fix(auto-import): Prefer imports of matching types for argument lists

Lukas Wirth 2025-04-08 08:59:57 +02:00
parent 588948f267
commit 7255ef1375
9 changed files with 291 additions and 109 deletions
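
In short: when the expression under the caret has a known expected type (a call argument, resolved via the new ActiveParameter::at_arg helper, or a let binding), relevance_score now ranks candidate imports whose type matches that expectation first. A minimal sketch of the user-visible effect, adapted from the prefers_type_match test added below:

    mod sync { pub mod atomic { pub enum Ordering { V } } }
    mod cmp { pub enum Ordering { V } }

    fn takes_ordering(_: cmp::Ordering) {}

    fn main() {
        // Invoking auto-import on `Ordering` here now proposes
        // `use cmp::Ordering;` first, because the argument's expected
        // type is `cmp::Ordering`; previously the score only considered
        // path ordering and module distance, not the expected type.
        takes_ordering(Ordering);
    }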

View file

@@ -1,14 +1,16 @@
use std::cmp::Reverse;
use hir::{Module, db::HirDatabase};
use either::Either;
use hir::{Module, Type, db::HirDatabase};
use ide_db::{
active_parameter::ActiveParameter,
helpers::mod_path_to_ast,
imports::{
import_assets::{ImportAssets, ImportCandidate, LocatedImport},
insert_use::{ImportScope, insert_use, insert_use_as_alias},
},
};
use syntax::{AstNode, Edition, NodeOrToken, SyntaxElement, ast};
use syntax::{AstNode, Edition, SyntaxNode, ast, match_ast};
use crate::{AssistContext, AssistId, Assists, GroupLabel};
@@ -92,7 +94,7 @@ use crate::{AssistContext, AssistId, Assists, GroupLabel};
pub(crate) fn auto_import(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
let cfg = ctx.config.import_path_config();
let (import_assets, syntax_under_caret) = find_importable_node(ctx)?;
let (import_assets, syntax_under_caret, expected) = find_importable_node(ctx)?;
let mut proposed_imports: Vec<_> = import_assets
.search_for_imports(&ctx.sema, cfg, ctx.config.insert_use.prefix_kind)
.collect();
@@ -100,17 +102,8 @@ pub(crate) fn auto_import(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<
return None;
}
let range = match &syntax_under_caret {
NodeOrToken::Node(node) => ctx.sema.original_range(node).range,
NodeOrToken::Token(token) => token.text_range(),
};
let scope = ImportScope::find_insert_use_container(
&match syntax_under_caret {
NodeOrToken::Node(it) => it,
NodeOrToken::Token(it) => it.parent()?,
},
&ctx.sema,
)?;
let range = ctx.sema.original_range(&syntax_under_caret).range;
let scope = ImportScope::find_insert_use_container(&syntax_under_caret, &ctx.sema)?;
// we aren't interested in different namespaces
proposed_imports.sort_by(|a, b| a.import_path.cmp(&b.import_path));
@@ -118,8 +111,9 @@ pub(crate) fn auto_import(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<
let current_module = ctx.sema.scope(scope.as_syntax_node()).map(|scope| scope.module());
// prioritize more relevant imports
proposed_imports
.sort_by_key(|import| Reverse(relevance_score(ctx, import, current_module.as_ref())));
proposed_imports.sort_by_key(|import| {
Reverse(relevance_score(ctx, import, expected.as_ref(), current_module.as_ref()))
});
let edition = current_module.map(|it| it.krate().edition(ctx.db())).unwrap_or(Edition::CURRENT);
let group_label = group_label(import_assets.import_candidate());
@@ -180,22 +174,61 @@ pub(crate) fn auto_import(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<
pub(super) fn find_importable_node(
ctx: &AssistContext<'_>,
) -> Option<(ImportAssets, SyntaxElement)> {
) -> Option<(ImportAssets, SyntaxNode, Option<Type>)> {
// Deduplicate this with the `expected_type_and_name` logic for completions
let expected = |expr_or_pat: Either<ast::Expr, ast::Pat>| match expr_or_pat {
Either::Left(expr) => {
let parent = expr.syntax().parent()?;
// FIXME: Expand this
match_ast! {
match parent {
ast::ArgList(list) => {
ActiveParameter::at_arg(
&ctx.sema,
list,
expr.syntax().text_range().start(),
).map(|ap| ap.ty)
},
ast::LetStmt(stmt) => {
ctx.sema.type_of_pat(&stmt.pat()?).map(|t| t.original)
},
_ => None,
}
}
}
Either::Right(pat) => {
let parent = pat.syntax().parent()?;
// FIXME: Expand this
match_ast! {
match parent {
ast::LetStmt(stmt) => {
ctx.sema.type_of_expr(&stmt.initializer()?).map(|t| t.original)
},
_ => None,
}
}
}
};
if let Some(path_under_caret) = ctx.find_node_at_offset_with_descend::<ast::Path>() {
let expected =
path_under_caret.top_path().syntax().parent().and_then(Either::cast).and_then(expected);
ImportAssets::for_exact_path(&path_under_caret, &ctx.sema)
.zip(Some(path_under_caret.syntax().clone().into()))
.map(|it| (it, path_under_caret.syntax().clone(), expected))
} else if let Some(method_under_caret) =
ctx.find_node_at_offset_with_descend::<ast::MethodCallExpr>()
{
let expected = expected(Either::Left(method_under_caret.clone().into()));
ImportAssets::for_method_call(&method_under_caret, &ctx.sema)
.zip(Some(method_under_caret.syntax().clone().into()))
.map(|it| (it, method_under_caret.syntax().clone(), expected))
} else if ctx.find_node_at_offset_with_descend::<ast::Param>().is_some() {
None
} else if let Some(pat) = ctx
.find_node_at_offset_with_descend::<ast::IdentPat>()
.filter(ast::IdentPat::is_simple_ident)
{
ImportAssets::for_ident_pat(&ctx.sema, &pat).zip(Some(pat.syntax().clone().into()))
let expected = expected(Either::Right(pat.clone().into()));
ImportAssets::for_ident_pat(&ctx.sema, &pat).map(|it| (it, pat.syntax().clone(), expected))
} else {
None
}
@@ -219,6 +252,7 @@ fn group_label(import_candidate: &ImportCandidate) -> GroupLabel {
pub(crate) fn relevance_score(
ctx: &AssistContext<'_>,
import: &LocatedImport,
expected: Option<&Type>,
current_module: Option<&Module>,
) -> i32 {
let mut score = 0;
@@ -230,6 +264,35 @@ pub(crate) fn relevance_score(
hir::ItemInNs::Macros(makro) => Some(makro.module(db)),
};
if let Some(expected) = expected {
let ty = match import.item_to_import {
hir::ItemInNs::Types(module_def) | hir::ItemInNs::Values(module_def) => {
match module_def {
hir::ModuleDef::Function(function) => Some(function.ret_type(ctx.db())),
hir::ModuleDef::Adt(adt) => Some(match adt {
hir::Adt::Struct(it) => it.ty(ctx.db()),
hir::Adt::Union(it) => it.ty(ctx.db()),
hir::Adt::Enum(it) => it.ty(ctx.db()),
}),
hir::ModuleDef::Variant(variant) => Some(variant.constructor_ty(ctx.db())),
hir::ModuleDef::Const(it) => Some(it.ty(ctx.db())),
hir::ModuleDef::Static(it) => Some(it.ty(ctx.db())),
hir::ModuleDef::TypeAlias(it) => Some(it.ty(ctx.db())),
hir::ModuleDef::BuiltinType(it) => Some(it.ty(ctx.db())),
_ => None,
}
}
hir::ItemInNs::Macros(_) => None,
};
if let Some(ty) = ty {
if ty == *expected {
score = 100000;
} else if ty.could_unify_with(ctx.db(), expected) {
score = 10000;
}
}
}
match item_module.zip(current_module) {
// get the distance between the imported path and the current module
// (prefer items that are more local)
@@ -554,7 +617,7 @@ mod baz {
}
",
r"
use PubMod3::PubStruct;
use PubMod1::PubStruct;
PubStruct
@@ -1722,4 +1785,96 @@ mod foo {
",
);
}
#[test]
fn prefers_type_match() {
check_assist(
auto_import,
r"
mod sync { pub mod atomic { pub enum Ordering { V } } }
mod cmp { pub enum Ordering { V } }
fn takes_ordering(_: sync::atomic::Ordering) {}
fn main() {
takes_ordering(Ordering$0);
}
",
r"
use sync::atomic::Ordering;
mod sync { pub mod atomic { pub enum Ordering { V } } }
mod cmp { pub enum Ordering { V } }
fn takes_ordering(_: sync::atomic::Ordering) {}
fn main() {
takes_ordering(Ordering);
}
",
);
check_assist(
auto_import,
r"
mod sync { pub mod atomic { pub enum Ordering { V } } }
mod cmp { pub enum Ordering { V } }
fn takes_ordering(_: cmp::Ordering) {}
fn main() {
takes_ordering(Ordering$0);
}
",
r"
use cmp::Ordering;
mod sync { pub mod atomic { pub enum Ordering { V } } }
mod cmp { pub enum Ordering { V } }
fn takes_ordering(_: cmp::Ordering) {}
fn main() {
takes_ordering(Ordering);
}
",
);
}
#[test]
fn prefers_type_match2() {
check_assist(
auto_import,
r"
mod sync { pub mod atomic { pub enum Ordering { V } } }
mod cmp { pub enum Ordering { V } }
fn takes_ordering(_: sync::atomic::Ordering) {}
fn main() {
takes_ordering(Ordering$0::V);
}
",
r"
use sync::atomic::Ordering;
mod sync { pub mod atomic { pub enum Ordering { V } } }
mod cmp { pub enum Ordering { V } }
fn takes_ordering(_: sync::atomic::Ordering) {}
fn main() {
takes_ordering(Ordering::V);
}
",
);
check_assist(
auto_import,
r"
mod sync { pub mod atomic { pub enum Ordering { V } } }
mod cmp { pub enum Ordering { V } }
fn takes_ordering(_: cmp::Ordering) {}
fn main() {
takes_ordering(Ordering$0::V);
}
",
r"
use cmp::Ordering;
mod sync { pub mod atomic { pub enum Ordering { V } } }
mod cmp { pub enum Ordering { V } }
fn takes_ordering(_: cmp::Ordering) {}
fn main() {
takes_ordering(Ordering::V);
}
",
);
}
}

View file

@@ -10,7 +10,7 @@ use ide_db::{
use syntax::Edition;
use syntax::ast::HasGenericArgs;
use syntax::{
AstNode, NodeOrToken, ast,
AstNode, ast,
ast::{HasArgList, make},
};
@@ -38,7 +38,7 @@ use crate::{
// # pub mod std { pub mod collections { pub struct HashMap { } } }
// ```
pub(crate) fn qualify_path(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
let (import_assets, syntax_under_caret) = find_importable_node(ctx)?;
let (import_assets, syntax_under_caret, expected) = find_importable_node(ctx)?;
let cfg = ctx.config.import_path_config();
let mut proposed_imports: Vec<_> =
@@ -47,57 +47,50 @@ pub(crate) fn qualify_path(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option
return None;
}
let range = ctx.sema.original_range(&syntax_under_caret).range;
let current_module = ctx.sema.scope(&syntax_under_caret).map(|scope| scope.module());
let candidate = import_assets.import_candidate();
let qualify_candidate = match syntax_under_caret.clone() {
NodeOrToken::Node(syntax_under_caret) => match candidate {
ImportCandidate::Path(candidate) if !candidate.qualifier.is_empty() => {
cov_mark::hit!(qualify_path_qualifier_start);
let path = ast::Path::cast(syntax_under_caret)?;
let (prev_segment, segment) = (path.qualifier()?.segment()?, path.segment()?);
QualifyCandidate::QualifierStart(segment, prev_segment.generic_arg_list())
}
ImportCandidate::Path(_) => {
cov_mark::hit!(qualify_path_unqualified_name);
let path = ast::Path::cast(syntax_under_caret)?;
let generics = path.segment()?.generic_arg_list();
QualifyCandidate::UnqualifiedName(generics)
}
ImportCandidate::TraitAssocItem(_) => {
cov_mark::hit!(qualify_path_trait_assoc_item);
let path = ast::Path::cast(syntax_under_caret)?;
let (qualifier, segment) = (path.qualifier()?, path.segment()?);
QualifyCandidate::TraitAssocItem(qualifier, segment)
}
ImportCandidate::TraitMethod(_) => {
cov_mark::hit!(qualify_path_trait_method);
let mcall_expr = ast::MethodCallExpr::cast(syntax_under_caret)?;
QualifyCandidate::TraitMethod(ctx.sema.db, mcall_expr)
}
},
// derive attribute path
NodeOrToken::Token(_) => QualifyCandidate::UnqualifiedName(None),
let qualify_candidate = match candidate {
ImportCandidate::Path(candidate) if !candidate.qualifier.is_empty() => {
cov_mark::hit!(qualify_path_qualifier_start);
let path = ast::Path::cast(syntax_under_caret)?;
let (prev_segment, segment) = (path.qualifier()?.segment()?, path.segment()?);
QualifyCandidate::QualifierStart(segment, prev_segment.generic_arg_list())
}
ImportCandidate::Path(_) => {
cov_mark::hit!(qualify_path_unqualified_name);
let path = ast::Path::cast(syntax_under_caret)?;
let generics = path.segment()?.generic_arg_list();
QualifyCandidate::UnqualifiedName(generics)
}
ImportCandidate::TraitAssocItem(_) => {
cov_mark::hit!(qualify_path_trait_assoc_item);
let path = ast::Path::cast(syntax_under_caret)?;
let (qualifier, segment) = (path.qualifier()?, path.segment()?);
QualifyCandidate::TraitAssocItem(qualifier, segment)
}
ImportCandidate::TraitMethod(_) => {
cov_mark::hit!(qualify_path_trait_method);
let mcall_expr = ast::MethodCallExpr::cast(syntax_under_caret)?;
QualifyCandidate::TraitMethod(ctx.sema.db, mcall_expr)
}
};
// we aren't interested in different namespaces
proposed_imports.sort_by(|a, b| a.import_path.cmp(&b.import_path));
proposed_imports.dedup_by(|a, b| a.import_path == b.import_path);
let range = match &syntax_under_caret {
NodeOrToken::Node(node) => ctx.sema.original_range(node).range,
NodeOrToken::Token(token) => token.text_range(),
};
let current_module = ctx
.sema
.scope(&match syntax_under_caret {
NodeOrToken::Node(node) => node.clone(),
NodeOrToken::Token(t) => t.parent()?,
})
.map(|scope| scope.module());
let current_edition =
current_module.map(|it| it.krate().edition(ctx.db())).unwrap_or(Edition::CURRENT);
// prioritize more relevant imports
proposed_imports.sort_by_key(|import| {
Reverse(super::auto_import::relevance_score(ctx, import, current_module.as_ref()))
Reverse(super::auto_import::relevance_score(
ctx,
import,
expected.as_ref(),
current_module.as_ref(),
))
});
let group_label = group_label(candidate);
@@ -353,7 +346,7 @@ pub mod PubMod3 {
}
"#,
r#"
PubMod3::PubStruct
PubMod1::PubStruct
pub mod PubMod1 {
pub struct PubStruct;

View file

@@ -144,7 +144,7 @@ fn f() { let a = A { x: 1, y: true }; let b: i32 = a.x; }"#,
term_search,
r#"//- minicore: todo, unimplemented, option
fn f() { let a: i32 = 1; let b: Option<i32> = todo$0!(); }"#,
r#"fn f() { let a: i32 = 1; let b: Option<i32> = Some(a); }"#,
r#"fn f() { let a: i32 = 1; let b: Option<i32> = None; }"#,
)
}
@@ -156,7 +156,7 @@ fn f() { let a: i32 = 1; let b: Option<i32> = todo$0!(); }"#,
enum Option<T> { None, Some(T) }
fn f() { let a: i32 = 1; let b: Option<i32> = todo$0!(); }"#,
r#"enum Option<T> { None, Some(T) }
fn f() { let a: i32 = 1; let b: Option<i32> = Option::Some(a); }"#,
fn f() { let a: i32 = 1; let b: Option<i32> = Option::None; }"#,
)
}
@@ -168,7 +168,7 @@ fn f() { let a: i32 = 1; let b: Option<i32> = Option::Some(a); }"#,
enum Option<T> { None, Some(T) }
fn f() { let a: Option<i32> = Option::None; let b: Option<Option<i32>> = todo$0!(); }"#,
r#"enum Option<T> { None, Some(T) }
fn f() { let a: Option<i32> = Option::None; let b: Option<Option<i32>> = Option::Some(a); }"#,
fn f() { let a: Option<i32> = Option::None; let b: Option<Option<i32>> = Option::None; }"#,
)
}
@@ -221,7 +221,7 @@ fn f() { let a: i32 = 1; let b: i32 = 2; let a: u32 = 0; let c: i32 = todo$0!();
term_search,
r#"//- minicore: todo, unimplemented
fn f() { let a: bool = todo$0!(); }"#,
r#"fn f() { let a: bool = false; }"#,
r#"fn f() { let a: bool = true; }"#,
)
}

View file

@@ -297,7 +297,9 @@ fn check_with_config(
let assist = match assist_label {
Some(label) => res.into_iter().find(|resolved| resolved.label == label),
None => res.pop(),
None if res.is_empty() => None,
// Pick the first as that is the one with the highest priority
None => Some(res.swap_remove(0)),
};
match (assist, expected) {

View file

@@ -7,7 +7,10 @@ use itertools::Either;
use syntax::{
AstNode, AstToken, Direction, NodeOrToken, SyntaxElement, SyntaxKind, SyntaxNode, SyntaxToken,
T, TextRange, TextSize,
algo::{self, ancestors_at_offset, find_node_at_offset, non_trivia_sibling},
algo::{
self, ancestors_at_offset, find_node_at_offset, non_trivia_sibling,
previous_non_trivia_token,
},
ast::{
self, AttrKind, HasArgList, HasGenericArgs, HasGenericParams, HasLoopBody, HasName,
NameOrNameRef,
@@ -1813,22 +1816,6 @@ fn is_in_block(node: &SyntaxNode) -> bool {
.unwrap_or(false)
}
fn previous_non_trivia_token(e: impl Into<SyntaxElement>) -> Option<SyntaxToken> {
let mut token = match e.into() {
SyntaxElement::Node(n) => n.first_token()?,
SyntaxElement::Token(t) => t,
}
.prev_token();
while let Some(inner) = token {
if !inner.kind().is_trivia() {
return Some(inner);
} else {
token = inner.prev_token();
}
}
None
}
fn next_non_trivia_token(e: impl Into<SyntaxElement>) -> Option<SyntaxToken> {
let mut token = match e.into() {
SyntaxElement::Node(n) => n.last_token()?,

View file

@@ -371,6 +371,17 @@ fn foo() {
"#,
expect![[r#"ty: Foo, name: ?"#]],
);
check_expected_type_and_name(
r#"
struct Foo { field: u32 }
fn foo() {
Foo {
..self::$0
}
}
"#,
expect!["ty: ?, name: ?"],
);
}
#[test]

View file

@@ -3,6 +3,7 @@
use either::Either;
use hir::{InFile, Semantics, Type};
use parser::T;
use span::TextSize;
use syntax::{
AstNode, NodeOrToken, SyntaxToken,
ast::{self, AstChildren, HasArgList, HasAttrs, HasName},
@@ -21,7 +22,24 @@ impl ActiveParameter {
/// Returns information about the call argument this token is part of.
pub fn at_token(sema: &Semantics<'_, RootDatabase>, token: SyntaxToken) -> Option<Self> {
let (signature, active_parameter) = callable_for_token(sema, token)?;
Self::from_signature_and_active_parameter(sema, signature, active_parameter)
}
/// Returns information about the call argument this token is part of.
pub fn at_arg(
sema: &Semantics<'_, RootDatabase>,
list: ast::ArgList,
at: TextSize,
) -> Option<Self> {
let (signature, active_parameter) = callable_for_arg_list(sema, list, at)?;
Self::from_signature_and_active_parameter(sema, signature, active_parameter)
}
fn from_signature_and_active_parameter(
sema: &Semantics<'_, RootDatabase>,
signature: hir::Callable,
active_parameter: Option<usize>,
) -> Option<Self> {
let idx = active_parameter?;
let mut params = signature.params();
if idx >= params.len() {
@@ -49,20 +67,32 @@ pub fn callable_for_token(
sema: &Semantics<'_, RootDatabase>,
token: SyntaxToken,
) -> Option<(hir::Callable, Option<usize>)> {
let offset = token.text_range().start();
// Find the calling expression and its NameRef
let parent = token.parent()?;
let calling_node = parent.ancestors().filter_map(ast::CallableExpr::cast).find(|it| {
it.arg_list()
.is_some_and(|it| it.syntax().text_range().contains(token.text_range().start()))
})?;
let calling_node = parent
.ancestors()
.filter_map(ast::CallableExpr::cast)
.find(|it| it.arg_list().is_some_and(|it| it.syntax().text_range().contains(offset)))?;
callable_for_node(sema, &calling_node, &token)
callable_for_node(sema, &calling_node, offset)
}
/// Returns a [`hir::Callable`] this token is a part of and its argument index of said callable.
pub fn callable_for_arg_list(
sema: &Semantics<'_, RootDatabase>,
arg_list: ast::ArgList,
at: TextSize,
) -> Option<(hir::Callable, Option<usize>)> {
debug_assert!(arg_list.syntax().text_range().contains(at));
let callable = arg_list.syntax().parent().and_then(ast::CallableExpr::cast)?;
callable_for_node(sema, &callable, at)
}
pub fn callable_for_node(
sema: &Semantics<'_, RootDatabase>,
calling_node: &ast::CallableExpr,
token: &SyntaxToken,
offset: TextSize,
) -> Option<(hir::Callable, Option<usize>)> {
let callable = match calling_node {
ast::CallableExpr::Call(call) => sema.resolve_expr_as_callable(&call.expr()?),
@@ -74,7 +104,7 @@ pub fn callable_for_node(
.children_with_tokens()
.filter_map(NodeOrToken::into_token)
.filter(|t| t.kind() == T![,])
.take_while(|t| t.text_range().start() <= token.text_range().start())
.take_while(|t| t.text_range().start() <= offset)
.count()
});
Some((callable, active_param))

View file

@@ -9,7 +9,7 @@ use hir::{
};
use ide_db::{
FilePosition, FxIndexMap,
active_parameter::{callable_for_node, generic_def_for_node},
active_parameter::{callable_for_arg_list, generic_def_for_node},
documentation::{Documentation, HasDocs},
};
use span::Edition;
@@ -17,7 +17,7 @@ use stdx::format_to;
use syntax::{
AstNode, Direction, NodeOrToken, SyntaxElementChildren, SyntaxNode, SyntaxToken, T, TextRange,
TextSize, ToSmolStr, algo,
ast::{self, AstChildren, HasArgList},
ast::{self, AstChildren},
match_ast,
};
@@ -163,20 +163,8 @@ fn signature_help_for_call(
edition: Edition,
display_target: DisplayTarget,
) -> Option<SignatureHelp> {
// Find the calling expression and its NameRef
let mut nodes = arg_list.syntax().ancestors().skip(1);
let calling_node = loop {
if let Some(callable) = ast::CallableExpr::cast(nodes.next()?) {
let inside_callable = callable
.arg_list()
.is_some_and(|it| it.syntax().text_range().contains(token.text_range().start()));
if inside_callable {
break callable;
}
}
};
let (callable, active_parameter) = callable_for_node(sema, &calling_node, &token)?;
let (callable, active_parameter) =
callable_for_arg_list(sema, arg_list, token.text_range().start())?;
let mut res =
SignatureHelp { doc: None, signature: String::new(), parameters: vec![], active_parameter };

View file

@@ -116,3 +116,19 @@ pub fn neighbor<T: AstNode>(me: &T, direction: Direction) -> Option<T> {
pub fn has_errors(node: &SyntaxNode) -> bool {
node.children().any(|it| it.kind() == SyntaxKind::ERROR)
}
pub fn previous_non_trivia_token(e: impl Into<SyntaxElement>) -> Option<SyntaxToken> {
let mut token = match e.into() {
SyntaxElement::Node(n) => n.first_token()?,
SyntaxElement::Token(t) => t,
}
.prev_token();
while let Some(inner) = token {
if !inner.kind().is_trivia() {
return Some(inner);
} else {
token = inner.prev_token();
}
}
None
}