refactor: consolidate document data in one object (#96)

This commit is contained in:
Riley Bruins 2025-04-29 20:04:28 -07:00 committed by GitHub
parent 0a205a9682
commit 28ffc04c44
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 367 additions and 458 deletions

View file

@ -22,14 +22,12 @@ pub async fn completion(
) -> Result<Option<CompletionResponse>> {
let uri = &params.text_document_position.text_document.uri;
let Some(tree) = backend.cst_map.get(uri) else {
warn!("No CST built for URI: {uri:?}");
return Ok(None);
};
let Some(rope) = &backend.document_map.get(uri) else {
warn!("No document built for URI: {uri:?}");
let Some(doc) = backend.document_map.get(uri) else {
warn!("No document for URI: {uri:?}");
return Ok(None);
};
let rope = &doc.rope;
let tree = &doc.tree;
let mut position = params.text_document_position.position;
if position.character > 0 {
@ -59,8 +57,8 @@ pub async fn completion(
{
let response = || {
let supertype = current_node.prev_named_sibling()?;
let supertype_map_map = backend.supertype_map_map.get(uri)?;
let subtypes = supertype_map_map.get(&SymbolInfo {
let supertypes = &doc.supertype_map;
let subtypes = supertypes.get(&SymbolInfo {
label: supertype.text(rope),
named: true,
})?;
@ -160,26 +158,23 @@ pub async fn completion(
let in_anon = node_is_or_has_ancestor(root, current_node, "string") && !in_predicate;
let top_level = current_node.kind() == "program";
if !top_level {
if let (Some(symbols), Some(supertypes)) = (
backend.symbols_vec_map.get(uri),
backend.supertype_map_map.get(uri),
) {
for symbol in symbols.iter() {
if (in_anon && !symbol.named) || (!in_anon && symbol.named) {
completion_items.push(CompletionItem {
label: symbol.label.clone(),
kind: if symbol.named {
if !supertypes.contains_key(symbol) {
Some(CompletionItemKind::CLASS)
} else {
Some(CompletionItemKind::INTERFACE)
}
let symbols = &doc.symbols_vec;
let supertypes = &doc.supertype_map;
for symbol in symbols.iter() {
if (in_anon && !symbol.named) || (!in_anon && symbol.named) {
completion_items.push(CompletionItem {
label: symbol.label.clone(),
kind: if symbol.named {
if !supertypes.contains_key(symbol) {
Some(CompletionItemKind::CLASS)
} else {
Some(CompletionItemKind::CONSTANT)
},
..Default::default()
});
}
Some(CompletionItemKind::INTERFACE)
}
} else {
Some(CompletionItemKind::CONSTANT)
},
..Default::default()
});
}
}
}
@ -191,14 +186,12 @@ pub async fn completion(
..Default::default()
});
}
if let Some(fields) = backend.fields_vec_map.get(uri) {
for field in fields.iter() {
completion_items.push(CompletionItem {
label: format!("{field}: "),
kind: Some(CompletionItemKind::FIELD),
..Default::default()
});
}
for field in doc.fields_vec.iter() {
completion_items.push(CompletionItem {
label: format!("{field}: "),
kind: Some(CompletionItemKind::FIELD),
..Default::default()
});
}
}
}

View file

@ -786,14 +786,15 @@ mod test {
&Default::default(),
)
.await;
let rope = &service.inner().document_map.get(&TEST_URI).unwrap();
let doc = &service.inner().document_map.get(&TEST_URI).unwrap();
let rope = &doc.rope;
let provider = &TextProviderRope(rope);
let symbols = &HashSet::from_iter(symbols.iter().cloned());
let fields = &HashSet::from_iter(fields.iter().map(|s| s.to_string()));
// Act
let diagnostics = get_diagnostics(
&service.inner().cst_map.get(&TEST_URI).unwrap(),
&doc.tree,
rope,
provider,
symbols,

View file

@ -1,4 +1,5 @@
use tower_lsp::lsp_types::{DidChangeTextDocumentParams, Position, Range};
use tracing::warn;
use tree_sitter::Parser;
use crate::{
@ -10,7 +11,10 @@ use super::diagnostic::get_diagnostics;
pub async fn did_change(backend: &Backend, params: DidChangeTextDocumentParams) {
let uri = &params.text_document.uri;
let mut rope = backend.document_map.get_mut(uri).unwrap();
let Some(mut document) = backend.document_map.get_mut(uri) else {
return;
};
let rope = &mut document.rope;
let mut parser = Parser::new();
parser
.set_language(&QUERY_LANGUAGE)
@ -33,7 +37,7 @@ pub async fn did_change(backend: &Backend, params: DidChangeTextDocumentParams)
Range { start, end }
};
edits.push(lsp_textdocchange_to_ts_inputedit(&rope, change).unwrap());
edits.push(lsp_textdocchange_to_ts_inputedit(rope, change).unwrap());
let start_row_char_idx = rope.line_to_char(range.start.line as usize);
let start_row_cu = rope.char_to_utf16_cu(start_row_char_idx);
@ -54,45 +58,43 @@ pub async fn did_change(backend: &Backend, params: DidChangeTextDocumentParams)
}
}
let contents = rope.to_string();
let result = {
let mut old_tree = backend.cst_map.get_mut(uri).unwrap();
let mut old_tree = document.tree.clone();
for edit in edits {
old_tree.edit(&edit);
}
for edit in edits {
old_tree.edit(&edit);
}
parser.parse(&contents, Some(&old_tree))
let Some(tree) = parser.parse(&contents, Some(&old_tree)) else {
warn!("Failure during tree parse");
return;
};
if let Some(tree) = result {
*backend.cst_map.get_mut(uri).unwrap() = tree.clone();
// Update diagnostics
if let (Some(symbols), Some(fields), Some(supertypes), options) = (
backend.symbols_set_map.get(uri),
backend.fields_set_map.get(uri),
backend.supertype_map_map.get(uri),
backend.options.read().await,
) {
let provider = TextProviderRope(&rope);
backend
.client
.publish_diagnostics(
uri.clone(),
get_diagnostics(
&tree,
&rope,
&provider,
&symbols,
&fields,
&supertypes,
&options,
uri,
),
None,
)
.await;
}
}
document.tree = tree;
let document = &*document;
let options = &backend.options.read().await;
let symbols = &document.symbols_set;
let fields = &document.fields_set;
let supertypes = &document.supertype_map;
let provider = TextProviderRope(&document.rope);
// Update diagnostics
backend
.client
.publish_diagnostics(
uri.clone(),
get_diagnostics(
&document.tree,
&document.rope,
&provider,
symbols,
fields,
supertypes,
options,
uri,
),
None,
)
.await;
}
#[cfg(test)]
@ -176,13 +178,10 @@ mod test {
.unwrap();
// Assert
let doc = service.inner().document_map.get(&TEST_URI);
let tree = service.inner().cst_map.get(&TEST_URI);
assert!(doc.is_some());
let doc = doc.unwrap();
assert_eq!(doc.to_string(), expected);
assert!(tree.is_some());
let tree = tree.unwrap();
let doc = service.inner().document_map.get(&TEST_URI).unwrap();
let rope = &doc.rope;
assert_eq!(rope.to_string(), expected);
let tree = &doc.tree;
assert_eq!(
tree.root_node().utf8_text(expected.as_bytes()).unwrap(),
expected

View file

@ -6,7 +6,7 @@ use tracing::info;
use tree_sitter::Parser;
use crate::{
Backend, QUERY_LANGUAGE, SymbolInfo,
Backend, DocumentData, QUERY_LANGUAGE, SymbolInfo,
handlers::diagnostic::get_diagnostics,
util::{TextProviderRope, get_language},
};
@ -20,10 +20,7 @@ pub async fn did_open(backend: &Backend, params: DidOpenTextDocumentParams) {
parser
.set_language(&QUERY_LANGUAGE)
.expect("Error loading Query grammar");
backend.document_map.insert(uri.clone(), rope.clone());
backend
.cst_map
.insert(uri.clone(), parser.parse(&contents, None).unwrap());
let tree = parser.parse(&contents, None).unwrap();
// Initialize language info
let mut symbols_vec: Vec<SymbolInfo> = vec![];
@ -63,9 +60,7 @@ pub async fn did_open(backend: &Backend, params: DidOpenTextDocumentParams) {
.collect(),
);
}
if symbols_set.contains(&symbol_info)
|| !(lang.node_kind_is_visible(i) || lang.node_kind_is_supertype(i))
{
if symbols_set.contains(&symbol_info) || !(lang.node_kind_is_visible(i) || supertype) {
continue;
}
symbols_set.insert(symbol_info.clone());
@ -80,34 +75,48 @@ pub async fn did_open(backend: &Backend, params: DidOpenTextDocumentParams) {
}
}
}
backend.symbols_vec_map.insert(uri.to_owned(), symbols_vec);
backend.symbols_set_map.insert(uri.to_owned(), symbols_set);
backend.fields_vec_map.insert(uri.to_owned(), fields_vec);
backend.fields_set_map.insert(uri.to_owned(), fields_set);
backend
.supertype_map_map
.insert(uri.to_owned(), supertype_map);
backend.document_map.insert(
uri.clone(),
DocumentData {
rope,
tree,
symbols_set,
symbols_vec,
fields_set,
fields_vec,
supertype_map,
},
);
// Publish diagnostics
if let (Some(tree), Some(symbols), Some(fields), Some(supertypes), options) = (
backend.cst_map.get(uri),
backend.symbols_set_map.get(uri),
backend.fields_set_map.get(uri),
backend.supertype_map_map.get(uri),
if let (
Some(DocumentData {
symbols_set,
fields_set,
supertype_map,
rope,
tree,
fields_vec: _,
symbols_vec: _,
}),
options,
) = (
backend.document_map.get(uri).as_deref(),
backend.options.read().await,
) {
let provider = TextProviderRope(&rope);
let provider = TextProviderRope(rope);
backend
.client
.publish_diagnostics(
uri.clone(),
get_diagnostics(
&tree,
&rope,
tree,
rope,
&provider,
&symbols,
&fields,
&supertypes,
symbols_set,
fields_set,
supertype_map,
&options,
uri,
),
@ -154,13 +163,10 @@ mod test {
.unwrap();
// Assert
let doc_rope = service.inner().document_map.get(&TEST_URI);
assert!(doc_rope.is_some());
let doc_rope = doc_rope.unwrap();
let doc = service.inner().document_map.get(&TEST_URI).unwrap();
let doc_rope = &doc.rope;
assert_eq!(doc_rope.to_string(), source);
let tree = service.inner().cst_map.get(&TEST_URI);
assert!(tree.is_some());
let tree = tree.unwrap();
let tree = &doc.tree;
assert_eq!(
tree.root_node().utf8_text(source.as_bytes()).unwrap(),
doc_rope.to_string()

View file

@ -18,14 +18,12 @@ pub async fn document_highlight(
) -> Result<Option<Vec<DocumentHighlight>>> {
let uri = &params.text_document_position_params.text_document.uri;
let Some(tree) = backend.cst_map.get(uri) else {
warn!("No CST built for URI: {uri:?}");
return Ok(None);
};
let Some(rope) = &backend.document_map.get(uri) else {
warn!("No document built for URI: {uri:?}");
let Some(doc) = backend.document_map.get(uri) else {
warn!("No document for URI: {uri:?}");
return Ok(None);
};
let rope = &doc.rope;
let tree = &doc.tree;
let cur_pos = params
.text_document_position_params
.position

View file

@ -17,11 +17,12 @@ pub async fn document_symbol(
let uri = &params.text_document.uri;
let mut document_symbols = vec![];
let (Some(rope), Some(tree)) = (&backend.document_map.get(uri), backend.cst_map.get(uri))
else {
let Some(doc) = backend.document_map.get(uri) else {
warn!("No document found for URI: {uri} when searching for document symbols.");
return Ok(None);
};
let rope = &doc.rope;
let tree = &doc.tree;
let provider = &TextProviderRope(rope);
let mut cursor = QueryCursor::new();

View file

@ -35,11 +35,12 @@ async fn check_impossible_patterns(backend: &Backend, params: ExecuteCommandPara
}
};
let (Some(rope), Some(tree)) = (&backend.document_map.get(&uri), backend.cst_map.get(&uri))
else {
let Some(doc) = &backend.document_map.get(&uri) else {
warn!("No document built for URI '{uri}' when executing check impossible patterns command");
return;
};
let rope = &doc.rope;
let tree = &doc.tree;
let options = &backend.options.read().await;
let Some(lang) = util::get_language(&uri, options) else {
error!("Could not retrieve language for path: '{}'", uri.path());
@ -83,23 +84,13 @@ async fn check_impossible_patterns(backend: &Backend, params: ExecuteCommandPara
}
}
}
if let (Some(symbols), Some(fields), Some(supertypes)) = (
backend.symbols_set_map.get(&uri),
backend.fields_set_map.get(&uri),
backend.supertype_map_map.get(&uri),
) {
let provider = TextProviderRope(rope);
diagnostics.append(&mut get_diagnostics(
&tree,
rope,
&provider,
&symbols,
&fields,
&supertypes,
options,
&uri,
));
}
let symbols = &doc.symbols_set;
let fields = &doc.fields_set;
let supertypes = &doc.supertype_map;
let provider = TextProviderRope(rope);
diagnostics.append(&mut get_diagnostics(
tree, rope, &provider, symbols, fields, supertypes, options, &uri,
));
backend
.client

View file

@ -6,6 +6,7 @@ use regex::Regex;
use ropey::Rope;
use tower_lsp::jsonrpc::Result;
use tower_lsp::lsp_types::{DocumentFormattingParams, Range, TextEdit};
use tracing::warn;
use tree_sitter::{
Node, Query, QueryCursor, QueryMatch, QueryPredicateArg, StreamingIterator as _, Tree,
TreeCursor,
@ -19,18 +20,16 @@ pub async fn formatting(
backend: &Backend,
params: DocumentFormattingParams,
) -> Result<Option<Vec<TextEdit>>> {
let uri = params.text_document.uri;
let tree = match backend.cst_map.get(&uri) {
None => return Ok(None),
Some(val) => val,
};
let rope = match backend.document_map.get(&uri) {
None => return Ok(None),
Some(val) => val,
let uri = &params.text_document.uri;
let Some(doc) = backend.document_map.get(uri) else {
warn!("No document for URI: {uri}");
return Ok(None);
};
let rope = &doc.rope;
let tree = &doc.tree;
if let Some(formatted_doc) = format_document(&rope, &tree) {
Ok(Some(diff(rope.to_string().as_str(), &formatted_doc, &rope)))
if let Some(formatted_doc) = format_document(rope, tree) {
Ok(Some(diff(rope.to_string().as_str(), &formatted_doc, rope)))
} else {
Ok(None)
}
@ -413,6 +412,6 @@ mod test {
// Assert
let doc = service.inner().document_map.get(&TEST_URI).unwrap();
assert_eq!(doc.to_string(), String::from(after));
assert_eq!(doc.rope.to_string(), String::from(after));
}
}

View file

@ -19,14 +19,12 @@ pub async fn goto_definition(
) -> Result<Option<GotoDefinitionResponse>> {
info!("ts_query_ls goto_definition: {params:?}");
let uri = &params.text_document_position_params.text_document.uri;
let Some(tree) = backend.cst_map.get(uri) else {
warn!("No CST built for URI: {uri:?}");
return Ok(None);
};
let Some(rope) = &backend.document_map.get(uri) else {
warn!("No document built for URI: {uri:?}");
let Some(doc) = backend.document_map.get(uri) else {
warn!("No document for URI: {uri:?}");
return Ok(None);
};
let rope = &doc.rope;
let tree = &doc.tree;
let cur_pos = params.text_document_position_params.position;
let Some(current_node) = get_current_capture_node(tree.root_node(), cur_pos.to_ts_point(rope))
else {

View file

@ -2,6 +2,7 @@ use tower_lsp::{
jsonrpc::Result,
lsp_types::{Hover, HoverContents, HoverParams, MarkupContent, MarkupKind},
};
use tracing::warn;
use crate::{
Backend, SymbolInfo,
@ -13,80 +14,83 @@ pub async fn hover(backend: &Backend, params: HoverParams) -> Result<Option<Hove
let position = params.text_document_position_params.position;
let options = backend.options.read().await;
if let (Some(tree), Some(rope), Some(supertypes)) = (
backend.cst_map.get(uri),
&backend.document_map.get(uri),
backend.supertype_map_map.get(uri),
) {
let Some(node) = tree
.root_node()
.descendant_for_point_range(position.to_ts_point(rope), position.to_ts_point(rope))
else {
return Ok(None);
};
let node_text = node.text(rope);
let node_range = node.lsp_range(rope);
let sym = SymbolInfo {
label: node_text.clone(),
named: true,
};
let Some(doc) = backend.document_map.get(uri) else {
warn!("No document for uri: {uri}");
return Ok(None);
};
let node_parent = node.parent();
if node.kind() == "identifier"
&& node_parent.is_some_and(|p| {
p.kind() == "named_node" || p.kind() == "missing_node" || p.kind() == "predicate"
})
{
let node_parent = node_parent.unwrap();
if node_parent.kind() == "predicate" {
let is_predicate = node_parent
.named_child(1)
.is_some_and(|c| c.text(rope) == "?");
let validator = if is_predicate {
&options.valid_predicates
} else {
&options.valid_directives
};
if let Some(predicate) = validator.get(&node_text) {
let mut value =
format!("{}\n\n---\n\n## Parameters:\n\n", predicate.description);
for param in &predicate.parameters {
value += format!("- Type: `{}` ({})\n", param.type_, param.arity).as_str();
if let Some(desc) = &param.description {
value += format!(" - {}\n", desc).as_str();
}
let tree = &doc.tree;
let rope = &doc.rope;
let supertypes = &doc.supertype_map;
let Some(node) = tree
.root_node()
.descendant_for_point_range(position.to_ts_point(rope), position.to_ts_point(rope))
else {
return Ok(None);
};
let node_text = node.text(rope);
let node_range = node.lsp_range(rope);
let sym = SymbolInfo {
label: node_text.clone(),
named: true,
};
let node_parent = node.parent();
if node.kind() == "identifier"
&& node_parent.is_some_and(|p| {
p.kind() == "named_node" || p.kind() == "missing_node" || p.kind() == "predicate"
})
{
let node_parent = node_parent.unwrap();
if node_parent.kind() == "predicate" {
let is_predicate = node_parent
.named_child(1)
.is_some_and(|c| c.text(rope) == "?");
let validator = if is_predicate {
&options.valid_predicates
} else {
&options.valid_directives
};
if let Some(predicate) = validator.get(&node_text) {
let mut value = format!("{}\n\n---\n\n## Parameters:\n\n", predicate.description);
for param in &predicate.parameters {
value += format!("- Type: `{}` ({})\n", param.type_, param.arity).as_str();
if let Some(desc) = &param.description {
value += format!(" - {}\n", desc).as_str();
}
return Ok(Some(Hover {
range: Some(node_range),
contents: HoverContents::Markup(MarkupContent {
kind: MarkupKind::Markdown,
value,
}),
}));
}
return Ok(None);
}
if let Some(subtypes) = supertypes.get(&sym).and_then(|subtypes| {
(subtypes.iter().fold(
format!("Subtypes of `({node_text})`:\n\n```query"),
|acc, subtype| format!("{acc}\n{}", subtype),
) + "\n```")
.into()
}) {
return Ok(Some(Hover {
range: Some(node_range),
contents: HoverContents::Markup(MarkupContent {
kind: MarkupKind::Markdown,
value: subtypes,
value,
}),
}));
} else if node_text == "ERROR" {
return Ok(Some(Hover {
range: Some(node_range),
contents: HoverContents::Markup(MarkupContent {
kind: MarkupKind::Markdown,
value: String::from(
r"### The `ERROR` Node
}
return Ok(None);
}
if let Some(subtypes) = supertypes.get(&sym).and_then(|subtypes| {
(subtypes.iter().fold(
format!("Subtypes of `({node_text})`:\n\n```query"),
|acc, subtype| format!("{acc}\n{}", subtype),
) + "\n```")
.into()
}) {
return Ok(Some(Hover {
range: Some(node_range),
contents: HoverContents::Markup(MarkupContent {
kind: MarkupKind::Markdown,
value: subtypes,
}),
}));
} else if node_text == "ERROR" {
return Ok(Some(Hover {
range: Some(node_range),
contents: HoverContents::Markup(MarkupContent {
kind: MarkupKind::Markdown,
value: String::from(
r"### The `ERROR` Node
When the parser encounters text it does not recognize, it represents this node
as `(ERROR)` in the syntax tree. These error nodes can be queried just like
@ -95,17 +99,17 @@ normal nodes:
```query
(ERROR) @error-node
```",
),
}),
}));
}
} else if node.kind() == "MISSING" {
return Ok(Some(Hover {
range: Some(node_range),
contents: HoverContents::Markup(MarkupContent {
kind: MarkupKind::Markdown,
value: String::from(
r"### The `MISSING` Node
),
}),
}));
}
} else if node.kind() == "MISSING" {
return Ok(Some(Hover {
range: Some(node_range),
contents: HoverContents::Markup(MarkupContent {
kind: MarkupKind::Markdown,
value: String::from(
r"### The `MISSING` Node
If the parser is able to recover from erroneous text by inserting a missing token and then reducing, it will insert that
missing node in the final tree so long as that tree has the lowest error cost. These missing nodes appear as seemingly normal
@ -116,16 +120,16 @@ using `(MISSING)`:
```query
(MISSING) @missing-node
```",
),
}),
}));
} else if node.kind() == "_" {
return Ok(Some(Hover {
range: Some(node_range),
contents: HoverContents::Markup(MarkupContent {
kind: MarkupKind::Markdown,
value: String::from(
r"### The Wildcard Node
),
}),
}));
} else if node.kind() == "_" {
return Ok(Some(Hover {
range: Some(node_range),
contents: HoverContents::Markup(MarkupContent {
kind: MarkupKind::Markdown,
value: String::from(
r"### The Wildcard Node
A wildcard node is represented with an underscore (`_`), it matches any node.
This is similar to `.` in regular expressions.
@ -137,28 +141,27 @@ For example, this pattern would match any node inside a call:
```query
(call (_) @call.inner)
```",
),
),
}),
}));
} else if let Some(capture) =
get_current_capture_node(tree.root_node(), position.to_ts_point(rope))
{
let options = backend.options.read().await;
if let Some(description) = uri_to_basename(uri).and_then(|base| {
options
.valid_captures
.get(&base)
.and_then(|c| c.get(&capture.text(rope)[1..].to_string()))
}) {
let value = format!("## `{}`\n\n{}", capture.text(rope), description);
return Ok(Some(Hover {
range: Some(capture.lsp_range(rope)),
contents: HoverContents::Markup(MarkupContent {
kind: MarkupKind::Markdown,
value,
}),
}));
} else if let Some(capture) =
get_current_capture_node(tree.root_node(), position.to_ts_point(rope))
{
let options = backend.options.read().await;
if let Some(description) = uri_to_basename(uri).and_then(|base| {
options
.valid_captures
.get(&base)
.and_then(|c| c.get(&capture.text(rope)[1..].to_string()))
}) {
let value = format!("## `{}`\n\n{}", capture.text(rope), description);
return Ok(Some(Hover {
range: Some(capture.lsp_range(rope)),
contents: HoverContents::Markup(MarkupContent {
kind: MarkupKind::Markdown,
value,
}),
}));
}
}
}

View file

@ -66,12 +66,6 @@ mod test {
let (mut service, _socket) = LspService::build(|client| Backend {
client,
document_map: Default::default(),
cst_map: Default::default(),
symbols_set_map: Default::default(),
symbols_vec_map: Default::default(),
fields_set_map: Default::default(),
fields_vec_map: Default::default(),
supertype_map_map: Default::default(),
workspace_uris: Default::default(),
options: Default::default(),
})

View file

@ -15,14 +15,12 @@ pub async fn references(
) -> Result<Option<Vec<Location>>> {
let uri = &params.text_document_position.text_document.uri;
let Some(tree) = backend.cst_map.get(uri) else {
warn!("No CST built for URI: {uri:?}");
return Ok(None);
};
let Some(rope) = &backend.document_map.get(uri) else {
warn!("No document built for URI: {uri:?}");
let Some(doc) = backend.document_map.get(uri) else {
warn!("No document for URI: {uri:?}");
return Ok(None);
};
let rope = &doc.rope;
let tree = &doc.tree;
let cur_pos = params.text_document_position.position.to_ts_point(rope);
let current_node = match get_current_capture_node(tree.root_node(), cur_pos) {
None => return Ok(None),

View file

@ -24,14 +24,12 @@ static IDENTIFIER_PATTERN: LazyLock<Regex> =
pub async fn rename(backend: &Backend, params: RenameParams) -> Result<Option<WorkspaceEdit>> {
let uri = &params.text_document_position.text_document.uri;
let Some(tree) = backend.cst_map.get(uri) else {
warn!("No CST built for URI: {uri:?}");
return Ok(None);
};
let Some(rope) = &backend.document_map.get(uri) else {
let Some(doc) = backend.document_map.get(uri) else {
warn!("No document built for URI: {uri:?}");
return Ok(None);
};
let rope = &doc.rope;
let tree = &doc.tree;
let current_node = match get_current_capture_node(
tree.root_node(),
params.text_document_position.position.to_ts_point(rope),

View file

@ -4,6 +4,7 @@ use tower_lsp::{
jsonrpc::Result,
lsp_types::{SemanticToken, SemanticTokens, SemanticTokensParams, SemanticTokensResult},
};
use tracing::warn;
use tree_sitter::{Query, QueryCursor, StreamingIterator};
use crate::{
@ -24,58 +25,61 @@ pub async fn semantic_tokens_full(
params: SemanticTokensParams,
) -> Result<Option<SemanticTokensResult>> {
let uri = &params.text_document.uri;
let Some(doc) = backend.document_map.get(uri) else {
warn!("Could not find document for which to retrieve semantic tokens");
return Ok(None);
};
let mut tokens = Vec::new();
if let (Some(tree), Some(rope), Some(supertypes)) = (
backend.cst_map.get(uri),
&backend.document_map.get(uri),
backend.supertype_map_map.get(uri),
) {
let query = &SEM_TOK_QUERY;
let mut cursor = QueryCursor::new();
let provider = TextProviderRope(rope);
let mut matches = cursor.matches(query, tree.root_node(), &provider);
let mut prev_line = 0;
let mut prev_col = 0;
while let Some(match_) = matches.next() {
for cap in match_.captures.iter() {
let node = &cap.node;
let node_text = node.text(rope);
let start_row = node.start_position().row as u32;
let start_col = node.start_position().column as u32;
let delta_line = start_row - prev_line;
let length = node.byte_range().len() as u32;
let delta_start = if start_row - prev_line == 0 {
start_col - prev_col
} else {
start_col
};
if node_text == "ERROR" {
tokens.push(SemanticToken {
delta_line,
delta_start,
length,
token_type: 1,
token_modifiers_bitset: 1,
});
prev_line = start_row;
prev_col = start_col;
} else if supertypes.contains_key(&SymbolInfo {
label: node_text,
named: true,
}) {
tokens.push(SemanticToken {
delta_line,
delta_start,
length,
token_type: 0,
token_modifiers_bitset: 0,
});
prev_line = start_row;
prev_col = start_col;
}
let tree = &doc.tree;
let rope = &doc.rope;
let supertypes = &doc.supertype_map;
let query = &SEM_TOK_QUERY;
let mut cursor = QueryCursor::new();
let provider = TextProviderRope(rope);
let mut matches = cursor.matches(query, tree.root_node(), &provider);
let mut prev_line = 0;
let mut prev_col = 0;
while let Some(match_) = matches.next() {
for cap in match_.captures.iter() {
let node = &cap.node;
let node_text = node.text(rope);
let start_row = node.start_position().row as u32;
let start_col = node.start_position().column as u32;
let delta_line = start_row - prev_line;
let length = node.byte_range().len() as u32;
let delta_start = if start_row - prev_line == 0 {
start_col - prev_col
} else {
start_col
};
if node_text == "ERROR" {
tokens.push(SemanticToken {
delta_line,
delta_start,
length,
token_type: 1,
token_modifiers_bitset: 1,
});
prev_line = start_row;
prev_col = start_col;
} else if supertypes.contains_key(&SymbolInfo {
label: node_text,
named: true,
}) {
tokens.push(SemanticToken {
delta_line,
delta_start,
length,
token_type: 0,
token_modifiers_bitset: 0,
});
prev_line = start_row;
prev_col = start_col;
}
}
}
Ok(Some(SemanticTokensResult::Tokens(SemanticTokens {
result_id: None,
data: tokens,

View file

@ -6,6 +6,7 @@ use std::{
env,
fs::{self},
path::{Path, PathBuf},
str,
sync::{Arc, LazyLock, RwLock, atomic::AtomicI32},
};
use ts_query_ls::Options;
@ -93,15 +94,19 @@ impl fmt::Display for SymbolInfo {
}
}
struct DocumentData {
symbols_set: HashSet<SymbolInfo>,
symbols_vec: Vec<SymbolInfo>,
fields_set: HashSet<String>,
fields_vec: Vec<String>,
supertype_map: HashMap<SymbolInfo, BTreeSet<SymbolInfo>>,
rope: Rope,
tree: Tree,
}
struct Backend {
client: Client,
document_map: DashMap<Url, Rope>,
cst_map: DashMap<Url, Tree>,
symbols_set_map: DashMap<Url, HashSet<SymbolInfo>>,
symbols_vec_map: DashMap<Url, Vec<SymbolInfo>>,
fields_set_map: DashMap<Url, HashSet<String>>,
fields_vec_map: DashMap<Url, Vec<String>>,
supertype_map_map: DashMap<Url, HashMap<SymbolInfo, BTreeSet<SymbolInfo>>>,
document_map: DashMap<Url, DocumentData>,
options: Arc<tokio::sync::RwLock<Options>>,
workspace_uris: Arc<RwLock<Vec<Url>>>,
}
@ -471,12 +476,6 @@ async fn main() {
let (service, socket) = LspService::build(|client| Backend {
client,
document_map: Default::default(),
cst_map: Default::default(),
symbols_set_map: Default::default(),
symbols_vec_map: Default::default(),
fields_set_map: Default::default(),
fields_vec_map: Default::default(),
supertype_map_map: Default::default(),
workspace_uris: Default::default(),
options,
})

View file

@ -20,7 +20,7 @@ pub mod helpers {
},
};
use crate::{Backend, Options, QUERY_LANGUAGE, SymbolInfo};
use crate::{Backend, DocumentData, Options, QUERY_LANGUAGE, SymbolInfo};
pub static TEST_URI: LazyLock<Url> =
LazyLock::new(|| Url::parse("file:///tmp/test.scm").unwrap());
@ -54,60 +54,36 @@ pub mod helpers {
let options = Arc::new(tokio::sync::RwLock::new(options.clone()));
let (mut service, _socket) = LspService::build(|client| Backend {
client,
document_map: DashMap::from_iter(
documents
.iter()
.map(|(uri, source, _, _, _)| (uri.clone(), Rope::from(*source))),
),
cst_map: DashMap::from_iter(
documents.iter().map(|(uri, source, _, _, _)| {
(uri.clone(), parser.parse(*source, None).unwrap())
}),
),
symbols_set_map: DashMap::from_iter(
documents.iter().map(|(uri, _, symbols, _, _)| {
(uri.clone(), HashSet::from_iter(symbols.clone()))
}),
),
symbols_vec_map: DashMap::from_iter(
documents
.iter()
.map(|(uri, _, symbols, _, _)| (uri.clone(), symbols.clone())),
),
fields_set_map: DashMap::from_iter(documents.iter().map(|(uri, _, _, fields, _)| {
(
uri.clone(),
HashSet::from_iter(fields.iter().map(ToString::to_string)),
)
})),
fields_vec_map: DashMap::from_iter(documents.iter().map(|(uri, _, _, fields, _)| {
(
uri.clone(),
fields.clone().iter().map(ToString::to_string).collect(),
)
})),
supertype_map_map: DashMap::from_iter(documents.iter().map(
|(uri, _, _, _, supertypes)| {
document_map: DashMap::from_iter(documents.iter().map(
|(uri, source, symbols, fields, supertypes)| {
(
uri.clone(),
HashMap::from_iter(supertypes.iter().map(|supertype| {
(
SymbolInfo {
named: true,
label: String::from(*supertype),
},
BTreeSet::from([
DocumentData {
rope: Rope::from(*source),
tree: parser.parse(*source, None).unwrap(),
symbols_set: HashSet::from_iter(symbols.clone()),
symbols_vec: symbols.clone(),
fields_set: HashSet::from_iter(fields.iter().map(ToString::to_string)),
fields_vec: fields.clone().iter().map(ToString::to_string).collect(),
supertype_map: HashMap::from_iter(supertypes.iter().map(|supertype| {
(
SymbolInfo {
named: true,
label: String::from("test"),
label: String::from(*supertype),
},
SymbolInfo {
named: true,
label: String::from("test2"),
},
]),
)
})),
BTreeSet::from([
SymbolInfo {
named: true,
label: String::from("test"),
},
SymbolInfo {
named: true,
label: String::from("test2"),
},
]),
)
})),
},
)
},
)),
@ -271,82 +247,33 @@ mod test {
let actual_options = backend.options.read().await;
assert_eq!(actual_options.deref(), options);
assert_eq!(backend.document_map.len(), documents.len());
assert_eq!(backend.cst_map.len(), documents.len());
assert_eq!(backend.symbols_vec_map.len(), documents.len());
assert_eq!(backend.symbols_set_map.len(), documents.len());
assert_eq!(backend.fields_vec_map.len(), documents.len());
assert_eq!(backend.fields_set_map.len(), documents.len());
for (uri, doc, symbols, fields, supertypes) in documents {
for (uri, source, symbols, fields, supertypes) in documents {
let doc = backend.document_map.get(uri).unwrap();
assert_eq!(doc.rope.to_string(), (*source).to_string());
assert_eq!(
backend.document_map.get(uri).unwrap().to_string(),
(*doc).to_string()
);
assert_eq!(
backend
.cst_map
.get(uri)
.unwrap()
doc.tree
.root_node()
.utf8_text((*doc).to_string().as_bytes())
.utf8_text((*source).to_string().as_bytes())
.unwrap(),
(*doc).to_string()
);
assert!(
backend
.symbols_vec_map
.get(uri)
.is_some_and(|v| v.len() == symbols.len())
);
assert!(
backend
.symbols_set_map
.get(uri)
.is_some_and(|v| v.len() == symbols.len())
(*source).to_string()
);
assert!(doc.symbols_vec.len() == symbols.len());
assert!(doc.symbols_set.len() == symbols.len());
for symbol in symbols {
assert!(backend.symbols_vec_map.get(uri).unwrap().contains(symbol));
assert!(backend.symbols_set_map.get(uri).unwrap().contains(symbol));
assert!(doc.symbols_vec.contains(symbol));
assert!(doc.symbols_set.contains(symbol));
}
assert!(
backend
.fields_vec_map
.get(uri)
.is_some_and(|v| v.len() == fields.len())
);
assert!(
backend
.fields_set_map
.get(uri)
.is_some_and(|v| v.len() == fields.len())
);
assert!(doc.fields_vec.len() == fields.len());
assert!(doc.fields_set.len() == fields.len());
for field in fields {
assert!(
backend
.fields_vec_map
.get(uri)
.unwrap()
.contains(&field.to_string())
);
assert!(
backend
.fields_set_map
.get(uri)
.unwrap()
.contains(&field.to_string())
);
assert!(doc.fields_vec.contains(&field.to_string()));
assert!(doc.fields_set.contains(*field));
}
assert!(backend.supertype_map_map.get(uri).is_some());
for supertype in supertypes {
assert!(
backend
.supertype_map_map
.get(uri)
.unwrap()
.contains_key(&SymbolInfo {
named: true,
label: String::from(*supertype)
})
)
assert!(doc.supertype_map.contains_key(&SymbolInfo {
named: true,
label: String::from(*supertype)
}))
}
}
}