[red-knot] knot_extensions Python API (#15103)

## Summary

Adds a type-check-time Python API that allows us to create and
manipulate types and to test various properties of them. For example,
this can be used to write a Markdown test to make sure that `A & B` is a
subtype of both `A` and `B`, but not of an unrelated class `C` (something
that requires quite a bit more code to do in Rust):
```py
from knot_extensions import Intersection, is_subtype_of, static_assert

class A: ...
class B: ...

type AB = Intersection[A, B]

static_assert(is_subtype_of(AB, A))
static_assert(is_subtype_of(AB, B))

class C: ...
static_assert(not is_subtype_of(AB, C))
```

I think this functionality is also helpful for interactive debugging
sessions, in order to query various properties of Red Knot's type
system, which is something that otherwise requires a custom Rust unit
test, some boilerplate code, and constant recompilation.

## Test Plan

- New Markdown tests
- Tested the modified typeshed_sync workflow locally
This commit is contained in:
David Peter 2025-01-08 12:52:07 +01:00 committed by GitHub
parent 03ff883626
commit 235fdfc57a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 826 additions and 17 deletions

View file

@ -109,6 +109,7 @@ pub enum KnownModule {
#[allow(dead_code)]
Abc, // currently only used in tests
Collections,
KnotExtensions,
}
impl KnownModule {
@ -122,6 +123,7 @@ impl KnownModule {
Self::Sys => "sys",
Self::Abc => "abc",
Self::Collections => "collections",
Self::KnotExtensions => "knot_extensions",
}
}
@ -147,6 +149,7 @@ impl KnownModule {
"sys" => Some(Self::Sys),
"abc" => Some(Self::Abc),
"collections" => Some(Self::Collections),
"knot_extensions" => Some(Self::KnotExtensions),
_ => None,
}
}
@ -154,4 +157,8 @@ impl KnownModule {
/// Returns `true` if this is the `Typing` variant.
pub const fn is_typing(self) -> bool {
matches!(self, Self::Typing)
}
/// Returns `true` if this is the `KnotExtensions` variant, i.e. the
/// module whose name is `knot_extensions`.
pub const fn is_knot_extensions(self) -> bool {
matches!(self, Self::KnotExtensions)
}
}

View file

@ -74,6 +74,11 @@ impl<'db> Definition<'db> {
Some(KnownModule::Typing | KnownModule::TypingExtensions)
)
}
/// Returns `true` when the file holding this definition maps (via
/// `file_to_module`) to a module that is known as
/// `KnownModule::KnotExtensions`; `false` if the file maps to no module
/// or to a different one.
pub(crate) fn is_knot_extensions_definition(self, db: &'db dyn Db) -> bool {
    match file_to_module(db, self.file(db)) {
        Some(module) => module.is_known(KnownModule::KnotExtensions),
        None => false,
    }
}
}
#[derive(Copy, Clone, Debug)]

View file

