[red-knot] Support assert_type (#15194)

## Summary

See #15103.

## Test Plan

Markdown tests and unit tests.
This commit is contained in:
InSync 2025-01-10 23:45:02 +07:00 committed by GitHub
parent c87463842a
commit 6b98a26452
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 422 additions and 38 deletions

View file

@ -0,0 +1,145 @@
# `assert_type`
## Basic
```py
from typing_extensions import assert_type
def _(x: int):
assert_type(x, int) # fine
assert_type(x, str) # error: [type-assertion-failure]
```
## Narrowing
The asserted type is checked against the inferred type, not the declared type.
```toml
[environment]
python-version = "3.10"
```
```py
from typing_extensions import assert_type
def _(x: int | str):
if isinstance(x, int):
reveal_type(x) # revealed: int
assert_type(x, int) # fine
```
## Equivalence
The actual type must match the asserted type precisely.
```py
from typing import Any, Type, Union
from typing_extensions import assert_type
# Subtype does not count
def _(x: bool):
assert_type(x, int) # error: [type-assertion-failure]
def _(a: type[int], b: type[Any]):
assert_type(a, type[Any]) # error: [type-assertion-failure]
assert_type(b, type[int]) # error: [type-assertion-failure]
# The expression constructing the type is not taken into account
def _(a: type[int]):
assert_type(a, Type[int]) # fine
```
## Gradual types
```py
from typing import Any
from typing_extensions import Literal, assert_type
from knot_extensions import Unknown
# Any and Unknown are considered equivalent
def _(a: Unknown, b: Any):
reveal_type(a) # revealed: Unknown
assert_type(a, Any) # fine
reveal_type(b) # revealed: Any
assert_type(b, Unknown) # fine
def _(a: type[Unknown], b: type[Any]):
# TODO: Should be `type[Unknown]`
reveal_type(a) # revealed: @Todo(unsupported type[X] special form)
# TODO: Should be fine
assert_type(a, type[Any]) # error: [type-assertion-failure]
reveal_type(b) # revealed: type[Any]
# TODO: Should be fine
assert_type(b, type[Unknown]) # error: [type-assertion-failure]
```
## Tuples
Tuple types with the same elements are the same.
```py
from typing_extensions import assert_type
from knot_extensions import Unknown
def _(a: tuple[int, str, bytes]):
assert_type(a, tuple[int, str, bytes]) # fine
assert_type(a, tuple[int, str]) # error: [type-assertion-failure]
assert_type(a, tuple[int, str, bytes, None]) # error: [type-assertion-failure]
assert_type(a, tuple[int, bytes, str]) # error: [type-assertion-failure]
def _(a: tuple[Any, ...], b: tuple[Unknown, ...]):
assert_type(a, tuple[Any, ...]) # fine
assert_type(a, tuple[Unknown, ...]) # fine
assert_type(b, tuple[Unknown, ...]) # fine
assert_type(b, tuple[Any, ...]) # fine
```
## Unions
Unions with the same elements are the same, regardless of order.
```toml
[environment]
python-version = "3.10"
```
```py
from typing_extensions import assert_type
def _(a: str | int):
assert_type(a, str | int) # fine
# TODO: Order-independent union handling in type equivalence
assert_type(a, int | str) # error: [type-assertion-failure]
```
## Intersections
Intersections are the same when their positive and negative parts are respectively the same,
regardless of order.
```py
from typing_extensions import assert_type
from knot_extensions import Intersection, Not
class A: ...
class B: ...
class C: ...
class D: ...
def _(a: A):
if isinstance(a, B) and not isinstance(a, C) and not isinstance(a, D):
reveal_type(a) # revealed: A & B & ~C & ~D
assert_type(a, Intersection[A, B, Not[C], Not[D]]) # fine
# TODO: Order-independent intersection handling in type equivalence
assert_type(a, Intersection[B, A, Not[D], Not[C]]) # error: [type-assertion-failure]
```

View file

@ -1,4 +1,5 @@
use std::hash::Hash; use std::hash::Hash;
use std::iter;
use context::InferContext; use context::InferContext;
use diagnostic::{report_not_iterable, report_not_iterable_possibly_unbound}; use diagnostic::{report_not_iterable, report_not_iterable_possibly_unbound};
@ -1095,6 +1096,87 @@ impl<'db> Type<'db> {
) )
} }
/// Returns true if this type and `other` are gradual equivalent.
///
/// > Two gradual types `A` and `B` are equivalent
/// > (that is, the same gradual type, not merely consistent with one another)
/// > if and only if all materializations of `A` are also materializations of `B`,
/// > and all materializations of `B` are also materializations of `A`.
/// >
/// > &mdash; [Summary of type relations]
///
/// This powers the `assert_type()` directive.
///
/// [Summary of type relations]: https://typing.readthedocs.io/en/latest/spec/concepts.html#summary-of-type-relations
pub(crate) fn is_gradual_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool {
let equivalent =
|(first, second): (&Type<'db>, &Type<'db>)| first.is_gradual_equivalent_to(db, *second);
match (self, other) {
(_, _) if self == other => true,
(Type::Dynamic(_), Type::Dynamic(_)) => true,
(Type::Instance(instance), Type::SubclassOf(subclass))
| (Type::SubclassOf(subclass), Type::Instance(instance)) => {
let Some(base_class) = subclass.subclass_of().into_class() else {
return false;
};
instance.class.is_known(db, KnownClass::Type)
&& base_class.is_known(db, KnownClass::Object)
}
(Type::SubclassOf(first), Type::SubclassOf(second)) => {
match (first.subclass_of(), second.subclass_of()) {
(first, second) if first == second => true,
(ClassBase::Dynamic(_), ClassBase::Dynamic(_)) => true,
_ => false,
}
}
(Type::Tuple(first), Type::Tuple(second)) => {
let first_elements = first.elements(db);
let second_elements = second.elements(db);
first_elements.len() == second_elements.len()
&& iter::zip(first_elements, second_elements).all(equivalent)
}
// TODO: Handle equivalent unions with items in different order
(Type::Union(first), Type::Union(second)) => {
let first_elements = first.elements(db);
let second_elements = second.elements(db);
if first_elements.len() != second_elements.len() {
return false;
}
iter::zip(first_elements, second_elements).all(equivalent)
}
// TODO: Handle equivalent intersections with items in different order
(Type::Intersection(first), Type::Intersection(second)) => {
let first_positive = first.positive(db);
let first_negative = first.negative(db);
let second_positive = second.positive(db);
let second_negative = second.negative(db);
if first_positive.len() != second_positive.len()
|| first_negative.len() != second_negative.len()
{
return false;
}
iter::zip(first_positive, second_positive).all(equivalent)
&& iter::zip(first_negative, second_negative).all(equivalent)
}
_ => false,
}
}
/// Return true if this type and `other` have no common elements. /// Return true if this type and `other` have no common elements.
/// ///
/// Note: This function aims to have no false positives, but might return /// Note: This function aims to have no false positives, but might return
@ -1924,6 +2006,14 @@ impl<'db> Type<'db> {
CallOutcome::callable(binding) CallOutcome::callable(binding)
} }
Some(KnownFunction::AssertType) => {
let Some((_, asserted_ty)) = binding.two_parameter_tys() else {
return CallOutcome::callable(binding);
};
CallOutcome::asserted(binding, asserted_ty)
}
_ => CallOutcome::callable(binding), _ => CallOutcome::callable(binding),
} }
} }
@ -3261,6 +3351,9 @@ pub enum KnownFunction {
/// [`typing(_extensions).no_type_check`](https://typing.readthedocs.io/en/latest/spec/directives.html#no-type-check) /// [`typing(_extensions).no_type_check`](https://typing.readthedocs.io/en/latest/spec/directives.html#no-type-check)
NoTypeCheck, NoTypeCheck,
/// `typing(_extensions).assert_type`
AssertType,
/// `knot_extensions.static_assert` /// `knot_extensions.static_assert`
StaticAssert, StaticAssert,
/// `knot_extensions.is_equivalent_to` /// `knot_extensions.is_equivalent_to`
@ -3283,18 +3376,7 @@ impl KnownFunction {
pub fn constraint_function(self) -> Option<KnownConstraintFunction> { pub fn constraint_function(self) -> Option<KnownConstraintFunction> {
match self { match self {
Self::ConstraintFunction(f) => Some(f), Self::ConstraintFunction(f) => Some(f),
Self::RevealType _ => None,
| Self::Len
| Self::Final
| Self::NoTypeCheck
| Self::StaticAssert
| Self::IsEquivalentTo
| Self::IsSubtypeOf
| Self::IsAssignableTo
| Self::IsDisjointFrom
| Self::IsFullyStatic
| Self::IsSingleton
| Self::IsSingleValued => None,
} }
} }
@ -3316,6 +3398,7 @@ impl KnownFunction {
"no_type_check" if definition.is_typing_definition(db) => { "no_type_check" if definition.is_typing_definition(db) => {
Some(KnownFunction::NoTypeCheck) Some(KnownFunction::NoTypeCheck)
} }
"assert_type" if definition.is_typing_definition(db) => Some(KnownFunction::AssertType),
"static_assert" if definition.is_knot_extensions_definition(db) => { "static_assert" if definition.is_knot_extensions_definition(db) => {
Some(KnownFunction::StaticAssert) Some(KnownFunction::StaticAssert)
} }
@ -3345,20 +3428,34 @@ impl KnownFunction {
} }
} }
/// Whether or not a particular function takes type expression as arguments, i.e. should /// Returns a `u32` bitmask specifying whether or not
/// the argument of a call like `f(int)` be interpreted as the type int (true) or as the /// arguments given to a particular function
/// type of the expression `int`, i.e. `Literal[int]` (false). /// should be interpreted as type expressions or value expressions.
const fn takes_type_expression_arguments(self) -> bool { ///
matches!( /// The argument is treated as a type expression
self, /// when the corresponding bit is `1`.
KnownFunction::IsEquivalentTo /// The least-significant (right-most) bit corresponds to
| KnownFunction::IsSubtypeOf /// the argument at the index 0 and so on.
| KnownFunction::IsAssignableTo ///
| KnownFunction::IsDisjointFrom /// For example, `assert_type()` has the bitmask value of `0b10`.
| KnownFunction::IsFullyStatic /// This means the second argument is a type expression and the first a value expression.
| KnownFunction::IsSingleton const fn takes_type_expression_arguments(self) -> u32 {
| KnownFunction::IsSingleValued const ALL_VALUES: u32 = 0b0;
) const SINGLE_TYPE: u32 = 0b1;
const TYPE_TYPE: u32 = 0b11;
const VALUE_TYPE: u32 = 0b10;
match self {
KnownFunction::IsEquivalentTo => TYPE_TYPE,
KnownFunction::IsSubtypeOf => TYPE_TYPE,
KnownFunction::IsAssignableTo => TYPE_TYPE,
KnownFunction::IsDisjointFrom => TYPE_TYPE,
KnownFunction::IsFullyStatic => SINGLE_TYPE,
KnownFunction::IsSingleton => SINGLE_TYPE,
KnownFunction::IsSingleValued => SINGLE_TYPE,
KnownFunction::AssertType => VALUE_TYPE,
_ => ALL_VALUES,
}
} }
} }
@ -3681,7 +3778,8 @@ impl<'db> Class<'db> {
// does not accept the right arguments // does not accept the right arguments
CallOutcome::Callable { binding } CallOutcome::Callable { binding }
| CallOutcome::RevealType { binding, .. } | CallOutcome::RevealType { binding, .. }
| CallOutcome::StaticAssertionError { binding, .. } => Ok(binding.return_ty()), | CallOutcome::StaticAssertionError { binding, .. }
| CallOutcome::AssertType { binding, .. } => Ok(binding.return_ty()),
}; };
return return_ty_result.map(|ty| ty.to_meta_type(db)); return return_ty_result.map(|ty| ty.to_meta_type(db));
@ -4636,6 +4734,82 @@ pub(crate) mod tests {
assert!(!from.into_type(&db).is_fully_static(&db)); assert!(!from.into_type(&db).is_fully_static(&db));
} }
#[test_case(Ty::Todo, Ty::Todo)]
#[test_case(Ty::Any, Ty::Any)]
#[test_case(Ty::Unknown, Ty::Unknown)]
#[test_case(Ty::Any, Ty::Unknown)]
#[test_case(Ty::Todo, Ty::Unknown)]
#[test_case(Ty::Todo, Ty::Any)]
#[test_case(Ty::Never, Ty::Never)]
#[test_case(Ty::AlwaysTruthy, Ty::AlwaysTruthy)]
#[test_case(Ty::AlwaysFalsy, Ty::AlwaysFalsy)]
#[test_case(Ty::LiteralString, Ty::LiteralString)]
#[test_case(Ty::BooleanLiteral(true), Ty::BooleanLiteral(true))]
#[test_case(Ty::BooleanLiteral(false), Ty::BooleanLiteral(false))]
#[test_case(Ty::SliceLiteral(0, 1, 2), Ty::SliceLiteral(0, 1, 2))]
#[test_case(Ty::BuiltinClassLiteral("str"), Ty::BuiltinClassLiteral("str"))]
#[test_case(Ty::BuiltinInstance("type"), Ty::SubclassOfBuiltinClass("object"))]
// TODO: Compare unions/intersections with different orders
// #[test_case(
// Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]),
// Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")])
// )]
// #[test_case(
// Ty::Intersection {
// pos: vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")],
// neg: vec![Ty::BuiltinInstance("bytes"), Ty::None]
// },
// Ty::Intersection {
// pos: vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")],
// neg: vec![Ty::None, Ty::BuiltinInstance("bytes")]
// }
// )]
// #[test_case(
// Ty::Intersection {
// pos: vec![Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")])],
// neg: vec![Ty::SubclassOfAny]
// },
// Ty::Intersection {
// pos: vec![Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")])],
// neg: vec![Ty::SubclassOfUnknown]
// }
// )]
fn is_gradual_equivalent_to(a: Ty, b: Ty) {
let db = setup_db();
let a = a.into_type(&db);
let b = b.into_type(&db);
assert!(a.is_gradual_equivalent_to(&db, b));
assert!(b.is_gradual_equivalent_to(&db, a));
}
#[test_case(Ty::BuiltinInstance("type"), Ty::SubclassOfAny)]
#[test_case(Ty::SubclassOfBuiltinClass("object"), Ty::SubclassOfAny)]
#[test_case(
Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]),
Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"), Ty::BuiltinInstance("bytes")])
)]
#[test_case(
Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int"), Ty::BuiltinInstance("bytes")]),
Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"), Ty::BuiltinInstance("dict")])
)]
#[test_case(
Ty::Tuple(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]),
Ty::Tuple(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int"), Ty::BuiltinInstance("bytes")])
)]
#[test_case(
Ty::Tuple(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]),
Ty::Tuple(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")])
)]
fn is_not_gradual_equivalent_to(a: Ty, b: Ty) {
let db = setup_db();
let a = a.into_type(&db);
let b = b.into_type(&db);
assert!(!a.is_gradual_equivalent_to(&db, b));
assert!(!b.is_gradual_equivalent_to(&db, a));
}
#[test_case(Ty::IntLiteral(1); "is_int_literal_truthy")] #[test_case(Ty::IntLiteral(1); "is_int_literal_truthy")]
#[test_case(Ty::IntLiteral(-1))] #[test_case(Ty::IntLiteral(-1))]
#[test_case(Ty::StringLiteral("foo"))] #[test_case(Ty::StringLiteral("foo"))]

