Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 28 additions & 27 deletions codeflash/code_utils/code_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,33 @@ def visit_Assign(self, node: cst.Assign) -> Optional[bool]:
return True


def find_insertion_index_after_imports(node: cst.Module) -> int:
"""Find the position of the last import statement in the top-level of the module."""
insert_index = 0
for i, stmt in enumerate(node.body):
is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body
)

is_conditional_import = isinstance(stmt, cst.If) and all(
isinstance(inner, cst.SimpleStatementLine)
and all(isinstance(child, (cst.Import, cst.ImportFrom)) for child in inner.body)
for inner in stmt.body.body
)

if is_top_level_import or is_conditional_import:
insert_index = i + 1

# Stop scanning once we reach a class or function definition.
# Imports are supposed to be at the top of the file, but they can technically appear anywhere, even at the bottom of the file.
# Without this check, a stray import later in the file
# would incorrectly shift our insertion index below actual code definitions.
if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
break

return insert_index


class GlobalAssignmentTransformer(cst.CSTTransformer):
"""Transforms global assignments in the original file with those from the new file."""

Expand Down Expand Up @@ -122,32 +149,6 @@ def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> c

return updated_node

def _find_insertion_index(self, updated_node: cst.Module) -> int:
"""Find the position of the last import statement in the top-level of the module."""
insert_index = 0
for i, stmt in enumerate(updated_node.body):
is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body
)

is_conditional_import = isinstance(stmt, cst.If) and all(
isinstance(inner, cst.SimpleStatementLine)
and all(isinstance(child, (cst.Import, cst.ImportFrom)) for child in inner.body)
for inner in stmt.body.body
)

if is_top_level_import or is_conditional_import:
insert_index = i + 1

# Stop scanning once we reach a class or function definition.
# Imports are supposed to be at the top of the file, but they can technically appear anywhere, even at the bottom of the file.
# Without this check, a stray import later in the file
# would incorrectly shift our insertion index below actual code definitions.
if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
break

return insert_index

def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
# Add any new assignments that weren't in the original file
new_statements = list(updated_node.body)
Expand All @@ -161,7 +162,7 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c

if assignments_to_append:
# after last top-level imports
insert_index = self._find_insertion_index(updated_node)
insert_index = find_insertion_index_after_imports(updated_node)

assignment_lines = [
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
Expand Down
44 changes: 36 additions & 8 deletions codeflash/code_utils/code_replacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@
import ast
from collections import defaultdict
from functools import lru_cache
from itertools import chain
from typing import TYPE_CHECKING, Optional, TypeVar

import libcst as cst
from libcst.metadata import PositionProvider

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module
from codeflash.code_utils.code_extractor import (
add_global_assignments,
add_needed_imports_from_module,
find_insertion_index_after_imports,
)
from codeflash.code_utils.config_parser import find_conftest_files
from codeflash.code_utils.formatter import sort_imports
from codeflash.code_utils.line_profile_utils import ImportAdder
Expand Down Expand Up @@ -249,6 +254,7 @@ def __init__(
] = {} # keys are (class_name, function_name)
self.new_functions: list[cst.FunctionDef] = []
self.new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list)
self.new_classes: list[cst.ClassDef] = []
self.current_class = None
self.modified_init_functions: dict[str, cst.FunctionDef] = {}

Expand All @@ -271,6 +277,10 @@ def visit_ClassDef(self, node: cst.ClassDef) -> bool:
self.current_class = node.name.value

parents = (FunctionParent(name=node.name.value, type="ClassDef"),)

if (node.name.value, ()) not in self.preexisting_objects:
self.new_classes.append(node)

for child_node in node.body.body:
if (
self.preexisting_objects
Expand All @@ -290,13 +300,15 @@ class OptimFunctionReplacer(cst.CSTTransformer):
def __init__(
self,
modified_functions: Optional[dict[tuple[str | None, str], cst.FunctionDef]] = None,
new_classes: Optional[list[cst.ClassDef]] = None,
new_functions: Optional[list[cst.FunctionDef]] = None,
new_class_functions: Optional[dict[str, list[cst.FunctionDef]]] = None,
modified_init_functions: Optional[dict[str, cst.FunctionDef]] = None,
) -> None:
super().__init__()
self.modified_functions = modified_functions if modified_functions is not None else {}
self.new_functions = new_functions if new_functions is not None else []
self.new_classes = new_classes if new_classes is not None else []
self.new_class_functions = new_class_functions if new_class_functions is not None else defaultdict(list)
self.modified_init_functions: dict[str, cst.FunctionDef] = (
modified_init_functions if modified_init_functions is not None else {}
Expand Down Expand Up @@ -335,19 +347,33 @@ def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
node = updated_node
max_function_index = None
class_index = None
max_class_index = None
for index, _node in enumerate(node.body):
if isinstance(_node, cst.FunctionDef):
max_function_index = index
if isinstance(_node, cst.ClassDef):
class_index = index
max_class_index = index

if self.new_classes:
existing_class_names = {_node.name.value for _node in node.body if isinstance(_node, cst.ClassDef)}

unique_classes = [
new_class for new_class in self.new_classes if new_class.name.value not in existing_class_names
]
if unique_classes:
new_classes_insertion_idx = max_class_index or find_insertion_index_after_imports(node)
new_body = list(
chain(node.body[:new_classes_insertion_idx], unique_classes, node.body[new_classes_insertion_idx:])
)
node = node.with_changes(body=new_body)

if max_function_index is not None:
node = node.with_changes(
body=(*node.body[: max_function_index + 1], *self.new_functions, *node.body[max_function_index + 1 :])
)
elif class_index is not None:
elif max_class_index is not None:
node = node.with_changes(
body=(*node.body[: class_index + 1], *self.new_functions, *node.body[class_index + 1 :])
body=(*node.body[: max_class_index + 1], *self.new_functions, *node.body[max_class_index + 1 :])
)
else:
node = node.with_changes(body=(*self.new_functions, *node.body))
Expand All @@ -373,18 +399,20 @@ def replace_functions_in_file(
parsed_function_names.append((class_name, function_name))

# Collect functions we want to modify from the optimized code
module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code))
optimized_module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code))
original_module = cst.parse_module(source_code)

