Skip to content

Commit 3406b05

Browse files
committed
Improve the branch selection in interactive go (now same as status)
1 parent cc50fcd commit 3406b05

2 files changed

Lines changed: 92 additions & 89 deletions

File tree

Lines changed: 59 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import sys
2-
from typing import List, Optional, Tuple
2+
from typing import Optional, Tuple
33

44
try:
55
import termios
@@ -9,21 +9,23 @@
99
termios = None # type: ignore[assignment]
1010
tty = None # type: ignore[assignment]
1111

12-
from git_machete import utils
13-
from git_machete.client.base import MacheteClient
12+
from git_machete import git_config_keys, utils
13+
from git_machete.client.base import SquashMergeDetection
14+
from git_machete.client.status import (StatusData, StatusFlags,
15+
StatusMacheteClient)
1416
from git_machete.exceptions import UnexpectedMacheteException
1517
from git_machete.git_operations import LocalBranchShortName
1618
from git_machete.utils import AnsiEscapeCodes, bold, index_or_none, warn
1719

1820

19-
class GoInteractiveMacheteClient(MacheteClient):
21+
class GoInteractiveMacheteClient(StatusMacheteClient):
2022
"""Client for interactive branch selection using curses-style interface (implemented without curses, just using ANSI sequences)."""
2123

2224
MAX_VISIBLE_BRANCHES_DEFAULT = 20
2325
MAX_VISIBLE_BRANCHES_LOWER = 2
2426
MAX_VISIBLE_BRANCHES_UPPER = 50
2527

26-
_managed_branches_with_depths: List[Tuple[LocalBranchShortName, int]]
28+
_status_data: StatusData
2729
_current_branch: Optional[LocalBranchShortName]
2830
_max_visible_branches: int
2931

@@ -37,35 +39,10 @@ def _get_max_visible_branches(self) -> int:
3739
max_visible_branches = terminal_height - 2
3840
return max(self.MAX_VISIBLE_BRANCHES_LOWER, min(max_visible_branches, self.MAX_VISIBLE_BRANCHES_UPPER))
3941

40-
def _get_branch_list_with_depths(self) -> List[Tuple[LocalBranchShortName, int]]:
41-
"""Get a flat list of branches with their depths using DFS traversal."""
42-
result: List[Tuple[LocalBranchShortName, int]] = []
43-
44-
def add_branch_and_children(branch: LocalBranchShortName, depth: int) -> None:
45-
result.append((branch, depth))
46-
for child_branch in self.down_branches_for(branch) or []:
47-
add_branch_and_children(child_branch, depth + 1)
48-
49-
for root in self._state.roots:
50-
add_branch_and_children(root, depth=0)
51-
52-
return result
53-
54-
def _render_branch_line(self, branch: LocalBranchShortName, depth: int) -> str:
55-
"""Render a single branch line with indentation."""
56-
indent = " " * depth
57-
marker = " " if self._current_branch is None or branch != self._current_branch else "*"
58-
59-
line = f"{indent}{marker} {branch}"
60-
annotation = self.annotations.get(branch)
61-
if annotation and annotation.formatted_full_text:
62-
line += f" {annotation.formatted_full_text}"
63-
64-
return line
65-
6642
def _draw_screen(self, *, selected_idx: int, scroll_offset: int,
67-
num_lines_drawn: int, is_first_draw: bool) -> int:
68-
"""Draw the branch selection screen using ANSI escape codes."""
43+
num_lines_drawn: int, is_first_draw: bool) -> Tuple[int, int]:
44+
"""Draw the branch selection screen using status-style output (format_status_output).
45+
Returns (scroll_offset, actual_lines_drawn) so the caller can move the cursor up correctly on next redraw."""
6946
# Move cursor up to the start of our display area (if we've drawn before)
7047
if not is_first_draw and num_lines_drawn > 0:
7148
sys.stdout.write(f'{AnsiEscapeCodes.CSI}{num_lines_drawn}A')
@@ -79,27 +56,31 @@ def _draw_screen(self, *, selected_idx: int, scroll_offset: int,
7956
"Enter or Space: checkout, q or Ctrl+C: quit)")
8057
sys.stdout.write(bold(header_text) + '\n')
8158