@ -31,7 +31,9 @@ use crate::semantic_index::{
use crate::stdlib::{builtins_symbol, known_module_symbol, typing_extensions_symbol};
use crate::suppression::check_suppressions;
use crate::symbol::{Boundness, Symbol};
use crate::types::call::{bind_call, CallArguments, CallBinding, CallDunderResult, CallOutcome};
use crate::types::call::{
bind_call, CallArguments, CallBinding, CallDunderResult, CallOutcome, StaticAssertionErrorKind,
};
use crate::types::class_base::ClassBase;
use crate::types::diagnostic::INVALID_TYPE_FORM;
use crate::types::mro::{Mro, MroError, MroIterator};
@ -657,6 +659,13 @@ impl<'db> Type<'db> {
}
}
/// Returns the inner [`StringLiteralType`] if this type is the
/// `Type::StringLiteral` variant, and `None` for every other variant.
pub fn into_string_literal(self) -> Option<StringLiteralType<'db>> {
    if let Type::StringLiteral(literal) = self {
        Some(literal)
    } else {
        None
    }
}
#[track_caller]
pub fn expect_int_literal(self) -> i64 {
self.into_int_literal()
@ -1824,12 +1833,88 @@ impl<'db> Type<'db> {
let mut binding = bind_call(db, arguments, function_type.signature(db), Some(self));
match function_type.known(db) {
Some(KnownFunction::RevealType) => {
let revealed_ty = binding.first_parameter().unwrap_or(Type::Unknown);
let revealed_ty = binding.one_parameter_ty().unwrap_or(Type::Unknown);
CallOutcome::revealed(binding, revealed_ty)
}
Some(KnownFunction::StaticAssert) => {
if let Some((parameter_ty, message)) = binding.two_parameter_tys() {
let truthiness = parameter_ty.bool(db);
if truthiness.is_always_true() {
CallOutcome::callable(binding)
} else {
let error_kind = if let Some(message) =
message.into_string_literal().map(|s| &**s.value(db))
{
StaticAssertionErrorKind::CustomError(message)
} else if parameter_ty == Type::BooleanLiteral(false) {
StaticAssertionErrorKind::ArgumentIsFalse
} else if truthiness.is_always_false() {
StaticAssertionErrorKind::ArgumentIsFalsy(parameter_ty)
} else {
StaticAssertionErrorKind::ArgumentTruthinessIsAmbiguous(
parameter_ty,
)
};
CallOutcome::StaticAssertionError {
binding,
error_kind,
}
}
} else {
CallOutcome::callable(binding)
}
}
Some(KnownFunction::IsEquivalentTo) => {
let (ty_a, ty_b) = binding
.two_parameter_tys()
.unwrap_or((Type::Unknown, Type::Unknown));
binding
.set_return_ty(Type::BooleanLiteral(ty_a.is_equivalent_to(db, ty_b)));
CallOutcome::callable(binding)
}
Some(KnownFunction::IsSubtypeOf) => {
let (ty_a, ty_b) = binding
.two_parameter_tys()
.unwrap_or((Type::Unknown, Type::Unknown));
binding.set_return_ty(Type::BooleanLiteral(ty_a.is_subtype_of(db, ty_b)));
CallOutcome::callable(binding)
}
Some(KnownFunction::IsAssignableTo) => {
let (ty_a, ty_b) = binding
.two_parameter_tys()
.unwrap_or((Type::Unknown, Type::Unknown));
binding
.set_return_ty(Type::BooleanLiteral(ty_a.is_assignable_to(db, ty_b)));
CallOutcome::callable(binding)
}
Some(KnownFunction::IsDisjointFrom) => {
let (ty_a, ty_b) = binding
.two_parameter_tys()
.unwrap_or((Type::Unknown, Type::Unknown));
binding
.set_return_ty(Type::BooleanLiteral(ty_a.is_disjoint_from(db, ty_b)));
CallOutcome::callable(binding)
}
Some(KnownFunction::IsFullyStatic) => {
let ty = binding.one_parameter_ty().unwrap_or(Type::Unknown);
binding.set_return_ty(Type::BooleanLiteral(ty.is_fully_static(db)));
CallOutcome::callable(binding)
}
Some(KnownFunction::IsSingleton) => {
let ty = binding.one_parameter_ty().unwrap_or(Type::Unknown);
binding.set_return_ty(Type::BooleanLiteral(ty.is_singleton(db)));
CallOutcome::callable(binding)
}
Some(KnownFunction::IsSingleValued) => {
let ty = binding.one_parameter_ty().unwrap_or(Type::Unknown);
binding.set_return_ty(Type::BooleanLiteral(ty.is_single_valued(db)));
CallOutcome::callable(binding)
}
Some(KnownFunction::Len) => {
if let Some(first_arg) = binding.first_parameter() {
if let Some(first_arg) = binding.one_parameter_ty() {
if let Some(len_ty) = first_arg.len(db) {
binding.set_return_ty(len_ty);
}
@ -2107,6 +2192,7 @@ impl<'db> Type<'db> {
invalid_expressions: smallvec::smallvec![InvalidTypeExpression::BareLiteral],
fallback_type: Type::Unknown,
}),
Type::KnownInstance(KnownInstanceType::Unknown) => Ok(Type::Unknown),
Type::Todo(_) => Ok(*self),
_ => Ok(todo_type!(
"Unsupported or invalid type in a type expression"
@ -2613,6 +2699,14 @@ pub enum KnownInstanceType<'db> {
TypeVar(TypeVarInstance<'db>),
/// A single instance of `typing.TypeAliasType` (PEP 695 type alias)
TypeAliasType(TypeAliasType<'db>),
/// The symbol `knot_extensions.Unknown`
Unknown,
/// The symbol `knot_extensions.Not`
Not,
/// The symbol `knot_extensions.Intersection`
Intersection,
/// The symbol `knot_extensions.TypeOf`
TypeOf,
// Various special forms, special aliases and type qualifiers that we don't yet understand
// (all currently inferred as TODO in most contexts):
@ -2667,6 +2761,10 @@ impl<'db> KnownInstanceType<'db> {
Self::ChainMap => "ChainMap",
Self::OrderedDict => "OrderedDict",
Self::ReadOnly => "ReadOnly",
Self::Unknown => "Unknown",
Self::Not => "Not",
Self::Intersection => "Intersection",
Self::TypeOf => "TypeOf",
}
}
@ -2705,7 +2803,11 @@ impl<'db> KnownInstanceType<'db> {
| Self::ChainMap
| Self::OrderedDict
| Self::ReadOnly
| Self::TypeAliasType(_) => Truthiness::AlwaysTrue,
| Self::TypeAliasType(_)
| Self::Unknown
| Self::Not
| Self::Intersection
| Self::TypeOf => Truthiness::AlwaysTrue,
}
}
@ -2745,6 +2847,10 @@ impl<'db> KnownInstanceType<'db> {
Self::ReadOnly => "typing.ReadOnly",
Self::TypeVar(typevar) => typevar.name(db),
Self::TypeAliasType(_) => "typing.TypeAliasType",
Self::Unknown => "knot_extensions.Unknown",
Self::Not => "knot_extensions.Not",
Self::Intersection => "knot_extensions.Intersection",
Self::TypeOf => "knot_extensions.TypeOf",
}
}
@ -2784,6 +2890,10 @@ impl<'db> KnownInstanceType<'db> {
Self::OrderedDict => KnownClass::StdlibAlias,
Self::TypeVar(_) => KnownClass::TypeVar,
Self::TypeAliasType(_) => KnownClass::TypeAliasType,
Self::TypeOf => KnownClass::SpecialForm,
Self::Not => KnownClass::SpecialForm,
Self::Intersection => KnownClass::SpecialForm,
Self::Unknown => KnownClass::Object,
}
}
@ -2834,6 +2944,10 @@ impl<'db> KnownInstanceType<'db> {
"Concatenate" => Self::Concatenate,
"NotRequired" => Self::NotRequired,
"LiteralString" => Self::LiteralString,
"Unknown" => Self::Unknown,
"Not" => Self::Not,
"Intersection" => Self::Intersection,
"TypeOf" => Self::TypeOf,
_ => return None,
};
@ -2883,6 +2997,9 @@ impl<'db> KnownInstanceType<'db> {
| Self::TypeVar(_) => {
matches!(module, KnownModule::Typing | KnownModule::TypingExtensions)
}
Self::Unknown | Self::Not | Self::Intersection | Self::TypeOf => {
module.is_knot_extensions()
}
}
}
@ -3121,13 +3238,41 @@ pub enum KnownFunction {
/// [`typing(_extensions).no_type_check`](https://typing.readthedocs.io/en/latest/spec/directives.html#no-type-check)
NoTypeCheck,
/// `knot_extensions.static_assert`
StaticAssert,
/// `knot_extensions.is_equivalent_to`
IsEquivalentTo,
/// `knot_extensions.is_subtype_of`
IsSubtypeOf,
/// `knot_extensions.is_assignable_to`
IsAssignableTo,
/// `knot_extensions.is_disjoint_from`
IsDisjointFrom,
/// `knot_extensions.is_fully_static`
IsFullyStatic,
/// `knot_extensions.is_singleton`
IsSingleton,
/// `knot_extensions.is_single_valued`
IsSingleValued,
}
impl KnownFunction {
pub fn constraint_function(self) -> Option<KnownConstraintFunction> {
match self {
Self::ConstraintFunction(f) => Some(f),
Self::RevealType | Self::Len | Self::Final | Self::NoTypeCheck => None,
Self::RevealType
| Self::Len
| Self::Final
| Self::NoTypeCheck
| Self::StaticAssert
| Self::IsEquivalentTo
| Self::IsSubtypeOf
| Self::IsAssignableTo
| Self::IsDisjointFrom
| Self::IsFullyStatic
| Self::IsSingleton
| Self::IsSingleValued => None,
}
}
@ -3149,9 +3294,50 @@ impl KnownFunction {
"no_type_check" if definition.is_typing_definition(db) => {
Some(KnownFunction::NoTypeCheck)
}
"static_assert" if definition.is_knot_extensions_definition(db) => {
Some(KnownFunction::StaticAssert)
}
"is_subtype_of" if definition.is_knot_extensions_definition(db) => {
Some(KnownFunction::IsSubtypeOf)
}
"is_disjoint_from" if definition.is_knot_extensions_definition(db) => {
Some(KnownFunction::IsDisjointFrom)
}
"is_equivalent_to" if definition.is_knot_extensions_definition(db) => {
Some(KnownFunction::IsEquivalentTo)
}
"is_assignable_to" if definition.is_knot_extensions_definition(db) => {
Some(KnownFunction::IsAssignableTo)
}
"is_fully_static" if definition.is_knot_extensions_definition(db) => {
Some(KnownFunction::IsFullyStatic)
}
"is_singleton" if definition.is_knot_extensions_definition(db) => {
Some(KnownFunction::IsSingleton)
}
"is_single_valued" if definition.is_knot_extensions_definition(db) => {
Some(KnownFunction::IsSingleValued)
}
_ => None,
}
}
/// Whether or not a particular function takes type expression as arguments, i.e. should
/// the argument of a call like `f(int)` be interpreted as the type int (true) or as the
/// type of the expression `int`, i.e. `Literal[int]` (false).
const fn takes_type_expression_arguments(self) -> bool {
matches!(
self,
KnownFunction::IsEquivalentTo
| KnownFunction::IsSubtypeOf
| KnownFunction::IsAssignableTo
| KnownFunction::IsDisjointFrom
| KnownFunction::IsFullyStatic
| KnownFunction::IsSingleton
| KnownFunction::IsSingleValued
)
}
}
#[salsa::interned]

View file

@ -1,6 +1,7 @@
use super::context::InferContext;
use super::diagnostic::CALL_NON_CALLABLE;
use super::{Severity, Signature, Type, TypeArrayDisplay, UnionBuilder};
use crate::types::diagnostic::STATIC_ASSERT_ERROR;
use crate::Db;
use ruff_db::diagnostic::DiagnosticId;
use ruff_python_ast as ast;
@ -11,6 +12,14 @@ mod bind;
pub(super) use arguments::{Argument, CallArguments};
pub(super) use bind::{bind_call, CallBinding};
/// The reason a `knot_extensions.static_assert` call failed; selects which
/// "Static assertion error: ..." diagnostic message is reported.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) enum StaticAssertionErrorKind<'db> {
/// The asserted argument has the type `Type::BooleanLiteral(false)`.
ArgumentIsFalse,
/// The asserted argument (whose type is carried here) is statically known
/// to be falsy, without being the literal `False`.
ArgumentIsFalsy(Type<'db>),
/// The truthiness of the asserted argument (whose type is carried here) is
/// neither always true nor always false.
ArgumentTruthinessIsAmbiguous(Type<'db>),
/// The call supplied a string-literal message as its second argument; the
/// message text is carried here and takes precedence over the other kinds.
CustomError(&'db str),
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) enum CallOutcome<'db> {
Callable {
@ -31,6 +40,10 @@ pub(super) enum CallOutcome<'db> {
called_ty: Type<'db>,
call_outcome: Box<CallOutcome<'db>>,
},
StaticAssertionError {
binding: CallBinding<'db>,
error_kind: StaticAssertionErrorKind<'db>,
},
}
impl<'db> CallOutcome<'db> {
@ -89,6 +102,7 @@ impl<'db> CallOutcome<'db> {
})
.map(UnionBuilder::build),
Self::PossiblyUnboundDunderCall { call_outcome, .. } => call_outcome.return_ty(db),
Self::StaticAssertionError { .. } => Some(Type::none(db)),
}
}
@ -181,6 +195,7 @@ impl<'db> CallOutcome<'db> {
binding,
revealed_ty,
} => {
binding.report_diagnostics(context, node);
context.report_diagnostic(
node,
DiagnosticId::RevealedType,
@ -249,6 +264,51 @@ impl<'db> CallOutcome<'db> {
}),
}
}
CallOutcome::StaticAssertionError {
binding,
error_kind,
} => {
binding.report_diagnostics(context, node);
match error_kind {
StaticAssertionErrorKind::ArgumentIsFalse => {
context.report_lint(
&STATIC_ASSERT_ERROR,
node,
format_args!("Static assertion error: argument evaluates to `False`"),
);
}
StaticAssertionErrorKind::ArgumentIsFalsy(parameter_ty) => {
context.report_lint(
&STATIC_ASSERT_ERROR,
node,
format_args!(
"Static assertion error: argument of type `{parameter_ty}` is statically known to be falsy",
parameter_ty=parameter_ty.display(context.db())
),
);
}
StaticAssertionErrorKind::ArgumentTruthinessIsAmbiguous(parameter_ty) => {
context.report_lint(
&STATIC_ASSERT_ERROR,
node,
format_args!(
"Static assertion error: argument of type `{parameter_ty}` has an ambiguous static truthiness",
parameter_ty=parameter_ty.display(context.db())
),
);
}
StaticAssertionErrorKind::CustomError(message) => {
context.report_lint(
&STATIC_ASSERT_ERROR,
node,
format_args!("Static assertion error: {message}"),
);
}
}
Ok(Type::Unknown)
}
}
}
}

