diff --git a/src/cli/check.rs b/src/cli/check.rs index 41774ae..68f3236 100644 --- a/src/cli/check.rs +++ b/src/cli/check.rs @@ -23,7 +23,11 @@ pub fn check_directories(directories: &[PathBuf], config: String, format: bool, }; scm_files.par_iter().for_each(|path| { let uri = Url::from_file_path(path.canonicalize().unwrap()).unwrap(); - if let Some(lang) = util::get_language(&uri, &options) { + let language = (|| { + let name = util::get_language_name(&uri, &options)?; + util::get_language(&name, &options) + })(); + if let Some(lang) = language { if let Ok(source) = fs::read_to_string(path) { if let Err(err) = Query::new(&lang, source.as_str()) { match err.kind { diff --git a/src/cli/lint.rs b/src/cli/lint.rs index 1c303b5..ac83481 100644 --- a/src/cli/lint.rs +++ b/src/cli/lint.rs @@ -37,7 +37,7 @@ pub(super) fn lint_file( fields_vec: Default::default(), supertype_map: Default::default(), version: Default::default(), - language: Default::default(), + language_data: Default::default(), }; let provider = &util::TextProviderRope(&doc.rope); let diagnostics = get_diagnostics(uri, &doc, options, provider); diff --git a/src/handlers/diagnostic.rs b/src/handlers/diagnostic.rs index ce3c4aa..ec35817 100644 --- a/src/handlers/diagnostic.rs +++ b/src/handlers/diagnostic.rs @@ -11,13 +11,12 @@ use tower_lsp::{ }, }; use tree_sitter::{ - Language, Node, Query, QueryCursor, QueryError, QueryErrorKind, StreamingIterator as _, - TreeCursor, + Node, Query, QueryCursor, QueryError, QueryErrorKind, StreamingIterator as _, TreeCursor, }; use ts_query_ls::{Options, PredicateParameter, PredicateParameterArity, PredicateParameterType}; use crate::{ - Backend, DocumentData, QUERY_LANGUAGE, SymbolInfo, + Backend, DocumentData, LanguageData, QUERY_LANGUAGE, SymbolInfo, util::{CAPTURES_QUERY, NodeUtil as _, TextProviderRope, uri_to_basename}, }; @@ -36,27 +35,18 @@ static QUERY_STRUCTURE_RESULTS: LazyLock LazyLock::new(DashMap::new); fn get_pattern_diagnostic( - uri: Url, pattern_node: &Node, rope: &Rope, - language: Language, + language_data: LanguageData, ) -> Option { - // Assume all sibling queries are of the same language - // TODO: Once many parsers have been upgraded to ABI 15, hash by language name rather than - // directory - let parent = uri - .to_file_path() - .ok()? - .parent()? - .to_string_lossy() - .into_owned(); + let LanguageData { language, name } = language_data; // Format patterns to ignore syntactically insignificant diffs let pattern_text = pattern_node.text(rope); - if let Some(cached_diag) = QUERY_STRUCTURE_RESULTS.get(&(parent.clone(), pattern_text.clone())) - { + let pattern_key = (name, pattern_text); + if let Some(cached_diag) = QUERY_STRUCTURE_RESULTS.get(&pattern_key) { return *cached_diag.deref(); } - match Query::new(&language, &pattern_text) { + match Query::new(&language, &pattern_key.1) { Err(QueryError { kind: QueryErrorKind::Structure, offset, @@ -65,11 +55,11 @@ fn get_pattern_diagnostic( message: _, }) => { let offset = Some(offset); - QUERY_STRUCTURE_RESULTS.insert((parent, pattern_text), offset); + QUERY_STRUCTURE_RESULTS.insert(pattern_key, offset); offset } _ => { - QUERY_STRUCTURE_RESULTS.insert((parent, pattern_text), None); + QUERY_STRUCTURE_RESULTS.insert(pattern_key, None); None } } @@ -293,9 +283,9 @@ pub fn get_diagnostics( } } "definition" => { - if let Some(language) = document.language.clone() { + if let Some(language_data) = &document.language_data { if let Some(offset) = - get_pattern_diagnostic(uri.clone(), &capture.node, rope, language) + get_pattern_diagnostic(&capture.node, rope, language_data.clone()) { let true_offset = offset + capture.node.start_byte(); diagnostics.push(Diagnostic { diff --git a/src/handlers/did_open.rs b/src/handlers/did_open.rs index 2030766..91b96e0 100644 --- a/src/handlers/did_open.rs +++ b/src/handlers/did_open.rs @@ -5,7 +5,10 @@ use tower_lsp::lsp_types::DidOpenTextDocumentParams; use tracing::info; use tree_sitter::Parser; -use crate::{Backend, DocumentData, QUERY_LANGUAGE, SymbolInfo, util::get_language}; +use crate::{ + Backend, DocumentData, LanguageData, QUERY_LANGUAGE, SymbolInfo, + util::{get_language, get_language_name}, +}; pub async fn did_open(backend: &Backend, params: DidOpenTextDocumentParams) { let uri = ¶ms.text_document.uri; @@ -24,21 +27,29 @@ pub async fn did_open(backend: &Backend, params: DidOpenTextDocumentParams) { let mut fields_vec: Vec = vec![]; let mut fields_set: HashSet = HashSet::new(); let mut supertype_map: HashMap> = HashMap::new(); - let language = get_language(uri, &*backend.options.read().await); - if let Some(lang) = &language { + let language_data = async { + let options = backend.options.read().await; + let name = get_language_name(uri, &options)?; + let language = get_language(&name, &options)?; + Some(LanguageData { name, language }) + } + .await; + + if let Some(LanguageData { language, name: _ }) = &language_data { let error_symbol = SymbolInfo { label: "ERROR".to_owned(), named: true, }; symbols_set.insert(error_symbol.clone()); symbols_vec.push(error_symbol); - for i in 0..lang.node_kind_count() as u16 { - let supertype = lang.node_kind_is_supertype(i); - let named = lang.node_kind_is_named(i) || supertype; + for i in 0..language.node_kind_count() as u16 { + let supertype = language.node_kind_is_supertype(i); + let named = language.node_kind_is_named(i) || supertype; let label = if named { - lang.node_kind_for_id(i).unwrap().to_owned() + language.node_kind_for_id(i).unwrap().to_owned() } else { - lang.node_kind_for_id(i) + language + .node_kind_for_id(i) .unwrap() .replace('\\', r"\\") .replace('"', r#"\""#) @@ -48,24 +59,28 @@ pub async fn did_open(backend: &Backend, params: DidOpenTextDocumentParams) { if supertype { supertype_map.insert( symbol_info.clone(), - lang.subtypes_for_supertype(i) + language + .subtypes_for_supertype(i) .iter() .map(|s| SymbolInfo { - label: lang.node_kind_for_id(*s).unwrap().to_string(), - named: lang.node_kind_is_named(*s) || lang.node_kind_is_supertype(*s), + label: language.node_kind_for_id(*s).unwrap().to_string(), + named: language.node_kind_is_named(*s) + || language.node_kind_is_supertype(*s), }) .collect(), ); } - if symbols_set.contains(&symbol_info) || !(lang.node_kind_is_visible(i) || supertype) { + if symbols_set.contains(&symbol_info) + || !(language.node_kind_is_visible(i) || supertype) + { continue; } symbols_set.insert(symbol_info.clone()); symbols_vec.push(symbol_info); } // Field IDs go from 1 to nfields inclusive (extra index 0 maps to NULL) - for i in 1..=lang.field_count() as u16 { - let field_name = lang.field_name_for_id(i).unwrap().to_owned(); + for i in 1..=language.field_count() as u16 { + let field_name = language.field_name_for_id(i).unwrap().to_owned(); if !fields_set.contains(&field_name) { fields_set.insert(field_name.clone()); fields_vec.push(field_name); @@ -86,7 +101,7 @@ pub async fn did_open(backend: &Backend, params: DidOpenTextDocumentParams) { fields_vec, supertype_map, version, - language, + language_data, }, ); } diff --git a/src/handlers/execute_command.rs b/src/handlers/execute_command.rs index 422b43a..1f6aa51 100644 --- a/src/handlers/execute_command.rs +++ b/src/handlers/execute_command.rs @@ -42,7 +42,11 @@ async fn check_impossible_patterns(backend: &Backend, params: ExecuteCommandPara let rope = &doc.rope; let tree = &doc.tree; let options = &backend.options.read().await; - let Some(lang) = util::get_language(&uri, options) else { + let language = (|| { + let name = util::get_language_name(&uri, options)?; + util::get_language(&name, options) + })(); + let Some(lang) = language else { warn!("Could not retrieve language for path: '{}'", uri.path()); return; }; diff --git a/src/main.rs b/src/main.rs index 7e0c7dd..bfbfe27 100644 --- a/src/main.rs +++ b/src/main.rs @@ -102,6 +102,14 @@ impl fmt::Display for SymbolInfo { } } +#[derive(Clone)] +pub struct LanguageData { + language: Language, + // TODO: Once most parsers are upgraded to ABI 15, just get the name from the language object + // itself + name: String, +} + struct DocumentData { symbols_set: HashSet, symbols_vec: Vec, @@ -111,7 +119,7 @@ struct DocumentData { rope: Rope, tree: Tree, version: i32, - language: Option, + language_data: Option, } struct Backend { diff --git a/src/test_helpers.rs b/src/test_helpers.rs index 4303f09..a2605e0 100644 --- a/src/test_helpers.rs +++ b/src/test_helpers.rs @@ -84,6 +84,7 @@ pub mod helpers { ]), ) })), + language_data: None, }, ) }, diff --git a/src/util.rs b/src/util.rs index 91945e0..a54325a 100644 --- a/src/util.rs +++ b/src/util.rs @@ -181,7 +181,8 @@ pub fn lsp_textdocchange_to_ts_inputedit( const DYLIB_EXTENSIONS: [&str; 3] = [".so", ".dll", ".dylib"]; -pub fn get_language(uri: &Url, options: &Options) -> Option { +/// Get the language name associated with the given URI, after aliases. +pub fn get_language_name(uri: &Url, options: &Options) -> Option { let mut language_retrieval_regexes: Vec = options .language_retrieval_patterns .clone() @@ -197,28 +198,24 @@ pub fn get_language(uri: &Url, options: &Options) -> Option { break; } } - let lang = captures - .and_then(|captures| captures.get(1)) - .and_then(|cap| { - let cap_str = cap.as_str(); - get_language_object( - options - .parser_aliases - .get(cap_str) - .unwrap_or(&cap_str.to_owned()) - .as_str(), - &options.parser_install_directories, - &ENGINE, - ) - }); - lang + let raw_name = captures + .and_then(|caps| caps.get(1)) + .map(|cap| cap.as_str().to_owned())?; + Some( + options + .parser_aliases + .get(&raw_name) + .cloned() + .unwrap_or(raw_name), + ) } -pub fn get_language_object( - name: &str, - directories: &Vec, - engine: &Engine, -) -> Option { +/// Get the TSLanguage object for a given name and configuration. +pub fn get_language(name: &str, options: &Options) -> Option { + get_language_object(name, &options.parser_install_directories, &ENGINE) +} + +fn get_language_object(name: &str, directories: &Vec, engine: &Engine) -> Option { let name = name.replace('-', "_"); let language_fn_name = format!("tree_sitter_{name}");