View file

@ -1,5 +1,5 @@
use super::context::InferContext; use super::context::InferContext;
use super::diagnostic::CALL_NON_CALLABLE; use super::diagnostic::{CALL_NON_CALLABLE, TYPE_ASSERTION_FAILURE};
use super::{Severity, Signature, Type, TypeArrayDisplay, UnionBuilder}; use super::{Severity, Signature, Type, TypeArrayDisplay, UnionBuilder};
use crate::types::diagnostic::STATIC_ASSERT_ERROR; use crate::types::diagnostic::STATIC_ASSERT_ERROR;
use crate::Db; use crate::Db;
@ -44,6 +44,10 @@ pub(super) enum CallOutcome<'db> {
binding: CallBinding<'db>, binding: CallBinding<'db>,
error_kind: StaticAssertionErrorKind<'db>, error_kind: StaticAssertionErrorKind<'db>,
}, },
AssertType {
binding: CallBinding<'db>,
asserted_ty: Type<'db>,
},
} }
impl<'db> CallOutcome<'db> { impl<'db> CallOutcome<'db> {
@ -76,6 +80,14 @@ impl<'db> CallOutcome<'db> {
} }
} }
/// Create a new `CallOutcome::AssertType` with given revealed and return types.
pub(super) fn asserted(binding: CallBinding<'db>, asserted_ty: Type<'db>) -> CallOutcome<'db> {
CallOutcome::AssertType {
binding,
asserted_ty,
}
}
/// Get the return type of the call, or `None` if not callable. /// Get the return type of the call, or `None` if not callable.
pub(super) fn return_ty(&self, db: &'db dyn Db) -> Option<Type<'db>> { pub(super) fn return_ty(&self, db: &'db dyn Db) -> Option<Type<'db>> {
match self { match self {
@ -103,6 +115,10 @@ impl<'db> CallOutcome<'db> {
.map(UnionBuilder::build), .map(UnionBuilder::build),
Self::PossiblyUnboundDunderCall { call_outcome, .. } => call_outcome.return_ty(db), Self::PossiblyUnboundDunderCall { call_outcome, .. } => call_outcome.return_ty(db),
Self::StaticAssertionError { .. } => Some(Type::none(db)), Self::StaticAssertionError { .. } => Some(Type::none(db)),
Self::AssertType {
binding,
asserted_ty: _,
} => Some(binding.return_ty()),
} }
} }
@ -309,6 +325,28 @@ impl<'db> CallOutcome<'db> {
Ok(Type::unknown()) Ok(Type::unknown())
} }
CallOutcome::AssertType {
binding,
asserted_ty,
} => {
let [actual_ty, _asserted] = binding.parameter_tys() else {
return Ok(binding.return_ty());
};
if !actual_ty.is_gradual_equivalent_to(context.db(), *asserted_ty) {
context.report_lint(
&TYPE_ASSERTION_FAILURE,
node,
format_args!(
"Actual type `{}` is not the same as asserted type `{}`",
actual_ty.display(context.db()),
asserted_ty.display(context.db()),
),
);
}
Ok(binding.return_ty())
}
} }
} }
} }

