diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index e9d7dde222..2d263310c9 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -51,9 +51,21 @@ fn symbol_ty_by_id<'db>(db: &'db dyn Db, scope: ScopeId<'db>, symbol: ScopedSymb // on inference from bindings. if use_def.has_public_declarations(symbol) { let declarations = use_def.public_declarations(symbol); + // If the symbol is undeclared in some paths, include the inferred type in the public type. + let undeclared_ty = if declarations.may_be_undeclared() { + Some(bindings_ty( + db, + use_def.public_bindings(symbol), + use_def + .public_may_be_unbound(symbol) + .then_some(Type::Unknown), + )) + } else { + None + }; // Intentionally ignore conflicting declared types; that's not our problem, it's the // problem of the module we are importing from. - declarations_ty(db, declarations).unwrap_or_else(|(ty, _)| ty) + declarations_ty(db, declarations, undeclared_ty).unwrap_or_else(|(ty, _)| ty) } else { bindings_ty( db, @@ -173,26 +185,21 @@ type DeclaredTypeResult<'db> = Result, (Type<'db>, Box<[Type<'db>]>)>; /// `Ok(declared_type)`. If there are conflicting declarations, returns /// `Err((union_of_declared_types, conflicting_declared_types))`. /// -/// If undeclared is a possibility, `Unknown` type will be part of the return type (and may +/// If undeclared is a possibility, `undeclared_ty` type will be part of the return type (and may /// conflict with other declarations.) /// /// # Panics -/// Will panic if there are no declarations and no possibility of undeclared. This is a logic -/// error, as any symbol with zero live declarations clearly must be undeclared. +/// Will panic if there are no declarations and no `undeclared_ty` is provided. This is a logic +/// error, as any symbol with zero live declarations clearly must be undeclared, and the caller +/// should provide an `undeclared_ty`. fn declarations_ty<'db>( db: &'db dyn Db, declarations: DeclarationsIterator<'_, 'db>, + undeclared_ty: Option>, ) -> DeclaredTypeResult<'db> { - let may_be_undeclared = declarations.may_be_undeclared(); let decl_types = declarations.map(|declaration| declaration_ty(db, declaration)); - let mut all_types = (if may_be_undeclared { - Some(Type::Unknown) - } else { - None - }) - .into_iter() - .chain(decl_types); + let mut all_types = undeclared_ty.into_iter().chain(decl_types); let first = all_types.next().expect( "declarations_ty must not be called with zero declarations and no may-be-undeclared.", diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 4a83800d9b..5cf8be35ef 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -506,9 +506,14 @@ impl<'db> TypeInferenceBuilder<'db> { debug_assert!(binding.is_binding(self.db)); let use_def = self.index.use_def_map(binding.file_scope(self.db)); let declarations = use_def.declarations_at_binding(binding); + let undeclared_ty = if declarations.may_be_undeclared() { + Some(Type::Unknown) + } else { + None + }; let mut bound_ty = ty; - let declared_ty = - declarations_ty(self.db, declarations).unwrap_or_else(|(ty, conflicting)| { + let declared_ty = declarations_ty(self.db, declarations, undeclared_ty).unwrap_or_else( + |(ty, conflicting)| { // TODO point out the conflicting declarations in the diagnostic? let symbol_table = self.index.symbol_table(binding.file_scope(self.db)); let symbol_name = symbol_table.symbol(binding.symbol(self.db)).name(); @@ -521,7 +526,8 @@ impl<'db> TypeInferenceBuilder<'db> { ), ); ty - }); + }, + ); if !bound_ty.is_assignable_to(self.db, declared_ty) { self.invalid_assignment_diagnostic(node, declared_ty, bound_ty); // allow declarations to override inference in case of invalid assignment @@ -5777,6 +5783,27 @@ mod tests { assert_public_ty(&db, "/src/a.py", "f", "Literal[f, f]"); } + #[test] + fn import_from_conditional_reimport_vs_non_declaration() { + let mut db = setup_db(); + + db.write_file("/src/a.py", "from b import x").unwrap(); + db.write_dedented( + "/src/b.py", + " + if flag: + from c import x + else: + x = 1 + ", + ) + .unwrap(); + db.write_file("/src/c.pyi", "x: int").unwrap(); + + // TODO this should simplify to just 'int' + assert_public_ty(&db, "/src/a.py", "x", "int | Literal[1]"); + } + // Incremental inference tests fn first_public_binding<'db>(db: &'db TestDb, file: File, name: &str) -> Definition<'db> {