Skip to content

Commit f1d13f7

Browse files
committed
Improve the branch selection in interactive go (now same as status)
1 parent bd4a773 commit f1d13f7

3 files changed

Lines changed: 116 additions & 102 deletions

File tree

Lines changed: 53 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import sys
2-
from typing import List, Optional, Tuple
2+
from typing import Optional, Tuple
33

44
try:
55
import termios
@@ -10,20 +10,21 @@
1010
tty = None # type: ignore[assignment]
1111

1212
from git_machete import utils
13-
from git_machete.client.base import MacheteClient
13+
from git_machete.client.status import (StatusData, StatusFlags,
14+
StatusMacheteClient)
1415
from git_machete.exceptions import MacheteException, UnexpectedMacheteException
1516
from git_machete.git_operations import LocalBranchShortName
1617
from git_machete.utils import bold, index_or_none, warn
1718

1819

19-
class GoInteractiveMacheteClient(MacheteClient):
20+
class GoInteractiveMacheteClient(StatusMacheteClient):
2021
"""Client for interactive branch selection using curses-style interface (implemented without curses, just using ANSI sequences)."""
2122

2223
MAX_VISIBLE_BRANCHES_DEFAULT = 20
2324
MAX_VISIBLE_BRANCHES_LOWER = 2
2425
MAX_VISIBLE_BRANCHES_UPPER = 50
2526

26-
_managed_branches_with_depths: List[Tuple[LocalBranchShortName, int]]
27+
_status_data: StatusData
2728
_current_branch: Optional[LocalBranchShortName]
2829
_max_visible_branches: int
2930

@@ -37,35 +38,10 @@ def _get_max_visible_branches(self) -> int:
3738
max_visible_branches = terminal_height - 2
3839
return max(self.MAX_VISIBLE_BRANCHES_LOWER, min(max_visible_branches, self.MAX_VISIBLE_BRANCHES_UPPER))
3940

40-
def _get_branch_list_with_depths(self) -> List[Tuple[LocalBranchShortName, int]]:
41-
"""Get a flat list of branches with their depths using DFS traversal."""
42-
result: List[Tuple[LocalBranchShortName, int]] = []
43-
44-
def add_branch_and_children(branch: LocalBranchShortName, depth: int) -> None:
45-
result.append((branch, depth))
46-
for child_branch in self.down_branches_for(branch) or []:
47-
add_branch_and_children(child_branch, depth + 1)
48-
49-
for root in self._state.roots:
50-
add_branch_and_children(root, depth=0)
51-
52-
return result
53-
54-
def _render_branch_line(self, branch: LocalBranchShortName, depth: int) -> str:
55-
"""Render a single branch line with indentation."""
56-
indent = " " * depth
57-
marker = " " if self._current_branch is None or branch != self._current_branch else "*"
58-
59-
line = f"{indent}{marker} {branch}"
60-
annotation = self.annotations.get(branch)
61-
if annotation and annotation.formatted_full_text:
62-
line += f" {annotation.formatted_full_text}"
63-
64-
return line
65-
6641
def _draw_screen(self, *, selected_idx: int, scroll_offset: int,
67-
num_lines_drawn: int, is_first_draw: bool) -> int:
68-
"""Draw the branch selection screen using ANSI escape codes."""
42+
num_lines_drawn: int, is_first_draw: bool) -> Tuple[int, int]:
43+
"""Draw the branch selection screen using status-style output (format_status_output).
44+
Returns (scroll_offset, actual_lines_drawn) so the caller can move the cursor up correctly on next redraw."""
6945
# Move cursor up to the start of our display area (if we've drawn before)
7046
if not is_first_draw and num_lines_drawn > 0:
7147
sys.stdout.write(f'{utils.AE.CSI}{num_lines_drawn}A')
@@ -79,27 +55,30 @@ def _draw_screen(self, *, selected_idx: int, scroll_offset: int,
7955
"Enter or Space: checkout, q or Ctrl+C: quit)")
8056
sys.stdout.write(bold(header_text) + '\n')
8157

