From e9d2ff532b23d2ed9205e231ab7ff4180fcbe835 Mon Sep 17 00:00:00 2001 From: Shuhei Takahashi Date: Tue, 23 Dec 2025 01:44:31 +0900 Subject: [PATCH] Improve context-aware completion --- src/server/providers/completion.rs | 276 +++++++++++++++++++---------- 1 file changed, 183 insertions(+), 93 deletions(-) diff --git a/src/server/providers/completion.rs b/src/server/providers/completion.rs index a8efc19..c7f7815 100644 --- a/src/server/providers/completion.rs +++ b/src/server/providers/completion.rs @@ -145,35 +145,65 @@ impl Template<'_> { } } -fn is_statement_context(parsed_root: &Block<'_>, offset: usize) -> bool { +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum CompletionContext { + TopLevel, + Target, + Expression, +} + +fn compute_completion_context(parsed_root: &Block<'_>, offset: usize) -> CompletionContext { let parents: Vec<_> = parsed_root .walk() .filter(|node| node.span().start() <= offset && offset <= node.span().end()) .collect(); + let in_target = parents.iter().any(|node| { + matches!( + node.as_statement(), + Some(Statement::Call(call)) if + matches!( + &call.block, + Some(block) + if block.span().start() <= offset && offset <= block.span().end()) + ) + }); + let statement_context = if in_target { + CompletionContext::Target + } else { + CompletionContext::TopLevel + }; for node in parents.into_iter().rev() { if node.as_block().is_some() { - return true; + return statement_context; } if let Some(statement) = node.as_statement() { match statement { Statement::Assignment(assignment) => { let primary_span = assignment.lvalue.primary_identifier().span; - return offset <= primary_span.end(); + return if offset <= primary_span.end() { + statement_context + } else { + CompletionContext::Expression + }; } Statement::Call(call) => { let function_span = call.function.span; - return offset <= function_span.end(); + return if offset <= function_span.end() { + statement_context + } else { + CompletionContext::Expression + }; } Statement::Condition(_) => { - return false; + return CompletionContext::Expression; } Statement::Error(_) => { - return true; + return statement_context; } } } } - true + statement_context } async fn build_identifier_completions( @@ -256,33 +286,49 @@ async fn build_identifier_completions( }; // Enumerate builtins. - let builtin_function_items = BUILTINS - .functions + let builtin_function_items = BUILTINS.functions.iter().map(|symbol| CompletionItem { + label: symbol.name.to_string(), + kind: Some(CompletionItemKind::KEYWORD), + documentation: Some(Documentation::MarkupContent(MarkupContent { + kind: MarkupKind::Markdown, + value: symbol.doc.to_string(), + })), + ..Default::default() + }); + let builtin_target_items = BUILTINS.targets.iter().map(|symbol| CompletionItem { + label: symbol.name.to_string(), + kind: Some(CompletionItemKind::FUNCTION), + documentation: Some(Documentation::MarkupContent(MarkupContent { + kind: MarkupKind::Markdown, + value: symbol.doc.to_string(), + })), + ..Default::default() + }); + let predefined_variable_items = + BUILTINS + .predefined_variables + .iter() + .map(|symbol| CompletionItem { + label: symbol.name.to_string(), + kind: Some(CompletionItemKind::KEYWORD), + documentation: Some(Documentation::MarkupContent(MarkupContent { + kind: MarkupKind::Markdown, + value: symbol.doc.to_string(), + })), + ..Default::default() + }); + let target_variable_items = BUILTINS + .target_variables .iter() - .chain(BUILTINS.targets.iter()) .map(|symbol| CompletionItem { label: symbol.name.to_string(), - kind: Some(CompletionItemKind::FUNCTION), + kind: Some(CompletionItemKind::KEYWORD), documentation: Some(Documentation::MarkupContent(MarkupContent { kind: MarkupKind::Markdown, value: symbol.doc.to_string(), })), ..Default::default() }); - let builtin_variable_items = BUILTINS - .predefined_variables - .iter() - .chain(BUILTINS.target_variables.iter()) - .map(|symbol| CompletionItem { - label: symbol.name.to_string(), - kind: Some(CompletionItemKind::VARIABLE), - documentation: Some(Documentation::MarkupContent(MarkupContent { - kind: MarkupKind::Markdown, - value: symbol.doc.to_string(), - })), - ..Default::default() - }); - let builtin_items = builtin_variable_items.chain(builtin_function_items); // Keywords. let literal_items = ["true", "false"].map(|name| CompletionItem { @@ -296,25 +342,44 @@ async fn build_identifier_completions( ..Default::default() }); - if is_statement_context(current_file.parsed_root.get(), offset) { - // No external variables. - Ok(conditional_items - .into_iter() - .chain(builtin_items) - .chain(local_variable_items) - .chain(local_template_items) - .chain(imported_template_items) - .chain(workspace_template_items) - .collect()) - } else { - // No templates. - Ok(literal_items - .into_iter() - .chain(builtin_items) - .chain(local_variable_items) - .chain(imported_variable_items) - .chain(workspace_variable_items) - .collect()) + match compute_completion_context(current_file.parsed_root.get(), offset) { + CompletionContext::TopLevel => { + // No external variables and builtin variables. + Ok(conditional_items + .into_iter() + .chain(builtin_function_items) + .chain(builtin_target_items) + .chain(local_variable_items) + .chain(local_template_items) + .chain(imported_template_items) + .chain(workspace_template_items) + .collect()) + } + CompletionContext::Target => { + // No external variables. + Ok(conditional_items + .into_iter() + .chain(builtin_function_items) + .chain(builtin_target_items) + .chain(target_variable_items) + .chain(local_variable_items) + .chain(local_template_items) + .chain(imported_template_items) + .chain(workspace_template_items) + .collect()) + } + CompletionContext::Expression => { + // No templates. + Ok(literal_items + .into_iter() + .chain(builtin_function_items) + .chain(predefined_variable_items) + .chain(target_variable_items) + .chain(local_variable_items) + .chain(imported_variable_items) + .chain(workspace_variable_items) + .collect()) + } } } @@ -370,19 +435,15 @@ mod tests { use super::*; - #[tokio::test] - async fn test_smoke_statement_context() { + async fn run_completion(path: &Path, position: Position) -> impl Iterator { let response = completion( - &RequestContext::new_for_testing(Some(&testdata("workspaces/completion"))), + &RequestContext::new_for_testing(Some(path)), CompletionParams { text_document_position: TextDocumentPositionParams { text_document: TextDocumentIdentifier { - uri: Url::from_file_path(testdata("workspaces/completion/BUILD.gn")) - .unwrap(), + uri: Url::from_file_path(path).unwrap(), }, - // assert(true) - // ^ - position: Position::new(36, 4), + position, }, work_done_progress_params: Default::default(), partial_result_params: Default::default(), @@ -410,10 +471,69 @@ mod tests { duplicates.iter().sorted().join(", ") ); - // Check items. - let names: HashSet<_> = items.iter().map(|item| item.label.as_str()).collect(); + // Return names. + items.into_iter().map(|item| item.label) + } + + #[tokio::test] + async fn test_smoke_top_level_context() { + let names: HashSet<_> = run_completion( + &testdata("workspaces/completion/BUILD.gn"), + Position::new(38, 0), + ) + .await + .collect(); let expectation = [ + ("assert", true), + ("source_set", true), + ("current_cpu", false), + ("sources", false), + ("_config_variable", false), + ("config_template", true), + ("_config_template", false), + ("import_variable", false), + ("_import_variable", false), + ("import_template", true), + ("_import_template", false), + ("indirect_variable", false), + ("_indirect_variable", false), + ("indirect_template", true), + ("_indirect_template", false), + ("outer_variable", true), + ("_outer_variable", true), + ("outer_template", true), + ("_outer_template", true), + ("inner_variable", false), + ("_inner_variable", false), + ("inner_template", false), + ("_inner_template", false), + ("child_variable", false), + ("_child_variable", false), + ("child_template", false), + ("_child_template", false), + ]; + + for (name, want) in expectation { + let got = names.contains(name); + assert_eq!(got, want, "{name}: got {got}, want {want}"); + } + } + + #[tokio::test] + async fn test_smoke_template_context() { + let names: HashSet<_> = run_completion( + &testdata("workspaces/completion/BUILD.gn"), + Position::new(36, 4), + ) + .await + .collect(); + + let expectation = [ + ("assert", true), + ("source_set", true), + ("current_cpu", false), + ("sources", true), ("config_variable", false), ("_config_variable", false), ("config_template", true), @@ -447,49 +567,19 @@ mod tests { } #[tokio::test] - async fn test_smoke_non_statement_context() { - let response = completion( - &RequestContext::new_for_testing(Some(&testdata("workspaces/completion"))), - CompletionParams { - text_document_position: TextDocumentPositionParams { - text_document: TextDocumentIdentifier { - uri: Url::from_file_path(testdata("workspaces/completion/BUILD.gn")) - .unwrap(), - }, - // assert(true) - // ^ - position: Position::new(36, 11), - }, - work_done_progress_params: Default::default(), - partial_result_params: Default::default(), - context: Default::default(), - }, + async fn test_smoke_expression_context() { + let names: HashSet<_> = run_completion( + &testdata("workspaces/completion/BUILD.gn"), + Position::new(36, 11), ) .await - .unwrap() - .unwrap(); - - let CompletionResponse::Array(items) = response else { - panic!(); - }; - - // Don't return duplicates. - let duplicates: Vec<_> = items - .iter() - .filter(|item| item.label != "cflags" && item.label != "pool") - .map(|item| item.label.as_str()) - .duplicates() - .collect(); - assert!( - duplicates.is_empty(), - "Duplicates in completion items: {}", - duplicates.iter().sorted().join(", ") - ); - - // Check items. - let names: HashSet<_> = items.iter().map(|item| item.label.as_str()).collect(); + .collect(); let expectation = [ + ("assert", true), + ("source_set", false), + ("current_cpu", true), + ("sources", true), ("config_variable", true), ("_config_variable", false), ("config_template", false),