Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions codeflash/benchmarking/instrument_codeflash_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from typing import TYPE_CHECKING, Optional, Union

import isort
import libcst as cst

from codeflash.code_utils.formatter import sort_imports

if TYPE_CHECKING:
from pathlib import Path

Expand Down Expand Up @@ -107,7 +108,7 @@ def instrument_codeflash_trace_decorator(file_to_funcs_to_optimize: dict[Path, l
original_code = file_path.read_text(encoding="utf-8")
new_code = add_codeflash_decorator_to_code(original_code, functions_to_optimize)
# Modify the code
modified_code = isort.code(code=new_code, float_to_top=True)
modified_code = sort_imports(code=new_code, float_to_top=True)

# Write the modified code back to the file
file_path.write_text(modified_code, encoding="utf-8")
5 changes: 2 additions & 3 deletions codeflash/benchmarking/replay_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any

import isort

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.formatter import sort_imports
from codeflash.discovery.functions_to_optimize import inspect_top_level_functions_or_methods
from codeflash.verification.verification_utils import get_test_file_path

Expand Down Expand Up @@ -299,7 +298,7 @@ def generate_replay_test(
test_framework=test_framework,
max_run_count=max_run_count,
)
test_code = isort.code(test_code)
test_code = sort_imports(code=test_code)
output_file = get_test_file_path(
test_dir=Path(output_dir), function_name=benchmark_module_path, test_type="replay"
)
Expand Down
4 changes: 2 additions & 2 deletions codeflash/code_utils/code_replacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from functools import lru_cache
from typing import TYPE_CHECKING, Optional, TypeVar

import isort
import libcst as cst
from libcst.metadata import PositionProvider

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module
from codeflash.code_utils.config_parser import find_conftest_files
from codeflash.code_utils.formatter import sort_imports
from codeflash.code_utils.line_profile_utils import ImportAdder
from codeflash.models.models import FunctionParent

Expand Down Expand Up @@ -226,7 +226,7 @@ def add_custom_marker_to_all_tests(test_paths: list[Path]) -> None:
module = cst.parse_module(file_content)
importadder = ImportAdder("import pytest")
modified_module = module.visit(importadder)
modified_module = cst.parse_module(isort.code(modified_module.code, float_to_top=True))
modified_module = cst.parse_module(sort_imports(code=modified_module.code, float_to_top=True))
pytest_mark_adder = PytestMarkAdder("codeflash_no_autouse")
modified_module = modified_module.visit(pytest_mark_adder)
test_path.write_text(modified_module.code, encoding="utf-8")
Expand Down
6 changes: 3 additions & 3 deletions codeflash/code_utils/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,11 @@ def format_code(
return formatted_code


def sort_imports(code: str) -> str:
def sort_imports(code: str, *, float_to_top: bool = False) -> str:
try:
# Deduplicate and sort imports, modify the code in memory, not on disk
sorted_code = isort.code(code)
except Exception:
sorted_code = isort.code(code=code, float_to_top=float_to_top)
except Exception: # this will also catch the FileSkipComment exception, use this fn everywhere
logger.exception("Failed to sort imports with isort.")
return code # Fall back to original code if isort fails

Expand Down
3 changes: 2 additions & 1 deletion codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
from codeflash.code_utils.formatter import sort_imports
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent, TestingMode, VerificationType

Expand Down Expand Up @@ -1129,7 +1130,7 @@ def add_async_decorator_to_function(
import_transformer = AsyncDecoratorImportAdder(mode)
module = module.visit(import_transformer)

return isort.code(module.code, float_to_top=True), decorator_transformer.added_decorator
return sort_imports(code=module.code, float_to_top=True), decorator_transformer.added_decorator
except Exception as e:
logger.exception(f"Error adding async decorator to function {function.qualified_name}: {e}")
return source_code, False
Expand Down
4 changes: 2 additions & 2 deletions codeflash/code_utils/line_profile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from pathlib import Path
from typing import TYPE_CHECKING, Union

import isort
import libcst as cst

from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.formatter import sort_imports

if TYPE_CHECKING:
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
Expand Down Expand Up @@ -213,7 +213,7 @@ def add_decorator_imports(function_to_optimize: FunctionToOptimize, code_context
transformer = ImportAdder("from line_profiler import profile as codeflash_line_profile")
# Apply the transformer to add the import
module_node = module_node.visit(transformer)
modified_code = isort.code(module_node.code, float_to_top=True)
modified_code = sort_imports(code=module_node.code, float_to_top=True)
# write to file
with file_path.open("w", encoding="utf-8") as file:
file.write(modified_code)
Expand Down
3 changes: 1 addition & 2 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from pathlib import Path
from typing import TYPE_CHECKING

import isort
import libcst as cst
from rich.console import Group
from rich.panel import Panel
Expand Down Expand Up @@ -900,7 +899,7 @@ def reformat_code_and_helpers(
optimized_context: CodeStringsMarkdown,
) -> tuple[str, dict[Path, str]]:
should_sort_imports = not self.args.disable_imports_sorting
if should_sort_imports and isort.code(original_code) != original_code:
if should_sort_imports and sort_imports(code=original_code) != original_code:
should_sort_imports = False

optimized_code = ""
Expand Down
5 changes: 4 additions & 1 deletion codeflash/tracing/tracing_new_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ def __exit__(

# These modules have been imported here now the tracer is done. It is safe to import codeflash and external modules here

from contextlib import suppress

import isort

from codeflash.tracing.replay_test import create_trace_replay_test
Expand All @@ -280,7 +282,8 @@ def __exit__(
test_file_path = get_test_file_path(
test_dir=Path(self.config["tests_root"]), function_name=function_path, test_type="replay"
)
replay_test = isort.code(replay_test)
with suppress(Exception):
replay_test = isort.code(replay_test)

with Path(test_file_path).open("w", encoding="utf8") as file:
file.write(replay_test)
Expand Down
5 changes: 2 additions & 3 deletions codeflash/verification/instrument_codeflash_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from pathlib import Path
from typing import TYPE_CHECKING

import isort

from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.formatter import sort_imports

if TYPE_CHECKING:
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
Expand Down Expand Up @@ -70,7 +69,7 @@ def add_codeflash_capture_to_init(
ast.fix_missing_locations(modified_tree)

# Convert back to source code
return isort.code(code=ast.unparse(modified_tree), float_to_top=True)
return sort_imports(code=ast.unparse(modified_tree), float_to_top=True)


class InitDecorator(ast.NodeTransformer):
Expand Down
10 changes: 9 additions & 1 deletion tests/test_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,4 +796,12 @@ def _is_valid(self, item):
optimization_function = """def process(self,data):
'''Single quote docstring with formatting issues.'''
return{'result':[item for item in data if self._is_valid(item)]}"""
_run_formatting_test(source_code, True, optimized_function=optimization_function, expected=expected)
_run_formatting_test(source_code, True, optimized_function=optimization_function, expected=expected)

def test_sort_imports_skip_file():
"""Test that isort skips files with # isort:skip_file."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is skip file test? Whats the use of this. We should have a test where we are sorting the packages and asserting on them I think,

code = """# isort:skip_file

import sys, os, json # isort will ignore this file completely"""
new_code = sort_imports(code)
assert new_code == code
Loading