diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 607fa80ac1..e9d7dde222 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -762,12 +762,25 @@ impl<'db> CallOutcome<'db> { } => { let mut not_callable = vec![]; let mut union_builder = UnionBuilder::new(db); + let mut revealed = false; for outcome in &**outcomes { - let return_ty = if let Self::NotCallable { not_callable_ty } = outcome { - not_callable.push(*not_callable_ty); - Type::Unknown - } else { - outcome.unwrap_with_diagnostic(db, node, builder) + let return_ty = match outcome { + Self::NotCallable { not_callable_ty } => { + not_callable.push(*not_callable_ty); + Type::Unknown + } + Self::RevealType { + return_ty, + revealed_ty: _, + } => { + if revealed { + *return_ty + } else { + revealed = true; + outcome.unwrap_with_diagnostic(db, node, builder) + } + } + _ => outcome.unwrap_with_diagnostic(db, node, builder), }; union_builder = union_builder.add(return_ty); } @@ -841,6 +854,15 @@ impl<'db> FunctionType<'db> { }) } + /// Return true if this is a symbol with given name from `typing` or `typing_extensions`. + pub(crate) fn is_typing_symbol(self, db: &'db dyn Db, name: &str) -> bool { + name == self.name(db) + && file_to_module(db, self.definition(db).file(db)).is_some_and(|module| { + module.search_path().is_standard_library() + && matches!(&**module.name(), "typing" | "typing_extensions") + }) + } + pub fn has_decorator(self, db: &dyn Db, decorator: Type<'_>) -> bool { self.decorators(db).contains(&decorator) } diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 84c2303d1a..4a83800d9b 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -705,7 +705,7 @@ impl<'db> TypeInferenceBuilder<'db> { } let function_type = FunctionType::new(self.db, name.id.clone(), definition, decorator_tys); - let function_ty = if function_type.is_stdlib_symbol(self.db, "typing", "reveal_type") { + let function_ty = if function_type.is_typing_symbol(self.db, "reveal_type") { Type::RevealTypeFunction(function_type) } else { Type::Function(function_type) @@ -2761,6 +2761,44 @@ mod tests { Ok(()) } + #[test] + fn reveal_type_aliased() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + from typing import reveal_type as rt + + x = 1 + rt(x) + ", + )?; + + assert_file_diagnostics(&db, "/src/a.py", &["Revealed type is 'Literal[1]'."]); + + Ok(()) + } + + #[test] + fn reveal_type_typing_extensions() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + import typing_extensions + + x = 1 + typing_extensions.reveal_type(x) + ", + )?; + + assert_file_diagnostics(&db, "/src/a.py", &["Revealed type is 'Literal[1]'."]); + + Ok(()) + } + #[test] fn follow_import_to_class() -> anyhow::Result<()> { let mut db = setup_db();