@@ -20,12 +20,13 @@ def get_slug_from_remote_url(remote_url: str) -> str:
2020 return potential_slug .removesuffix (".git" )
2121
2222
23- @contextlib .contextmanager
24- def transitioning_branches (
25- repo : Repo , branch_prefix : str , branch_suffix : str = "" , force : bool = True
26- ) -> Generator [tuple [Head , Head ], None , None ]:
23+ def get_current_branch (repo : Repo ) -> Head :
24+ remote = repo .remote ("origin" )
2725 if repo .head .is_detached :
28- from_branch = next ((branch for branch in repo .branches if branch .commit == repo .head .commit ), None )
26+ from_branch = next (
27+ (branch for branch in remote .refs if branch .commit == repo .head .commit and branch .remote_head != "HEAD" ),
28+ None ,
29+ )
2930 else :
3031 from_branch = repo .active_branch
3132
@@ -35,16 +36,27 @@ def transitioning_branches(
3536 "Make sure repository is not in a detached HEAD state with additional commits."
3637 )
3738
38- next_branch_name = f"{ branch_prefix } { from_branch .name } { branch_suffix } "
39+ return from_branch
40+
41+
42+ @contextlib .contextmanager
43+ def transitioning_branches (
44+ repo : Repo , branch_prefix : str , branch_suffix : str = "" , force : bool = True
45+ ) -> Generator [tuple [str , str ], None , None ]:
46+ from_branch = get_current_branch (repo )
47+ from_branch_name = from_branch .name if not from_branch .is_remote () else from_branch .remote_head
48+ next_branch_name = f"{ branch_prefix } { from_branch_name } { branch_suffix } "
3949 if next_branch_name in repo .heads and not force :
40- raise ValueError (f'Branch "{ next_branch_name } " already exists.' )
50+ raise ValueError (f'Local Branch "{ next_branch_name } " already exists.' )
51+ if next_branch_name in repo .remote ("origin" ).refs and not force :
52+ raise ValueError (f'Remote Branch "{ next_branch_name } " already exists.' )
4153
4254 logger .info (f'Creating new branch "{ next_branch_name } ".' )
4355 to_branch = repo .create_head (next_branch_name , force = force )
4456
4557 try :
4658 to_branch .checkout ()
47- yield from_branch , to_branch
59+ yield from_branch_name , next_branch_name
4860 finally :
4961 from_branch .checkout ()
5062
@@ -137,7 +149,9 @@ def run(self) -> dict:
137149 repo = git .Repo (Path .cwd ())
138150 if not self .enabled :
139151 logger .debug ("Branch creation is disabled." )
140- return dict (target_branch = repo .active_branch .name )
152+ from_branch = get_current_branch (repo )
153+ from_branch_name = from_branch .name if not from_branch .is_remote () else from_branch .remote_head
154+ return dict (target_branch = from_branch_name )
141155
142156 modified_files = {modified_code_file ["path" ] for modified_code_file in self .modified_code_files }
143157
@@ -153,6 +167,6 @@ def run(self) -> dict:
153167
154168 logger .info (f"Run completed { self .__class__ .__name__ } " )
155169 return dict (
156- base_branch = from_branch . name ,
157- target_branch = to_branch . name ,
170+ base_branch = from_branch ,
171+ target_branch = to_branch ,
158172 )
0 commit comments