diff --git a/lua/dropbar/bar.lua b/lua/dropbar/bar.lua index d00b46e4..25e1cba6 100644 --- a/lua/dropbar/bar.lua +++ b/lua/dropbar/bar.lua @@ -11,7 +11,7 @@ local function str_sanitize(str) return str and vim.gsplit(str, '\n')() end ----@alias dropbar_symbol_range_t lsp_range_t +---@alias dropbar_symbol_range_t dropbar_range_t ---Symbol in dropbar, basic element of `dropbar_t` and ---`dropbar_menu_entry_t` diff --git a/lua/dropbar/sources/lsp.lua b/lua/dropbar/sources/lsp.lua index 4c3fdd05..2f8d0baf 100644 --- a/lua/dropbar/sources/lsp.lua +++ b/lua/dropbar/sources/lsp.lua @@ -4,7 +4,7 @@ local utils = require('dropbar.utils') local groupid = vim.api.nvim_create_augroup('dropbar.sources.lsp', {}) local initialized = false ----@type table +---@type table local lsp_buf_symbols = {} setmetatable(lsp_buf_symbols, { __index = function(_, k) @@ -13,40 +13,36 @@ setmetatable(lsp_buf_symbols, { end, }) ----@alias lsp_client_t table +---@alias dropbar_lsp_client_t table ----@class lsp_range_t ----@field start {line: integer, character: integer} ----@field end {line: integer, character: integer} - ----@class lsp_location_t +---@class dropbar_lsp_location_t ---@field uri string ----@field range lsp_range_t +---@field range dropbar_range_t ----@class lsp_document_symbol_t +---@class dropbar_lsp_document_symbol_t ---@field name string ---@field kind integer ---@field tags? table ---@field deprecated? boolean ---@field detail? string ----@field range? lsp_range_t ----@field selectionRange? lsp_range_t ----@field children? lsp_document_symbol_t[] +---@field range? dropbar_range_t +---@field selectionRange? dropbar_range_t +---@field children? dropbar_lsp_document_symbol_t[] ----@class lsp_symbol_information_t +---@class dropbar_lsp_symbol_information_t ---@field name string ---@field kind integer ---@field tags? table ---@field deprecated? boolean ----@field location? lsp_location_t +---@field location? dropbar_lsp_location_t ---@field containerName? string ----@class lsp_symbol_information_tree_t: lsp_symbol_information_t ----@field parent? lsp_symbol_information_tree_t ----@field children? lsp_symbol_information_tree_t[] ----@field siblings? lsp_symbol_information_tree_t[] +---@class dropbar_lsp_symbol_information_tree_t: dropbar_lsp_symbol_information_t +---@field parent? dropbar_lsp_symbol_information_tree_t +---@field children? dropbar_lsp_symbol_information_tree_t[] +---@field siblings? dropbar_lsp_symbol_information_tree_t[] ----@alias lsp_symbol_t lsp_document_symbol_t|lsp_symbol_information_t +---@alias dropbar_lsp_symbol_t dropbar_lsp_document_symbol_t|dropbar_lsp_symbol_information_t -- Map symbol number to symbol kind -- stylua: ignore start @@ -87,7 +83,7 @@ local symbol_kind_names = setmetatable({ ---@alias lsp_symbol_type_t 'SymbolInformation'|'DocumentSymbol' ---Return type of the symbol table ----@param symbols lsp_symbol_t[] symbol table +---@param symbols dropbar_lsp_symbol_t[] symbol table ---@return lsp_symbol_type_t? type symbol type local function symbol_type(symbols) if symbols[1] and symbols[1].location then @@ -98,61 +94,11 @@ local function symbol_type(symbols) end end ----Check if cursor is in range ----@param cursor integer[] cursor position (line, character); (1, 0)-based ----@param range lsp_range_t 0-based range ----@return boolean -local function cursor_in_range(cursor, range) - local cursor0 = { cursor[1] - 1, cursor[2] } - -- stylua: ignore start - return ( - cursor0[1] > range.start.line - or (cursor0[1] == range.start.line - and cursor0[2] >= range.start.character) - ) - and ( - cursor0[1] < range['end'].line - or (cursor0[1] == range['end'].line - and cursor0[2] <= range['end'].character) - ) - -- stylua: ignore end -end - ----Check if range1 contains range2 ----Strict indexing -- if range1 == range2, return false ----@param range1 lsp_range_t 0-based range ----@param range2 lsp_range_t 0-based range ----@return boolean -local function range_contains(range1, range2) - -- stylua: ignore start - return ( - range2.start.line > range1.start.line - or (range2.start.line == range1.start.line - and range2.start.character > range1.start.character) - ) - and ( - range2.start.line < range1['end'].line - or (range2.start.line == range1['end'].line - and range2.start.character < range1['end'].character) - ) - and ( - range2['end'].line > range1.start.line - or (range2['end'].line == range1.start.line - and range2['end'].character > range1.start.character) - ) - and ( - range2['end'].line < range1['end'].line - or (range2['end'].line == range1['end'].line - and range2['end'].character < range1['end'].character) - ) - -- stylua: ignore end -end - ---Convert LSP DocumentSymbol into winbar symbol ----@param document_symbol lsp_document_symbol_t LSP DocumentSymbol +---@param document_symbol dropbar_lsp_document_symbol_t LSP DocumentSymbol ---@param buf integer buffer number ---@param win integer window number ----@param siblings lsp_document_symbol_t[]? siblings of the symbol +---@param siblings dropbar_lsp_document_symbol_t[]? siblings of the symbol ---@param idx integer? index of the symbol in siblings ---@return dropbar_symbol_t local function convert_document_symbol( @@ -200,7 +146,7 @@ end ---Convert LSP DocumentSymbol[] into a list of dropbar symbols ---Side effect: change dropbar_symbols ---LSP Specification document: https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/ ----@param lsp_symbols lsp_document_symbol_t[] +---@param lsp_symbols dropbar_lsp_document_symbol_t[] ---@param dropbar_symbols dropbar_symbol_t[] (reference to) dropbar symbols ---@param buf integer buffer number ---@param win integer window number @@ -219,7 +165,7 @@ local function convert_document_symbol_list( -- Parse in reverse order so that the symbol with the largest start position -- is preferred for idx, symbol in vim.iter(lsp_symbols):enumerate():rev() do - if cursor_in_range(cursor, symbol.range) then + if utils.range.contains_cursor(cursor, symbol.range) then if vim.tbl_contains( configs.opts.sources.lsp.valid_symbols, @@ -246,8 +192,8 @@ local function convert_document_symbol_list( end ---Convert LSP SymbolInformation[] into DocumentSymbol[] ----@param symbols lsp_symbol_t LSP symbols ----@return lsp_document_symbol_t[] +---@param symbols dropbar_lsp_symbol_t LSP symbols +---@return dropbar_lsp_document_symbol_t[] local function unify(symbols) if symbol_type(symbols) == 'DocumentSymbol' or vim.tbl_isempty(symbols) then return symbols @@ -262,9 +208,9 @@ local function unify(symbols) -- symbol can only be a child or a sibling of the previous symbol in the -- same list for list_idx, sym in vim.iter(symbols):enumerate():skip(1) do - local prev = symbols[list_idx - 1] --[[@as lsp_symbol_information_tree_t]] + local prev = symbols[list_idx - 1] --[[@as dropbar_lsp_symbol_information_tree_t]] -- If the symbol is a child of the previous symbol - if range_contains(prev.location.range, sym.location.range) then + if utils.range.contains(prev.location.range, sym.location.range) then sym.parent = prev else -- Else the symbol is a sibling of the previous symbol sym.parent = prev.parent @@ -326,8 +272,8 @@ local function update_symbols(buf, ttl) -- responses can be disordered i.e. later symbols can appear first lsp_buf_symbols[buf] = unify(symbols) - ---@param s1 lsp_document_symbol_t - ---@param s2 lsp_document_symbol_t + ---@param s1 dropbar_lsp_document_symbol_t + ---@param s2 dropbar_lsp_document_symbol_t ---@return boolean precedes true if `s1` appears before `s2` table.sort(lsp_buf_symbols[buf], function(s1, s2) local l1, l2, c1, c2 = diff --git a/lua/dropbar/sources/treesitter.lua b/lua/dropbar/sources/treesitter.lua index b6be7937..27568f01 100644 --- a/lua/dropbar/sources/treesitter.lua +++ b/lua/dropbar/sources/treesitter.lua @@ -2,6 +2,11 @@ local configs = require('dropbar.configs') local bar = require('dropbar.bar') local utils = require('dropbar.utils') +-- Max character offset when comparing symbol boundaries during deduplication +-- If two boundaries are on the same line and within this offset, treat them as +-- equal range +local DEDUP_RANGE_MATCH_TOL = 2 + ---Convert a snake_case string to camelCase ---@param str string? ---@return string? @@ -14,21 +19,135 @@ local function snake_to_camel(str) ) end +---@class dropbar_ts_cache_t +---@field symbol_info table +---@field short_names table + +---@return dropbar_ts_cache_t +local function create_symbol_cache() + return { + symbol_info = setmetatable({}, { __mode = 'k' }), + short_names = setmetatable({}, { __mode = 'k' }), + } +end + ---Get short name of treesitter symbols in buffer buf ---@param node TSNode ----@param buf integer buffer handler ----@return string name -local function get_node_short_name(node, buf) - return ( - vim - .trim( - vim.fn.matchstr( - vim.treesitter.get_node_text(node, buf):gsub('\n', ' '), - configs.opts.sources.treesitter.name_regex - ) +---@param buf integer +---@param cache dropbar_ts_cache_t +---@return string? +local function get_node_short_name(node, buf, cache) + local cached = cache.short_names[node] + if cached ~= nil then + return cached or nil + end + + local name = vim + .trim( + vim.fn.matchstr( + vim.treesitter.get_node_text(node, buf):gsub('\n', ' '), + configs.opts.sources.treesitter.name_regex ) - :gsub('%s+', ' ') - ) + ) + :gsub('%s+', ' ') + if name == '' then + cache.short_names[node] = false + return nil + end + + cache.short_names[node] = name + return name +end + +---@param node TSNode +---@return dropbar_range_t +local function get_node_range(node) + local start_line, start_col, end_line, end_col = + vim.treesitter.get_node_range(node) + return { + start = { + line = start_line, + character = start_col, + }, + ['end'] = { + line = end_line, + character = end_col, + }, + } +end + +---Returns true if the node has at least one child and none of its children are +---named treesitter nodes i.e. all children are anonymous +--- +---By heuristic such nodes can be skipped when collecting symbols +---@param node TSNode +---@return boolean +local function has_only_anonymous_children(node) + return node:child_count() > 0 and node:named_child_count() == 0 +end + +---@param node TSNode +---@param buf integer +---@param cache dropbar_ts_cache_t +---@return { name: string, source_range?: dropbar_range_t }? +local function resolve_node_short_name(node, buf, cache) + local has_named_children = false + local named_children = {} ---@type TSNode[] + local node_start_line = vim.treesitter.get_node_range(node) + + for child, field_name in node:iter_children() do + if child:named() then + if has_only_anonymous_children(child) then + goto continue + end + + has_named_children = true + table.insert(named_children, child) + + if field_name then + local name = get_node_short_name(child, buf, cache) + if name then + return { + name = name, + source_range = get_node_range(child), + } + end + end + end + + ::continue:: + end + + for _, child in ipairs(named_children) do + local child_start_line = vim.treesitter.get_node_range(child) + if child_start_line ~= node_start_line then + goto continue + end + + local name = get_node_short_name(child, buf, cache) + if name then + return { + name = name, + source_range = get_node_range(child), + } + end + + ::continue:: + end + + if has_named_children then + return nil + end + + local name = get_node_short_name(node, buf, cache) + if not name then + return nil + end + + return { + name = name, + source_range = get_node_range(node), + } end ---Get valid treesitter node type name @@ -44,26 +163,189 @@ local function get_node_short_type(node) return '' end +---@class dropbar_ts_symbol_info +---@field short_type string +---@field kind string +---@field name_info { name: string, source_range?: dropbar_range_t } + +---@param node TSNode +---@param buf integer buffer handler +---@param cache dropbar_ts_cache_t +---@return dropbar_ts_symbol_info? +local function resolve_symbol_info(node, buf, cache) + local cached = cache.symbol_info[node] + if cached ~= nil then + return cached or nil + end + + local short_type = get_node_short_type(node) + if short_type == '' then + cache.symbol_info[node] = false + return nil + end + + local name_info = resolve_node_short_name(node, buf, cache) + if not name_info then + cache.symbol_info[node] = false + return nil + end + + local symbol_info = { + short_type = short_type, + kind = snake_to_camel(short_type), + name_info = name_info, + } + cache.symbol_info[node] = symbol_info + return symbol_info +end + ---Check if treesitter node is valid ---@param node TSNode ---@param buf integer buffer handler +---@param cache dropbar_ts_cache_t +---@return boolean +local function valid_node(node, buf, cache) + return resolve_symbol_info(node, buf, cache) ~= nil +end + +---@param a_pos dropbar_pos_t +---@param b_pos dropbar_pos_t +---@return integer +local function compare_pos(a_pos, b_pos) + if a_pos.line ~= b_pos.line then + return a_pos.line < b_pos.line and -1 or 1 + end + if a_pos.character ~= b_pos.character then + return a_pos.character < b_pos.character and -1 or 1 + end + return 0 +end + +---@param s1 dropbar_symbol_t +---@param s2 dropbar_symbol_t ---@return boolean -local function valid_node(node, buf) - return get_node_short_type(node) ~= '' - and get_node_short_name(node, buf) ~= '' +---@return boolean s1_contains_s2 +---@return boolean s2_contains_s1 +local function should_dedup_adjacent(s1, s2) + if s1.name ~= s2.name or s1.name == '' then + return false, false, false + end + + local s1_contains_s2, s2_contains_s1 + + if s1.name_source and s2.name_source then + if + utils.range.matches( + s1.name_source, + s2.name_source, + DEDUP_RANGE_MATCH_TOL + ) + then + s1_contains_s2 = utils.range.contains(s1.range, s2.range, false) + s2_contains_s1 = utils.range.contains(s2.range, s1.range, false) + return true, s1_contains_s2, s2_contains_s1 + end + end + + local same_start = compare_pos(s1.range.start, s2.range.start) == 0 + local same_end = compare_pos(s1.range['end'], s2.range['end']) == 0 + if not same_start and not same_end then + return false, false, false + end + + -- Equal ranges should still deduplicate; strict containment would return false + if same_start and same_end then + return true, false, false + end + + s1_contains_s2 = utils.range.contains(s1.range, s2.range, false) + s2_contains_s1 = utils.range.contains(s2.range, s1.range, false) + return s1_contains_s2 or s2_contains_s1, s1_contains_s2, s2_contains_s1 +end + +---@param symbols dropbar_symbol_t[] +---@return dropbar_symbol_t[] +local function dedup_adjacent_symbols(symbols) + if #symbols < 2 then + return symbols + end + + local deduped = { symbols[1] } + for i = 2, #symbols do + local current = symbols[i] + local previous = deduped[#deduped] + + if + previous.name_source + and current.name_source + and utils.range.contains( + previous.name_source, + current.name_source, + false + ) + then + local same_start = compare_pos( + previous.name_source.start, + current.name_source.start + ) == 0 + local current_ends_earlier = compare_pos( + current.name_source['end'], + previous.name_source['end'] + ) < 0 + if + same_start + and current_ends_earlier + and previous.name_source['end'].line + ~= current.name_source['end'].line + then + deduped[#deduped] = current + goto continue + end + + if + previous.name_source.start.line == current.name_source.start.line + and previous.name_source['end'].line + == current.name_source['end'].line + then + goto continue + end + end + + local should_dedup, previous_contains_current, current_contains_previous = + should_dedup_adjacent(previous, current) + if should_dedup then + if previous_contains_current and not current_contains_previous then + -- Keep narrower symbol when names overlap. + deduped[#deduped] = current + elseif current_contains_previous and not previous_contains_current then + -- Keep narrower symbol when names overlap. + deduped[#deduped] = previous + else + -- Equal ranges: keep the deeper (later) symbol. + deduped[#deduped] = current + end + else + table.insert(deduped, current) + end + + ::continue:: + end + + return deduped end ---Get treesitter node children ---@param node TSNode ---@param buf integer buffer handler +---@param cache dropbar_ts_cache_t ---@return TSNode[] children -local function get_node_children(node, buf) +local function get_node_children(node, buf, cache) local children = {} for child in node:iter_children() do - if valid_node(child, buf) then + if valid_node(child, buf, cache) then table.insert(children, child) else - vim.list_extend(children, get_node_children(child, buf)) + vim.list_extend(children, get_node_children(child, buf, cache)) end end return children @@ -72,17 +354,19 @@ end ---Get treesitter node siblings ---@param node TSNode ---@param buf integer buffer handler +---@param cache dropbar_ts_cache_t ---@return TSNode[] siblings ---@return integer idx index of the node in its siblings -local function get_node_siblings(node, buf) +local function get_node_siblings(node, buf, cache) local siblings = {} local current = node ---@type TSNode? while current do - if valid_node(current, buf) then + if valid_node(current, buf, cache) then table.insert(siblings, 1, current) else - siblings = vim.list_extend(get_node_children(current, buf), siblings) + siblings = + vim.list_extend(get_node_children(current, buf, cache), siblings) end current = current:prev_sibling() end @@ -90,10 +374,10 @@ local function get_node_siblings(node, buf) current = node:next_sibling() while current do - if valid_node(current, buf) then + if valid_node(current, buf, cache) then table.insert(siblings, current) else - vim.list_extend(siblings, get_node_children(current, buf)) + vim.list_extend(siblings, get_node_children(current, buf, cache)) end current = current:next_sibling() end @@ -105,45 +389,41 @@ end ---@param ts_node TSNode ---@param buf integer buffer handler ---@param win integer window handler +---@param cache dropbar_ts_cache_t +---@param symbol_info? dropbar_ts_symbol_info ---@return dropbar_symbol_t? -local function convert(ts_node, buf, win) - if not valid_node(ts_node, buf) then +local function convert(ts_node, buf, win, cache, symbol_info) + symbol_info = symbol_info or resolve_symbol_info(ts_node, buf, cache) + if not symbol_info then return nil end - local kind = snake_to_camel(get_node_short_type(ts_node)) - local range = { ts_node:range() } + return bar.dropbar_symbol_t:new(setmetatable({ buf = buf, win = win, - name = get_node_short_name(ts_node, buf), - icon = configs.opts.icons.kinds.symbols[kind], - name_hl = 'DropBarKind' .. kind, - icon_hl = 'DropBarIconKind' .. kind, - range = { - start = { - line = range[1], - character = range[2], - }, - ['end'] = { - line = range[3], - character = range[4], - }, - }, + ts_node = ts_node, + kind = symbol_info.kind, + name = symbol_info.name_info.name, + name_source = symbol_info.name_info.source_range, + icon = configs.opts.icons.kinds.symbols[symbol_info.kind], + name_hl = 'DropBarKind' .. symbol_info.kind, + icon_hl = 'DropBarIconKind' .. symbol_info.kind, + range = get_node_range(ts_node), }, { ---@param self dropbar_symbol_t ---@param k string|number __index = function(self, k) if k == 'children' then self.children = vim.tbl_map(function(child) - return convert(child, buf, win) - end, get_node_children(ts_node, buf)) + return convert(child, buf, win, cache) + end, get_node_children(ts_node, buf, cache)) return self.children end if k == 'siblings' or k == 'sibling_idx' then - local siblings, idx = get_node_siblings(ts_node, buf) + local siblings, idx = get_node_siblings(ts_node, buf, cache) self.siblings = vim.tbl_map(function(sibling) - return convert(sibling, buf, win) + return convert(sibling, buf, win, cache) end, siblings) self.sibling_idx = idx return self[k] @@ -171,25 +451,30 @@ local function get_symbols(buf, win, cursor) end local symbols = {} ---@type dropbar_symbol_t[] + local cache = create_symbol_cache() + local mode = vim.api.nvim_get_mode().mode + local is_insert_mode = mode:sub(1, 1) == 'i' + local col_offset = cursor[2] >= 1 and is_insert_mode and 1 or 0 -- Prevent errors when getting node from filetypes without a parser local node = vim.F.npcall(vim.treesitter.get_node, { - ft = vim.filetype.match({ buf = buf }), bufnr = buf, pos = { cursor[1] - 1, - cursor[2] - - (cursor[2] >= 1 and vim.startswith(vim.fn.mode(), 'i') and 1 or 0), + cursor[2] - col_offset, }, }) while node and #symbols < configs.opts.sources.treesitter.max_depth do - if valid_node(node, buf) then - table.insert(symbols, 1, convert(node, buf, win)) + local symbol_info = resolve_symbol_info(node, buf, cache) + if symbol_info then + table.insert(symbols, 1, convert(node, buf, win, cache, symbol_info)) end node = node:parent() end + symbols = dedup_adjacent_symbols(symbols) + utils.bar.set_min_widths(symbols, configs.opts.sources.treesitter.min_widths) return symbols end diff --git a/lua/dropbar/utils/init.lua b/lua/dropbar/utils/init.lua index f12a2725..1d38d7f3 100644 --- a/lua/dropbar/utils/init.lua +++ b/lua/dropbar/utils/init.lua @@ -2,6 +2,7 @@ return setmetatable({ bar = nil, ---@module 'dropbar.utils.bar' menu = nil, ---@module 'dropbar.utils.menu' source = nil, ---@module 'dropbar.utils.source' + range = nil, ---@module 'dropbar.utils.range' }, { __index = function(_, key) return require('dropbar.utils.' .. key) diff --git a/lua/dropbar/utils/pos.lua b/lua/dropbar/utils/pos.lua new file mode 100644 index 00000000..f459caed --- /dev/null +++ b/lua/dropbar/utils/pos.lua @@ -0,0 +1,18 @@ +local M = {} + +---@class dropbar_pos_t +---@field line integer +---@field character integer + +---Check if two positions are equal within a column tolerance +---Requires both positions to be on the same line; columns may differ by at most `tol`. +---@param pos1 dropbar_pos_t 0-based position +---@param pos2 dropbar_pos_t 0-based position +---@param tol integer maximum allowed column delta (>= 0) +---@return boolean +function M.matches(pos1, pos2, tol) + return pos1.line == pos2.line + and math.abs(pos1.character - pos2.character) <= tol +end + +return M diff --git a/lua/dropbar/utils/range.lua b/lua/dropbar/utils/range.lua new file mode 100644 index 00000000..9d1b65a0 --- /dev/null +++ b/lua/dropbar/utils/range.lua @@ -0,0 +1,83 @@ +local M = {} + +local utils = require('dropbar.utils') + +---@class dropbar_range_t +---@field start dropbar_pos_t +---@field end dropbar_pos_t + +---Check if r1 contains r2 +---Strict indexing -- if r1 == r2, return false +---@param r1 dropbar_range_t 0-based range +---@param r2 dropbar_range_t 0-based range +---@param strict boolean? only return true if `range1` fully contains `range2` (no overlapping boundaries), default false +---@return boolean +function M.contains(r1, r2, strict) + return ( + r2.start.line > r1.start.line + or ( + r2.start.line == r1.start.line + and ( + r2.start.character > r1.start.character + or not strict and r2.start.character == r1.start.character + ) + ) + ) + and (r2.start.line < r1['end'].line or (r2.start.line == r1['end'].line and (r2.start.character < r1['end'].character or not strict and r2.start.character == r1['end'].character))) + and (r2['end'].line > r1.start.line or (r2['end'].line == r1.start.line and (r2['end'].character > r1.start.character or not strict and r2['end'].character == r1.start.character))) + and ( + r2['end'].line < r1['end'].line + or ( + r2['end'].line == r1['end'].line + and ( + r2['end'].character < r1['end'].character + or not strict and r2['end'].character == r1['end'].character + ) + ) + ) +end + +---Check if cursor is in range +---@param cursor integer[] cursor position (line, character); (1, 0)-based +---@param range dropbar_range_t 0-based range +---@param strict boolean? only return true if `cursor` is fully contained in `range` (not on the boundary), default false +---@return boolean +function M.contains_cursor(cursor, range, strict) + cursor = cursor or vim.api.nvim_win_get_cursor(0) + local line = cursor[1] - 1 + local char = cursor[2] + return ( + line > range.start.line + or ( + line == range.start.line + and ( + char > range.start.character + or not strict and char == range.start.character + ) + ) + ) + and ( + line < range['end'].line + or ( + line == range['end'].line + and ( + char < range['end'].character + or not strict and char == range['end'].character + ) + ) + ) +end + +---Check if two ranges match at either boundary within a tolerance +---Two ranges 'match' when their starts are on the same line and within `tol` columns, +---or their ends are on the same line and within `tol` columns. +---@param r1 dropbar_range_t 0-based range +---@param r2 dropbar_range_t 0-based range +---@param tol integer maximum allowed column delta for boundary comparison (>= 0) +---@return boolean +function M.matches(r1, r2, tol) + return utils.pos.matches(r1.start, r2.start, tol) + or utils.pos.matches(r1['end'], r2['end'], tol) +end + +return M diff --git a/tests/sources/treesitter_spec.lua b/tests/sources/treesitter_spec.lua new file mode 100644 index 00000000..854930ae --- /dev/null +++ b/tests/sources/treesitter_spec.lua @@ -0,0 +1,385 @@ +---@diagnostic disable: undefined-field + +local dropbar = require('dropbar') +local source_treesitter = require('dropbar.sources.treesitter') +local stub = require('luassert.stub') + +---@param opts? { +--- type_name?: string, +--- text?: string, +--- range?: integer[], +--- named?: boolean, +--- children?: TSNode[], +--- fields?: (string|nil)[], +---} +---@return TSNode +local function ts_node(opts) + opts = opts or {} + local children = opts.children or {} + local fields = opts.fields or {} + local ts = { + _type = opts.type_name or 'identifier', + _text = opts.text or '', + _range = opts.range or { 0, 0, 0, 0 }, + _named = opts.named ~= false, + _children = children, + _fields = fields, + _parent = nil, + _index = nil, + } + + for i, child in ipairs(children) do + child._parent = ts + child._index = i + end + + ts.type = function(self) + return self._type + end + ts.range = function(self) + return unpack(self._range) + end + ts.parent = function(self) + return self._parent + end + ts.named = function(self) + return self._named + end + ts.child_count = function(self) + return #self._children + end + ts.named_child_count = function(self) + local count = 0 + for _, child in ipairs(self._children) do + if child:named() then + count = count + 1 + end + end + return count + end + ts.iter_children = function(self) + local i = 0 + return function() + i = i + 1 + local child = self._children[i] + if not child then + return nil + end + return child, self._fields[i] + end + end + ts.prev_sibling = function(self) + local parent = self._parent + if not parent or not self._index or self._index <= 1 then + return nil + end + return parent._children[self._index - 1] + end + ts.next_sibling = function(self) + local parent = self._parent + if not parent or not self._index then + return nil + end + return parent._children[self._index + 1] + end + + return ts +end + +---@param cursor_node TSNode +---@param stubs luassert.stub[] +local function stub_treesitter(cursor_node, stubs) + table.insert( + stubs, + stub(vim.treesitter, 'get_parser', function() + return true + end) + ) + table.insert( + stubs, + stub(vim.filetype, 'match', function() + return 'nickel' + end) + ) + table.insert( + stubs, + stub(vim.treesitter, 'get_node', function() + return cursor_node + end) + ) + table.insert( + stubs, + stub(vim.treesitter, 'get_node_text', function(node) + return node._text + end) + ) + table.insert( + stubs, + stub(vim.treesitter, 'get_node_range', function(node_or_range) + if type(node_or_range) == 'table' and node_or_range._range then + return unpack(node_or_range._range) + end + return unpack(node_or_range) + end) + ) + table.insert( + stubs, + stub(vim.treesitter, 'node_contains', function(node, range) + local node_start_row, node_start_col, node_end_row, node_end_col = + unpack(node._range) + local range_start_row, range_start_col, range_end_row, range_end_col = + unpack(range) + local node_starts_before_range = node_start_row < range_start_row + or ( + node_start_row == range_start_row + and node_start_col <= range_start_col + ) + local node_ends_after_range = node_end_row > range_end_row + or (node_end_row == range_end_row and node_end_col >= range_end_col) + return node_starts_before_range and node_ends_after_range + end) + ) +end + +---@param symbols dropbar_symbol_t[] +---@return string[] +local function symbol_names(symbols) + return vim.tbl_map(function(symbol) + return symbol.name + end, symbols) +end + +describe('[source][treesitter]', function() + local stubs = {} + + before_each(function() + dropbar.setup({ + bar = { + sources = { + source_treesitter, + }, + }, + sources = { + treesitter = { + valid_types = { + 'classProperty', + 'modifier', + 'pair', + 'identifier', + 'class', + 'function', + }, + name_regex = '[A-Za-z_][A-Za-z0-9_.]*', + }, + }, + }) + end) + + after_each(function() + for _, s in ipairs(stubs) do + s:revert() + end + stubs = {} + end) + + it('resolves names from named children, not anonymous tokens', function() + local local_token = ts_node({ + type_name = 'string', + text = 'local', + range = { 0, 0, 0, 5 }, + named = false, + }) + local modifier = ts_node({ + type_name = 'modifier', + text = 'local', + range = { 0, 0, 0, 5 }, + children = { local_token }, + fields = { nil }, + }) + local identifier = ts_node({ + type_name = 'identifier', + text = 'myField', + range = { 0, 6, 0, 13 }, + }) + local class_property = ts_node({ + type_name = 'classProperty', + text = 'local myField', + range = { 0, 0, 0, 13 }, + children = { modifier, identifier }, + fields = { nil, 'name' }, + }) + + stub_treesitter(class_property, stubs) + + local symbols = source_treesitter.get_symbols( + vim.api.nvim_get_current_buf(), + vim.api.nvim_get_current_win(), + { 1, 7 } + ) + + assert.are.same({ 'myField' }, symbol_names(symbols)) + end) + + it( + 'collapses same-line contained path segments into parent breadcrumb', + function() + local path = ts_node({ + type_name = 'identifier', + text = 'grammar.source.git', + range = { 1, 2, 1, 20 }, + }) + local grammar = ts_node({ + type_name = 'identifier', + text = 'grammar', + range = { 1, 2, 1, 9 }, + }) + local source = ts_node({ + type_name = 'identifier', + text = 'source', + range = { 1, 10, 1, 16 }, + }) + grammar._parent = path + source._parent = grammar + + stub_treesitter(source, stubs) + + local symbols = source_treesitter.get_symbols( + vim.api.nvim_get_current_buf(), + vim.api.nvim_get_current_win(), + { 2, 12 } + ) + + assert.are.same({ 'grammar.source.git' }, symbol_names(symbols)) + end + ) + + it( + 'keeps child symbol when parent and child names are on different lines', + function() + local path = ts_node({ + type_name = 'identifier', + text = 'grammar.source.git', + range = { 1, 2, 1, 20 }, + }) + local rev = ts_node({ + type_name = 'identifier', + text = 'rev', + range = { 2, 4, 2, 7 }, + }) + rev._parent = path + + stub_treesitter(rev, stubs) + + local symbols = source_treesitter.get_symbols( + vim.api.nvim_get_current_buf(), + vim.api.nvim_get_current_win(), + { 3, 5 } + ) + + assert.are.same({ 'grammar.source.git', 'rev' }, symbol_names(symbols)) + end + ) + + it( + 'prefers narrower name when broader parent starts same place across lines', + function() + local broad = ts_node({ + type_name = 'identifier', + text = 'self.get_base_types_for_class', + range = { 0, 17, 2, 33 }, + }) + local self_symbol = ts_node({ + type_name = 'identifier', + text = 'self', + range = { 0, 17, 1, 29 }, + }) + self_symbol._parent = broad + + stub_treesitter(self_symbol, stubs) + + local symbols = source_treesitter.get_symbols( + vim.api.nvim_get_current_buf(), + vim.api.nvim_get_current_win(), + { 1, 18 } + ) + + assert.are.same({ 'self' }, symbol_names(symbols)) + end + ) + + it('deduplicates wrapper symbols with identical names', function() + local settings_pair = ts_node({ + type_name = 'pair', + text = 'settings = { ... }', + range = { 0, 2, 8, 0 }, + }) + local settings_id = ts_node({ + type_name = 'identifier', + text = 'settings', + range = { 0, 2, 0, 10 }, + }) + local server_pair = ts_node({ + type_name = 'pair', + text = 'server = { ... }', + range = { 1, 4, 7, 2 }, + }) + local server_id = ts_node({ + type_name = 'identifier', + text = 'server', + range = { 1, 4, 1, 10 }, + }) + local host_id = ts_node({ + type_name = 'identifier', + text = 'host', + range = { 2, 6, 2, 10 }, + }) + + settings_id._parent = settings_pair + server_pair._parent = settings_id + server_id._parent = server_pair + host_id._parent = server_id + + stub_treesitter(host_id, stubs) + + local symbols = source_treesitter.get_symbols( + vim.api.nvim_get_current_buf(), + vim.api.nvim_get_current_win(), + { 3, 8 } + ) + + assert.are.same({ 'settings', 'server', 'host' }, symbol_names(symbols)) + end) + + it( + 'does not deduplicate non-wrapper symbols that just share names', + function() + local class_foo = ts_node({ + type_name = 'class', + text = 'Foo', + range = { 0, 0, 10, 0 }, + }) + local function_foo = ts_node({ + type_name = 'function', + text = 'Foo', + range = { 2, 2, 4, 2 }, + }) + local identifier = ts_node({ + type_name = 'identifier', + text = 'x', + range = { 3, 4, 3, 5 }, + }) + function_foo._parent = class_foo + identifier._parent = function_foo + + stub_treesitter(identifier, stubs) + + local symbols = source_treesitter.get_symbols( + vim.api.nvim_get_current_buf(), + vim.api.nvim_get_current_win(), + { 4, 5 } + ) + + assert.are.same({ 'Foo', 'Foo', 'x' }, symbol_names(symbols)) + end + ) +end)