feat(picker.matcher): better scoring algorithm based on fzf. Closes #512. Fixes #513

This commit is contained in:
Folke Lemaitre 2025-01-16 08:11:38 +01:00
parent 821e23101f
commit e4e2e88c76
No known key found for this signature in database
GPG key ID: 41F8B1FBACAE2040
2 changed files with 218 additions and 51 deletions

View file

@ -8,6 +8,7 @@ local Async = require("snacks.picker.util.async")
---@field tick number
---@field task snacks.picker.Async
---@field live? boolean
---@field score snacks.picker.Score
local M = {}
M.__index = M
M.DEFAULT_SCORE = 1000
@ -39,6 +40,7 @@ function M.new(opts)
self.task = Async.nop()
self.mods = {}
self.tick = 0
self.score = require("snacks.picker.core.score").new()
return self
end
@ -282,13 +284,6 @@ function M:fuzzy_positions(str, pattern, from)
return ret
end
---Tells whether the byte at index `c` of `str` is an ASCII letter.
---@param str string
---@param c number 1-based byte index into `str`
---@return boolean
function M.is_alpha(str, c)
  local code = str:byte(c, c)
  local upper = code >= 65 and code <= 90 -- 'A'..'Z'
  local lower = code >= 97 and code <= 122 -- 'a'..'z'
  return upper or lower