82-
# Adjust scroll offset if needed
83-
visible_lines = min(self._max_visible_branches, len(self._managed_branches_with_depths))
84-
if selected_idx < scroll_offset:
85-
scroll_offset = selected_idx
86-
elif selected_idx >= scroll_offset + visible_lines:
87-
scroll_offset = selected_idx - visible_lines + 1
88-
89-
# Draw branches
59+
branches = self._status_data.branches_in_display_order
60+
selected_branch = branches[selected_idx] if 0 <= selected_idx < len(branches) else None
61+
formatted = self.format_status_output(
62+
self._status_data,
63+
selected_branch=selected_branch,
64+
)
65+
lines = formatted.result.splitlines()
66+
num_lines = len(lines)
67+
visible_lines = min(self._max_visible_branches, num_lines)
68+
# line_for_branch maps branch -> 0-based line index in result
69+
selected_line_idx = formatted.line_for_branch.get(selected_branch, 0) if selected_branch else 0
70+
if selected_line_idx < scroll_offset:
71+
scroll_offset = selected_line_idx
72+
elif selected_line_idx >= scroll_offset + visible_lines:
73+
scroll_offset = selected_line_idx - visible_lines + 1
74+
75+
lines_drawn = 1 # header
9076
for i in range(visible_lines):
91-
branch_idx = scroll_offset + i
92-
branch, depth = self._managed_branches_with_depths[branch_idx]
93-
line = self._render_branch_line(branch, depth)
94-
95-
if branch_idx == selected_idx:
96-
# Highlight selected line (inverse video)
97-
sys.stdout.write(f'{AnsiEscapeCodes.REVERSE_VIDEO}{line}{AnsiEscapeCodes.ENDC}\n')
98-
else:
99-
sys.stdout.write(f'{line}\n')
77+
line_idx = scroll_offset + i
78+
if line_idx < num_lines:
79+
sys.stdout.write(lines[line_idx] + '\n')
80+
lines_drawn += 1
10081

10182
sys.stdout.flush()
102-
return scroll_offset
83+
return scroll_offset, lines_drawn
10384

10485
def _get_stdin_fd(self) -> int: # pragma: no cover; always mocked in tests
10586
return sys.stdin.fileno()
@@ -136,26 +117,35 @@ def go_interactive(self, *, current_branch: Optional[LocalBranchShortName]) -> O
136117
"""
137118
Launch interactive branch selection interface.
138119
Returns the selected branch or None if cancelled.
120+
Status data is computed once at start; only rendering (format_status_output) runs on each redraw.
139121
"""
140122
if termios is None or tty is None:
141123
raise UnexpectedMacheteException("Interactive mode is not supported on Windows yet")
142124

143-
# Get flat list of branches with depths from already-parsed state
144-
self._managed_branches_with_depths = self._get_branch_list_with_depths()
145-
146125
self._current_branch = current_branch
147-
148-
# Determine maximum visible branches from terminal height
149126
self._max_visible_branches = self._get_max_visible_branches()
150127

128+
# Compute status data once (no list-commits in TUI; squash merge detection defaults to simple)
129+
maybe_space = (
130+
' ' if self._git.get_boolean_config_attr(
131+
git_config_keys.STATUS_EXTRA_SPACE_BEFORE_BRANCH_NAME, default_value=False) else ''
132+
)
133+
flags = StatusFlags(
134+
maybe_space_before_branch_name=maybe_space,
135+
opt_list_commits=False,
136+
opt_list_commits_with_hashes=False,
137+
opt_squash_merge_detection=SquashMergeDetection.SIMPLE,
138+
)
139+
self._status_data = self.compute_status_data(flags=flags)
140+
branches_ordered = self._status_data.branches_in_display_order
141+
151142
# Find initial selection (current branch or first branch if detached HEAD)
152143
if current_branch is not None:
153-
selected_idx = index_or_none(self.managed_branches, self._current_branch)
144+
selected_idx = index_or_none(branches_ordered, self._current_branch)
154145
if selected_idx is None:
155146
selected_idx = 0
156147
warn(f"current branch {self._current_branch} is unmanaged\n")
157148
else:
158-
# Detached HEAD - start with first managed branch
159149
selected_idx = 0
160150

