perf(scope): use async treesitter parsing when available

This commit is contained in:
Folke Lemaitre 2025-02-23 09:20:57 +01:00
parent dd15e3a05a
commit e0f882e6d6
No known key found for this signature in database
GPG key ID: 41F8B1FBACAE2040
2 changed files with 100 additions and 69 deletions

View file

@ -106,6 +106,9 @@ local defaults = {
}, },
} }
---@diagnostic disable-next-line: invisible
M.TS_ASYNC = vim.treesitter.languagetree._async_parse ~= nil
local id = 0 local id = 0
---@alias snacks.scope.scope {buf: number, from: number, to: number, indent?: number} ---@alias snacks.scope.scope {buf: number, from: number, to: number, indent?: number}
@ -395,14 +398,30 @@ function TSScope:with(opts)
end end
---@param opts snacks.scope.Opts ---@param opts snacks.scope.Opts
function TSScope:find(opts) function TSScope:parser(opts)
local lang = vim.bo[opts.buf].filetype local lang = vim.bo[opts.buf].filetype
local has_parser, parser = pcall(vim.treesitter.get_parser, opts.buf, lang, { error = false }) local has_parser, parser = pcall(vim.treesitter.get_parser, opts.buf, lang, { error = false })
if not has_parser or parser == nil then return has_parser and parser or nil
end
---@param cb fun()
---@param opts snacks.scope.Opts
function TSScope:init(cb, opts)
local parser = self:parser(opts)
if not parser then
return return
end end
if M.TS_ASYNC then
parser:parse(opts.treesitter.injections, cb)
else
parser:parse(opts.treesitter.injections) parser:parse(opts.treesitter.injections)
cb()
end
end
---@param opts snacks.scope.Opts
function TSScope:find(opts)
local lang = vim.treesitter.language.get_lang(vim.bo[opts.buf].filetype)
local line = vim.fn.nextnonblank(opts.pos[1]) local line = vim.fn.nextnonblank(opts.pos[1])
line = line == 0 and vim.fn.prevnonblank(opts.pos[1]) or line line = line == 0 and vim.fn.prevnonblank(opts.pos[1]) or line
-- FIXME: -- FIXME:
@ -475,9 +494,9 @@ function Scope:__tostring()
) )
end end
---@param opts? snacks.scope.Opts ---@param cb fun(scope?: snacks.scope.Scope)
---@return snacks.scope.Scope? ---@param opts? snacks.scope.Opts|{parse?:boolean}
function M.get(opts) function M.get(cb, opts)
opts = Snacks.config.get("scope", defaults, opts or {}) --[[ @as snacks.scope.Opts ]] opts = Snacks.config.get("scope", defaults, opts or {}) --[[ @as snacks.scope.Opts ]]
opts.buf = (opts.buf == nil or opts.buf == 0) and vim.api.nvim_get_current_buf() or opts.buf opts.buf = (opts.buf == nil or opts.buf == 0) and vim.api.nvim_get_current_buf() or opts.buf
if not opts.pos then if not opts.pos then
@ -487,32 +506,38 @@ function M.get(opts)
-- run in the context of the buffer if not current -- run in the context of the buffer if not current
if vim.api.nvim_get_current_buf() ~= opts.buf then if vim.api.nvim_get_current_buf() ~= opts.buf then
local ret ---@type snacks.scope.Scope?
vim.api.nvim_buf_call(opts.buf, function() vim.api.nvim_buf_call(opts.buf, function()
ret = M.get(opts) M.get(cb, opts)
end) end)
return ret return
end end
---@type snacks.scope.Scope ---@type snacks.scope.Scope
local Class = opts.treesitter.enabled and TSScope.has_ts(opts.buf) and TSScope or IndentScope local Class = opts.treesitter.enabled and TSScope.has_ts(opts.buf) and TSScope or IndentScope
local ret = Class:find(opts) --[[ @as snacks.scope.Scope? ]] if Class == TSScope and opts.parse ~= false then
TSScope:init(function()
opts.parse = false
M.get(cb, opts)
end, opts)
return
end
local scope = Class:find(opts) --[[ @as snacks.scope.Scope? ]]
-- fallback to indent based detection -- fallback to indent based detection
if not ret and Class == TSScope then if not scope and Class == TSScope then
Class = IndentScope Class = IndentScope
ret = Class:find(opts) scope = Class:find(opts)
end end
-- when end_pos is provided, get its scope and expand the current scope -- when end_pos is provided, get its scope and expand the current scope
-- to include it. -- to include it.
if ret and opts.end_pos and not vim.deep_equal(opts.pos, opts.end_pos) then if scope and opts.end_pos and not vim.deep_equal(opts.pos, opts.end_pos) then
local end_scope = Class:find(vim.tbl_extend("keep", { pos = opts.end_pos }, opts)) --[[ @as snacks.scope.Scope? ]] local end_scope = Class:find(vim.tbl_extend("keep", { pos = opts.end_pos }, opts)) --[[ @as snacks.scope.Scope? ]]
if end_scope and end_scope.from < ret.from then if end_scope and end_scope.from < scope.from then
ret = ret:expand(end_scope.from) or ret scope = scope:expand(end_scope.from) or scope
end end
if end_scope and end_scope.to > ret.to then if end_scope and end_scope.to > scope.to then
ret = ret:expand(end_scope.to) or ret scope = scope:expand(end_scope.to) or scope
end end
end end
@ -521,41 +546,40 @@ function M.get(opts)
-- expand block with ancestors until min_size is reached -- expand block with ancestors until min_size is reached
-- or max_size is reached -- or max_size is reached
if ret then if scope then
local s = ret --- @type snacks.scope.Scope? local s = scope --- @type snacks.scope.Scope?
while s do while s do
if opts.edge and ret:size_with_edge() >= min_size and s:size_with_edge() > max_size then if opts.edge and scope:size_with_edge() >= min_size and s:size_with_edge() > max_size then
break break
elseif not opts.edge and ret:size() >= min_size and s:size() > max_size then elseif not opts.edge and scope:size() >= min_size and s:size() > max_size then
break break
end end
ret, s = s, s:parent() scope, s = s, s:parent()
end end
-- expand with edge -- expand with edge
if opts.edge then if opts.edge then
ret = ret:with_edge() --[[@as snacks.scope.Scope]] scope = scope:with_edge() --[[@as snacks.scope.Scope]]
end end
end end
-- expand single line blocks with single line siblings -- expand single line blocks with single line siblings
if opts.siblings and ret and ret:size() == 1 then if opts.siblings and scope and scope:size() == 1 then
while ret and ret:size() < min_size do while scope and scope:size() < min_size do
local prev, next = vim.fn.prevnonblank(ret.from - 1), vim.fn.nextnonblank(ret.to + 1) ---@type number, number local prev, next = vim.fn.prevnonblank(scope.from - 1), vim.fn.nextnonblank(scope.to + 1) ---@type number, number
local prev_dist, next_dist = math.abs(opts.pos[1] - prev), math.abs(opts.pos[1] - next) local prev_dist, next_dist = math.abs(opts.pos[1] - prev), math.abs(opts.pos[1] - next)
local prev_s = prev > 0 and Class:find(vim.tbl_extend("keep", { pos = { prev, 0 } }, opts)) local prev_s = prev > 0 and Class:find(vim.tbl_extend("keep", { pos = { prev, 0 } }, opts))
local next_s = next > 0 and Class:find(vim.tbl_extend("keep", { pos = { next, 0 } }, opts)) local next_s = next > 0 and Class:find(vim.tbl_extend("keep", { pos = { next, 0 } }, opts))
prev_s = prev_s and prev_s:size() == 1 and prev_s prev_s = prev_s and prev_s:size() == 1 and prev_s
next_s = next_s and next_s:size() == 1 and next_s next_s = next_s and next_s:size() == 1 and next_s
local s = prev_dist < next_dist and prev_s or next_s or prev_s local s = prev_dist < next_dist and prev_s or next_s or prev_s
if s and (s.from < ret.from or s.to > ret.to) then if s and (s.from < scope.from or s.to > scope.to) then
ret = Scope.with(ret, { from = math.min(ret.from, s.from), to = math.max(ret.to, s.to) }) scope = Scope.with(scope, { from = math.min(scope.from, s.from), to = math.max(scope.to, s.to) })
else else
break break
end end
end end
end end
cb(scope)
return ret
end end
---@class snacks.scope.Listener ---@class snacks.scope.Listener
@ -591,16 +615,20 @@ function Listener:check(win)
return return
end end
local scope = M.get(vim.tbl_extend("keep", { M.get(
buf = buf, function(scope)
pos = vim.api.nvim_win_get_cursor(win),
}, self.opts))
local prev = self.active[win] local prev = self.active[win]
if prev == scope then if prev == scope then
return -- no change return -- no change
end end
self.active[win] = scope self.active[win] = scope
self.cb(win, buf, scope, prev) self.cb(win, buf, scope, prev)
end,
vim.tbl_extend("keep", {
buf = buf,
pos = vim.api.nvim_win_get_cursor(win),
}, self.opts)
)
end end
--- Get the active scope for a window --- Get the active scope for a window
@ -719,7 +747,7 @@ function M.textobject(opts)
local inner = not opts.edge local inner = not opts.edge
opts.edge = true -- always include the edge of the scope to make inner work opts.edge = true -- always include the edge of the scope to make inner work
local scope = M.get(opts) M.get(function(scope)
if not scope then if not scope then
return opts.notify ~= false and Snacks.notify.warn("No scope in range") return opts.notify ~= false and Snacks.notify.warn("No scope in range")
end end
@ -734,6 +762,7 @@ function M.textobject(opts)
vim.api.nvim_win_set_cursor(0, from) vim.api.nvim_win_set_cursor(0, from)
vim.cmd("normal! " .. (opts.linewise and "V" or "v")) vim.cmd("normal! " .. (opts.linewise and "V" or "v"))
vim.api.nvim_win_set_cursor(0, to) vim.api.nvim_win_set_cursor(0, to)
end, opts)
end end
--- Jump to the top or bottom of the scope --- Jump to the top or bottom of the scope
@ -741,7 +770,7 @@ end
---@param opts? snacks.scope.Jump ---@param opts? snacks.scope.Jump
function M.jump(opts) function M.jump(opts)
opts = Snacks.config.get("scope", defaults, opts or {}) --[[ @as snacks.scope.Jump ]] opts = Snacks.config.get("scope", defaults, opts or {}) --[[ @as snacks.scope.Jump ]]
local scope = M.get(opts) M.get(function(scope)
if not scope then if not scope then
return opts.notify ~= false and Snacks.notify.warn("No scope in range") return opts.notify ~= false and Snacks.notify.warn("No scope in range")
end end
@ -753,6 +782,7 @@ function M.jump(opts)
end end
scope = scope:parent() scope = scope:parent()
end end
end, opts)
end end
---@private ---@private

View file

@ -78,14 +78,15 @@ describe("scope", function()
ft = "lua", ft = "lua",
ts = ts, ts = ts,
}) })
local scope = Snacks.scope.get({ Snacks.scope.get(function(scope)
pos = { line, 0 },
treesitter = { enabled = ts },
})
assert(scope) assert(scope)
assert((scope.node == nil) == not ts) assert((scope.node == nil) == not ts)
assert.same(scope.from, s[1]) assert.same(scope.from, s[1])
assert.same(scope.to, s[2]) assert.same(scope.to, s[2])
end, {
pos = { line, 0 },
treesitter = { enabled = ts },
})
end) end)
end end
end end