82-
# Adjust scroll offset if needed
83-
visible_lines = min(self._max_visible_branches, len(self._managed_branches_with_depths))
84-
if selected_idx < scroll_offset:
85-
scroll_offset = selected_idx
86-
elif selected_idx >= scroll_offset + visible_lines:
87-
scroll_offset = selected_idx - visible_lines + 1
88-
89-
# Draw branches
90-
for i in range(visible_lines):
91-
branch_idx = scroll_offset + i
92-
branch, depth = self._managed_branches_with_depths[branch_idx]
93-
line = self._render_branch_line(branch, depth)
94-
95-
if branch_idx == selected_idx:
96-
# Highlight selected line (inverse video)
97-
sys.stdout.write(f'{utils.AE.REVERSE_VIDEO}{line}{utils.AE.ENDC}\n')
98-
else:
99-
sys.stdout.write(f'{line}\n')
58+
branches = self._status_data.branches_in_display_order
59+
selected_branch = branches[selected_idx] if 0 <= selected_idx < len(branches) else None
60+
formatted = self.format_status_output(
61+
self._status_data,
62+
selected_branch=selected_branch,
63+
)
64+
lines = formatted.result.splitlines()
65+
num_lines = len(lines)
66+
visible_lines = min(self._max_visible_branches, num_lines)
67+
# line_for_branch maps branch -> 0-based line index in result
68+
selected_line_idx = formatted.line_for_branch.get(selected_branch, 0) if selected_branch else 0
69+
if selected_line_idx < scroll_offset:
70+
scroll_offset = selected_line_idx
71+
elif selected_line_idx >= scroll_offset + visible_lines:
72+
scroll_offset = selected_line_idx - visible_lines + 1
73+
74+
lines_drawn = 1 # header
75+
end = min(scroll_offset + visible_lines, num_lines)
76+
for line_idx in range(scroll_offset, end):
77+
sys.stdout.write(lines[line_idx] + '\n')
78+
lines_drawn += 1
10079

10180
sys.stdout.flush()
102-
return scroll_offset
81+
return scroll_offset, lines_drawn
10382

10483
def _get_stdin_fd(self) -> int: # pragma: no cover; always mocked in tests
10584
return sys.stdin.fileno()
@@ -136,28 +115,33 @@ def go_interactive(self, *, current_branch: Optional[LocalBranchShortName]) -> O
136115
"""
137116
Launch interactive branch selection interface.
138117
Returns the selected branch or None if cancelled.
118+
Status data is computed once at start; only rendering (format_status_output) runs on each redraw.
139119
"""
140120
if termios is None or tty is None:
141121
raise UnexpectedMacheteException("Interactive mode is not supported on Windows yet")
142122
if not utils.is_stdout_a_tty():
143123
raise MacheteException("Interactive `git machete go` requires stdout to be a TTY.")
144124

145-
# Get flat list of branches with depths from already-parsed state
146-
self._managed_branches_with_depths = self._get_branch_list_with_depths()
147-
148125
self._current_branch = current_branch
149-
150-
# Determine maximum visible branches from terminal height
151126
self._max_visible_branches = self._get_max_visible_branches()
152127

128+
# Compute status data once (no list-commits in TUI; config same as status)
129+
flags = StatusFlags(
130+
maybe_space_before_branch_name=(' ' if self._config.status_extra_space_before_branch_name() else ''),
131+
opt_list_commits=False,
132+
opt_list_commits_with_hashes=False,
133+
opt_squash_merge_detection=self._config.squash_merge_detection(),
134+
)
135+
self._status_data = self.compute_status_data(flags=flags)
136+
branches_ordered = self._status_data.branches_in_display_order
137+
153138
# Find initial selection (current branch or first branch if detached HEAD)
154139
if current_branch is not None:
155-
selected_idx = index_or_none(self.managed_branches, self._current_branch)
140+
selected_idx = index_or_none(branches_ordered, self._current_branch)
156141
if selected_idx is None:
157142
selected_idx = 0
158143
warn(f"current branch {self._current_branch} is unmanaged\n")
159144
else:
160-
# Detached HEAD - start with first managed branch
161145
selected_idx = 0
162146

163147
scroll_offset = 0
@@ -168,55 +152,43 @@ def go_interactive(self, *, current_branch: Optional[LocalBranchShortName]) -> O
168152
sys.stdout.write(utils.AE.HIDE_CURSOR)
169153
sys.stdout.flush()
170154