161151
scroll_offset = 0
@@ -166,55 +156,43 @@ def go_interactive(self, *, current_branch: Optional[LocalBranchShortName]) -> O
166156
sys.stdout.write(AnsiEscapeCodes.HIDE_CURSOR)
167157
sys.stdout.flush()
168158

159+
branches = self._status_data.branches_in_display_order
169160
try:
170161
while True:
171-
# Calculate how many lines we'll draw (header + visible branches)
172-
visible_lines = min(self._max_visible_branches, len(self._managed_branches_with_depths))
173-
num_lines_drawn = visible_lines + 1 # +1 for header
174-
175-
scroll_offset = self._draw_screen(
162+
scroll_offset, num_lines_drawn = self._draw_screen(
176163
selected_idx=selected_idx,
177164
scroll_offset=scroll_offset,
178165
num_lines_drawn=num_lines_drawn,
179166
is_first_draw=is_first_draw
180167
)
181168
is_first_draw = False
182169

183-
# Read key
184170
key = self._getch()
185171

186172
if key == AnsiEscapeCodes.KEY_UP:
187-
# Wrap around from first to last
188-
selected_idx = (selected_idx - 1) % len(self._managed_branches_with_depths)
173+
selected_idx = (selected_idx - 1) % len(branches)
189174
elif key == AnsiEscapeCodes.KEY_DOWN:
190-
# Wrap around from last to first
191-
selected_idx = (selected_idx + 1) % len(self._managed_branches_with_depths)
175+
selected_idx = (selected_idx + 1) % len(branches)
192176
elif key == AnsiEscapeCodes.KEY_SHIFT_UP:
193-
# Jump to first branch
194177
selected_idx = 0
195178
elif key == AnsiEscapeCodes.KEY_SHIFT_DOWN:
196-
# Jump to last branch
197-
selected_idx = len(self._managed_branches_with_depths) - 1
179+
selected_idx = len(branches) - 1
198180
elif key == AnsiEscapeCodes.KEY_LEFT:
199-
# Go to parent
200-
selected_branch, _ = self._managed_branches_with_depths[selected_idx]
181+
selected_branch = branches[selected_idx]
201182
parent_branch = self.up_branch_for(selected_branch)
202-
if parent_branch:
203-
selected_idx = self.managed_branches.index(parent_branch)
183+
if parent_branch is not None:
184+
selected_idx = branches.index(parent_branch)
204185
elif key == AnsiEscapeCodes.KEY_RIGHT:
205-
# Go to first child
206-
selected_branch, _ = self._managed_branches_with_depths[selected_idx]
186+
selected_branch = branches[selected_idx]
207187
child_branches = self.down_branches_for(selected_branch)
208188
if child_branches:
209-
selected_idx = self.managed_branches.index(child_branches[0])
189+
selected_idx = branches.index(child_branches[0])
210190
elif key in AnsiEscapeCodes.KEYS_ENTER or key == AnsiEscapeCodes.KEY_SPACE:
211-
selected_branch, _ = self._managed_branches_with_depths[selected_idx]
212-
return selected_branch
191+
return branches[selected_idx]
213192
elif key in ('q', 'Q'):
214193
return None
215194
elif key == AnsiEscapeCodes.KEY_CTRL_C:
216195
return None
217196
finally:
218-
# Show cursor again and move past our interface
219197
sys.stdout.write(AnsiEscapeCodes.SHOW_CURSOR)
220198
sys.stdout.flush()

git_machete/client/status.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,30 @@ class StatusData(NamedTuple):
8383
ongoing_operation: StatusOngoingOperation
8484

8585

86+
class StatusFormatOutput(NamedTuple):
87+
"""Result of formatting status output. Returned by format_status_output."""
88+
89+
result: str
90+
# Maps each branch to the 0-based index of the line in result (when split by newlines) where it appears.
91+
line_for_branch: Dict[LocalBranchShortName, int]
92+
93+
8694
class StatusMacheteClient(MacheteClient):
8795
"""Client for the status command. Exposes status() and can be used as a mixin for other clients."""
8896

8997
@staticmethod
90-
def _format_status_output(data: StatusData) -> str:
91-
"""Pure function: given StatusData, returns the formatted status tree string."""
98+
def format_status_output(
99+
data: StatusData,
100+
*,
101+
selected_branch: Optional[LocalBranchShortName] = None,
102+
) -> StatusFormatOutput:
103+
"""Pure function: given StatusData, returns StatusFormatOutput (result string and line_for_branch).
104+
When selected_branch is set, that branch's name (not annotation/sync status) is wrapped in reverse video.
105+
line_for_branch maps each branch to the 0-based line index in result where it appears."""
92106
out = io.StringIO()
93107
space = data.flags.maybe_space_before_branch_name
108+
line_for_branch: Dict[LocalBranchShortName, int] = {}
109+
line_index = 0
94110

95111
next_sibling_of_ancestor_by_branch: Dict[LocalBranchShortName, List[Optional[LocalBranchShortName]]] = {}
96112

@@ -125,12 +141,14 @@ def write_line_prefix(
125141
next_sibling_of_ancestor = next_sibling_of_ancestor_by_branch[branch]
126142
if b.up_branch is not None:
127143
write_line_prefix(branch, next_sibling_of_ancestor, f"{utils.get_vertical_bar()}\n")
144+
line_index += 1
128145
for commit, fp_suffix in b.commits:
129146
write_line_prefix(branch, next_sibling_of_ancestor, utils.get_vertical_bar())
130147
out.write(
131148
f' {f"{dim(commit.short_hash)} " if data.flags.opt_list_commits_with_hashes else ""}'
132149
f'{dim(commit.subject)}{fp_suffix}\n'
133150
)
151+
line_index += 1
134152
if utils.ascii_only:
135153
junction = sync_to_parent_status_to_junction_ascii_only_map[b.sync_to_parent_status]
136154
else:
@@ -143,8 +161,10 @@ def write_line_prefix(
143161
else:
144162
if branch != data.roots[0]:
145163
out.write("\n")
164+
line_index += 1
146165
out.write(" " + space)
147166

167+
line_for_branch[branch] = line_index
148168
op = data.ongoing_operation
149169
if branch in (op.currently_checked_out_branch, op.currently_rebased_branch, op.currently_bisected_branch):
150170
if branch == op.currently_rebased_branch:
@@ -169,9 +189,14 @@ def write_line_prefix(
169189
if b.annotation is not None and b.annotation.formatted_full_text:
170190
anno = ' ' + b.annotation.formatted_full_text
171191

172-
out.write(current + anno + b.sync_status + b.hook_output + "\n")
192+
if selected_branch is not None and branch == selected_branch:
193+
current_part = f"{AnsiEscapeCodes.REVERSE_VIDEO}{current}{AnsiEscapeCodes.ENDC}"
194+
else:
195+
current_part = current
196+
out.write(f"{current_part}{anno}{b.sync_status}{b.hook_output}\n")
197+
line_index += 1
173198

174-
return out.getvalue()
199+
return StatusFormatOutput(result=out.getvalue(), line_for_branch=line_for_branch)
175200

176201
@staticmethod
177202
def _status_warning_message(data: StatusData) -> Optional[str]:
@@ -214,7 +239,7 @@ def _status_warning_message(data: StatusData) -> Optional[str]:
214239
)
215240
return f"{first_part}.\n\n{second_part}."
216241

217-
def _compute_status_data(self, *, flags: StatusFlags) -> StatusData:
242+
def compute_status_data(self, *, flags: StatusFlags) -> StatusData:
218243
managed_branches: List[LocalBranchShortName] = list(self._state.managed_branches)
219244

220245
sync_to_parent_status: Dict[LocalBranchShortName, SyncToParentStatus] = {}
@@ -360,9 +385,9 @@ def status(
360385
opt_list_commits_with_hashes=opt_list_commits_with_hashes,
361386
opt_squash_merge_detection=opt_squash_merge_detection,
362387
)
363-
data = self._compute_status_data(flags=flags)
364-
status_str = self._format_status_output(data)
365-
sys.stdout.write(status_str)
388+
data = self.compute_status_data(flags=flags)
389+
format_out = self.format_status_output(data)
390+
sys.stdout.write(format_out.result)
366391
if warn_when_branch_in_sync_but_fork_point_off:
367392
warning_msg = self._status_warning_message(data)
368393
if warning_msg is not None:

0 commit comments

Comments
 (0)