Skip to content

Commit def3a7a

Browse files
committed
Improve the branch selection in interactive go (now same as status)
1 parent cc50fcd commit def3a7a

2 files changed

Lines changed: 86 additions & 86 deletions

File tree

git_machete/client/go_interactive.py

Lines changed: 71 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,35 @@
99
termios = None # type: ignore[assignment]
1010
tty = None # type: ignore[assignment]
1111

12-
from git_machete import utils
13-
from git_machete.client.base import MacheteClient
12+
from git_machete import git_config_keys, utils
13+
from git_machete.client.base import SquashMergeDetection
14+
from git_machete.client.status import (StatusData, StatusFlags,
15+
StatusMacheteClient)
1416
from git_machete.exceptions import UnexpectedMacheteException
1517
from git_machete.git_operations import LocalBranchShortName
1618
from git_machete.utils import AnsiEscapeCodes, bold, index_or_none, warn
1719

1820

19-
class GoInteractiveMacheteClient(MacheteClient):
21+
def _branch_line_indices(data: StatusData, branches: List[LocalBranchShortName]) -> List[int]:
22+
"""Line index in status output for each branch (status has an extra newline before non-first roots)."""
23+
roots = set(data.roots)
24+
first_root = next((b for b in branches if b in roots), None)
25+
result: List[int] = []
26+
for i, branch in enumerate(branches):
27+
blanks_before = sum(1 for j in range(i) if branches[j] in roots and branches[j] != first_root)
28+
result.append(i + blanks_before)
29+
return result
30+
31+
32+
class GoInteractiveMacheteClient(StatusMacheteClient):
2033
"""Client for interactive branch selection using curses-style interface (implemented without curses, just using ANSI sequences)."""
2134

2235
MAX_VISIBLE_BRANCHES_DEFAULT = 20
2336
MAX_VISIBLE_BRANCHES_LOWER = 2
2437
MAX_VISIBLE_BRANCHES_UPPER = 50
2538

26-
_managed_branches_with_depths: List[Tuple[LocalBranchShortName, int]]
39+
_status_data: StatusData
40+
_branch_line_indices: List[int] # line index in status output for each branch
2741
_current_branch: Optional[LocalBranchShortName]
2842
_max_visible_branches: int
2943

@@ -37,35 +51,10 @@ def _get_max_visible_branches(self) -> int:
3751
max_visible_branches = terminal_height - 2
3852
return max(self.MAX_VISIBLE_BRANCHES_LOWER, min(max_visible_branches, self.MAX_VISIBLE_BRANCHES_UPPER))
3953

40-
def _get_branch_list_with_depths(self) -> List[Tuple[LocalBranchShortName, int]]:
41-
"""Get a flat list of branches with their depths using DFS traversal."""
42-
result: List[Tuple[LocalBranchShortName, int]] = []
43-
44-
def add_branch_and_children(branch: LocalBranchShortName, depth: int) -> None:
45-
result.append((branch, depth))
46-
for child_branch in self.down_branches_for(branch) or []:
47-
add_branch_and_children(child_branch, depth + 1)
48-
49-
for root in self._state.roots:
50-
add_branch_and_children(root, depth=0)
51-
52-
return result
53-
54-
def _render_branch_line(self, branch: LocalBranchShortName, depth: int) -> str:
55-
"""Render a single branch line with indentation."""
56-
indent = " " * depth
57-
marker = " " if self._current_branch is None or branch != self._current_branch else "*"
58-
59-
line = f"{indent}{marker} {branch}"
60-
annotation = self.annotations.get(branch)
61-
if annotation and annotation.formatted_full_text:
62-
line += f" {annotation.formatted_full_text}"
63-
64-
return line
65-
6654
def _draw_screen(self, *, selected_idx: int, scroll_offset: int,
67-
num_lines_drawn: int, is_first_draw: bool) -> int:
68-
"""Draw the branch selection screen using ANSI escape codes."""
55+
num_lines_drawn: int, is_first_draw: bool) -> Tuple[int, int]:
56+
"""Draw the branch selection screen using status-style output (format_status_output).
57+
Returns (scroll_offset, actual_lines_drawn) so the caller can move the cursor up correctly on next redraw."""
6958
# Move cursor up to the start of our display area (if we've drawn before)
7059
if not is_first_draw and num_lines_drawn > 0:
7160
sys.stdout.write(f'{AnsiEscapeCodes.CSI}{num_lines_drawn}A')
@@ -79,27 +68,31 @@ def _draw_screen(self, *, selected_idx: int, scroll_offset: int,
7968
"Enter or Space: checkout, q or Ctrl+C: quit)")
8069
sys.stdout.write(bold(header_text) + '\n')
8170