155+
branches = self._status_data.branches_in_display_order
171156
try:
172157
while True:
173-
# Calculate how many lines we'll draw (header + visible branches)
174-
visible_lines = min(self._max_visible_branches, len(self._managed_branches_with_depths))
175-
num_lines_drawn = visible_lines + 1 # +1 for header
176-
177-
scroll_offset = self._draw_screen(
158+
scroll_offset, num_lines_drawn = self._draw_screen(
178159
selected_idx=selected_idx,
179160
scroll_offset=scroll_offset,
180161
num_lines_drawn=num_lines_drawn,
181162
is_first_draw=is_first_draw
182163
)
183164
is_first_draw = False
184165

185-
# Read key
186166
key = self._getch()
187167

188168
if key == utils.AE.KEY_UP:
189-
# Wrap around from first to last
190-
selected_idx = (selected_idx - 1) % len(self._managed_branches_with_depths)
169+
selected_idx = (selected_idx - 1) % len(branches)
191170
elif key == utils.AE.KEY_DOWN:
192-
# Wrap around from last to first
193-
selected_idx = (selected_idx + 1) % len(self._managed_branches_with_depths)
171+
selected_idx = (selected_idx + 1) % len(branches)
194172
elif key == utils.AE.KEY_SHIFT_UP:
195-
# Jump to first branch
196173
selected_idx = 0
197174
elif key == utils.AE.KEY_SHIFT_DOWN:
198-
# Jump to last branch
199-
selected_idx = len(self._managed_branches_with_depths) - 1
175+
selected_idx = len(branches) - 1
200176
elif key == utils.AE.KEY_LEFT:
201-
# Go to parent
202-
selected_branch, _ = self._managed_branches_with_depths[selected_idx]
177+
selected_branch = branches[selected_idx]
203178
parent_branch = self.up_branch_for(selected_branch)
204-
if parent_branch:
205-
selected_idx = self.managed_branches.index(parent_branch)
179+
if parent_branch is not None:
180+
selected_idx = branches.index(parent_branch)
206181
elif key == utils.AE.KEY_RIGHT:
207-
# Go to first child
208-
selected_branch, _ = self._managed_branches_with_depths[selected_idx]
182+
selected_branch = branches[selected_idx]
209183
child_branches = self.down_branches_for(selected_branch)
210184
if child_branches:
211-
selected_idx = self.managed_branches.index(child_branches[0])
185+
selected_idx = branches.index(child_branches[0])
212186
elif key in utils.AE.KEYS_ENTER or key == utils.AE.KEY_SPACE:
213-
selected_branch, _ = self._managed_branches_with_depths[selected_idx]
214-
return selected_branch
187+
return branches[selected_idx]
215188
elif key in ('q', 'Q'):
216189
return None
217190
elif key == utils.AE.KEY_CTRL_C:
218191
return None
219192
finally:
220-
# Show cursor again and move past our interface
221193
sys.stdout.write(utils.AE.SHOW_CURSOR)
222194
sys.stdout.flush()

git_machete/client/status.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,26 @@ class StatusData(NamedTuple):
6969
ongoing_operation: StatusOngoingOperation
7070

7171

72+
class StatusFormatOutput(NamedTuple):
73+
"""Result of formatting status output. Returned by format_status_output."""
74+
75+
result: str
76+
# Maps each branch to the 0-based index of the line in result (when split by newlines) where it appears.
77+
line_for_branch: Dict[LocalBranchShortName, int]
78+
79+
7280
class StatusMacheteClient(MacheteClient):
7381
"""Client for the status command. Exposes status() and can be used as a mixin for other clients."""
7482

7583
@staticmethod
76-
def _format_status_output(data: StatusData) -> str:
77-
"""Pure function: given StatusData, returns the formatted status tree string."""
84+
def format_status_output(
85+
data: StatusData,
86+
*,
87+
selected_branch: Optional[LocalBranchShortName] = None,
88+
) -> StatusFormatOutput:
89+
"""Pure function: given StatusData, returns StatusFormatOutput (result string and line_for_branch).
90+
When selected_branch is set, that branch's name (not annotation/sync status) is wrapped in reverse video.
91+
line_for_branch maps each branch to the 0-based line index in result where it appears."""
7892

7993
# These maps need to be defined in a local scope to avoid for mocking the color palette more easily.
8094
sync_to_parent_status_to_edge_color_map: Dict[SyncToParentStatus, str] = {
@@ -89,9 +103,10 @@ def _format_status_output(data: StatusData) -> str:
89103
SyncToParentStatus.OUT_OF_SYNC: "x-",
90104
SyncToParentStatus.MERGED_TO_PARENT: "m-"
91105
}
92-
93106
out = io.StringIO()
94107
space = data.flags.maybe_space_before_branch_name
108+
line_for_branch: Dict[LocalBranchShortName, int] = {}
109+
line_index = 0
95110

96111
next_sibling_of_ancestor_by_branch: Dict[LocalBranchShortName, List[Optional[LocalBranchShortName]]] = {}
97112

@@ -127,12 +142,14 @@ def write_line_prefix(
127142
next_sibling_of_ancestor = next_sibling_of_ancestor_by_branch[branch]
128143
if b.up_branch is not None:
129144
write_line_prefix(branch, next_sibling_of_ancestor, f"{utils.get_vertical_bar()}\n")
145+
line_index += 1
130146
for commit, fp_suffix in b.commits:
131147
write_line_prefix(branch, next_sibling_of_ancestor, utils.get_vertical_bar())
132148
out.write(
133149
f' {f"{dim(commit.short_hash)} " if data.flags.opt_list_commits_with_hashes else ""}'
134150
f'{dim(commit.subject)}{fp_suffix}\n'
135151
)
152+
line_index += 1
136153
if utils.ascii_only:
137154
junction = sync_to_parent_status_to_junction_ascii_only_map[b.sync_to_parent_status]
138155
else:
@@ -145,8 +162,10 @@ def write_line_prefix(
145162
else:
146163
if branch != data.roots[0]:
147164
out.write("\n")
165+
line_index += 1
148166
out.write(" " + space)
149167

168+
line_for_branch[branch] = line_index
150169
op = data.ongoing_operation
151170
if branch in (op.currently_checked_out_branch, op.currently_rebased_branch, op.currently_bisected_branch):
152171
if branch == op.currently_rebased_branch:
@@ -171,9 +190,14 @@ def write_line_prefix(
171190
if b.annotation is not None and b.annotation.formatted_full_text:
172191
anno = ' ' + b.annotation.formatted_full_text
173192

174-
out.write(current + anno + b.sync_status + b.hook_output + "\n")
193+
if selected_branch is not None and branch == selected_branch:
194+
current_part = f"{utils.AE.REVERSE_VIDEO}{current}{utils.AE.ENDC}"
195+
else:
196+
current_part = current
197+
out.write(f"{current_part}{anno}{b.sync_status}{b.hook_output}\n")
198+
line_index += 1
175199

176-
return out.getvalue()
200+
return StatusFormatOutput(result=out.getvalue(), line_for_branch=line_for_branch)
177201

178202
@staticmethod
179203
def _status_warning_message(data: StatusData) -> Optional[str]:
@@ -216,7 +240,7 @@ def _status_warning_message(data: StatusData) -> Optional[str]:
216240
)
217241
return f"{first_part}.\n\n{second_part}."
218242

219-
def _compute_status_data(self, *, flags: StatusFlags) -> StatusData:
243+
def compute_status_data(self, *, flags: StatusFlags) -> StatusData:
220244
managed_branches: List[LocalBranchShortName] = list(self._state.managed_branches)
221245

222246
sync_to_parent_status: Dict[LocalBranchShortName, SyncToParentStatus] = {}
@@ -359,9 +383,9 @@ def status(
359383
opt_list_commits_with_hashes=opt_list_commits_with_hashes,
360384
opt_squash_merge_detection=opt_squash_merge_detection,
361385
)
362-
data = self._compute_status_data(flags=flags)
363-
status_str = self._format_status_output(data)
364-
sys.stdout.write(status_str)
386+
data = self.compute_status_data(flags=flags)
387+
format_out = self.format_status_output(data)
388+
sys.stdout.write(format_out.result)
365389
if warn_when_branch_in_sync_but_fork_point_off:
366390
warning_msg = self._status_warning_message(data)
367391
if warning_msg is not None:

0 commit comments

Comments
 (0)