View file

@ -154,8 +154,18 @@ impl<'db> CallBinding<'db> {
&self.parameter_tys
}
pub(crate) fn first_parameter(&self) -> Option<Type<'db>> {
self.parameter_tys().first().copied()
/// Returns the sole parameter type when exactly one parameter type was
/// bound; returns `None` for zero or for more than one.
pub(crate) fn one_parameter_ty(&self) -> Option<Type<'db>> {
    if let [only] = self.parameter_tys() {
        Some(*only)
    } else {
        None
    }
}
/// Returns both parameter types, in order, when exactly two parameter
/// types were bound; returns `None` for any other count.
pub(crate) fn two_parameter_tys(&self) -> Option<(Type<'db>, Type<'db>)> {
    if let [lhs, rhs] = self.parameter_tys() {
        Some((*lhs, *rhs))
    } else {
        None
    }
}
fn callable_name(&self, db: &'db dyn Db) -> Option<&ast::name::Name> {

View file

@ -100,7 +100,11 @@ impl<'db> ClassBase<'db> {
| KnownInstanceType::Required
| KnownInstanceType::TypeAlias
| KnownInstanceType::ReadOnly
| KnownInstanceType::Optional => None,
| KnownInstanceType::Optional
| KnownInstanceType::Not
| KnownInstanceType::Intersection
| KnownInstanceType::TypeOf => None,
KnownInstanceType::Unknown => Some(Self::Unknown),
KnownInstanceType::Any => Some(Self::Any),
// TODO: Classes inheriting from `typing.Type` et al. also have `Generic` in their MRO
KnownInstanceType::Dict => {

View file

@ -56,6 +56,7 @@ pub(crate) fn register_lints(registry: &mut LintRegistryBuilder) {
registry.register_lint(&UNRESOLVED_REFERENCE);
registry.register_lint(&UNSUPPORTED_OPERATOR);
registry.register_lint(&ZERO_STEPSIZE_IN_SLICE);
registry.register_lint(&STATIC_ASSERT_ERROR);
// String annotations
registry.register_lint(&BYTE_STRING_TYPE_ANNOTATION);
@ -678,6 +679,25 @@ declare_lint! {
}
}
declare_lint! {
/// ## What it does
/// Makes sure that the argument of `static_assert` is statically known to be true.
///
/// ## Examples
/// ```python
/// from knot_extensions import static_assert
///
/// static_assert(1 + 1 == 3) # error: evaluates to `False`
///
/// static_assert(int(2.0 * 3.0) == 6) # error: does not have a statically known truthiness
/// ```
pub(crate) static STATIC_ASSERT_ERROR = {
summary: "Failed static assertion",
status: LintStatus::preview("1.0.0"),
default_level: Level::Error,
}
}
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct TypeCheckDiagnostic {
pub(crate) id: DiagnosticId,

View file

@ -28,7 +28,7 @@
//! definitions once the rest of the types in the scope have been inferred.
use std::num::NonZeroU32;
use itertools::Itertools;
use itertools::{Either, Itertools};
use ruff_db::files::File;
use ruff_db::parsed::parsed_module;
use ruff_python_ast::{self as ast, AnyNodeRef, ExprContext};
@ -919,7 +919,7 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_type_parameters(type_params);
if let Some(arguments) = class.arguments.as_deref() {
self.infer_arguments(arguments);
self.infer_arguments(arguments, false);
}
}
@ -2523,7 +2523,17 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_expression(expression)
}
fn infer_arguments(&mut self, arguments: &ast::Arguments) -> CallArguments<'db> {
fn infer_arguments(
&mut self,
arguments: &ast::Arguments,
infer_as_type_expressions: bool,
) -> CallArguments<'db> {
let infer_argument_type = if infer_as_type_expressions {
Self::infer_type_expression
} else {
Self::infer_expression
};
arguments
.arguments_source_order()
.map(|arg_or_keyword| {
@ -2534,19 +2544,19 @@ impl<'db> TypeInferenceBuilder<'db> {
range: _,
ctx: _,
}) => {
let ty = self.infer_expression(value);
let ty = infer_argument_type(self, value);
self.store_expression_type(arg, ty);
Argument::Variadic(ty)
}
// TODO diagnostic if after a keyword argument
_ => Argument::Positional(self.infer_expression(arg)),
_ => Argument::Positional(infer_argument_type(self, arg)),
},
ast::ArgOrKeyword::Keyword(ast::Keyword {
arg,
value,
range: _,
}) => {
let ty = self.infer_expression(value);
let ty = infer_argument_type(self, value);
if let Some(arg) = arg {
Argument::Keyword {
name: arg.id.clone(),
@ -3070,8 +3080,14 @@ impl<'db> TypeInferenceBuilder<'db> {
arguments,
} = call_expression;
let call_arguments = self.infer_arguments(arguments);
let function_type = self.infer_expression(func);
let infer_arguments_as_type_expressions = function_type
.into_function_literal()
.and_then(|f| f.known(self.db()))
.is_some_and(KnownFunction::takes_type_expression_arguments);
let call_arguments = self.infer_arguments(arguments, infer_arguments_as_type_expressions);
function_type
.call(self.db(), &call_arguments)
.unwrap_with_diagnostic(&self.context, call_expression.into())
@ -4448,7 +4464,7 @@ impl<'db> TypeInferenceBuilder<'db> {
return dunder_getitem_method
.call(self.db(), &CallArguments::positional([value_ty, slice_ty]))
.return_ty_result(&self.context, value_node.into())
.return_ty_result( &self.context, value_node.into())
.unwrap_or_else(|err| {
self.context.report_lint(
&CALL_NON_CALLABLE,
@ -5156,6 +5172,55 @@ impl<'db> TypeInferenceBuilder<'db> {
todo_type!("Callable types")
}
// Type API special forms
KnownInstanceType::Not => match arguments_slice {
ast::Expr::Tuple(_) => {
self.context.report_lint(
&INVALID_TYPE_FORM,
subscript.into(),
format_args!(
"Special form `{}` expected exactly one type parameter",
known_instance.repr(self.db())
),
);
Type::Unknown
}
_ => {
let argument_type = self.infer_type_expression(arguments_slice);
argument_type.negate(self.db())
}
},
KnownInstanceType::Intersection => {
let elements = match arguments_slice {
ast::Expr::Tuple(tuple) => Either::Left(tuple.iter()),
element => Either::Right(std::iter::once(element)),
};
elements
.fold(IntersectionBuilder::new(self.db()), |builder, element| {
builder.add_positive(self.infer_type_expression(element))
})
.build()
}
KnownInstanceType::TypeOf => match arguments_slice {
ast::Expr::Tuple(_) => {
self.context.report_lint(
&INVALID_TYPE_FORM,
subscript.into(),
format_args!(
"Special form `{}` expected exactly one type parameter",
known_instance.repr(self.db())
),
);
Type::Unknown
}
_ => {
// NB: This calls `infer_expression` instead of `infer_type_expression`.
let argument_type = self.infer_expression(arguments_slice);
argument_type
}
},
// TODO: Generics
KnownInstanceType::ChainMap => {
self.infer_type_expression(arguments_slice);
@ -5241,7 +5306,9 @@ impl<'db> TypeInferenceBuilder<'db> {
);
Type::Unknown
}
KnownInstanceType::TypingSelf | KnownInstanceType::TypeAlias => {
KnownInstanceType::TypingSelf
| KnownInstanceType::TypeAlias
| KnownInstanceType::Unknown => {
self.context.report_lint(
&INVALID_TYPE_FORM,
subscript.into(),