82-
# Adjust scroll offset if needed
83-
visible_lines = min(self._max_visible_branches, len(self._managed_branches_with_depths))
84-
if selected_idx < scroll_offset:
85-
scroll_offset = selected_idx
86-
elif selected_idx >= scroll_offset + visible_lines:
87-
scroll_offset = selected_idx - visible_lines + 1
88-
89-
# Draw branches
71+
branches = self._status_data.branches_in_display_order
72+
line_indices = self._branch_line_indices
73+
selected_branch = branches[selected_idx] if 0 <= selected_idx < len(branches) else None
74+
status_str = self.format_status_output(
75+
self._status_data,
76+
selected_branch=selected_branch,
77+
)
78+
lines = status_str.splitlines()
79+
num_lines = len(lines)
80+
visible_lines = min(self._max_visible_branches, num_lines)
81+
selected_line_idx = line_indices[selected_idx] if 0 <= selected_idx < len(line_indices) else 0
82+
if selected_line_idx < scroll_offset:
83+
scroll_offset = selected_line_idx
84+
elif selected_line_idx >= scroll_offset + visible_lines:
85+
scroll_offset = selected_line_idx - visible_lines + 1
86+
87+
lines_drawn = 1 # header
9088
for i in range(visible_lines):
91-
branch_idx = scroll_offset + i
92-
branch, depth = self._managed_branches_with_depths[branch_idx]
93-
line = self._render_branch_line(branch, depth)
94-
95-
if branch_idx == selected_idx:
96-
# Highlight selected line (inverse video)
97-
sys.stdout.write(f'{AnsiEscapeCodes.REVERSE_VIDEO}{line}{AnsiEscapeCodes.ENDC}\n')
98-
else:
99-
sys.stdout.write(f'{line}\n')
89+
line_idx = scroll_offset + i
90+
if line_idx < num_lines:
91+
sys.stdout.write(lines[line_idx] + '\n')
92+
lines_drawn += 1
10093

10194
sys.stdout.flush()
102-
return scroll_offset
95+
return scroll_offset, lines_drawn
10396

10497
def _get_stdin_fd(self) -> int: # pragma: no cover; always mocked in tests
10598
return sys.stdin.fileno()
@@ -136,26 +129,36 @@ def go_interactive(self, *, current_branch: Optional[LocalBranchShortName]) -> O
136129
"""
137130
Launch interactive branch selection interface.
138131
Returns the selected branch or None if cancelled.
132+
Status data is computed once at start; only rendering (format_status_output) runs on each redraw.
139133
"""
140134
if termios is None or tty is None:
141135
raise UnexpectedMacheteException("Interactive mode is not supported on Windows yet")
142136

143-
# Get flat list of branches with depths from already-parsed state
144-
self._managed_branches_with_depths = self._get_branch_list_with_depths()
145-
146137
self._current_branch = current_branch
147-
148-
# Determine maximum visible branches from terminal height
149138
self._max_visible_branches = self._get_max_visible_branches()
150139

140+
# Compute status data once (no list-commits in TUI; squash merge detection defaults to simple)
141+
maybe_space = (
142+
' ' if self._git.get_boolean_config_attr(
143+
git_config_keys.STATUS_EXTRA_SPACE_BEFORE_BRANCH_NAME, default_value=False) else ''
144+
)
145+
flags = StatusFlags(
146+
maybe_space_before_branch_name=maybe_space,
147+
opt_list_commits=False,
148+
opt_list_commits_with_hashes=False,
149+
opt_squash_merge_detection=SquashMergeDetection.SIMPLE,
150+
)
151+
self._status_data = self.compute_status_data(flags=flags)
152+
branches_ordered = self._status_data.branches_in_display_order
153+
self._branch_line_indices = _branch_line_indices(self._status_data, branches_ordered)
154+
151155
# Find initial selection (current branch or first branch if detached HEAD)
152156
if current_branch is not None:
153-
selected_idx = index_or_none(self.managed_branches, self._current_branch)
157+
selected_idx = index_or_none(branches_ordered, self._current_branch)
154158
if selected_idx is None:
155159
selected_idx = 0
156160
warn(f"current branch {self._current_branch} is unmanaged\n")
157161
else:
158-
# Detached HEAD - start with first managed branch
159162
selected_idx = 0
160163

161164
scroll_offset = 0
@@ -166,55 +169,43 @@ def go_interactive(self, *, current_branch: Optional[LocalBranchShortName]) -> O
166169
sys.stdout.write(AnsiEscapeCodes.HIDE_CURSOR)
167170
sys.stdout.flush()
168171