end
---@param item snacks.picker.Item
---@param mods snacks.picker.matcher.Mods
---@return number? score, number? from, number? to, string? str
@ -304,53 +299,44 @@ function M:_match(item, mods)
str = tostring(item[mods.field])
end
local str_orig = str
str = mods.ignorecase and str:lower() or str
local from, to ---@type number?, number?
if mods.fuzzy then
from, to = self:fuzzy(str, mods.chars)
return self:fuzzy(str, mods.chars)
end
if mods.exact_prefix then
if str:sub(1, #mods.pattern) == mods.pattern then
from, to = 1, #mods.pattern
end
elseif mods.exact_suffix then
if str:sub(-#mods.pattern) == mods.pattern then
from, to = #str - #mods.pattern + 1, #str
end
else
if mods.exact_prefix then
if str:sub(1, #mods.pattern) == mods.pattern then
from, to = 1, #mods.pattern
end
elseif mods.exact_suffix then
if str:sub(-#mods.pattern) == mods.pattern then
from, to = #str - #mods.pattern + 1, #str
end
else
from, to = str:find(mods.pattern, 1, true)
-- word match
while mods.word and from and to do
local bound_left = from == 1 or not M.is_alpha(str, from - 1)
local bound_right = to == #str or not M.is_alpha(str, to + 1)
if bound_left and bound_right then
break
end
from, to = str:find(mods.pattern, to + 1, true)
from, to = str:find(mods.pattern, 1, true)
-- word match
while mods.word and from and to do
local bound_left = self.score:is_left_boundary(str, from)
local bound_right = self.score:is_right_boundary(str, to)
if bound_left and bound_right then
break
end
from, to = str:find(mods.pattern, to + 1, true)
end
if mods.inverse then
if not from then
return M.INVERSE_SCORE
end
return
end
if mods.inverse then
if not from then
return M.INVERSE_SCORE
end
return
end
if from then
---@cast to number
return M.score(from, to, #str), from, to, str
return self.score:get(str_orig, from, to), from, to, str
end
end
---Heuristic score for a match spanning `from`..`to` in a string of length `len`.
---@param from number
---@param to number
---@param len number
---@return number
function M.score(from, to, len)
  -- compactness: the shorter the span between first and last match, the better
  local compactness = 1000 / (to - from + 1)
  -- matches that start early in the string are preferred
  local early_bonus = 100 / from
  -- shorter strings are preferred
  local length_bonus = 100 / (len + 1)
  return compactness + early_bonus + length_bonus
end
---@param str string
---@param pattern string[]
---@param init? number
@ -360,11 +346,14 @@ function M:fuzzy_find(str, pattern, init)
if not from then
return
end
self.score:init(str, from)
---@type number?, number
local last, n = from, #pattern
for i = 2, n do
last = string.find(str, pattern[i], last + 1, true)
if not last then
if last then
self.score:update(last)
else
return
end
end
@ -375,24 +364,22 @@ end
--- to find the best match.
---@param str string
---@param pattern string[]
---@return number? from, number? to
---@return number? score, number? from, number? to, string? str
function M:fuzzy(str, pattern)
local from, to = self:fuzzy_find(str, pattern)
if not from then
return
end
---@cast to number
local best_from, best_to, best_width = from, to, to - from + 1
local n, width = #pattern, 0
-- short circuit if we have a perfect match
while from and best_width > n do
width = to - from + 1
if width < best_width then
best_from, best_to, best_width = from, to, width
local best_from, best_to, best_score = from, to, self.score.score
while from do
if self.score.score > best_score then
best_from, best_to, best_score = from, to, self.score.score
end
from, to = self:fuzzy_find(str, pattern, from + 1)
end
return best_from, best_to
return best_score, best_from, best_to, str
end
return M

View file

@ -0,0 +1,180 @@
--- This is a port of the scoring logic from fzf. See:
--- https://github.com/junegunn/fzf/blob/master/src/algo/algo.go
---
--- A Score object accumulates the score of one candidate match: `init`
--- resets the state and scores the first matched character, `update` scores
--- each following matched character, and `get` scores a contiguous range.
---@class snacks.picker.Score
---@field score number running score of the match being evaluated
---@field consecutive number number of adjacent matches in the current run (reset to 0 after a gap)
---@field prev? number position of the previously matched character (nil before the first one)
---@field prev_class number character class treated as the "previous" class by the next bonus lookup
---@field in_gap boolean set when an update skipped characters, cleared on an adjacent match
---@field str string the string currently being scored
local M = {}
M.__index = M
-- Scoring constants, modelled on fzf's algo.go.
-- NOTE(review): SCORE_LEN does not appear to have an fzf counterpart; it is
-- the per-byte length tiebreak applied in M:init -- confirm.
local SCORE_MATCH = 16 -- base score added for every matched character
local SCORE_GAP_START = -3 -- penalty for the first skipped character of a gap
local SCORE_GAP_EXTENSION = -1 -- penalty for each further skipped character
local SCORE_LEN = -0.01 -- multiplied by the string length (tiebreak by length)
local BONUS_BOUNDARY = SCORE_MATCH / 2 -- 8
local BONUS_NONWORD = SCORE_MATCH / 2 -- 8
local BONUS_CAMEL_123 = BONUS_BOUNDARY - 1 -- 7, camelCase and non-number -> number transitions
local BONUS_CONSECUTIVE = -(SCORE_GAP_START + SCORE_GAP_EXTENSION) -- 4, scaled by the length of the adjacent run
local BONUS_FIRST_CHAR_MULTIPLIER = 2 -- the bonus of the first matched character counts double
-- ASCII character classes (simplified). The numeric values are significant:
-- other code in this file compares classes with `<` and `>`.
local CHAR_WHITE = 0
local CHAR_NONWORD = 1
local CHAR_DELIMITER = 2
local CHAR_LOWER = 3
local CHAR_UPPER = 4
local CHAR_LETTER = 5 -- declared, but never assigned by the table below
local CHAR_NUMBER = 6

-- Lookup table from byte value (0..255) to character class.
local CHAR_CLASS = {} ---@type number[]
do
  ---Assigns `class` to every byte in the inclusive range `first`..`last`.
  ---@param first number
  ---@param last number
  ---@param class number
  local function assign(first, last, class)
    for byte = first, last do
      CHAR_CLASS[byte] = class
    end
  end

  -- Later assignments win, so classes are applied from lowest to highest
  -- priority: whitespace is checked last and therefore takes precedence.
  assign(0, 255, CHAR_NONWORD)
  assign(97, 122, CHAR_LOWER) -- 'a'..'z'
  assign(65, 90, CHAR_UPPER) -- 'A'..'Z'
  assign(48, 57, CHAR_NUMBER) -- '0'..'9'
  for byte = 0, 255 do
    local char = string.char(byte)
    if char:match("[/\\,:;|]") then
      CHAR_CLASS[byte] = CHAR_DELIMITER
    end
  end
  for byte = 0, 255 do
    if string.char(byte):match("%s") then
      CHAR_CLASS[byte] = CHAR_WHITE
    end
  end
end
-- Bonus lookup indexed as BONUS_MATRIX[prev_class][curr_class]. Every cell
-- of the 7x7 grid (classes 0..6) starts out at 0.
local BONUS_MATRIX = {} ---@type number[][]
for prev_class = 0, 6 do
  local row = {}
  for curr_class = 0, 6 do
    row[curr_class] = 0
  end
  BONUS_MATRIX[prev_class] = row
end
---Computes the bonus for a matched character of class `currC` that is
---preceded by a character of class `prevC` (mimics fzf's approach).
---@param prevC number class of the preceding character
---@param currC number class of the matched character
---@return number bonus
local function computeBonus(prevC, currC)
  -- Any class above CHAR_NONWORD (this includes CHAR_DELIMITER) earns a
  -- boundary bonus when it follows whitespace, a delimiter or a non-word char.
  local above_nonword = currC > CHAR_NONWORD
  if above_nonword and prevC == CHAR_WHITE then
    return BONUS_BOUNDARY + 2 -- e.g. bonusBoundaryWhite
  end
  if above_nonword and prevC == CHAR_DELIMITER then
    return BONUS_BOUNDARY + 1 -- e.g. bonusBoundaryDelimiter
  end
  if above_nonword and prevC == CHAR_NONWORD then
    return BONUS_BOUNDARY
  end

  -- camelCase (lower -> upper) or anything-but-number -> number
  local camel_case = prevC == CHAR_LOWER and currC == CHAR_UPPER
  local enters_number = currC == CHAR_NUMBER and prevC ~= CHAR_NUMBER
  if camel_case or enters_number then
    return BONUS_CAMEL_123
  end

  -- The matched character itself is whitespace, a delimiter or a non-word char
  if currC == CHAR_WHITE then
    return BONUS_BOUNDARY + 2
  end
  if currC == CHAR_NONWORD or currC == CHAR_DELIMITER then
    return BONUS_NONWORD
  end

  return 0
end
-- Precompute the bonus for every (previous class, current class) pair
for prev_class = 0, 6 do
  local row = BONUS_MATRIX[prev_class]
  for curr_class = 0, 6 do
    row[curr_class] = computeBonus(prev_class, curr_class)
  end
end
---Creates a scorer with an empty string and a zero score.
---@return snacks.picker.Score
function M.new()
  return setmetatable({
    score = 0,
    consecutive = 0,
    prev_class = CHAR_WHITE,
    in_gap = false,
    str = "",
  }, M)
end
---Tells whether a match starting at `pos` begins on a word boundary, i.e.
---it is at the start of the string or the preceding byte is whitespace,
---a delimiter or another non-word character.
---@param str string
---@param pos number 1-based byte index of the first matched character
---@return boolean
function M:is_left_boundary(str, pos)
  -- `pos <= 1` also covers pos == 0, where `str:byte(pos - 1)` would wrap
  -- around and inspect the *last* byte of the string.
  if pos <= 1 then
    return true
  end
  -- Same fallback as `init`/`update`: an unknown class counts as non-word
  -- instead of raising a nil comparison error.
  local class = CHAR_CLASS[str:byte(pos - 1)] or CHAR_NONWORD
  return class < CHAR_LOWER
end
---Tells whether a match ending at `pos` ends on a word boundary, i.e. it
---reaches the end of the string or the following byte is whitespace,
---a delimiter or another non-word character.
---@param str string
---@param pos number 1-based byte index of the last matched character
---@return boolean
function M:is_right_boundary(str, pos)
  -- `>=` instead of `==`: a position past the end has nothing to its right
  if pos >= #str then
    return true
  end
  -- Same fallback as `init`/`update`: an unknown class counts as non-word
  -- instead of raising a nil comparison error.
  local class = CHAR_CLASS[str:byte(pos + 1)] or CHAR_NONWORD
  return class < CHAR_LOWER
end
---Resets the scorer for `str` and scores the first matched character.
---@param str string string being scored
---@param first number 1-based byte index of the first matched character
function M:init(str, first)
  -- The start of the string counts as whitespace; otherwise use the class
  -- of the character right before the first match.
  local before = CHAR_WHITE
  if first > 1 then
    before = CHAR_CLASS[str:byte(first - 1)] or CHAR_NONWORD
  end

  self.str, self.prev = str, nil
  self.consecutive, self.in_gap = 0, false
  self.prev_class = before
  self.score = #str * SCORE_LEN -- tiebreak by length
  self:update(first)
end
---Scores the matched character at `pos` and records it as the latest match.
---
---The score gains SCORE_MATCH plus a transition bonus looked up from the
---class of the character directly before `pos` and the class of the matched
---character. The first matched character gets its bonus multiplied by
---BONUS_FIRST_CHAR_MULTIPLIER. Later characters either pay a gap penalty
---(characters were skipped since the previous match) or earn a growing
---bonus for extending a run of adjacent matches.
---@param pos number 1-based byte index of the matched character in `self.str`
function M:update(pos)
  local str = self.str
  local class = CHAR_CLASS[str:byte(pos)] or CHAR_NONWORD
  local prev = self.prev
  -- number of characters skipped since the previous match (0 for the first)
  local gap = prev and (pos - prev - 1) or 0

  if gap > 0 then
    -- Characters were skipped, so the previously *matched* character is not
    -- the one in front of `pos`. The bonus must look at the character right
    -- before `pos`, otherwise e.g. the `b` of "foo_bar" never gets its
    -- boundary bonus when matching "fb".
    self.prev_class = CHAR_CLASS[str:byte(pos - 1)] or CHAR_NONWORD
  end
  local bonus = BONUS_MATRIX[self.prev_class][class] or 0

  if not prev then
    -- first matched character
    bonus = bonus * BONUS_FIRST_CHAR_MULTIPLIER
  elseif gap > 0 then
    -- Every gap pays the start penalty once plus an extension penalty for
    -- each additional skipped character. Whether the previous match also
    -- followed a gap is irrelevant: a match always terminates the gap.
    self.score = self.score + SCORE_GAP_START + (gap - 1) * SCORE_GAP_EXTENSION
    self.consecutive = 0
    self.in_gap = true
  else
    -- consecutive match => reward, growing with the length of the run
    self.consecutive = self.consecutive + 1
    self.score = self.score + (BONUS_CONSECUTIVE * self.consecutive)
    self.in_gap = false
  end

  -- Add base match + boundary/camel bonus
  self.score = self.score + SCORE_MATCH + bonus

  -- Update for next iteration
  self.prev_class = class
  self.prev = pos
end
---Scores the contiguous range `from`..`to` of `str`, treating every
---character in the range as matched.
---@param str string
---@param from number 1-based byte index of the first matched character
---@param to number 1-based byte index of the last matched character
---@return number score
function M:get(str, from, to)
  self:init(str, from) -- scores the character at `from`
  local pos = from + 1
  while pos <= to do
    self:update(pos)
    pos = pos + 1
  end
  return self.score
end
return M