visitor = OptimFunctionCollector(preexisting_objects, set(parsed_function_names))
module.visit(visitor)
optimized_module.visit(visitor)

# Replace these functions in the original code
transformer = OptimFunctionReplacer(
modified_functions=visitor.modified_functions,
new_classes=visitor.new_classes,
new_functions=visitor.new_functions,
new_class_functions=visitor.new_class_functions,
modified_init_functions=visitor.modified_init_functions,
)
original_module = cst.parse_module(source_code)
modified_tree = original_module.visit(transformer)
return modified_tree.code

Expand Down
143 changes: 133 additions & 10 deletions tests/test_code_replacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def new_function(self, value: cst.Name):
return other_function(self.name)
def new_function2(value):
return value
"""
"""

original_code = """import libcst as cst
from typing import Mandatory
Expand All @@ -230,19 +230,28 @@ def other_function(st):

print("Salut monde")
"""
expected = """from typing import Mandatory
expected = """import libcst as cst
from typing import Mandatory

class NewClass:
def __init__(self, name):
self.name = name
def new_function(self, value: cst.Name):
return other_function(self.name)
def new_function2(value):
return value

print("Au revoir")

def yet_another_function(values):
return len(values)

def other_function(st):
return(st * 2)

def totally_new_function(value):
return value

def other_function(st):
return(st * 2)

print("Salut monde")
"""

Expand Down Expand Up @@ -279,7 +288,7 @@ def new_function(self, value):
return other_function(self.name)
def new_function2(value):
return value
"""
"""

original_code = """import libcst as cst
from typing import Mandatory
Expand All @@ -296,17 +305,25 @@ def other_function(st):
"""
expected = """from typing import Mandatory

class NewClass:
def __init__(self, name):
self.name = name
def new_function(self, value):
return other_function(self.name)
def new_function2(value):
return value

print("Au revoir")

def yet_another_function(values):
return len(values) + 2

def other_function(st):
return(st * 2)

def totally_new_function(value):
return value

def other_function(st):
return(st * 2)

print("Salut monde")
"""

Expand Down Expand Up @@ -3619,4 +3636,110 @@ async def task():
await asyncio.sleep(1)
return "done"
'''
assert is_zero_diff(original_code, optimized_code)
assert is_zero_diff(original_code, optimized_code)



def test_code_replacement_with_new_helper_class() -> None:
optim_code = """from __future__ import annotations

import itertools
import re
from dataclasses import dataclass
from typing import Any, Callable, Iterator, Sequence

from bokeh.models import HoverTool, Plot, Tool


# Move the Item dataclass to module-level to avoid redefining it on every function call
@dataclass(frozen=True)
class _RepeatedToolItem:
obj: Tool
properties: dict[str, Any]

def _collect_repeated_tools(tool_objs: list[Tool]) -> Iterator[Tool]:
key: Callable[[Tool], str] = lambda obj: obj.__class__.__name__
# Pre-collect properties for all objects by group to avoid repeated calls
for _, group in itertools.groupby(sorted(tool_objs, key=key), key=key):
grouped = list(group)
n = len(grouped)
if n > 1:
# Precompute all properties once for this group
props = [_RepeatedToolItem(obj, obj.properties_with_values()) for obj in grouped]
i = 0
while i < len(props) - 1:
head = props[i]
for j in range(i+1, len(props)):
item = props[j]
if item.properties == head.properties:
yield item.obj
i += 1
"""

original_code = """from __future__ import annotations
import itertools
import re
from bokeh.models import HoverTool, Plot, Tool
from dataclasses import dataclass
from typing import Any, Callable, Iterator, Sequence

def _collect_repeated_tools(tool_objs: list[Tool]) -> Iterator[Tool]:
@dataclass(frozen=True)
class Item:
obj: Tool
properties: dict[str, Any]

key: Callable[[Tool], str] = lambda obj: obj.__class__.__name__

for _, group in itertools.groupby(sorted(tool_objs, key=key), key=key):
rest = [ Item(obj, obj.properties_with_values()) for obj in group ]
while len(rest) > 1:
head, *rest = rest
for item in rest:
if item.properties == head.properties:
yield item.obj
"""

expected = """from __future__ import annotations
import itertools
from bokeh.models import Tool
from dataclasses import dataclass
from typing import Any, Callable, Iterator


# Move the Item dataclass to module-level to avoid redefining it on every function call
@dataclass(frozen=True)
class _RepeatedToolItem:
obj: Tool
properties: dict[str, Any]

def _collect_repeated_tools(tool_objs: list[Tool]) -> Iterator[Tool]:
key: Callable[[Tool], str] = lambda obj: obj.__class__.__name__
# Pre-collect properties for all objects by group to avoid repeated calls
for _, group in itertools.groupby(sorted(tool_objs, key=key), key=key):
grouped = list(group)
n = len(grouped)
if n > 1:
# Precompute all properties once for this group
props = [_RepeatedToolItem(obj, obj.properties_with_values()) for obj in grouped]
i = 0
while i < len(props) - 1:
head = props[i]
for j in range(i+1, len(props)):
item = props[j]
if item.properties == head.properties:
yield item.obj
i += 1
"""

function_names: list[str] = ["_collect_repeated_tools"]
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
optimized_code=optim_code,
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == expected
Loading