create generic context lazily

This commit is contained in:
Douglas Creager 2025-04-18 17:05:03 -04:00
parent 0a4dec0323
commit b44fb47f25
8 changed files with 172 additions and 136 deletions

View file

@ -4623,8 +4623,8 @@ impl<'db> Type<'db> {
match self {
Type::TypeVar(typevar) => specialization.get(db, typevar).unwrap_or(self),
Type::FunctionLiteral(function) => {
Type::FunctionLiteral(function.apply_specialization(db, specialization))
Type::FunctionLiteral(function) =>{
Type::FunctionLiteral(FunctionType::Specialized(SpecializedFunction::new(db, function, specialization)))
}
// Note that we don't need to apply the specialization to `self_instance`, since it
@ -4637,19 +4637,19 @@ impl<'db> Type<'db> {
// specialized.)
Type::BoundMethod(method) => Type::BoundMethod(BoundMethodType::new(
db,
method.function(db).apply_specialization(db, specialization),
FunctionType::Specialized(SpecializedFunction::new(db, method.function(db), specialization)),
method.self_instance(db),
)),
Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(function)) => {
Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(
function.apply_specialization(db, specialization),
FunctionType::Specialized(SpecializedFunction::new(db, function, specialization))
))
}
Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderCall(function)) => {
Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderCall(
function.apply_specialization(db, specialization),
FunctionType::Specialized(SpecializedFunction::new(db, function, specialization))
))
}
@ -5852,6 +5852,34 @@ impl<'db> FunctionSignature<'db> {
pub(crate) fn iter(&self) -> Iter<Signature<'db>> {
self.as_slice().iter()
}
fn apply_specialization(&mut self, db: &'db dyn Db, specialization: Specialization<'db>) {
match self {
Self::Single(signature) => signature.apply_specialization(db, specialization),
Self::Overloaded(signatures, implementation) => {
signatures
.iter_mut()
.for_each(|signature| signature.apply_specialization(db, specialization));
implementation
.as_mut()
.map(|signature| signature.apply_specialization(db, specialization));
}
}
}
fn set_generic_context(&mut self, generic_context: GenericContext<'db>) {
match self {
Self::Single(signature) => signature.set_generic_context(generic_context),
Self::Overloaded(signatures, implementation) => {
signatures
.iter_mut()
.for_each(|signature| signature.set_generic_context(generic_context));
implementation
.as_mut()
.map(|signature| signature.set_generic_context(generic_context));
}
}
}
}
impl<'db> IntoIterator for &'db FunctionSignature<'db> {
@ -5864,7 +5892,9 @@ impl<'db> IntoIterator for &'db FunctionSignature<'db> {
}
/// A callable type that represents a single Python function.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, salsa::Update)]
#[derive(
Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, salsa::Supertype, salsa::Update,
)]
pub enum FunctionType<'db> {
/// A function literal in the Python AST
FunctionLiteral(FunctionLiteral<'db>),
@ -5888,6 +5918,7 @@ impl<'db> FunctionType<'db> {
fn function_literal(self, db: &'db dyn Db) -> FunctionLiteral<'db> {
match self {
FunctionType::FunctionLiteral(literal) => literal,
FunctionType::Specialized(specialized) => specialized.function(db).function_literal(db),
FunctionType::InheritedGenericContext(inherited) => inherited.function(db),
}
}
@ -5922,9 +5953,7 @@ impl<'db> FunctionType<'db> {
}
pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> {
let body_scope = self.function_literal(db).body_scope(db);
let index = semantic_index(db, body_scope.file(db));
index.expect_single_definition(body_scope.node(db).expect_function())
self.function_literal(db).definition(db)
}
/// Typed externally-visible signature for this function.
@ -5942,8 +5971,18 @@ impl<'db> FunctionType<'db> {
pub(crate) fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> {
match self {
FunctionType::FunctionLiteral(literal) => literal.signature(db),
FunctionType::Specialized(specialized) => specialized.signature(db),
FunctionType::InheritedGenericContext(inherited) => inherited.signature(db),
}
}
pub(crate) fn known(self, db: &'db dyn Db) -> Option<KnownFunction> {
self.function_literal(db).known(db)
}
pub(crate) fn is_known(self, db: &'db dyn Db, known_function: KnownFunction) -> bool {
self.known(db) == Some(known_function)
}
}
#[salsa::interned(debug)]
@ -5958,16 +5997,11 @@ pub struct FunctionLiteral<'db> {
/// The scope that's created by the function, in which the function body is evaluated.
body_scope: ScopeId<'db>,
/// The scope containing the PEP 695 type parameters in the function definition, if any.
type_params_scope: Option<ScopeId<'db>>,
/// A set of special decorators that were applied to this function
decorators: FunctionDecorators,
/// The generic context of a generic function.
generic_context: Option<GenericContext<'db>>,
/// A specialization that should be applied to the function's parameter and return types,
/// either because the function is itself generic, or because it appears in the body of a
/// generic class.
specialization: Option<Specialization<'db>>,
}
#[salsa::tracked]
@ -5976,6 +6010,12 @@ impl<'db> FunctionLiteral<'db> {
self.decorators(db).contains(decorator)
}
fn definition(self, db: &'db dyn Db) -> Definition<'db> {
let body_scope = self.body_scope(db);
let index = semantic_index(db, body_scope.file(db));
index.expect_single_definition(body_scope.node(db).expect_function())
}
/// Typed externally-visible signature for this function.
///
/// This is the signature as seen by external callers, possibly modified by decorators and/or
@ -5990,11 +6030,7 @@ impl<'db> FunctionLiteral<'db> {
/// would depend on the function's AST and rerun for every change in that file.
#[salsa::tracked]
fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> {
let mut internal_signature = self.internal_signature(db);
if let Some(specialization) = self.specialization(db) {
internal_signature = internal_signature.apply_specialization(db, specialization);
}
let internal_signature = self.internal_signature(db);
// The semantic model records a use for each function on the name node. This is used here
// to get the previous function definition with the same name.
@ -6054,11 +6090,11 @@ impl<'db> FunctionLiteral<'db> {
let scope = self.body_scope(db);
let function_stmt_node = scope.node(db).expect_function();
let definition = self.definition(db);
Signature::from_function(db, self.generic_context(db), definition, function_stmt_node)
}
pub(crate) fn is_known(self, db: &'db dyn Db, known_function: KnownFunction) -> bool {
self.known(db) == Some(known_function)
let generic_context = function_stmt_node.type_params.as_ref().map(|type_params| {
let index = semantic_index(db, scope.file(db));
GenericContext::from_type_params(db, index, type_params)
});
Signature::from_function(db, generic_context, definition, function_stmt_node)
}
fn with_generic_context(
@ -6072,21 +6108,27 @@ impl<'db> FunctionLiteral<'db> {
generic_context,
))
}
}
fn apply_specialization(self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self {
let specialization = match self.specialization(db) {
Some(existing) => existing.apply_specialization(db, specialization),
None => specialization,
};
Self::new(
db,
self.name(db).clone(),
self.known(db),
self.body_scope(db),
self.decorators(db),
self.generic_context(db),
Some(specialization),
)
impl<'db> From<FunctionLiteral<'db>> for Type<'db> {
fn from(literal: FunctionLiteral<'db>) -> Type<'db> {
Type::FunctionLiteral(FunctionType::FunctionLiteral(literal))
}
}
#[salsa::interned(debug)]
pub struct SpecializedFunction<'db> {
function: FunctionType<'db>,
specialization: Specialization<'db>,
}
#[salsa::tracked]
impl<'db> SpecializedFunction<'db> {
#[salsa::tracked]
fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> {
let mut signature = self.function(db).signature(db);
signature.apply_specialization(db, self.specialization(db));
signature
}
}
@ -6096,6 +6138,16 @@ pub struct FunctionWithInheritedGenericContext<'db> {
generic_context: GenericContext<'db>,
}
#[salsa::tracked]
impl<'db> FunctionWithInheritedGenericContext<'db> {
#[salsa::tracked]
fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> {
let mut signature = self.function(db).signature(db);
signature.set_generic_context(self.generic_context(db));
signature
}
}
/// Non-exhaustive enumeration of known functions (e.g. `builtins.reveal_type`, ...) that might
/// have special behavior.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, strum_macros::EnumString)]
@ -6308,9 +6360,11 @@ impl<'db> CallableType<'db> {
fn apply_specialization(self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self {
CallableType::from_overloads(
db,
self.signatures(db)
.iter()
.map(|signature| signature.apply_specialization(db, specialization)),
self.signatures(db).iter().map(|signature| {
let mut signature = signature.clone();
signature.apply_specialization(db, specialization);
signature
}),
)
}

View file

@ -219,7 +219,7 @@ impl<'db> Bindings<'db> {
match binding_type {
Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(function)) => {
let function_literal = function.function_literal();
let function_literal = function.function_literal(db);
if function_literal.has_known_decorator(db, FunctionDecorators::CLASSMETHOD) {
match overload.parameter_types() {
[_, Some(owner)] => {
@ -251,7 +251,7 @@ impl<'db> Bindings<'db> {
if let [Some(function_ty @ Type::FunctionLiteral(function)), ..] =
overload.parameter_types()
{
let function_literal = function.function_literal();
let function_literal = function.function_literal(db);
if function_literal.has_known_decorator(db, FunctionDecorators::CLASSMETHOD)
{
match overload.parameter_types() {
@ -301,7 +301,7 @@ impl<'db> Bindings<'db> {
if property.getter(db).is_some_and(|getter| {
getter
.into_function_literal()
.is_some_and(|f| f.function_literal().name(db) == "__name__")
.is_some_and(|f| f.function_literal(db).name(db) == "__name__")
}) =>
{
overload.set_return_type(Type::string_literal(db, type_alias.name(db)));
@ -310,7 +310,7 @@ impl<'db> Bindings<'db> {
if property.getter(db).is_some_and(|getter| {
getter
.into_function_literal()
.is_some_and(|f| f.function_literal().name(db) == "__name__")
.is_some_and(|f| f.function_literal(db).name(db) == "__name__")
}) =>
{
overload.set_return_type(Type::string_literal(db, type_var.name(db)));
@ -421,7 +421,7 @@ impl<'db> Bindings<'db> {
{
match bound_method
.function(db)
.function_literal()
.function_literal(db)
.name(db)
.as_str()
{
@ -465,7 +465,7 @@ impl<'db> Bindings<'db> {
}
Type::FunctionLiteral(function_type) => match function_type
.function_literal()
.function_literal(db)
.known(db)
{
Some(KnownFunction::IsEquivalentTo) => {
@ -1177,7 +1177,7 @@ impl<'db> CallableDescription<'db> {
match callable_type {
Type::FunctionLiteral(function) => Some(CallableDescription {
kind: "function",
name: function.function_literal().name(db),
name: function.function_literal(db).name(db),
}),
Type::ClassLiteral(class_type) => Some(CallableDescription {
kind: "class",
@ -1185,12 +1185,12 @@ impl<'db> CallableDescription<'db> {
}),
Type::BoundMethod(bound_method) => Some(CallableDescription {
kind: "bound method",
name: bound_method.function(db).function_literal().name(db),
name: bound_method.function(db).function_literal(db).name(db),
}),
Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(function)) => {
Some(CallableDescription {
kind: "method wrapper `__get__` of function",
name: function.function_literal().name(db),
name: function.function_literal(db).name(db),
})
}
Type::MethodWrapper(MethodWrapperKind::PropertyDunderGet(_)) => {
@ -1315,7 +1315,7 @@ impl<'db> BindingError<'db> {
) -> Option<(Span, Span)> {
match callable_ty {
Type::FunctionLiteral(function) => {
let function_scope = function.function_literal().body_scope(db);
let function_scope = function.function_literal(db).body_scope(db);
let span = Span::from(function_scope.file(db));
let node = function_scope.node(db);
if let Some(func_def) = node.as_function() {

View file

@ -610,11 +610,7 @@ impl<'db> ClassLiteralType<'db> {
self.decorators(db)
.iter()
.filter_map(|deco| deco.into_function_literal())
.any(|decorator| {
decorator
.function_literal()
.is_known(db, KnownFunction::Final)
})
.any(|decorator| decorator.is_known(db, KnownFunction::Final))
}
/// Attempt to resolve the [method resolution order] ("MRO") for this class.
@ -956,7 +952,7 @@ impl<'db> ClassLiteralType<'db> {
"__new__" | "__init__",
) => Type::FunctionLiteral(
function
.function_literal()
.function_literal(db)
.with_generic_context(db, origin.generic_context(db)),
),
_ => ty,

View file

@ -170,7 +170,7 @@ impl<'db> InferContext<'db> {
// Iterate over all functions and test if any is decorated with `@no_type_check`.
function_scope_tys.any(|function_ty| {
function_ty
.function_literal()
.function_literal(self.db)
.has_known_decorator(self.db, FunctionDecorators::NO_TYPE_CHECK)
})
}

View file

@ -1096,7 +1096,7 @@ fn report_invalid_assignment_with_message(
Type::FunctionLiteral(function) => {
context.report_lint_old(&INVALID_ASSIGNMENT, node, format_args!(
"Implicit shadowing of function `{}`; annotate to make it explicit if this is intentional",
function.name(context.db())));
function.function_literal(context.db()).name(context.db())));
}
_ => {
context.report_lint_old(&INVALID_ASSIGNMENT, node, message);

View file

@ -10,7 +10,7 @@ use crate::types::class::{ClassType, GenericAlias, GenericClass};
use crate::types::generics::{GenericContext, Specialization};
use crate::types::signatures::{Parameter, Parameters, Signature};
use crate::types::{
FunctionSignature, InstanceType, IntersectionType, KnownClass, MethodWrapperKind,
FunctionSignature, FunctionType, InstanceType, IntersectionType, KnownClass, MethodWrapperKind,
StringLiteralType, SubclassOfInner, Type, TypeVarBoundOrConstraints, TypeVarInstance,
UnionType, WrapperDescriptorKind,
};
@ -108,7 +108,7 @@ impl Display for DisplayRepresentation<'_> {
f,
// "def {name}{specialization}{signature}",
"def {name}{signature}",
name = function.name(self.db),
name = function.function_literal(self.db).name(self.db),
signature = signature.display(self.db)
)
}
@ -135,7 +135,7 @@ impl Display for DisplayRepresentation<'_> {
write!(
f,
"bound method {instance}.{method}{signature}",
method = function.name(self.db),
method = function.function_literal(self.db).name(self.db),
instance = bound_method.self_instance(self.db).display(self.db),
signature = signature.bind_self().display(self.db)
)
@ -155,10 +155,12 @@ impl Display for DisplayRepresentation<'_> {
write!(
f,
"<method-wrapper `__get__` of `{function}{specialization}`>",
function = function.name(self.db),
specialization = if let Some(specialization) = function.specialization(self.db)
{
specialization.display_short(self.db).to_string()
function = function.function_literal(self.db).name(self.db),
specialization = if let FunctionType::Specialized(specialized) = function {
specialized
.specialization(self.db)
.display_short(self.db)
.to_string()
} else {
String::new()
},
@ -168,10 +170,12 @@ impl Display for DisplayRepresentation<'_> {
write!(
f,
"<method-wrapper `__call__` of `{function}{specialization}`>",
function = function.name(self.db),
specialization = if let Some(specialization) = function.specialization(self.db)
{
specialization.display_short(self.db).to_string()
function = function.function_literal(self.db).name(self.db),
specialization = if let FunctionType::Specialized(specialized) = function {
specialized
.specialization(self.db)
.display_short(self.db)
.to_string()
} else {
String::new()
},

View file

@ -82,7 +82,7 @@ use crate::types::mro::MroErrorKind;
use crate::types::unpacker::{UnpackResult, Unpacker};
use crate::types::{
todo_type, CallDunderError, CallableSignature, CallableType, Class, ClassLiteralType,
ClassType, DataclassMetadata, DynamicType, FunctionDecorators, FunctionType, GenericAlias,
ClassType, DataclassMetadata, DynamicType, FunctionDecorators, FunctionLiteral, GenericAlias,
GenericClass, IntersectionBuilder, IntersectionType, KnownClass, KnownFunction,
KnownInstanceType, MemberLookupPolicy, MetaclassCandidate, NonGenericClass, Parameter,
ParameterForm, Parameters, Signature, Signatures, SliceLiteralType, StringLiteralType,
@ -1491,10 +1491,6 @@ impl<'db> TypeInferenceBuilder<'db> {
}
}
let generic_context = type_params.as_ref().map(|type_params| {
GenericContext::from_type_params(self.db(), self.index, type_params)
});
let function_kind =
KnownFunction::try_from_definition_and_name(self.db(), definition, name);
@ -1503,16 +1499,19 @@ impl<'db> TypeInferenceBuilder<'db> {
.node_scope(NodeWithScopeRef::Function(function))
.to_scope_id(self.db(), self.file());
let specialization = None;
let type_params_scope = type_params.as_ref().map(|_| {
self.index
.node_scope(NodeWithScopeRef::FunctionTypeParameters(function))
.to_scope_id(self.db(), self.file())
});
let mut inferred_ty = Type::FunctionLiteral(FunctionType::new(
let mut inferred_ty = Type::from(FunctionLiteral::new(
self.db(),
&name.id,
function_kind,
body_scope,
type_params_scope,
function_decorators,
generic_context,
specialization,
));
for (decorator_ty, decorator_node) in decorator_types_and_nodes.iter().rev() {

View file

@ -289,17 +289,18 @@ impl<'db> Signature<'db> {
}
pub(crate) fn apply_specialization(
&self,
&mut self,
db: &'db dyn Db,
specialization: Specialization<'db>,
) -> Self {
Self {
generic_context: self.generic_context,
parameters: self.parameters.apply_specialization(db, specialization),
return_ty: self
.return_ty
.map(|ty| ty.apply_specialization(db, specialization)),
}
) {
self.parameters.apply_specialization(db, specialization);
self.return_ty
.as_mut()
.map(|ty| *ty = ty.apply_specialization(db, specialization));
}
pub(crate) fn set_generic_context(&mut self, generic_context: GenericContext<'db>) {
self.generic_context = Some(generic_context);
}
/// Return the parameters in this signature.
@ -1000,15 +1001,10 @@ impl<'db> Parameters<'db> {
)
}
fn apply_specialization(&self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self {
Self {
value: self
.value
.iter()
.map(|param| param.apply_specialization(db, specialization))
.collect(),
is_gradual: self.is_gradual,
}
fn apply_specialization(&mut self, db: &'db dyn Db, specialization: Specialization<'db>) {
self.value
.iter_mut()
.for_each(|param| param.apply_specialization(db, specialization));
}
pub(crate) fn len(&self) -> usize {
@ -1172,14 +1168,11 @@ impl<'db> Parameter<'db> {
self
}
fn apply_specialization(&self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self {
Self {
annotated_type: self
.annotated_type
.map(|ty| ty.apply_specialization(db, specialization)),
kind: self.kind.apply_specialization(db, specialization),
form: self.form,
}
fn apply_specialization(&mut self, db: &'db dyn Db, specialization: Specialization<'db>) {
self.annotated_type
.as_mut()
.map(|ty| *ty = ty.apply_specialization(db, specialization));
self.kind.apply_specialization(db, specialization);
}
/// Strip information from the parameter so that two equivalent parameters compare equal.
@ -1369,27 +1362,16 @@ pub(crate) enum ParameterKind<'db> {
}
impl<'db> ParameterKind<'db> {
fn apply_specialization(&self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self {
fn apply_specialization(&mut self, db: &'db dyn Db, specialization: Specialization<'db>) {
match self {
Self::PositionalOnly { default_type, name } => Self::PositionalOnly {
default_type: default_type
.as_ref()
.map(|ty| ty.apply_specialization(db, specialization)),
name: name.clone(),
},
Self::PositionalOrKeyword { default_type, name } => Self::PositionalOrKeyword {
default_type: default_type
.as_ref()
.map(|ty| ty.apply_specialization(db, specialization)),
name: name.clone(),
},
Self::KeywordOnly { default_type, name } => Self::KeywordOnly {
default_type: default_type
.as_ref()
.map(|ty| ty.apply_specialization(db, specialization)),
name: name.clone(),
},
Self::Variadic { .. } | Self::KeywordVariadic { .. } => self.clone(),
Self::PositionalOnly { default_type, .. }
| Self::PositionalOrKeyword { default_type, .. }
| Self::KeywordOnly { default_type, .. } => {
default_type
.as_mut()
.map(|ty| *ty = ty.apply_specialization(db, specialization));
}
Self::Variadic { .. } | Self::KeywordVariadic { .. } => {}
}
}
}
@ -1406,16 +1388,20 @@ mod tests {
use super::*;
use crate::db::tests::{setup_db, TestDb};
use crate::symbol::global_symbol;
use crate::types::{FunctionSignature, FunctionType, KnownClass};
use crate::types::{FunctionLiteral, FunctionSignature, FunctionType, KnownClass};
use ruff_db::system::DbWithWritableSystem as _;
#[track_caller]
fn get_function_f<'db>(db: &'db TestDb, file: &'static str) -> FunctionType<'db> {
fn get_function_f<'db>(db: &'db TestDb, file: &'static str) -> FunctionLiteral<'db> {
let module = ruff_db::files::system_path_to_file(db, file).unwrap();
global_symbol(db, module, "f")
let function = global_symbol(db, module, "f")
.symbol
.expect_type()
.expect_function_literal()
.expect_function_literal();
let FunctionType::FunctionLiteral(literal) = function else {
panic!("function should be a function literal");
};
literal
}
#[track_caller]
@ -1653,9 +1639,6 @@ mod tests {
let expected_sig = func.internal_signature(&db);
// With no decorators, internal and external signature are the same
assert_eq!(
func.signature(&db),
&FunctionSignature::Single(expected_sig)
);
assert_eq!(func.signature(&db), FunctionSignature::Single(expected_sig));
}
}