Skip to content
42 changes: 23 additions & 19 deletions codeflash/cli_cmds/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,25 +249,29 @@ def project_root_from_module_root(module_root: Path, pyproject_file_path: Path)


def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
if hasattr(args, "all"):
import git

from codeflash.code_utils.git_utils import check_and_push_branch, get_repo_owner_and_name
from codeflash.code_utils.github_utils import require_github_app_or_exit

# Ensure that the user can actually open PRs on the repo.
try:
git_repo = git.Repo(search_parent_directories=True)
except git.exc.InvalidGitRepositoryError:
logger.exception(
"I couldn't find a git repository in the current directory. "
"I need a git repository to run --all and open PRs for optimizations. Exiting..."
)
apologize_and_exit()
if not args.no_pr and not check_and_push_branch(git_repo, git_remote=args.git_remote):
exit_with_message("Branch is not pushed...", error_on_exit=True)
owner, repo = get_repo_owner_and_name(git_repo)
if not args.no_pr:
if hasattr(args, "all") or (hasattr(args, "file") and args.file):
no_pr = getattr(args, "no_pr", False)

if not no_pr:
import git

from codeflash.code_utils.git_utils import check_and_push_branch, get_repo_owner_and_name
from codeflash.code_utils.github_utils import require_github_app_or_exit

# Ensure that the user can actually open PRs on the repo.
try:
git_repo = git.Repo(search_parent_directories=True)
except git.exc.InvalidGitRepositoryError:
mode = "--all" if hasattr(args, "all") else "--file"
logger.exception(
f"I couldn't find a git repository in the current directory. "
f"I need a git repository to run {mode} and open PRs for optimizations. Exiting..."
)
apologize_and_exit()
git_remote = getattr(args, "git_remote", None)
if not check_and_push_branch(git_repo, git_remote=git_remote):
exit_with_message("Branch is not pushed...", error_on_exit=True)
owner, repo = get_repo_owner_and_name(git_repo)
require_github_app_or_exit(owner, repo)
if not hasattr(args, "all"):
args.all = None
Expand Down
2 changes: 2 additions & 0 deletions tests/test_worktree.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def test_mirror_paths_for_worktree_mode(monkeypatch: pytest.MonkeyPatch):
args = Namespace()
args.benchmark = False
args.benchmarks_root = None
args.no_pr = True

args.config_file = project_root / "pyproject.toml"
args.file = project_root / "src" / "app" / "main.py"
Expand Down Expand Up @@ -42,6 +43,7 @@ def test_mirror_paths_for_worktree_mode(monkeypatch: pytest.MonkeyPatch):
args = Namespace()
args.benchmark = False
args.benchmarks_root = None
args.no_pr = True

args.config_file = repo_root / "pyproject.toml"
args.file = repo_root / "codeflash/optimization/optimizer.py"
Expand Down
Loading
Loading