diff --git a/crates/ty_ide/src/completion.rs b/crates/ty_ide/src/completion.rs index 2470308c41..dc7072dfd7 100644 --- a/crates/ty_ide/src/completion.rs +++ b/crates/ty_ide/src/completion.rs @@ -23,7 +23,18 @@ pub fn completion(db: &dyn Db, file: File, offset: TextSize) -> Vec model.attribute_completions(expr), - CompletionTargetAst::ImportFrom { import, name } => model.import_completions(import, name), + CompletionTargetAst::ObjectDotInImport { import, name } => { + model.import_submodule_completions(import, name) + } + CompletionTargetAst::ObjectDotInImportFrom { import } => { + model.from_import_submodule_completions(import) + } + CompletionTargetAst::ImportFrom { import, name } => { + model.from_import_completions(import, name) + } + CompletionTargetAst::Import { .. } | CompletionTargetAst::ImportViaFrom { .. } => { + model.import_completions() + } CompletionTargetAst::Scoped { node } => model.scoped_completions(node), }; completions.sort_by(compare_suggestions); @@ -50,11 +61,11 @@ enum CompletionTargetTokens<'t> { object: &'t Token, /// The token, if non-empty, following the dot. /// - /// This is currently unused, but we should use this - /// eventually to remove completions that aren't a - /// prefix of what has already been typed. (We are - /// currently relying on the LSP client to do this.) - #[expect(dead_code)] + /// For right now, this is only used to determine which + /// module in an `import` statement to return submodule + /// completions for. But we could use it for other things, + /// like only returning completions that start with a prefix + /// corresponding to this token. attribute: Option<&'t Token>, }, /// A `from module import attribute` token form was found, where @@ -63,6 +74,20 @@ enum CompletionTargetTokens<'t> { /// The module being imported from. module: &'t Token, }, + /// A `import module` token form was found, where `module` may be + /// empty. + Import { + /// The token corresponding to the `import` keyword. + import: &'t Token, + /// The token closest to the cursor. + /// + /// This is currently unused, but we should use this + /// eventually to remove completions that aren't a + /// prefix of what has already been typed. (We are + /// currently relying on the LSP client to do this.) + #[expect(dead_code)] + module: &'t Token, + }, /// A token was found under the cursor, but it didn't /// match any of our anticipated token patterns. Generic { token: &'t Token }, @@ -105,6 +130,8 @@ impl<'t> CompletionTargetTokens<'t> { } } else if let Some(module) = import_from_tokens(before) { CompletionTargetTokens::ImportFrom { module } + } else if let Some((import, module)) = import_tokens(before) { + CompletionTargetTokens::Import { import, module } } else if let Some([_]) = token_suffix_by_kinds(before, [TokenKind::Float]) { // If we're writing a `float`, then we should // specifically not offer completions. This wouldn't @@ -140,19 +167,47 @@ impl<'t> CompletionTargetTokens<'t> { offset: TextSize, ) -> Option> { match *self { - CompletionTargetTokens::PossibleObjectDot { object, .. } => { + CompletionTargetTokens::PossibleObjectDot { object, attribute } => { let covering_node = covering_node(parsed.syntax().into(), object.range()) - // We require that the end of the node range not - // exceed the cursor offset. This avoids selecting - // a node "too high" in the AST in cases where - // completions are requested in the middle of an - // expression. e.g., `foo..bar`. - .find_last(|node| node.is_expr_attribute() && node.range().end() <= offset) + .find_last(|node| { + // We require that the end of the node range not + // exceed the cursor offset. This avoids selecting + // a node "too high" in the AST in cases where + // completions are requested in the middle of an + // expression. e.g., `foo..bar`. + if node.is_expr_attribute() { + return node.range().end() <= offset; + } + // For import statements though, they can't be + // nested, so we don't care as much about the + // cursor being strictly after the statement. + // And indeed, sometimes it won't be! e.g., + // + // import re, os.p, zlib + // + // So just return once we find an import. + node.is_stmt_import() || node.is_stmt_import_from() + }) .ok()?; match covering_node.node() { ast::AnyNodeRef::ExprAttribute(expr) => { Some(CompletionTargetAst::ObjectDot { expr }) } + ast::AnyNodeRef::StmtImport(import) => { + let range = attribute + .map(Ranged::range) + .unwrap_or_else(|| object.range()); + // Find the name that overlaps with the + // token we identified for the attribute. + let name = import + .names + .iter() + .position(|alias| alias.range().contains_range(range))?; + Some(CompletionTargetAst::ObjectDotInImport { import, name }) + } + ast::AnyNodeRef::StmtImportFrom(import) => { + Some(CompletionTargetAst::ObjectDotInImportFrom { import }) + } _ => None, } } @@ -165,6 +220,20 @@ impl<'t> CompletionTargetTokens<'t> { }; Some(CompletionTargetAst::ImportFrom { import, name: None }) } + CompletionTargetTokens::Import { import, .. } => { + let covering_node = covering_node(parsed.syntax().into(), import.range()) + .find_first(|node| node.is_stmt_import() || node.is_stmt_import_from()) + .ok()?; + match covering_node.node() { + ast::AnyNodeRef::StmtImport(import) => { + Some(CompletionTargetAst::Import { import, name: None }) + } + ast::AnyNodeRef::StmtImportFrom(import) => { + Some(CompletionTargetAst::ImportViaFrom { import }) + } + _ => None, + } + } CompletionTargetTokens::Generic { token } => { let covering_node = covering_node(parsed.syntax().into(), token.range()); Some(CompletionTargetAst::Scoped { @@ -188,6 +257,18 @@ enum CompletionTargetAst<'t> { /// A `object.attribute` scenario, where we want to /// list attributes on `object` for completions. ObjectDot { expr: &'t ast::ExprAttribute }, + /// A `import module.submodule` scenario, where we only want to + /// list submodules for completions. + ObjectDotInImport { + /// The import statement. + import: &'t ast::StmtImport, + /// An index into `import.names`. The index is guaranteed to be + /// valid. + name: usize, + }, + /// A `from module.submodule` scenario, where we only want to list + /// submodules for completions. + ObjectDotInImportFrom { import: &'t ast::StmtImportFrom }, /// A `from module import attribute` scenario, where we want to /// list attributes on `module` for completions. ImportFrom { @@ -197,6 +278,24 @@ enum CompletionTargetAst<'t> { /// set, the index is guaranteed to be valid. name: Option, }, + /// A `import module` scenario, where we want to + /// list available modules for completions. + Import { + /// The import statement. + #[expect(dead_code)] + import: &'t ast::StmtImport, + /// An index into `import.names` if relevant. When this is + /// set, the index is guaranteed to be valid. + #[expect(dead_code)] + name: Option, + }, + /// A `from module` scenario, where we want to + /// list available modules for completions. + ImportViaFrom { + /// The import statement. + #[expect(dead_code)] + import: &'t ast::StmtImportFrom, + }, /// A scoped scenario, where we want to list all items available in /// the most narrow scope containing the giving AST node. Scoped { node: ast::AnyNodeRef<'t> }, @@ -317,6 +416,52 @@ fn import_from_tokens(tokens: &[Token]) -> Option<&Token> { None } +/// Looks for the start of a `import ` statement. +/// +/// This also handles cases like `import foo, c, bar`. +/// +/// If found, a token corresponding to the `import` or `from` keyword +/// and the the closest point of the `` is returned. +/// +/// It is assumed that callers will call `from_import_tokens` first to +/// try and recognize a `from ... import ...` statement before using +/// this. +fn import_tokens(tokens: &[Token]) -> Option<(&Token, &Token)> { + use TokenKind as TK; + + /// A look-back limit, in order to bound work. + /// + /// See `LIMIT` in `import_from_tokens` for more context. + const LIMIT: usize = 1_000; + + /// A state used to "parse" the tokens preceding the user's cursor, + /// in reverse, to detect a `import` statement. + enum S { + Start, + Names, + } + + let mut state = S::Start; + let module_token = tokens.last()?; + // Move backward through the tokens until we get to + // the `import` token. + for token in tokens.iter().rev().take(LIMIT) { + state = match (state, token.kind()) { + // It's okay to pop off a newline token here initially, + // since it may occur when the name being imported is + // empty. + (S::Start, TK::Newline) => S::Names, + // Munch through tokens that can make up an alias. + (S::Start | S::Names, TK::Name | TK::Comma | TK::As | TK::Unknown) => S::Names, + (S::Start | S::Names, TK::Import | TK::From) => { + return Some((token, module_token)); + } + _ => return None, + }; + } + None +} + /// Order completions lexicographically, with these exceptions: /// /// 1) A `_[^_]` prefix sorts last and @@ -2709,6 +2854,143 @@ importlib. test.assert_completions_include("resources"); } + #[test] + fn import_with_leading_character() { + let test = cursor_test( + "\ +import c +", + ); + test.assert_completions_include("collections"); + } + + #[test] + fn import_without_leading_character() { + let test = cursor_test( + "\ +import +", + ); + test.assert_completions_include("collections"); + } + + #[test] + fn import_multiple() { + let test = cursor_test( + "\ +import re, c, sys +", + ); + test.assert_completions_include("collections"); + } + + #[test] + fn import_with_aliases() { + let test = cursor_test( + "\ +import re as regexp, c, sys as system +", + ); + test.assert_completions_include("collections"); + } + + #[test] + fn import_over_multiple_lines() { + let test = cursor_test( + "\ +import re as regexp, \\ + c, \\ + sys as system +", + ); + test.assert_completions_include("collections"); + } + + #[test] + fn import_unknown_in_module() { + let test = cursor_test( + "\ +import ?, +", + ); + test.assert_completions_include("collections"); + } + + #[test] + fn import_via_from_with_leading_character() { + let test = cursor_test( + "\ +from c +", + ); + test.assert_completions_include("collections"); + } + + #[test] + fn import_via_from_without_leading_character() { + let test = cursor_test( + "\ +from +", + ); + test.assert_completions_include("collections"); + } + + #[test] + fn import_statement_with_submodule_with_leading_character() { + let test = cursor_test( + "\ +import os.p +", + ); + test.assert_completions_include("path"); + test.assert_completions_do_not_include("abspath"); + } + + #[test] + fn import_statement_with_submodule_multiple() { + let test = cursor_test( + "\ +import re, os.p, zlib +", + ); + test.assert_completions_include("path"); + test.assert_completions_do_not_include("abspath"); + } + + #[test] + fn import_statement_with_submodule_without_leading_character() { + let test = cursor_test( + "\ +import os. +", + ); + test.assert_completions_include("path"); + test.assert_completions_do_not_include("abspath"); + } + + #[test] + fn import_via_from_with_submodule_with_leading_character() { + let test = cursor_test( + "\ +from os.p +", + ); + test.assert_completions_include("path"); + test.assert_completions_do_not_include("abspath"); + } + + #[test] + fn import_via_from_with_submodule_without_leading_character() { + let test = cursor_test( + "\ +from os. +", + ); + test.assert_completions_include("path"); + test.assert_completions_do_not_include("abspath"); + } + #[test] fn regression_test_issue_642() { // Regression test for https://github.com/astral-sh/ty/issues/642 diff --git a/crates/ty_python_semantic/src/semantic_model.rs b/crates/ty_python_semantic/src/semantic_model.rs index f038111c5f..47eacbb3a5 100644 --- a/crates/ty_python_semantic/src/semantic_model.rs +++ b/crates/ty_python_semantic/src/semantic_model.rs @@ -6,7 +6,7 @@ use ruff_source_file::LineIndex; use crate::Db; use crate::module_name::ModuleName; -use crate::module_resolver::{KnownModule, Module, resolve_module}; +use crate::module_resolver::{KnownModule, Module, list_modules, resolve_module}; use crate::semantic_index::definition::Definition; use crate::semantic_index::scope::FileScopeId; use crate::semantic_index::semantic_index; @@ -41,8 +41,24 @@ impl<'db> SemanticModel<'db> { resolve_module(self.db, module_name) } + /// Returns completions for symbols available in a `import ` context. + pub fn import_completions(&self) -> Vec> { + list_modules(self.db) + .into_iter() + .map(|module| { + let builtin = module.is_known(self.db, KnownModule::Builtins); + let ty = Type::module_literal(self.db, self.file, module); + Completion { + name: Name::new(module.name(self.db).as_str()), + ty, + builtin, + } + }) + .collect() + } + /// Returns completions for symbols available in a `from module import ` context. - pub fn import_completions( + pub fn from_import_completions( &self, import: &ast::StmtImportFrom, _name: Option, @@ -61,6 +77,79 @@ impl<'db> SemanticModel<'db> { self.module_completions(&module_name) } + /// Returns completions only for submodules for the module + /// identified by `name` in `import`. + /// + /// For example, `import re, os., zlib`. + pub fn import_submodule_completions( + &self, + import: &ast::StmtImport, + name: usize, + ) -> Vec> { + let module_ident = &import.names[name].name; + let Some((parent_ident, _)) = module_ident.rsplit_once('.') else { + return vec![]; + }; + let module_name = + match ModuleName::from_identifier_parts(self.db, self.file, Some(parent_ident), 0) { + Ok(module_name) => module_name, + Err(err) => { + tracing::debug!( + "Could not extract module name from `{module:?}`: {err:?}", + module = module_ident, + ); + return vec![]; + } + }; + self.import_submodule_completions_for_name(&module_name) + } + + /// Returns completions only for submodules for the module + /// used in a `from module import attribute` statement. + /// + /// For example, `from os.`. + pub fn from_import_submodule_completions( + &self, + import: &ast::StmtImportFrom, + ) -> Vec> { + let level = import.level; + let Some(module_ident) = import.module.as_deref() else { + return vec![]; + }; + let Some((parent_ident, _)) = module_ident.rsplit_once('.') else { + return vec![]; + }; + let module_name = match ModuleName::from_identifier_parts( + self.db, + self.file, + Some(parent_ident), + level, + ) { + Ok(module_name) => module_name, + Err(err) => { + tracing::debug!( + "Could not extract module name from `{module:?}` with level {level}: {err:?}", + module = import.module, + level = import.level, + ); + return vec![]; + } + }; + self.import_submodule_completions_for_name(&module_name) + } + + /// Returns submodule-only completions for the given module. + fn import_submodule_completions_for_name( + &self, + module_name: &ModuleName, + ) -> Vec> { + let Some(module) = resolve_module(self.db, module_name) else { + tracing::debug!("Could not resolve module from `{module_name:?}`"); + return vec![]; + }; + self.submodule_completions(&module) + } + /// Returns completions for symbols available in the given module as if /// it were imported by this model's `File`. fn module_completions(&self, module_name: &ModuleName) -> Vec> { @@ -75,11 +164,20 @@ impl<'db> SemanticModel<'db> { for crate::types::Member { name, ty } in crate::types::all_members(self.db, ty) { completions.push(Completion { name, ty, builtin }); } + completions.extend(self.submodule_completions(&module)); + completions + } + + /// Returns completions for submodules of the given module. + fn submodule_completions(&self, module: &Module<'db>) -> Vec> { + let builtin = module.is_known(self.db, KnownModule::Builtins); + + let mut completions = vec![]; for submodule_basename in module.all_submodules(self.db) { let Some(basename) = ModuleName::new(submodule_basename.as_str()) else { continue; }; - let mut submodule_name = module_name.clone(); + let mut submodule_name = module.name(self.db).clone(); submodule_name.extend(&basename); let Some(submodule) = resolve_module(self.db, &submodule_name) else {