From 1c8e8de7434c7dce364db07ca7ffcfb830fa3bb6 Mon Sep 17 00:00:00 2001 From: Riley Bruins Date: Tue, 6 May 2025 12:03:12 -0700 Subject: [PATCH] feat: hash patterns by language name This required a slight refactor. Now documents store a language information object which contains the actual language object and the name of the language. When ABI 15 is prevalent enough, we can eventually derive the name from the language object itself. --- src/cli/check.rs | 6 ++++- src/cli/lint.rs | 2 +- src/handlers/diagnostic.rs | 32 ++++++++--------------- src/handlers/did_open.rs | 45 ++++++++++++++++++++++----------- src/handlers/execute_command.rs | 6 ++++- src/main.rs | 10 +++++++- src/test_helpers.rs | 1 + src/util.rs | 39 +++++++++++++--------------- 8 files changed, 80 insertions(+), 61 deletions(-) diff --git a/src/cli/check.rs b/src/cli/check.rs index 41774ae..68f3236 100644 --- a/src/cli/check.rs +++ b/src/cli/check.rs @@ -23,7 +23,11 @@ pub fn check_directories(directories: &[PathBuf], config: String, format: bool, }; scm_files.par_iter().for_each(|path| { let uri = Url::from_file_path(path.canonicalize().unwrap()).unwrap(); - if let Some(lang) = util::get_language(&uri, &options) { + let language = (|| { + let name = util::get_language_name(&uri, &options)?; + util::get_language(&name, &options) + })(); + if let Some(lang) = language { if let Ok(source) = fs::read_to_string(path) { if let Err(err) = Query::new(&lang, source.as_str()) { match err.kind { diff --git a/src/cli/lint.rs b/src/cli/lint.rs index 1c303b5..ac83481 100644 --- a/src/cli/lint.rs +++ b/src/cli/lint.rs @@ -37,7 +37,7 @@ pub(super) fn lint_file( fields_vec: Default::default(), supertype_map: Default::default(), version: Default::default(), - language: Default::default(), + language_data: Default::default(), }; let provider = &util::TextProviderRope(&doc.rope); let diagnostics = get_diagnostics(uri, &doc, options, provider); diff --git a/src/handlers/diagnostic.rs b/src/handlers/diagnostic.rs index 
ce3c4aa..ec35817 100644 --- a/src/handlers/diagnostic.rs +++ b/src/handlers/diagnostic.rs @@ -11,13 +11,12 @@ use tower_lsp::{ }, }; use tree_sitter::{ - Language, Node, Query, QueryCursor, QueryError, QueryErrorKind, StreamingIterator as _, - TreeCursor, + Node, Query, QueryCursor, QueryError, QueryErrorKind, StreamingIterator as _, TreeCursor, }; use ts_query_ls::{Options, PredicateParameter, PredicateParameterArity, PredicateParameterType}; use crate::{ - Backend, DocumentData, QUERY_LANGUAGE, SymbolInfo, + Backend, DocumentData, LanguageData, QUERY_LANGUAGE, SymbolInfo, util::{CAPTURES_QUERY, NodeUtil as _, TextProviderRope, uri_to_basename}, }; @@ -36,27 +35,18 @@ static QUERY_STRUCTURE_RESULTS: LazyLock LazyLock::new(DashMap::new); fn get_pattern_diagnostic( - uri: Url, pattern_node: &Node, rope: &Rope, - language: Language, + language_data: LanguageData, ) -> Option { - // Assume all sibling queries are of the same language - // TODO: Once many parsers have been upgraded to ABI 15, hash by language name rather than - // directory - let parent = uri - .to_file_path() - .ok()? - .parent()? 
- .to_string_lossy() - .into_owned(); + let LanguageData { language, name } = language_data; // Format patterns to ignore syntactically insignificant diffs let pattern_text = pattern_node.text(rope); - if let Some(cached_diag) = QUERY_STRUCTURE_RESULTS.get(&(parent.clone(), pattern_text.clone())) - { + let pattern_key = (name, pattern_text); + if let Some(cached_diag) = QUERY_STRUCTURE_RESULTS.get(&pattern_key) { return *cached_diag.deref(); } - match Query::new(&language, &pattern_text) { + match Query::new(&language, &pattern_key.1) { Err(QueryError { kind: QueryErrorKind::Structure, offset, @@ -65,11 +55,11 @@ fn get_pattern_diagnostic( message: _, }) => { let offset = Some(offset); - QUERY_STRUCTURE_RESULTS.insert((parent, pattern_text), offset); + QUERY_STRUCTURE_RESULTS.insert(pattern_key, offset); offset } _ => { - QUERY_STRUCTURE_RESULTS.insert((parent, pattern_text), None); + QUERY_STRUCTURE_RESULTS.insert(pattern_key, None); None } } @@ -293,9 +283,9 @@ pub fn get_diagnostics( } } "definition" => { - if let Some(language) = document.language.clone() { + if let Some(language_data) = &document.language_data { if let Some(offset) = - get_pattern_diagnostic(uri.clone(), &capture.node, rope, language) + get_pattern_diagnostic(&capture.node, rope, language_data.clone()) { let true_offset = offset + capture.node.start_byte(); diagnostics.push(Diagnostic { diff --git a/src/handlers/did_open.rs b/src/handlers/did_open.rs index 2030766..91b96e0 100644 --- a/src/handlers/did_open.rs +++ b/src/handlers/did_open.rs @@ -5,7 +5,10 @@ use tower_lsp::lsp_types::DidOpenTextDocumentParams; use tracing::info; use tree_sitter::Parser; -use crate::{Backend, DocumentData, QUERY_LANGUAGE, SymbolInfo, util::get_language}; +use crate::{ + Backend, DocumentData, LanguageData, QUERY_LANGUAGE, SymbolInfo, + util::{get_language, get_language_name}, +}; pub async fn did_open(backend: &Backend, params: DidOpenTextDocumentParams) { let uri = &params.text_document.uri; @@ -24,21 +27,29 @@ pub 
async fn did_open(backend: &Backend, params: DidOpenTextDocumentParams) { let mut fields_vec: Vec = vec![]; let mut fields_set: HashSet = HashSet::new(); let mut supertype_map: HashMap> = HashMap::new(); - let language = get_language(uri, &*backend.options.read().await); - if let Some(lang) = &language { + let language_data = async { + let options = backend.options.read().await; + let name = get_language_name(uri, &options)?; + let language = get_language(&name, &options)?; + Some(LanguageData { name, language }) + } + .await; + + if let Some(LanguageData { language, name: _ }) = &language_data { let error_symbol = SymbolInfo { label: "ERROR".to_owned(), named: true, }; symbols_set.insert(error_symbol.clone()); symbols_vec.push(error_symbol); - for i in 0..lang.node_kind_count() as u16 { - let supertype = lang.node_kind_is_supertype(i); - let named = lang.node_kind_is_named(i) || supertype; + for i in 0..language.node_kind_count() as u16 { + let supertype = language.node_kind_is_supertype(i); + let named = language.node_kind_is_named(i) || supertype; let label = if named { - lang.node_kind_for_id(i).unwrap().to_owned() + language.node_kind_for_id(i).unwrap().to_owned() } else { - lang.node_kind_for_id(i) + language + .node_kind_for_id(i) .unwrap() .replace('\\', r"\\") .replace('"', r#"\""#) @@ -48,24 +59,28 @@ pub async fn did_open(backend: &Backend, params: DidOpenTextDocumentParams) { if supertype { supertype_map.insert( symbol_info.clone(), - lang.subtypes_for_supertype(i) + language + .subtypes_for_supertype(i) .iter() .map(|s| SymbolInfo { - label: lang.node_kind_for_id(*s).unwrap().to_string(), - named: lang.node_kind_is_named(*s) || lang.node_kind_is_supertype(*s), + label: language.node_kind_for_id(*s).unwrap().to_string(), + named: language.node_kind_is_named(*s) + || language.node_kind_is_supertype(*s), }) .collect(), ); } - if symbols_set.contains(&symbol_info) || !(lang.node_kind_is_visible(i) || supertype) { + if symbols_set.contains(&symbol_info) + 
|| !(language.node_kind_is_visible(i) || supertype) + { continue; } symbols_set.insert(symbol_info.clone()); symbols_vec.push(symbol_info); } // Field IDs go from 1 to nfields inclusive (extra index 0 maps to NULL) - for i in 1..=lang.field_count() as u16 { - let field_name = lang.field_name_for_id(i).unwrap().to_owned(); + for i in 1..=language.field_count() as u16 { + let field_name = language.field_name_for_id(i).unwrap().to_owned(); if !fields_set.contains(&field_name) { fields_set.insert(field_name.clone()); fields_vec.push(field_name); @@ -86,7 +101,7 @@ pub async fn did_open(backend: &Backend, params: DidOpenTextDocumentParams) { fields_vec, supertype_map, version, - language, + language_data, }, ); } diff --git a/src/handlers/execute_command.rs b/src/handlers/execute_command.rs index 422b43a..1f6aa51 100644 --- a/src/handlers/execute_command.rs +++ b/src/handlers/execute_command.rs @@ -42,7 +42,11 @@ async fn check_impossible_patterns(backend: &Backend, params: ExecuteCommandPara let rope = &doc.rope; let tree = &doc.tree; let options = &backend.options.read().await; - let Some(lang) = util::get_language(&uri, options) else { + let language = (|| { + let name = util::get_language_name(&uri, options)?; + util::get_language(&name, options) + })(); + let Some(lang) = language else { warn!("Could not retrieve language for path: '{}'", uri.path()); return; }; diff --git a/src/main.rs b/src/main.rs index 7e0c7dd..bfbfe27 100644 --- a/src/main.rs +++ b/src/main.rs @@ -102,6 +102,14 @@ impl fmt::Display for SymbolInfo { } } +#[derive(Clone)] +pub struct LanguageData { + language: Language, + // TODO: Once most parsers are upgraded to ABI 15, just get the name from the language object + // itself + name: String, +} + struct DocumentData { symbols_set: HashSet, symbols_vec: Vec, @@ -111,7 +119,7 @@ struct DocumentData { rope: Rope, tree: Tree, version: i32, - language: Option, + language_data: Option, } struct Backend { diff --git a/src/test_helpers.rs 
b/src/test_helpers.rs index 4303f09..a2605e0 100644 --- a/src/test_helpers.rs +++ b/src/test_helpers.rs @@ -84,6 +84,7 @@ pub mod helpers { ]), ) })), + language_data: None, }, ) }, diff --git a/src/util.rs b/src/util.rs index 91945e0..a54325a 100644 --- a/src/util.rs +++ b/src/util.rs @@ -181,7 +181,8 @@ pub fn lsp_textdocchange_to_ts_inputedit( const DYLIB_EXTENSIONS: [&str; 3] = [".so", ".dll", ".dylib"]; -pub fn get_language(uri: &Url, options: &Options) -> Option { +/// Get the language name associated with the given URI, after aliases. +pub fn get_language_name(uri: &Url, options: &Options) -> Option { let mut language_retrieval_regexes: Vec = options .language_retrieval_patterns .clone() @@ -197,28 +198,24 @@ pub fn get_language(uri: &Url, options: &Options) -> Option { break; } } - let lang = captures - .and_then(|captures| captures.get(1)) - .and_then(|cap| { - let cap_str = cap.as_str(); - get_language_object( - options - .parser_aliases - .get(cap_str) - .unwrap_or(&cap_str.to_owned()) - .as_str(), - &options.parser_install_directories, - &ENGINE, - ) - }); - lang + let raw_name = captures + .and_then(|caps| caps.get(1)) + .map(|cap| cap.as_str().to_owned())?; + Some( + options + .parser_aliases + .get(&raw_name) + .cloned() + .unwrap_or(raw_name), + ) } -pub fn get_language_object( - name: &str, - directories: &Vec, - engine: &Engine, -) -> Option { +/// Get the TSLanguage object for a given name and configuration. +pub fn get_language(name: &str, options: &Options) -> Option { + get_language_object(name, &options.parser_install_directories, &ENGINE) +} + +fn get_language_object(name: &str, directories: &Vec, engine: &Engine) -> Option { let name = name.replace('-', "_"); let language_fn_name = format!("tree_sitter_{name}");