@@ -63,6 +63,13 @@ def git_get_current_branch() -> str:
63
63
return p .stdout .decode ().strip ()
64
64
65
65
66
+ def print_branch_level (branch : str , current_branch : str , depth : int ) -> None :
67
+ """Prints a branch line as part of a tree"""
68
+ branch_line = f"\u001b [32m{ branch } \u001b [0m" if branch == current_branch else branch
69
+ level_indicator = " " * (2 * (depth - 1 )) + "↳ "
70
+ print (f"{ level_indicator } { branch_line } " if depth > 0 else branch_line )
71
+
72
+
66
73
class NoValidTrunkError (Exception ):
67
74
"""Error for when none of the trunk candidates match"""
68
75
@@ -117,10 +124,9 @@ def wrapup(self) -> None:
117
124
118
125
def print_stack (self ):
119
126
"""Pretty print the entire stack"""
127
+ current_branch = git_get_current_branch ()
120
128
self ._traverse_stack (
121
- lambda branch , depth : print (" " * (2 * (depth - 1 )) + "↳ " + branch )
122
- if depth > 0
123
- else print (branch )
129
+ lambda branch , depth : print_branch_level (branch , current_branch , depth )
124
130
)
125
131
126
132
def create_branch (self , branch : str , parent ) -> None :
@@ -211,7 +217,15 @@ def switch_to_child(self) -> None:
211
217
212
218
def sync (self ):
213
219
"""Rebase all branches on top of current trunk"""
220
+ current_branch = git_get_current_branch ()
214
221
self ._traverse_stack (lambda branch , depth : self ._check_and_rebase (branch ))
222
+ # switch back to original branch once done
223
+ subprocess .run (
224
+ ["git" , "switch" , current_branch ],
225
+ check = True ,
226
+ stdout = sys .stdout .buffer ,
227
+ stderr = sys .stderr .buffer ,
228
+ )
215
229
216
230
def _traverse_stack (self , fn : Callable [[str , int ], None ]):
217
231
"""DFS through the gitstack from trunk, calling a function on each branch"""
0 commit comments