View file

@ -49,6 +49,7 @@ pub(crate) fn register_lints(registry: &mut LintRegistryBuilder) {
registry.register_lint(&POSSIBLY_UNBOUND_IMPORT); registry.register_lint(&POSSIBLY_UNBOUND_IMPORT);
registry.register_lint(&POSSIBLY_UNRESOLVED_REFERENCE); registry.register_lint(&POSSIBLY_UNRESOLVED_REFERENCE);
registry.register_lint(&SUBCLASS_OF_FINAL_CLASS); registry.register_lint(&SUBCLASS_OF_FINAL_CLASS);
registry.register_lint(&TYPE_ASSERTION_FAILURE);
registry.register_lint(&TOO_MANY_POSITIONAL_ARGUMENTS); registry.register_lint(&TOO_MANY_POSITIONAL_ARGUMENTS);
registry.register_lint(&UNDEFINED_REVEAL); registry.register_lint(&UNDEFINED_REVEAL);
registry.register_lint(&UNKNOWN_ARGUMENT); registry.register_lint(&UNKNOWN_ARGUMENT);
@ -575,6 +576,28 @@ declare_lint! {
} }
} }
declare_lint! {
/// ## What it does
/// Checks for `assert_type()` calls where the actual type
/// is not the same as the asserted type.
///
/// ## Why is this bad?
/// `assert_type()` allows confirming the inferred type of a certain value.
///
/// ## Example
///
/// ```python
/// def _(x: int):
/// assert_type(x, int) # fine
/// assert_type(x, str) # error: Actual type does not match asserted type
/// ```
pub(crate) static TYPE_ASSERTION_FAILURE = {
summary: "detects failed type assertions",
status: LintStatus::preview("1.0.0"),
default_level: Level::Error,
}
}
declare_lint! { declare_lint! {
/// ## What it does /// ## What it does
/// Checks for calls that pass more positional arguments than the callable can accept. /// Checks for calls that pass more positional arguments than the callable can accept.

View file

@ -956,7 +956,7 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_type_parameters(type_params); self.infer_type_parameters(type_params);
if let Some(arguments) = class.arguments.as_deref() { if let Some(arguments) = class.arguments.as_deref() {
self.infer_arguments(arguments, false); self.infer_arguments(arguments, 0b0);
} }
} }
@ -2601,17 +2601,20 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_arguments<'a>( fn infer_arguments<'a>(
&mut self, &mut self,
arguments: &'a ast::Arguments, arguments: &'a ast::Arguments,
infer_as_type_expressions: bool, infer_as_type_expressions: u32,
) -> CallArguments<'a, 'db> { ) -> CallArguments<'a, 'db> {
let infer_argument_type = if infer_as_type_expressions {
Self::infer_type_expression
} else {
Self::infer_expression
};
arguments arguments
.arguments_source_order() .arguments_source_order()
.map(|arg_or_keyword| { .enumerate()
.map(|(index, arg_or_keyword)| {
let infer_argument_type = if index < u32::BITS as usize
&& infer_as_type_expressions & (1 << index) != 0
{
Self::infer_type_expression
} else {
Self::infer_expression
};
match arg_or_keyword { match arg_or_keyword {
ast::ArgOrKeyword::Arg(arg) => match arg { ast::ArgOrKeyword::Arg(arg) => match arg {
ast::Expr::Starred(ast::ExprStarred { ast::Expr::Starred(ast::ExprStarred {
@ -3157,7 +3160,8 @@ impl<'db> TypeInferenceBuilder<'db> {
let infer_arguments_as_type_expressions = function_type let infer_arguments_as_type_expressions = function_type
.into_function_literal() .into_function_literal()
.and_then(|f| f.known(self.db())) .and_then(|f| f.known(self.db()))
.is_some_and(KnownFunction::takes_type_expression_arguments); .map(KnownFunction::takes_type_expression_arguments)
.unwrap_or(0b0);
let call_arguments = self.infer_arguments(arguments, infer_arguments_as_type_expressions); let call_arguments = self.infer_arguments(arguments, infer_arguments_as_type_expressions);
function_type function_type