172+
branches = self._status_data.branches_in_display_order
169173
try:
170174
while True:
171-
# Calculate how many lines we'll draw (header + visible branches)
172-
visible_lines = min(self._max_visible_branches, len(self._managed_branches_with_depths))
173-
num_lines_drawn = visible_lines + 1 # +1 for header
174-
175-
scroll_offset = self._draw_screen(
175+
scroll_offset, num_lines_drawn = self._draw_screen(
176176
selected_idx=selected_idx,
177177
scroll_offset=scroll_offset,
178178
num_lines_drawn=num_lines_drawn,
179179
is_first_draw=is_first_draw
180180
)
181181
is_first_draw = False
182182

183-
# Read key
184183
key = self._getch()
185184

186185
if key == AnsiEscapeCodes.KEY_UP:
187-
# Wrap around from first to last
188-
selected_idx = (selected_idx - 1) % len(self._managed_branches_with_depths)
186+
selected_idx = (selected_idx - 1) % len(branches)
189187
elif key == AnsiEscapeCodes.KEY_DOWN:
190-
# Wrap around from last to first
191-
selected_idx = (selected_idx + 1) % len(self._managed_branches_with_depths)
188+
selected_idx = (selected_idx + 1) % len(branches)
192189
elif key == AnsiEscapeCodes.KEY_SHIFT_UP:
193-
# Jump to first branch
194190
selected_idx = 0
195191
elif key == AnsiEscapeCodes.KEY_SHIFT_DOWN:
196-
# Jump to last branch
197-
selected_idx = len(self._managed_branches_with_depths) - 1
192+
selected_idx = len(branches) - 1
198193
elif key == AnsiEscapeCodes.KEY_LEFT:
199-
# Go to parent
200-
selected_branch, _ = self._managed_branches_with_depths[selected_idx]
194+
selected_branch = branches[selected_idx]
201195
parent_branch = self.up_branch_for(selected_branch)
202-
if parent_branch:
203-
selected_idx = self.managed_branches.index(parent_branch)
196+
if parent_branch is not None:
197+
selected_idx = branches.index(parent_branch)
204198
elif key == AnsiEscapeCodes.KEY_RIGHT:
205-
# Go to first child
206-
selected_branch, _ = self._managed_branches_with_depths[selected_idx]
199+
selected_branch = branches[selected_idx]
207200
child_branches = self.down_branches_for(selected_branch)
208201
if child_branches:
209-
selected_idx = self.managed_branches.index(child_branches[0])
202+
selected_idx = branches.index(child_branches[0])
210203
elif key in AnsiEscapeCodes.KEYS_ENTER or key == AnsiEscapeCodes.KEY_SPACE:
211-
selected_branch, _ = self._managed_branches_with_depths[selected_idx]
212-
return selected_branch
204+
return branches[selected_idx]
213205
elif key in ('q', 'Q'):
214206
return None
215207
elif key == AnsiEscapeCodes.KEY_CTRL_C:
216208
return None
217209
finally:
218-
# Show cursor again and move past our interface
219210
sys.stdout.write(AnsiEscapeCodes.SHOW_CURSOR)
220211
sys.stdout.flush()

git_machete/client/status.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,13 @@ class StatusMacheteClient(MacheteClient):
8787
"""Client for the status command. Exposes status() and can be used as a mixin for other clients."""
8888

8989
@staticmethod
90-
def _format_status_output(data: StatusData) -> str:
91-
"""Pure function: given StatusData, returns the formatted status tree string."""
90+
def format_status_output(
91+
data: StatusData,
92+
*,
93+
selected_branch: Optional[LocalBranchShortName] = None,
94+
) -> str:
95+
"""Pure function: given StatusData, returns the formatted status tree string.
96+
When selected_branch is set, that branch's name (not annotation/sync status) is wrapped in reverse video."""
9297
out = io.StringIO()
9398
space = data.flags.maybe_space_before_branch_name
9499

@@ -169,7 +174,11 @@ def write_line_prefix(
169174
if b.annotation is not None and b.annotation.formatted_full_text:
170175
anno = ' ' + b.annotation.formatted_full_text
171176

172-
out.write(current + anno + b.sync_status + b.hook_output + "\n")
177+
if selected_branch is not None and branch == selected_branch:
178+
current_part = f"{AnsiEscapeCodes.REVERSE_VIDEO}{current}{AnsiEscapeCodes.ENDC}"
179+
else:
180+
current_part = current
181+
out.write(f"{current_part}{anno}{b.sync_status}{b.hook_output}\n")
173182

174183
return out.getvalue()
175184

@@ -214,7 +223,7 @@ def _status_warning_message(data: StatusData) -> Optional[str]:
214223
)
215224
return f"{first_part}.\n\n{second_part}."
216225

217-
def _compute_status_data(self, *, flags: StatusFlags) -> StatusData:
226+
def compute_status_data(self, *, flags: StatusFlags) -> StatusData:
218227
managed_branches: List[LocalBranchShortName] = list(self._state.managed_branches)
219228

220229
sync_to_parent_status: Dict[LocalBranchShortName, SyncToParentStatus] = {}
@@ -360,8 +369,8 @@ def status(
360369
opt_list_commits_with_hashes=opt_list_commits_with_hashes,
361370
opt_squash_merge_detection=opt_squash_merge_detection,
362371
)
363-
data = self._compute_status_data(flags=flags)
364-
status_str = self._format_status_output(data)
372+
data = self.compute_status_data(flags=flags)
373+
status_str = self.format_status_output(data)
365374
sys.stdout.write(status_str)
366375
if warn_when_branch_in_sync_but_fork_point_off:
367376
warning_msg = self._status_warning_message(data)

0 commit comments

Comments
 (0)