diff --git a/crates/ty_ide/src/importer.rs b/crates/ty_ide/src/importer.rs index bbc5281ddf..646f12e52a 100644 --- a/crates/ty_ide/src/importer.rs +++ b/crates/ty_ide/src/importer.rs @@ -145,8 +145,12 @@ impl<'a> Importer<'a> { let request = request.avoid_conflicts(self.db, self.file, members); let mut symbol_text: Box = request.member.into(); let Some(response) = self.find(&request, members.at) else { - let import = Insertion::start_of_file(self.parsed.suite(), self.source, self.stylist) - .into_edit(&request.to_string()); + let insertion = if let Some(future) = self.find_last_future_import() { + Insertion::end_of_statement(future.stmt, self.source, self.stylist) + } else { + Insertion::start_of_file(self.parsed.suite(), self.source, self.stylist) + }; + let import = insertion.into_edit(&request.to_string()); if matches!(request.style, ImportStyle::Import) { symbol_text = format!("{}.{}", request.module, request.member).into(); } @@ -241,6 +245,19 @@ impl<'a> Importer<'a> { } choice } + + /// Find the last `from __future__` import statement in the AST. + fn find_last_future_import(&self) -> Option<&'a AstImport> { + self.imports + .iter() + .take_while(|import| { + import + .stmt + .as_import_from_stmt() + .is_some_and(|import_from| import_from.module.as_deref() == Some("__future__")) + }) + .last() + } } /// A map of symbols in scope at a particular location in a module. @@ -1293,6 +1310,44 @@ def foo(): "); } + #[test] + fn existing_future_import() { + let test = cursor_test( + "\ +from __future__ import annotations + + + ", + ); + assert_snapshot!( + test.import("typing", "TypeVar"), @r" + from __future__ import annotations + import typing + + typing.TypeVar + "); + } + + #[test] + fn existing_future_import_after_docstring() { + let test = cursor_test( + r#" +"This is a module level docstring" +from __future__ import annotations + + + "#, + ); + assert_snapshot!( + test.import("typing", "TypeVar"), @r#" + "This is a module level docstring" + from __future__ import annotations + import typing + + typing.TypeVar + "#); + } + #[test] fn qualify_symbol_to_avoid_overwriting_other_symbol_in_scope() { let test = cursor_test(