[red-knot] support for typing.reveal_type (#13384)

Add support for the `typing.reveal_type` function, emitting a diagnostic
revealing the type of its single argument. This is a necessary piece for
the planned testing framework.

This puts the cart slightly in front of the horse, in that we don't yet
have proper support for validating call signatures / argument types. But
it's easy to do just enough to make `reveal_type` work.

This PR includes support for calling union types (this is necessary
because we don't yet support `sys.version_info` checks, so
`typing.reveal_type` itself is a union type), plus some nice
consolidated error messages for calls to unions where some elements are
not callable. This is mostly to demonstrate the flexibility in
diagnostics that we get from the `CallOutcome` enum.
This commit is contained in:
Carl Meyer 2024-09-18 09:59:51 -07:00 committed by GitHub
parent 44d916fb4e
commit c173ec5bc7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 384 additions and 51 deletions

View file

@ -238,6 +238,8 @@ pub enum Type<'db> {
None,
/// a specific function object
Function(FunctionType<'db>),
/// The `typing.reveal_type` function, which has special `__call__` behavior.
RevealTypeFunction(FunctionType<'db>),
/// a specific module object
Module(File),
/// a specific class object
@ -324,14 +326,16 @@ impl<'db> Type<'db> {
pub const fn into_function_type(self) -> Option<FunctionType<'db>> {
match self {
Type::Function(function_type) => Some(function_type),
Type::Function(function_type) | Type::RevealTypeFunction(function_type) => {
Some(function_type)
}
_ => None,
}
}
pub fn expect_function(self) -> FunctionType<'db> {
self.into_function_type()
.expect("Expected a Type::Function variant")
.expect("Expected a variant wrapping a FunctionType")
}
pub const fn into_int_literal_type(self) -> Option<i64> {
@ -367,6 +371,16 @@ impl<'db> Type<'db> {
}
}
pub fn is_stdlib_symbol(&self, db: &'db dyn Db, module_name: &str, name: &str) -> bool {
match self {
Type::Class(class) => class.is_stdlib_symbol(db, module_name, name),
Type::Function(function) | Type::RevealTypeFunction(function) => {
function.is_stdlib_symbol(db, module_name, name)
}
_ => false,
}
}
/// Return true if this type is [assignable to] type `target`.
///
/// [assignable to]: https://typing.readthedocs.io/en/latest/spec/concepts.html#the-assignable-to-or-consistent-subtyping-relation
@ -436,7 +450,7 @@ impl<'db> Type<'db> {
// TODO: attribute lookup on None type
Type::Unknown
}
Type::Function(_) => {
Type::Function(_) | Type::RevealTypeFunction(_) => {
// TODO: attribute lookup on function type
Type::Unknown
}
@ -482,26 +496,39 @@ impl<'db> Type<'db> {
///
/// Returns `None` if `self` is not a callable type.
#[must_use]
pub fn call(&self, db: &'db dyn Db) -> Option<Type<'db>> {
fn call(self, db: &'db dyn Db, arg_types: &[Type<'db>]) -> CallOutcome<'db> {
match self {
Type::Function(function_type) => Some(function_type.return_type(db)),
// TODO validate typed call arguments vs callable signature
Type::Function(function_type) => CallOutcome::callable(function_type.return_type(db)),
Type::RevealTypeFunction(function_type) => CallOutcome::revealed(
function_type.return_type(db),
*arg_types.first().unwrap_or(&Type::Unknown),
),
// TODO annotated return type on `__new__` or metaclass `__call__`
Type::Class(class) => Some(Type::Instance(*class)),
Type::Class(class) => CallOutcome::callable(Type::Instance(class)),
// TODO: handle classes which implement the Callable protocol
Type::Instance(_instance_ty) => Some(Type::Unknown),
// TODO: handle classes which implement the `__call__` protocol
Type::Instance(_instance_ty) => CallOutcome::callable(Type::Unknown),
// `Any` is callable, and its return type is also `Any`.
Type::Any => Some(Type::Any),
Type::Any => CallOutcome::callable(Type::Any),
Type::Unknown => Some(Type::Unknown),
Type::Unknown => CallOutcome::callable(Type::Unknown),
// TODO: union and intersection types, if they reduce to `Callable`
Type::Union(_) => Some(Type::Unknown),
Type::Intersection(_) => Some(Type::Unknown),
Type::Union(union) => CallOutcome::union(
self,
union
.elements(db)
.iter()
.map(|elem| elem.call(db, arg_types))
.collect::<Box<[CallOutcome<'db>]>>(),
),
_ => None,
// TODO: intersection types
Type::Intersection(_) => CallOutcome::callable(Type::Unknown),
_ => CallOutcome::not_callable(self),
}
}
@ -513,7 +540,7 @@ impl<'db> Type<'db> {
/// for y in x:
/// pass
/// ```
fn iterate(&self, db: &'db dyn Db) -> IterationOutcome<'db> {
fn iterate(self, db: &'db dyn Db) -> IterationOutcome<'db> {
if let Type::Tuple(tuple_type) = self {
return IterationOutcome::Iterable {
element_ty: UnionType::from_elements(db, &**tuple_type.elements(db)),
@ -526,18 +553,22 @@ impl<'db> Type<'db> {
let dunder_iter_method = iterable_meta_type.member(db, "__iter__");
if !dunder_iter_method.is_unbound() {
let Some(iterator_ty) = dunder_iter_method.call(db) else {
let CallOutcome::Callable {
return_ty: iterator_ty,
} = dunder_iter_method.call(db, &[])
else {
return IterationOutcome::NotIterable {
not_iterable_ty: *self,
not_iterable_ty: self,
};
};
let dunder_next_method = iterator_ty.to_meta_type(db).member(db, "__next__");
return dunder_next_method
.call(db)
.call(db, &[])
.return_ty(db)
.map(|element_ty| IterationOutcome::Iterable { element_ty })
.unwrap_or(IterationOutcome::NotIterable {
not_iterable_ty: *self,
not_iterable_ty: self,
});
}
@ -550,10 +581,11 @@ impl<'db> Type<'db> {
let dunder_get_item_method = iterable_meta_type.member(db, "__getitem__");
dunder_get_item_method
.call(db)
.call(db, &[])
.return_ty(db)
.map(|element_ty| IterationOutcome::Iterable { element_ty })
.unwrap_or(IterationOutcome::NotIterable {
not_iterable_ty: *self,
not_iterable_ty: self,
})
}
@ -573,6 +605,7 @@ impl<'db> Type<'db> {
Type::BooleanLiteral(_)
| Type::BytesLiteral(_)
| Type::Function(_)
| Type::RevealTypeFunction(_)
| Type::Instance(_)
| Type::Module(_)
| Type::IntLiteral(_)
@ -595,7 +628,7 @@ impl<'db> Type<'db> {
Type::BooleanLiteral(_) => builtins_symbol_ty(db, "bool"),
Type::BytesLiteral(_) => builtins_symbol_ty(db, "bytes"),
Type::IntLiteral(_) => builtins_symbol_ty(db, "int"),
Type::Function(_) => types_symbol_ty(db, "FunctionType"),
Type::Function(_) | Type::RevealTypeFunction(_) => types_symbol_ty(db, "FunctionType"),
Type::Module(_) => types_symbol_ty(db, "ModuleType"),
Type::None => typeshed_symbol_ty(db, "NoneType"),
// TODO not accurate if there's a custom metaclass...
@ -619,6 +652,152 @@ impl<'db> From<&Type<'db>> for Type<'db> {
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
enum CallOutcome<'db> {
Callable {
return_ty: Type<'db>,
},
RevealType {
return_ty: Type<'db>,
revealed_ty: Type<'db>,
},
NotCallable {
not_callable_ty: Type<'db>,
},
Union {
called_ty: Type<'db>,
outcomes: Box<[CallOutcome<'db>]>,
},
}
impl<'db> CallOutcome<'db> {
/// Create a new `CallOutcome::Callable` with given return type.
fn callable(return_ty: Type<'db>) -> CallOutcome {
CallOutcome::Callable { return_ty }
}
/// Create a new `CallOutcome::NotCallable` with given not-callable type.
fn not_callable(not_callable_ty: Type<'db>) -> CallOutcome {
CallOutcome::NotCallable { not_callable_ty }
}
/// Create a new `CallOutcome::RevealType` with given revealed and return types.
fn revealed(return_ty: Type<'db>, revealed_ty: Type<'db>) -> CallOutcome<'db> {
CallOutcome::RevealType {
return_ty,
revealed_ty,
}
}
/// Create a new `CallOutcome::Union` with given wrapped outcomes.
fn union(called_ty: Type<'db>, outcomes: impl Into<Box<[CallOutcome<'db>]>>) -> CallOutcome {
CallOutcome::Union {
called_ty,
outcomes: outcomes.into(),
}
}
/// Get the return type of the call, or `None` if not callable.
fn return_ty(&self, db: &'db dyn Db) -> Option<Type<'db>> {
match self {
Self::Callable { return_ty } => Some(*return_ty),
Self::RevealType {
return_ty,
revealed_ty: _,
} => Some(*return_ty),
Self::NotCallable { not_callable_ty: _ } => None,
Self::Union {
outcomes,
called_ty: _,
} => outcomes
.iter()
// If all outcomes are NotCallable, we return None; if some outcomes are callable
// and some are not, we return a union including Unknown.
.fold(None, |acc, outcome| {
let ty = outcome.return_ty(db);
match (acc, ty) {
(None, None) => None,
(None, Some(ty)) => Some(UnionBuilder::new(db).add(ty)),
(Some(builder), ty) => Some(builder.add(ty.unwrap_or(Type::Unknown))),
}
})
.map(UnionBuilder::build),
}
}
/// Get the return type of the call, emitting diagnostics if needed.
fn unwrap_with_diagnostic<'a>(
&self,
db: &'db dyn Db,
node: ast::AnyNodeRef,
builder: &'a mut TypeInferenceBuilder<'db>,
) -> Type<'db> {
match self {
Self::Callable { return_ty } => *return_ty,
Self::RevealType {
return_ty,
revealed_ty,
} => {
builder.add_diagnostic(
node,
"revealed-type",
format_args!("Revealed type is '{}'.", revealed_ty.display(db)),
);
*return_ty
}
Self::NotCallable { not_callable_ty } => {
builder.add_diagnostic(
node,
"call-non-callable",
format_args!(
"Object of type '{}' is not callable.",
not_callable_ty.display(db)
),
);
Type::Unknown
}
Self::Union {
outcomes,
called_ty,
} => {
let mut not_callable = vec![];
let mut union_builder = UnionBuilder::new(db);
for outcome in &**outcomes {
let return_ty = if let Self::NotCallable { not_callable_ty } = outcome {
not_callable.push(*not_callable_ty);
Type::Unknown
} else {
outcome.unwrap_with_diagnostic(db, node, builder)
};
union_builder = union_builder.add(return_ty);
}
match not_callable[..] {
[] => {}
[elem] => builder.add_diagnostic(
node,
"call-non-callable",
format_args!(
"Union element '{}' of type '{}' is not callable.",
elem.display(db),
called_ty.display(db)
),
),
_ => builder.add_diagnostic(
node,
"call-non-callable",
format_args!(
"Union elements {} of type '{}' are not callable.",
not_callable.display(db),
called_ty.display(db)
),
),
}
union_builder.build()
}
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IterationOutcome<'db> {
Iterable { element_ty: Type<'db> },
@ -654,6 +833,14 @@ pub struct FunctionType<'db> {
}
impl<'db> FunctionType<'db> {
/// Return true if this is a standard library function with given module name and name.
pub(crate) fn is_stdlib_symbol(self, db: &'db dyn Db, module_name: &str, name: &str) -> bool {
name == self.name(db)
&& file_to_module(db, self.definition(db).file(db)).is_some_and(|module| {
module.search_path().is_standard_library() && module.name() == module_name
})
}
pub fn has_decorator(self, db: &dyn Db, decorator: Type<'_>) -> bool {
self.decorators(db).contains(&decorator)
}

View file

@ -36,6 +36,7 @@ impl Display for DisplayType<'_> {
| Type::BytesLiteral(_)
| Type::Class(_)
| Type::Function(_)
| Type::RevealTypeFunction(_)
) {
write!(f, "Literal[{representation}]",)
} else {
@ -72,7 +73,9 @@ impl Display for DisplayRepresentation<'_> {
// TODO functions and classes should display using a fully qualified name
Type::Class(class) => f.write_str(class.name(self.db)),
Type::Instance(class) => f.write_str(class.name(self.db)),
Type::Function(function) => f.write_str(function.name(self.db)),
Type::Function(function) | Type::RevealTypeFunction(function) => {
f.write_str(function.name(self.db))
}
Type::Union(union) => union.display(self.db).fmt(f),
Type::Intersection(intersection) => intersection.display(self.db).fmt(f),
Type::IntLiteral(n) => n.fmt(f),
@ -191,7 +194,7 @@ impl TryFrom<Type<'_>> for LiteralTypeKind {
fn try_from(value: Type<'_>) -> Result<Self, Self::Error> {
match value {
Type::Class(_) => Ok(Self::Class),
Type::Function(_) => Ok(Self::Function),
Type::Function(_) | Type::RevealTypeFunction(_) => Ok(Self::Function),
Type::IntLiteral(_) => Ok(Self::IntLiteral),
Type::StringLiteral(_) => Ok(Self::StringLiteral),
Type::BytesLiteral(_) => Ok(Self::BytesLiteral),

View file

@ -704,12 +704,12 @@ impl<'db> TypeInferenceBuilder<'db> {
}
}
let function_ty = Type::Function(FunctionType::new(
self.db,
name.id.clone(),
definition,
decorator_tys,
));
let function_type = FunctionType::new(self.db, name.id.clone(), definition, decorator_tys);
let function_ty = if function_type.is_stdlib_symbol(self.db, "typing", "reveal_type") {
Type::RevealTypeFunction(function_type)
} else {
Type::Function(function_type)
};
self.add_declaration_with_binding(function.into(), definition, function_ty, function_ty);
}
@ -1241,7 +1241,7 @@ impl<'db> TypeInferenceBuilder<'db> {
node,
"not-iterable",
format_args!(
"Object of type '{}' is not iterable",
"Object of type '{}' is not iterable.",
not_iterable_ty.display(self.db)
),
);
@ -2023,19 +2023,12 @@ impl<'db> TypeInferenceBuilder<'db> {
arguments,
} = call_expression;
self.infer_arguments(arguments);
// TODO: proper typed call signature, representing keyword args etc
let arg_types = self.infer_arguments(arguments);
let function_type = self.infer_expression(func);
function_type.call(self.db).unwrap_or_else(|| {
self.add_diagnostic(
func.as_ref().into(),
"call-non-callable",
format_args!(
"Object of type '{}' is not callable",
function_type.display(self.db)
),
);
Type::Unknown
})
function_type
.call(self.db, arg_types.as_slice())
.unwrap_with_diagnostic(self.db, func.as_ref().into(), self)
}
fn infer_starred_expression(&mut self, starred: &ast::ExprStarred) -> Type<'db> {
@ -2410,7 +2403,12 @@ impl<'db> TypeInferenceBuilder<'db> {
/// Adds a new diagnostic.
///
/// The diagnostic does not get added if the rule isn't enabled for this file.
fn add_diagnostic(&mut self, node: AnyNodeRef, rule: &str, message: std::fmt::Arguments) {
pub(super) fn add_diagnostic(
&mut self,
node: AnyNodeRef,
rule: &str,
message: std::fmt::Arguments,
) {
if !self.db.is_file_open(self.file) {
return;
}
@ -2746,6 +2744,25 @@ mod tests {
assert_diagnostic_messages(&diagnostics, expected);
}
#[test]
fn reveal_type() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"/src/a.py",
"
from typing import reveal_type
x = 1
reveal_type(x)
",
)?;
assert_file_diagnostics(&db, "/src/a.py", &["Revealed type is 'Literal[1]'."]);
Ok(())
}
#[test]
fn follow_import_to_class() -> anyhow::Result<()> {
let mut db = setup_db();
@ -3333,6 +3350,104 @@ mod tests {
Ok(())
}
#[test]
fn call_union() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
if flag:
def f() -> int:
return 1
else:
def f() -> str:
return 'foo'
x = f()
",
)?;
assert_public_ty(&db, "src/a.py", "x", "int | str");
Ok(())
}
#[test]
fn call_union_with_unknown() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
from nonexistent import f
if flag:
def f() -> int:
return 1
x = f()
",
)?;
assert_public_ty(&db, "src/a.py", "x", "Unknown | int");
Ok(())
}
#[test]
fn call_union_with_not_callable() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
if flag:
f = 1
else:
def f() -> int:
return 1
x = f()
",
)?;
assert_file_diagnostics(
&db,
"src/a.py",
&["Union element 'Literal[1]' of type 'Literal[1] | Literal[f]' is not callable."],
);
assert_public_ty(&db, "src/a.py", "x", "Unknown | int");
Ok(())
}
#[test]
fn call_union_with_multiple_not_callable() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
if flag:
f = 1
elif flag2:
f = 'foo'
else:
def f() -> int:
return 1
x = f()
",
)?;
assert_file_diagnostics(
&db,
"src/a.py",
&[
r#"Union elements Literal[1], Literal["foo"] of type 'Literal[1] | Literal["foo"] | Literal[f]' are not callable."#,
],
);
assert_public_ty(&db, "src/a.py", "x", "Unknown | int");
Ok(())
}
#[test]
fn invalid_callable() {
let mut db = setup_db();
@ -3349,7 +3464,7 @@ mod tests {
assert_file_diagnostics(
&db,
"/src/a.py",
&["Object of type 'Literal[123]' is not callable"],
&["Object of type 'Literal[123]' is not callable."],
);
}
@ -4666,6 +4781,34 @@ mod tests {
Ok(())
}
#[test]
fn for_loop_non_callable_iter() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
class NotIterable:
if flag:
__iter__ = 1
else:
__iter__ = None
for x in NotIterable():
pass
",
)?;
assert_file_diagnostics(
&db,
"src/a.py",
&["Object of type 'NotIterable' is not iterable."],
);
assert_public_ty(&db, "src/a.py", "x", "Unbound | Unknown");
Ok(())
}
#[test]
fn except_handler_single_exception() -> anyhow::Result<()> {
let mut db = setup_db();
@ -4970,7 +5113,7 @@ mod tests {
assert_file_diagnostics(
&db,
"src/a.py",
&["Object of type 'Unbound' is not iterable"],
&["Object of type 'Unbound' is not iterable."],
);
Ok(())
@ -4998,7 +5141,7 @@ mod tests {
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "x", "int");
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "z", "Unknown");
assert_file_diagnostics(&db, "src/a.py", &["Object of type 'int' is not iterable"]);
assert_file_diagnostics(&db, "src/a.py", &["Object of type 'int' is not iterable."]);
Ok(())
}
@ -5192,7 +5335,7 @@ mod tests {
assert_file_diagnostics(
&db,
"/src/a.py",
&["Object of type 'Literal[123]' is not iterable"],
&["Object of type 'Literal[123]' is not iterable."],
);
}
@ -5218,7 +5361,7 @@ mod tests {
assert_file_diagnostics(
&db,
"/src/a.py",
&["Object of type 'NotIterable' is not iterable"],
&["Object of type 'NotIterable' is not iterable."],
);
}
@ -5247,7 +5390,7 @@ mod tests {
assert_file_diagnostics(
&db,
"/src/a.py",
&["Object of type 'NotIterable' is not iterable"],
&["Object of type 'NotIterable' is not iterable."],
);
}
@ -5277,7 +5420,7 @@ mod tests {
assert_file_diagnostics(
&db,
"/src/a.py",
&["Object of type 'NotIterable' is not iterable"],
&["Object of type 'NotIterable' is not iterable."],
);
}