Rename Red Knot (#17820)

Micha Reiser 2025-05-03 19:49:15 +02:00 committed by GitHub
parent e6a798b962
commit b51c4f82ea
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
1564 changed files with 1598 additions and 1578 deletions


@@ -0,0 +1,202 @@
use std::hash::Hash;
use std::ops::Deref;
use ruff_db::parsed::ParsedModule;
/// Ref-counted owned reference to an AST node.
///
/// The type holds an owned reference to the node's ref-counted [`ParsedModule`].
/// Holding on to the node's [`ParsedModule`] guarantees that the reference to the
/// node remains valid.
///
/// Holding on to any [`AstNodeRef`] prevents the [`ParsedModule`] from being released.
///
/// ## Equality
/// Two `AstNodeRef`s that belong to the same [`ParsedModule`] are considered equal if their
/// pointer addresses are equal; otherwise, equality falls back to comparing the nodes themselves.
///
/// ## Usage in salsa tracked structs
/// It's important that [`AstNodeRef`] fields in salsa tracked structs are tracked fields
/// (attributed with `#[tracked]`). This prevents the tracked struct from getting a new ID
/// every time the AST changes, which would, in turn, invalidate the result of any query
/// that takes said tracked struct as a query argument or returns the tracked struct as part of its result.
///
/// For example, marking the [`AstNodeRef`] as tracked on `Expression`
/// has the effect that salsa will consider the expression as "unchanged" for as long as it:
///
/// * belongs to the same file
/// * belongs to the same scope
/// * has the same kind
/// * was created in the same order
///
/// This means that changes to expressions in other scopes don't invalidate the expression's id, giving
/// us some form of scope-stable identity for expressions. Only queries accessing the node field
/// run on every AST change. All other queries only run when the expression's identity changes.
#[derive(Clone)]
pub struct AstNodeRef<T> {
/// Owned reference to the node's [`ParsedModule`].
///
/// The node's reference is guaranteed to remain valid as long as its enclosing
/// [`ParsedModule`] is alive.
parsed: ParsedModule,
/// Pointer to the referenced node.
node: std::ptr::NonNull<T>,
}
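// A hedged sketch (not part of this commit) of the tracked-field pattern the docs above
// describe. `Expression`, its fields, and the `ast`/`ScopeId` paths are hypothetical
// placeholders for whatever tracked struct holds the `AstNodeRef`:
//
// #[salsa::tracked]
// pub(crate) struct Expression<'db> {
//     pub(crate) file: File,
//     pub(crate) scope: ScopeId<'db>,
//     #[tracked]
//     pub(crate) node_ref: AstNodeRef<ast::Expr>,
// }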
#[allow(unsafe_code)]
impl<T> AstNodeRef<T> {
/// Creates a new `AstNodeRef` that references `node`. `parsed` is the [`ParsedModule`] to
/// which the `AstNodeRef` belongs.
///
/// ## Safety
///
/// Dereferencing the `node` can result in undefined behavior if `parsed` isn't the
/// [`ParsedModule`] to which `node` belongs. It's the caller's responsibility to ensure that
/// the invariant `node belongs to parsed` is upheld.
pub(super) unsafe fn new(parsed: ParsedModule, node: &T) -> Self {
Self {
parsed,
node: std::ptr::NonNull::from(node),
}
}
/// Returns a reference to the wrapped node.
pub const fn node(&self) -> &T {
// SAFETY: Holding on to `parsed` ensures that the AST to which `node` belongs is still
// alive and not moved.
unsafe { self.node.as_ref() }
}
}
impl<T> Deref for AstNodeRef<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.node()
}
}
impl<T> std::fmt::Debug for AstNodeRef<T>
where
T: std::fmt::Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("AstNodeRef").field(&self.node()).finish()
}
}
impl<T> PartialEq for AstNodeRef<T>
where
T: PartialEq,
{
fn eq(&self, other: &Self) -> bool {
if self.parsed == other.parsed {
// Comparing the pointer addresses is sufficient to determine equality
// if both nodes belong to the same `ParsedModule`.
self.node.eq(&other.node)
} else {
// Otherwise perform a deep comparison.
self.node().eq(other.node())
}
}
}
impl<T> Eq for AstNodeRef<T> where T: Eq {}
impl<T> Hash for AstNodeRef<T>
where
T: Hash,
{
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.node().hash(state);
}
}
#[allow(unsafe_code)]
unsafe impl<T> salsa::Update for AstNodeRef<T> {
unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
let old_ref = &mut (*old_pointer);
if old_ref.parsed == new_value.parsed && old_ref.node.eq(&new_value.node) {
false
} else {
*old_ref = new_value;
true
}
}
}
#[allow(unsafe_code)]
unsafe impl<T> Send for AstNodeRef<T> where T: Send {}
#[allow(unsafe_code)]
unsafe impl<T> Sync for AstNodeRef<T> where T: Sync {}
#[cfg(test)]
mod tests {
use crate::ast_node_ref::AstNodeRef;
use ruff_db::parsed::ParsedModule;
use ruff_python_ast::PySourceType;
use ruff_python_parser::parse_unchecked_source;
#[test]
#[allow(unsafe_code)]
fn equality() {
let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python);
let parsed = ParsedModule::new(parsed_raw.clone());
let stmt = &parsed.syntax().body[0];
let node1 = unsafe { AstNodeRef::new(parsed.clone(), stmt) };
let node2 = unsafe { AstNodeRef::new(parsed.clone(), stmt) };
assert_eq!(node1, node2);
// Compare from different trees
let cloned = ParsedModule::new(parsed_raw);
let stmt_cloned = &cloned.syntax().body[0];
let cloned_node = unsafe { AstNodeRef::new(cloned.clone(), stmt_cloned) };
assert_eq!(node1, cloned_node);
let other_raw = parse_unchecked_source("2 + 2", PySourceType::Python);
let other = ParsedModule::new(other_raw);
let other_stmt = &other.syntax().body[0];
let other_node = unsafe { AstNodeRef::new(other.clone(), other_stmt) };
assert_ne!(node1, other_node);
}
#[allow(unsafe_code)]
#[test]
fn inequality() {
let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python);
let parsed = ParsedModule::new(parsed_raw);
let stmt = &parsed.syntax().body[0];
let node = unsafe { AstNodeRef::new(parsed.clone(), stmt) };
let other_raw = parse_unchecked_source("2 + 2", PySourceType::Python);
let other = ParsedModule::new(other_raw);
let other_stmt = &other.syntax().body[0];
let other_node = unsafe { AstNodeRef::new(other.clone(), other_stmt) };
assert_ne!(node, other_node);
}
#[test]
#[allow(unsafe_code)]
fn debug() {
let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python);
let parsed = ParsedModule::new(parsed_raw);
let stmt = &parsed.syntax().body[0];
let stmt_node = unsafe { AstNodeRef::new(parsed.clone(), stmt) };
let debug = format!("{stmt_node:?}");
assert_eq!(debug, format!("AstNodeRef({stmt:?})"));
}
}


@@ -0,0 +1,199 @@
use std::sync::Arc;
use crate::lint::{LintRegistry, RuleSelection};
use ruff_db::files::File;
use ruff_db::{Db as SourceDb, Upcast};
/// Database giving access to semantic information about a Python program.
#[salsa::db]
pub trait Db: SourceDb + Upcast<dyn SourceDb> {
fn is_file_open(&self, file: File) -> bool;
fn rule_selection(&self) -> Arc<RuleSelection>;
fn lint_registry(&self) -> &LintRegistry;
}
#[cfg(test)]
pub(crate) mod tests {
use std::sync::Arc;
use crate::program::{Program, SearchPathSettings};
use crate::{default_lint_registry, ProgramSettings, PythonPlatform};
use super::Db;
use crate::lint::{LintRegistry, RuleSelection};
use anyhow::Context;
use ruff_db::files::{File, Files};
use ruff_db::system::{
DbWithTestSystem, DbWithWritableSystem as _, System, SystemPath, SystemPathBuf, TestSystem,
};
use ruff_db::vendored::VendoredFileSystem;
use ruff_db::{Db as SourceDb, Upcast};
use ruff_python_ast::PythonVersion;
#[salsa::db]
#[derive(Clone)]
pub(crate) struct TestDb {
storage: salsa::Storage<Self>,
files: Files,
system: TestSystem,
vendored: VendoredFileSystem,
events: Arc<std::sync::Mutex<Vec<salsa::Event>>>,
rule_selection: Arc<RuleSelection>,
}
impl TestDb {
pub(crate) fn new() -> Self {
Self {
storage: salsa::Storage::default(),
system: TestSystem::default(),
vendored: ty_vendored::file_system().clone(),
events: Arc::default(),
files: Files::default(),
rule_selection: Arc::new(RuleSelection::from_registry(default_lint_registry())),
}
}
/// Takes the salsa events.
///
/// ## Panics
/// If there are any pending salsa snapshots.
pub(crate) fn take_salsa_events(&mut self) -> Vec<salsa::Event> {
let inner = Arc::get_mut(&mut self.events).expect("no pending salsa snapshots");
let events = inner.get_mut().unwrap();
std::mem::take(&mut *events)
}
/// Clears the salsa events.
///
/// ## Panics
/// If there are any pending salsa snapshots.
pub(crate) fn clear_salsa_events(&mut self) {
self.take_salsa_events();
}
}
impl DbWithTestSystem for TestDb {
fn test_system(&self) -> &TestSystem {
&self.system
}
fn test_system_mut(&mut self) -> &mut TestSystem {
&mut self.system
}
}
#[salsa::db]
impl SourceDb for TestDb {
fn vendored(&self) -> &VendoredFileSystem {
&self.vendored
}
fn system(&self) -> &dyn System {
&self.system
}
fn files(&self) -> &Files {
&self.files
}
fn python_version(&self) -> PythonVersion {
Program::get(self).python_version(self)
}
}
impl Upcast<dyn SourceDb> for TestDb {
fn upcast(&self) -> &(dyn SourceDb + 'static) {
self
}
fn upcast_mut(&mut self) -> &mut (dyn SourceDb + 'static) {
self
}
}
#[salsa::db]
impl Db for TestDb {
fn is_file_open(&self, file: File) -> bool {
!file.path(self).is_vendored_path()
}
fn rule_selection(&self) -> Arc<RuleSelection> {
self.rule_selection.clone()
}
fn lint_registry(&self) -> &LintRegistry {
default_lint_registry()
}
}
#[salsa::db]
impl salsa::Database for TestDb {
fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) {
let event = event();
tracing::trace!("event: {event:?}");
let mut events = self.events.lock().unwrap();
events.push(event);
}
}
pub(crate) struct TestDbBuilder<'a> {
/// Target Python version
python_version: PythonVersion,
/// Target Python platform
python_platform: PythonPlatform,
/// Path and content pairs for files that should be present
files: Vec<(&'a str, &'a str)>,
}
impl<'a> TestDbBuilder<'a> {
pub(crate) fn new() -> Self {
Self {
python_version: PythonVersion::default(),
python_platform: PythonPlatform::default(),
files: vec![],
}
}
pub(crate) fn with_python_version(mut self, version: PythonVersion) -> Self {
self.python_version = version;
self
}
pub(crate) fn with_file(
mut self,
path: &'a (impl AsRef<SystemPath> + ?Sized),
content: &'a str,
) -> Self {
self.files.push((path.as_ref().as_str(), content));
self
}
pub(crate) fn build(self) -> anyhow::Result<TestDb> {
let mut db = TestDb::new();
let src_root = SystemPathBuf::from("/src");
db.memory_file_system().create_directory_all(&src_root)?;
db.write_files(self.files)
.context("Failed to write test files")?;
Program::from_settings(
&db,
ProgramSettings {
python_version: self.python_version,
python_platform: self.python_platform,
search_paths: SearchPathSettings::new(vec![src_root]),
},
)
.context("Failed to configure Program settings")?;
Ok(db)
}
}
pub(crate) fn setup_db() -> TestDb {
TestDbBuilder::new().build().expect("valid TestDb setup")
}
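/// A minimal smoke test sketching `TestDbBuilder` usage end-to-end; the file path and
/// content are arbitrary illustrative values.
#[test]
fn builder_smoke_test() {
let db = TestDbBuilder::new()
.with_file("/src/main.py", "x = 1")
.build()
.expect("builder should produce a valid TestDb");
// The file was written to the in-memory test system under `/src`.
// (Assumed API: `System::path_exists` as provided by `ruff_db`.)
assert!(db.system().path_exists(SystemPath::new("/src/main.py")));
}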
}


@@ -0,0 +1,52 @@
use std::hash::BuildHasherDefault;
use rustc_hash::FxHasher;
use crate::lint::{LintRegistry, LintRegistryBuilder};
use crate::suppression::{INVALID_IGNORE_COMMENT, UNKNOWN_RULE, UNUSED_IGNORE_COMMENT};
pub use db::Db;
pub use module_name::ModuleName;
pub use module_resolver::{resolve_module, system_module_search_paths, KnownModule, Module};
pub use program::{Program, ProgramSettings, PythonPath, SearchPathSettings};
pub use python_platform::PythonPlatform;
pub use semantic_model::{HasType, SemanticModel};
pub use site_packages::SysPrefixPathOrigin;
pub mod ast_node_ref;
mod db;
pub mod lint;
pub(crate) mod list;
mod module_name;
mod module_resolver;
mod node_key;
mod program;
mod python_platform;
pub mod semantic_index;
mod semantic_model;
pub(crate) mod site_packages;
mod suppression;
pub(crate) mod symbol;
pub mod types;
mod unpack;
mod util;
type FxOrderSet<V> = ordermap::set::OrderSet<V, BuildHasherDefault<FxHasher>>;
/// Returns the default registry with all known semantic lints.
pub fn default_lint_registry() -> &'static LintRegistry {
static REGISTRY: std::sync::LazyLock<LintRegistry> = std::sync::LazyLock::new(|| {
let mut registry = LintRegistryBuilder::default();
register_lints(&mut registry);
registry.build()
});
&REGISTRY
}
/// Register all known semantic lints.
pub fn register_lints(registry: &mut LintRegistryBuilder) {
types::register_lints(registry);
registry.register_lint(&UNUSED_IGNORE_COMMENT);
registry.register_lint(&UNKNOWN_RULE);
registry.register_lint(&INVALID_IGNORE_COMMENT);
}


@@ -0,0 +1,537 @@
use core::fmt;
use itertools::Itertools;
use ruff_db::diagnostic::{DiagnosticId, LintName, Severity};
use rustc_hash::FxHashMap;
use std::fmt::Formatter;
use std::hash::Hasher;
use thiserror::Error;
#[derive(Debug, Clone)]
pub struct LintMetadata {
/// The unique identifier for the lint.
pub name: LintName,
/// A one-sentence summary of what the lint catches.
pub summary: &'static str,
/// An in-depth explanation of the lint in Markdown. Covers what the lint does, why it's bad, and possible fixes.
///
/// The documentation may require post-processing to be rendered correctly. For example, lines
/// might have leading or trailing whitespace that should be removed.
pub raw_documentation: &'static str,
/// The default level of the lint if the user doesn't specify one.
pub default_level: Level,
pub status: LintStatus,
/// The source file in which the lint is declared.
pub file: &'static str,
/// The 1-based line number in the source `file` where the lint is declared.
pub line: u32,
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
#[cfg_attr(
feature = "serde",
derive(serde::Serialize, serde::Deserialize),
serde(rename_all = "kebab-case")
)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub enum Level {
/// # Ignore
///
/// The lint is disabled and should not run.
Ignore,
/// # Warn
///
/// The lint is enabled and diagnostics have a warning severity.
Warn,
/// # Error
///
/// The lint is enabled and diagnostics have an error severity.
Error,
}
impl Level {
pub const fn is_error(self) -> bool {
matches!(self, Level::Error)
}
pub const fn is_warn(self) -> bool {
matches!(self, Level::Warn)
}
pub const fn is_ignore(self) -> bool {
matches!(self, Level::Ignore)
}
}
impl fmt::Display for Level {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Level::Ignore => f.write_str("ignore"),
Level::Warn => f.write_str("warn"),
Level::Error => f.write_str("error"),
}
}
}
impl TryFrom<Level> for Severity {
type Error = ();
fn try_from(level: Level) -> Result<Self, ()> {
match level {
Level::Ignore => Err(()),
Level::Warn => Ok(Severity::Warning),
Level::Error => Ok(Severity::Error),
}
}
}
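// A hedged usage sketch of the conversion above: only `Level::Ignore` fails to convert,
// since ignored lints have no severity.
//
// assert_eq!(Severity::try_from(Level::Warn), Ok(Severity::Warning));
// assert_eq!(Severity::try_from(Level::Error), Ok(Severity::Error));
// assert!(Severity::try_from(Level::Ignore).is_err());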
impl LintMetadata {
pub fn name(&self) -> LintName {
self.name
}
pub fn summary(&self) -> &str {
self.summary
}
/// Returns the documentation line by line, with up to one leading whitespace character and all trailing whitespace removed.
pub fn documentation_lines(&self) -> impl Iterator<Item = &str> {
self.raw_documentation.lines().map(|line| {
line.strip_prefix(char::is_whitespace)
.unwrap_or(line)
.trim_end()
})
}
/// Returns the documentation as a single string.
pub fn documentation(&self) -> String {
self.documentation_lines().join("\n")
}
pub fn default_level(&self) -> Level {
self.default_level
}
pub fn status(&self) -> &LintStatus {
&self.status
}
pub fn file(&self) -> &str {
self.file
}
pub fn line(&self) -> u32 {
self.line
}
}
#[doc(hidden)]
pub const fn lint_metadata_defaults() -> LintMetadata {
LintMetadata {
name: LintName::of(""),
summary: "",
raw_documentation: "",
default_level: Level::Error,
status: LintStatus::preview("0.0.0"),
file: "",
line: 1,
}
}
#[derive(Copy, Clone, Debug)]
pub enum LintStatus {
/// The lint has been added to the linter, but is not yet stable.
Preview {
/// The version in which the lint was added.
since: &'static str,
},
/// The lint is stable.
Stable {
/// The version in which the lint was stabilized.
since: &'static str,
},
/// The lint is deprecated and no longer recommended for use.
Deprecated {
/// The version in which the lint was deprecated.
since: &'static str,
/// The reason why the lint has been deprecated.
///
/// This should explain why the lint has been deprecated and if there's a replacement lint that users
/// can use instead.
reason: &'static str,
},
/// The lint has been removed and can no longer be used.
Removed {
/// The version in which the lint was removed.
since: &'static str,
/// The reason why the lint has been removed.
reason: &'static str,
},
}
impl LintStatus {
pub const fn preview(since: &'static str) -> Self {
LintStatus::Preview { since }
}
pub const fn stable(since: &'static str) -> Self {
LintStatus::Stable { since }
}
pub const fn deprecated(since: &'static str, reason: &'static str) -> Self {
LintStatus::Deprecated { since, reason }
}
pub const fn removed(since: &'static str, reason: &'static str) -> Self {
LintStatus::Removed { since, reason }
}
pub const fn is_removed(&self) -> bool {
matches!(self, LintStatus::Removed { .. })
}
pub const fn is_deprecated(&self) -> bool {
matches!(self, LintStatus::Deprecated { .. })
}
}
/// Declares a lint rule with the given metadata.
///
/// ```rust
/// use ty_python_semantic::declare_lint;
/// use ty_python_semantic::lint::{LintStatus, Level};
///
/// declare_lint! {
/// /// ## What it does
/// /// Checks for references to names that are not defined.
/// ///
/// /// ## Why is this bad?
/// /// Using an undefined variable will raise a `NameError` at runtime.
/// ///
/// /// ## Example
/// ///
/// /// ```python
/// /// print(x) # NameError: name 'x' is not defined
/// /// ```
/// pub(crate) static UNRESOLVED_REFERENCE = {
/// summary: "detects references to names that are not defined",
/// status: LintStatus::preview("1.0.0"),
/// default_level: Level::Warn,
/// }
/// }
/// ```
#[macro_export]
macro_rules! declare_lint {
(
$(#[doc = $doc:literal])+
$vis: vis static $name: ident = {
summary: $summary: literal,
status: $status: expr,
// Optional properties
$( $key:ident: $value:expr, )*
}
) => {
$( #[doc = $doc] )+
#[allow(clippy::needless_update)]
$vis static $name: $crate::lint::LintMetadata = $crate::lint::LintMetadata {
name: ruff_db::diagnostic::LintName::of(ruff_macros::kebab_case!($name)),
summary: $summary,
raw_documentation: concat!($($doc, '\n',)+),
status: $status,
file: file!(),
line: line!(),
$( $key: $value, )*
..$crate::lint::lint_metadata_defaults()
};
};
}
/// A unique identifier for a lint rule.
///
/// Implements `PartialEq`, `Eq`, and `Hash` based on the `LintMetadata` pointer
/// for fast comparison and lookup.
#[derive(Debug, Clone, Copy)]
pub struct LintId {
definition: &'static LintMetadata,
}
impl LintId {
pub const fn of(definition: &'static LintMetadata) -> Self {
LintId { definition }
}
}
impl PartialEq for LintId {
fn eq(&self, other: &Self) -> bool {
std::ptr::eq(self.definition, other.definition)
}
}
impl Eq for LintId {}
impl std::hash::Hash for LintId {
fn hash<H: Hasher>(&self, state: &mut H) {
std::ptr::hash(self.definition, state);
}
}
impl std::ops::Deref for LintId {
type Target = LintMetadata;
fn deref(&self) -> &Self::Target {
self.definition
}
}
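// A hedged usage sketch of `LintId` identity semantics: two ids built from the same
// `&'static LintMetadata` compare (and hash) equal, because both are based on the
// metadata's address. `SOME_LINT` is a hypothetical lint declared with `declare_lint!`:
//
// let a = LintId::of(&SOME_LINT);
// let b = LintId::of(&SOME_LINT);
// assert_eq!(a, b);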
#[derive(Default, Debug)]
pub struct LintRegistryBuilder {
/// Registered lints that haven't been removed.
lints: Vec<LintId>,
/// Lints indexed by name, including aliases and removed rules.
by_name: FxHashMap<&'static str, LintEntry>,
}
impl LintRegistryBuilder {
#[track_caller]
pub fn register_lint(&mut self, lint: &'static LintMetadata) {
assert_eq!(
self.by_name.insert(&*lint.name, lint.into()),
None,
"duplicate lint registration for '{name}'",
name = lint.name
);
if !lint.status.is_removed() {
self.lints.push(LintId::of(lint));
}
}
#[track_caller]
pub fn register_alias(&mut self, from: LintName, to: &'static LintMetadata) {
let target = match self.by_name.get(to.name.as_str()) {
Some(LintEntry::Lint(target) | LintEntry::Removed(target)) => target,
Some(LintEntry::Alias(target)) => {
panic!(
"lint alias {from} -> {to:?} points to another alias {target:?}",
target = target.name()
)
}
None => panic!(
"lint alias {from} -> {to} points to non-registered lint",
to = to.name
),
};
assert_eq!(
self.by_name
.insert(from.as_str(), LintEntry::Alias(*target)),
None,
"duplicate lint registration for '{from}'",
);
}
pub fn build(self) -> LintRegistry {
LintRegistry {
lints: self.lints,
by_name: self.by_name,
}
}
}
#[derive(Default, Debug, Clone)]
pub struct LintRegistry {
lints: Vec<LintId>,
by_name: FxHashMap<&'static str, LintEntry>,
}
impl LintRegistry {
/// Looks up a lint by its name.
pub fn get(&self, code: &str) -> Result<LintId, GetLintError> {
match self.by_name.get(code) {
Some(LintEntry::Lint(metadata)) => Ok(*metadata),
Some(LintEntry::Alias(lint)) => {
if lint.status.is_removed() {
Err(GetLintError::Removed(lint.name()))
} else {
Ok(*lint)
}
}
Some(LintEntry::Removed(lint)) => Err(GetLintError::Removed(lint.name())),
None => {
if let Some(without_prefix) = DiagnosticId::strip_category(code) {
if let Some(entry) = self.by_name.get(without_prefix) {
return Err(GetLintError::PrefixedWithCategory {
prefixed: code.to_string(),
suggestion: entry.id().name.to_string(),
});
}
}
Err(GetLintError::Unknown(code.to_string()))
}
}
}
/// Returns all registered, non-removed lints.
pub fn lints(&self) -> &[LintId] {
&self.lints
}
/// Returns an iterator over all known aliases and their target lints.
///
/// This iterator includes aliases that point to removed lints.
pub fn aliases(&self) -> impl Iterator<Item = (LintName, LintId)> + '_ {
self.by_name.iter().filter_map(|(key, value)| {
if let LintEntry::Alias(alias) = value {
Some((LintName::of(key), *alias))
} else {
None
}
})
}
/// Iterates over all removed lints.
pub fn removed(&self) -> impl Iterator<Item = LintId> + '_ {
self.by_name.iter().filter_map(|(_, value)| {
if let LintEntry::Removed(metadata) = value {
Some(*metadata)
} else {
None
}
})
}
}
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum GetLintError {
/// The name maps to this removed lint.
#[error("lint `{0}` has been removed")]
Removed(LintName),
/// No lint with the given name is known.
#[error("unknown lint `{0}`")]
Unknown(String),
/// The name uses the fully qualified diagnostic id `lint:<rule>` instead of just `rule`.
/// `prefixed` is the name as given; `suggestion` is the name without the `lint:` category prefix.
#[error("unknown lint `{prefixed}`. Did you mean `{suggestion}`?")]
PrefixedWithCategory {
prefixed: String,
suggestion: String,
},
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum LintEntry {
/// An existing lint rule. Can be in preview, stable or deprecated.
Lint(LintId),
/// A lint rule that has been removed.
Removed(LintId),
Alias(LintId),
}
impl LintEntry {
fn id(self) -> LintId {
match self {
LintEntry::Lint(id) => id,
LintEntry::Removed(id) => id,
LintEntry::Alias(id) => id,
}
}
}
impl From<&'static LintMetadata> for LintEntry {
fn from(metadata: &'static LintMetadata) -> Self {
if metadata.status.is_removed() {
LintEntry::Removed(LintId::of(metadata))
} else {
LintEntry::Lint(LintId::of(metadata))
}
}
}
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RuleSelection {
/// Map with the severity for each enabled lint rule.
///
/// If a rule isn't present in this map, then it should be considered disabled.
lints: FxHashMap<LintId, (Severity, LintSource)>,
}
impl RuleSelection {
/// Creates a new rule selection from all known lints in the registry that are enabled
/// according to their default severity.
pub fn from_registry(registry: &LintRegistry) -> Self {
let lints = registry
.lints()
.iter()
.filter_map(|lint| {
Severity::try_from(lint.default_level())
.ok()
.map(|severity| (*lint, (severity, LintSource::Default)))
})
.collect();
RuleSelection { lints }
}
/// Returns an iterator over all enabled lints.
pub fn enabled(&self) -> impl Iterator<Item = LintId> + '_ {
self.lints.keys().copied()
}
/// Returns an iterator over all enabled lints and their severity.
pub fn iter(&self) -> impl ExactSizeIterator<Item = (LintId, Severity)> + '_ {
self.lints
.iter()
.map(|(&lint, &(severity, _))| (lint, severity))
}
/// Returns the configured severity for the lint with the given id or `None` if the lint is disabled.
pub fn severity(&self, lint: LintId) -> Option<Severity> {
self.lints.get(&lint).map(|(severity, _)| *severity)
}
/// Returns `true` if the `lint` is enabled.
pub fn is_enabled(&self, lint: LintId) -> bool {
self.severity(lint).is_some()
}
/// Enables `lint` and configures it with the given `severity`.
///
/// Overrides any previous configuration for the lint.
pub fn enable(&mut self, lint: LintId, severity: Severity, source: LintSource) {
self.lints.insert(lint, (severity, source));
}
/// Disables `lint` if it was previously enabled.
pub fn disable(&mut self, lint: LintId) {
self.lints.remove(&lint);
}
}
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq)]
pub enum LintSource {
/// The user didn't enable the rule explicitly; it's enabled by default.
#[default]
Default,
/// The rule was enabled via a CLI argument.
Cli,
/// The rule was enabled in a configuration file.
File,
}


@@ -0,0 +1,745 @@
//! Sorted, arena-allocated association lists
//!
//! This module implements an [_association list_][alist]: a linked list of key/value pairs. We
//! additionally guarantee that the elements of an association list are sorted (by their keys), and that they do
//! not contain any entries with duplicate keys.
//!
//! Association lists have fallen out of favor in recent decades, since you often need operations
//! that are inefficient on them. In particular, looking up a random element by index is O(n), just
//! like a linked list; and looking up an element by key is also O(n), since you must do a linear
//! scan of the list to find the matching element. The typical implementation also suffers from
//! poor cache locality and high memory allocation overhead, since individual list cells are
//! typically allocated individually on the heap. We solve that last problem by storing the cells
//! of an association list in an [`IndexVec`] arena.
//!
//! We exploit structural sharing where possible, reusing cells across multiple lists when we can.
//! That said, we don't guarantee that lists are canonical — it's entirely possible for two lists
//! with identical contents to use different list cells and have different identifiers.
//!
//! Given all of this, association lists have the following benefits:
//!
//! - Lists can be represented by a single 32-bit integer (the index into the arena of the head of
//! the list).
//! - Lists can be cloned in constant time, since the underlying cells are immutable.
//! - Lists can be combined quickly (for both intersection and union), especially when you already
//! have to zip through both input lists to combine each key's values in some way.
//!
//! There is one remaining caveat:
//!
//! - You should construct lists in key order; doing this lets you insert each value in constant time.
//! Inserting entries in reverse order results in _quadratic_ overall time to construct the list.
//!
//! Lists are created using a [`ListBuilder`], and once created are accessed via a [`ListStorage`].
//!
//! ## Tests
//!
//! This module contains quickcheck-based property tests.
//!
//! These tests are disabled by default, as they are non-deterministic and slow. You can run them
//! explicitly using:
//!
//! ```sh
//! cargo test -p ruff_index -- --ignored list::property_tests
//! ```
//!
//! The number of tests (default: 100) can be controlled by setting the `QUICKCHECK_TESTS`
//! environment variable. For example:
//!
//! ```sh
//! QUICKCHECK_TESTS=10000 cargo test …
//! ```
//!
//! If you want to run these tests for a longer period of time, it's advisable to run them in
//! release mode. As some runs take longer than others, you can run them in a loop until
//! they fail:
//!
//! ```sh
//! export QUICKCHECK_TESTS=100000
//! while cargo test --release -p ruff_index -- \
//! --ignored list::property_tests; do :; done
//! ```
//!
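//! ## Example
//!
//! A minimal usage sketch against the `ListBuilder`/`ListStorage` API defined below. The
//! types are crate-private, so this block is not compiled as a doctest:
//!
//! ```ignore
//! let mut builder = ListBuilder::<u16>::default();
//! let set = builder.insert(List::empty(), 1); // [1]
//! let set = builder.insert(set, 2); // [1, 2]; inserting in key order is O(1)
//! let storage = builder.build();
//! let elements: Vec<_> = storage.iter_set_reverse(set).collect(); // [&2, &1]
//! ```
//!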
//! [alist]: https://en.wikipedia.org/wiki/Association_list
use std::cmp::Ordering;
use std::marker::PhantomData;
use std::ops::Deref;
use ruff_index::{newtype_index, IndexVec};
/// A handle to an association list. Use [`ListStorage`] to access its elements, and
/// [`ListBuilder`] to construct other lists based on this one.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub(crate) struct List<K, V = ()> {
last: Option<ListCellId>,
_phantom: PhantomData<(K, V)>,
}
impl<K, V> List<K, V> {
pub(crate) const fn empty() -> List<K, V> {
List::new(None)
}
const fn new(last: Option<ListCellId>) -> List<K, V> {
List {
last,
_phantom: PhantomData,
}
}
}
impl<K, V> Default for List<K, V> {
fn default() -> Self {
List::empty()
}
}
#[newtype_index]
#[derive(PartialOrd, Ord)]
struct ListCellId;
/// Stores one or more association lists. This type provides read-only access to the lists. Use a
/// [`ListBuilder`] to create lists.
#[derive(Debug, Eq, PartialEq)]
pub(crate) struct ListStorage<K, V = ()> {
cells: IndexVec<ListCellId, ListCell<K, V>>,
}
/// Each association list is represented by a sequence of snoc cells. A snoc cell is like the more
/// familiar cons cell `(a : (b : (c : nil)))`, but in reverse `(((nil : a) : b) : c)`.
///
/// **Terminology**: The elements of a cons cell are usually called `head` and `tail` (assuming
/// you're not in Lisp-land, where they're called `car` and `cdr`). The elements of a snoc cell
/// are usually called `rest` and `last`.
#[derive(Debug, Eq, PartialEq)]
struct ListCell<K, V> {
rest: Option<ListCellId>,
key: K,
value: V,
}
/// Constructs one or more association lists.
#[derive(Debug, Eq, PartialEq)]
pub(crate) struct ListBuilder<K, V = ()> {
storage: ListStorage<K, V>,
/// Scratch space that lets us implement our list operations iteratively instead of
/// recursively.
///
/// The snoc-list representation that we use for alists is very common in functional
/// programming, and the simplest implementations of most of the operations are defined
/// recursively on that data structure. However, they are not _tail_ recursive, which means
/// that the call stack grows linearly with the size of the input, which can be a problem for
/// large lists.
///
/// You can often rework those recursive implementations into iterative ones using an
/// _accumulator_, but that comes at the cost of reversing the list. If we didn't care about
/// ordering, that wouldn't be a problem. But since we want our lists to be sorted, we can't
/// rely on that trick on its own.
///
/// The next standard trick is to use an accumulator, and use a fix-up step at the end to
/// reverse the (reversed) result in the accumulator, restoring the correct order.
///
/// So, that's what we do! However, as one last optimization, we don't build up alist cells in
/// our accumulator, since that would add wasteful cruft to our list storage. Instead, we use a
/// normal Vec as our accumulator, holding the key/value pairs that should be stitched onto the
/// end of whatever result list we are creating. For our fix-up step, we can consume a Vec in
/// reverse order by `pop`ping the elements off one by one.
scratch: Vec<(K, V)>,
}
impl<K, V> Default for ListBuilder<K, V> {
fn default() -> Self {
ListBuilder {
storage: ListStorage {
cells: IndexVec::default(),
},
scratch: Vec::default(),
}
}
}
impl<K, V> Deref for ListBuilder<K, V> {
type Target = ListStorage<K, V>;
fn deref(&self) -> &ListStorage<K, V> {
&self.storage
}
}
impl<K, V> ListBuilder<K, V> {
/// Finalizes a `ListBuilder`. After calling this, you cannot create any new lists managed by
/// this storage.
pub(crate) fn build(mut self) -> ListStorage<K, V> {
self.storage.cells.shrink_to_fit();
self.storage
}
/// Adds a new cell to the list.
///
/// Adding an element always returns a non-empty list, which means we could technically use
/// `ListCellId` as our return type, since we never return `None`. However, for consistency
/// with our other methods, we always use `Option<ListCellId>` as the return type for any
/// method that can return a list.
#[allow(clippy::unnecessary_wraps)]
fn add_cell(&mut self, rest: Option<ListCellId>, key: K, value: V) -> Option<ListCellId> {
Some(self.storage.cells.push(ListCell { rest, key, value }))
}
/// Returns an entry pointing at where `key` would be inserted into a list.
///
/// Note that when we add a new element to a list, we might have to clone the keys and values
/// of some existing elements. This is because list cells are immutable once created, since
/// they might be shared across multiple lists. We must therefore create new cells for every
/// element that appears after the new element.
///
/// That means that you should construct lists in key order, since that means that there are no
/// entries to duplicate for each insertion. If you construct the list in reverse order, we
/// will have to duplicate O(n) entries for each insertion, making it _quadratic_ to construct
/// the entire list.
pub(crate) fn entry(&mut self, list: List<K, V>, key: K) -> ListEntry<K, V>
where
K: Clone + Ord,
V: Clone,
{
self.scratch.clear();
// Iterate through the input list, looking for the position where the key should be
// inserted. We will need to create new list cells for any elements that appear after the
// new key. Stash those away in our scratch accumulator as we step through the input. The
// result of the loop is the "rest" of the result list, onto which we will stitch the new key
// (and any succeeding keys).
let mut curr = list.last;
while let Some(curr_id) = curr {
let cell = &self.storage.cells[curr_id];
match key.cmp(&cell.key) {
// We found an existing entry in the input list with the desired key.
Ordering::Equal => {
return ListEntry {
builder: self,
list,
key,
rest: ListTail::Occupied(curr_id),
};
}
// The input list does not already contain this key, and this is where we should
// add it.
Ordering::Greater => {
return ListEntry {
builder: self,
list,
key,
rest: ListTail::Vacant(curr_id),
};
}
// If this key is in the list, it's further along. We'll need to create a new cell
// for this entry in the result list, so add its contents to the scratch
// accumulator.
Ordering::Less => {
let new_key = cell.key.clone();
let new_value = cell.value.clone();
self.scratch.push((new_key, new_value));
curr = cell.rest;
}
}
}
// We made it all the way through the list without finding the desired key, so it belongs
// at the beginning. (And we will unfortunately have to duplicate every existing cell if
// the caller proceeds with inserting the new key!)
ListEntry {
builder: self,
list,
key,
rest: ListTail::Beginning,
}
}
}
/// A view into a list, indicating where a key would be inserted.
pub(crate) struct ListEntry<'a, K, V = ()> {
builder: &'a mut ListBuilder<K, V>,
list: List<K, V>,
key: K,
/// Points at the element that already contains `key`, if there is one, or the element
/// immediately before where it would go, if not.
rest: ListTail<ListCellId>,
}
enum ListTail<I> {
/// The list does not already contain `key`, and it would go at the beginning of the list.
Beginning,
/// The list already contains `key`.
Occupied(I),
/// The list does not already contain `key`, and it would go immediately after the given element.
Vacant(I),
}
impl<K, V> ListEntry<'_, K, V>
where
K: Clone,
V: Clone,
{
fn stitch_up(self, rest: Option<ListCellId>, value: V) -> List<K, V> {
let mut last = rest;
last = self.builder.add_cell(last, self.key, value);
while let Some((key, value)) = self.builder.scratch.pop() {
last = self.builder.add_cell(last, key, value);
}
List::new(last)
}
/// Inserts a new key/value into the list if the key is not already present. If the list
/// already contains `key`, we return the original list as-is, and do not invoke your closure.
pub(crate) fn or_insert_with<F>(self, f: F) -> List<K, V>
where
F: FnOnce() -> V,
{
let rest = match self.rest {
// If the list already contains `key`, we don't need to replace anything, and can
// return the original list unmodified.
ListTail::Occupied(_) => return self.list,
// Otherwise we have to create a new entry and stitch it onto the list.
ListTail::Beginning => None,
ListTail::Vacant(index) => Some(index),
};
self.stitch_up(rest, f())
}
/// Inserts a new key and the default value into the list if the key is not already present. If
/// the list already contains `key`, we return the original list as-is.
pub(crate) fn or_insert_default(self) -> List<K, V>
where
V: Default,
{
self.or_insert_with(V::default)
}
}
impl<K, V> ListBuilder<K, V> {
/// Returns the intersection of two lists. The result will contain an entry for any key that
/// appears in both lists. The corresponding values will be combined using the `combine`
/// function that you provide.
#[allow(clippy::needless_pass_by_value)]
pub(crate) fn intersect_with<F>(
&mut self,
a: List<K, V>,
b: List<K, V>,
mut combine: F,
) -> List<K, V>
where
K: Clone + Ord,
V: Clone,
F: FnMut(&V, &V) -> V,
{
self.scratch.clear();
// Zip through the lists, building up the keys/values of the new entries into our scratch
// vector. Continue until we run out of elements in either list. (Any remaining elements in
// the other list cannot possibly be in the intersection.)
let mut a = a.last;
let mut b = b.last;
while let (Some(a_id), Some(b_id)) = (a, b) {
let a_cell = &self.storage.cells[a_id];
let b_cell = &self.storage.cells[b_id];
match a_cell.key.cmp(&b_cell.key) {
// Both lists contain this key; combine their values
Ordering::Equal => {
let new_key = a_cell.key.clone();
let new_value = combine(&a_cell.value, &b_cell.value);
self.scratch.push((new_key, new_value));
a = a_cell.rest;
b = b_cell.rest;
}
// a's key is only present in a, so it's not included in the result.
Ordering::Greater => a = a_cell.rest,
// b's key is only present in b, so it's not included in the result.
Ordering::Less => b = b_cell.rest,
}
}
// Once the iteration loop terminates, we stitch the new entries back together into proper
// alist cells.
let mut last = None;
while let Some((key, value)) = self.scratch.pop() {
last = self.add_cell(last, key, value);
}
List::new(last)
}
}
// ----
// Sets
impl<K> ListStorage<K, ()> {
/// Iterates through the elements in a set _in reverse order_.
#[allow(clippy::needless_pass_by_value)]
pub(crate) fn iter_set_reverse(&self, set: List<K, ()>) -> ListSetReverseIterator<K> {
ListSetReverseIterator {
storage: self,
curr: set.last,
}
}
}
pub(crate) struct ListSetReverseIterator<'a, K> {
storage: &'a ListStorage<K, ()>,
curr: Option<ListCellId>,
}
impl<'a, K> Iterator for ListSetReverseIterator<'a, K> {
type Item = &'a K;
fn next(&mut self) -> Option<Self::Item> {
let cell = &self.storage.cells[self.curr?];
self.curr = cell.rest;
Some(&cell.key)
}
}
impl<K> ListBuilder<K, ()> {
/// Adds an element to a set.
pub(crate) fn insert(&mut self, set: List<K, ()>, element: K) -> List<K, ()>
where
K: Clone + Ord,
{
self.entry(set, element).or_insert_default()
}
/// Returns the intersection of two sets. The result will contain any value that appears in
/// both sets.
pub(crate) fn intersect(&mut self, a: List<K, ()>, b: List<K, ()>) -> List<K, ()>
where
K: Clone + Ord,
{
self.intersect_with(a, b, |(), ()| ())
}
}
// -----
// Tests
#[cfg(test)]
mod tests {
use super::*;
use std::fmt::Display;
use std::fmt::Write;
// ----
// Sets
impl<K> ListStorage<K>
where
K: Display,
{
fn display_set(&self, list: List<K, ()>) -> String {
let elements: Vec<_> = self.iter_set_reverse(list).collect();
let mut result = String::new();
result.push('[');
for element in elements.into_iter().rev() {
if result.len() > 1 {
result.push_str(", ");
}
write!(&mut result, "{element}").unwrap();
}
result.push(']');
result
}
}
#[test]
fn can_insert_into_set() {
let mut builder = ListBuilder::<u16>::default();
// Build up the set in order
let empty = List::empty();
let set1 = builder.insert(empty, 1);
let set12 = builder.insert(set1, 2);
let set123 = builder.insert(set12, 3);
let set1232 = builder.insert(set123, 2);
assert_eq!(builder.display_set(empty), "[]");
assert_eq!(builder.display_set(set1), "[1]");
assert_eq!(builder.display_set(set12), "[1, 2]");
assert_eq!(builder.display_set(set123), "[1, 2, 3]");
assert_eq!(builder.display_set(set1232), "[1, 2, 3]");
// And in reverse order
let set3 = builder.insert(empty, 3);
let set32 = builder.insert(set3, 2);
let set321 = builder.insert(set32, 1);
let set3212 = builder.insert(set321, 2);
assert_eq!(builder.display_set(empty), "[]");
assert_eq!(builder.display_set(set3), "[3]");
assert_eq!(builder.display_set(set32), "[2, 3]");
assert_eq!(builder.display_set(set321), "[1, 2, 3]");
assert_eq!(builder.display_set(set3212), "[1, 2, 3]");
}
#[test]
fn can_intersect_sets() {
let mut builder = ListBuilder::<u16>::default();
let empty = List::empty();
let set1 = builder.insert(empty, 1);
let set12 = builder.insert(set1, 2);
let set123 = builder.insert(set12, 3);
let set1234 = builder.insert(set123, 4);
let set2 = builder.insert(empty, 2);
let set24 = builder.insert(set2, 4);
let set245 = builder.insert(set24, 5);
let set2457 = builder.insert(set245, 7);
let intersection = builder.intersect(empty, empty);
assert_eq!(builder.display_set(intersection), "[]");
let intersection = builder.intersect(empty, set1234);
assert_eq!(builder.display_set(intersection), "[]");
let intersection = builder.intersect(empty, set2457);
assert_eq!(builder.display_set(intersection), "[]");
let intersection = builder.intersect(set1, set1234);
assert_eq!(builder.display_set(intersection), "[1]");
let intersection = builder.intersect(set1, set2457);
assert_eq!(builder.display_set(intersection), "[]");
let intersection = builder.intersect(set2, set1234);
assert_eq!(builder.display_set(intersection), "[2]");
let intersection = builder.intersect(set2, set2457);
assert_eq!(builder.display_set(intersection), "[2]");
let intersection = builder.intersect(set1234, set2457);
assert_eq!(builder.display_set(intersection), "[2, 4]");
}
// ----
// Maps
impl<K, V> ListStorage<K, V> {
/// Iterates through the entries in a list _in reverse order by key_.
#[allow(clippy::needless_pass_by_value)]
pub(crate) fn iter_reverse(&self, list: List<K, V>) -> ListReverseIterator<'_, K, V> {
ListReverseIterator {
storage: self,
curr: list.last,
}
}
}
pub(crate) struct ListReverseIterator<'a, K, V> {
storage: &'a ListStorage<K, V>,
curr: Option<ListCellId>,
}
impl<'a, K, V> Iterator for ListReverseIterator<'a, K, V> {
type Item = (&'a K, &'a V);
fn next(&mut self) -> Option<Self::Item> {
let cell = &self.storage.cells[self.curr?];
self.curr = cell.rest;
Some((&cell.key, &cell.value))
}
}
impl<K, V> ListStorage<K, V>
where
K: Display,
V: Display,
{
fn display(&self, list: List<K, V>) -> String {
let entries: Vec<_> = self.iter_reverse(list).collect();
let mut result = String::new();
result.push('[');
for (key, value) in entries.into_iter().rev() {
if result.len() > 1 {
result.push_str(", ");
}
write!(&mut result, "{key}:{value}").unwrap();
}
result.push(']');
result
}
}
#[test]
fn can_insert_into_map() {
let mut builder = ListBuilder::<u16, u16>::default();
// Build up the map in order
let empty = List::empty();
let map1 = builder.entry(empty, 1).or_insert_with(|| 1);
let map12 = builder.entry(map1, 2).or_insert_with(|| 2);
let map123 = builder.entry(map12, 3).or_insert_with(|| 3);
let map1232 = builder.entry(map123, 2).or_insert_with(|| 4);
assert_eq!(builder.display(empty), "[]");
assert_eq!(builder.display(map1), "[1:1]");
assert_eq!(builder.display(map12), "[1:1, 2:2]");
assert_eq!(builder.display(map123), "[1:1, 2:2, 3:3]");
assert_eq!(builder.display(map1232), "[1:1, 2:2, 3:3]");
// And in reverse order
let map3 = builder.entry(empty, 3).or_insert_with(|| 3);
let map32 = builder.entry(map3, 2).or_insert_with(|| 2);
let map321 = builder.entry(map32, 1).or_insert_with(|| 1);
let map3212 = builder.entry(map321, 2).or_insert_with(|| 4);
assert_eq!(builder.display(empty), "[]");
assert_eq!(builder.display(map3), "[3:3]");
assert_eq!(builder.display(map32), "[2:2, 3:3]");
assert_eq!(builder.display(map321), "[1:1, 2:2, 3:3]");
assert_eq!(builder.display(map3212), "[1:1, 2:2, 3:3]");
}
#[test]
fn can_intersect_maps() {
let mut builder = ListBuilder::<u16, u16>::default();
let empty = List::empty();
let map1 = builder.entry(empty, 1).or_insert_with(|| 1);
let map12 = builder.entry(map1, 2).or_insert_with(|| 2);
let map123 = builder.entry(map12, 3).or_insert_with(|| 3);
let map1234 = builder.entry(map123, 4).or_insert_with(|| 4);
let map2 = builder.entry(empty, 2).or_insert_with(|| 20);
let map24 = builder.entry(map2, 4).or_insert_with(|| 40);
let map245 = builder.entry(map24, 5).or_insert_with(|| 50);
let map2457 = builder.entry(map245, 7).or_insert_with(|| 70);
let intersection = builder.intersect_with(empty, empty, |a, b| a + b);
assert_eq!(builder.display(intersection), "[]");
let intersection = builder.intersect_with(empty, map1234, |a, b| a + b);
assert_eq!(builder.display(intersection), "[]");
let intersection = builder.intersect_with(empty, map2457, |a, b| a + b);
assert_eq!(builder.display(intersection), "[]");
let intersection = builder.intersect_with(map1, map1234, |a, b| a + b);
assert_eq!(builder.display(intersection), "[1:2]");
let intersection = builder.intersect_with(map1, map2457, |a, b| a + b);
assert_eq!(builder.display(intersection), "[]");
let intersection = builder.intersect_with(map2, map1234, |a, b| a + b);
assert_eq!(builder.display(intersection), "[2:22]");
let intersection = builder.intersect_with(map2, map2457, |a, b| a + b);
assert_eq!(builder.display(intersection), "[2:40]");
let intersection = builder.intersect_with(map1234, map2457, |a, b| a + b);
assert_eq!(builder.display(intersection), "[2:22, 4:44]");
}
}
// --------------
// Property tests
#[cfg(test)]
mod property_tests {
use super::*;
use std::collections::{BTreeMap, BTreeSet};
impl<K> ListBuilder<K>
where
K: Clone + Ord,
{
fn set_from_elements<'a>(&mut self, elements: impl IntoIterator<Item = &'a K>) -> List<K>
where
K: 'a,
{
let mut set = List::empty();
for element in elements {
set = self.insert(set, element.clone());
}
set
}
}
// For most of the tests below, we use a vec as our input, instead of a HashSet or BTreeSet,
// since we want to test the behavior of adding duplicate elements to the set.
#[quickcheck_macros::quickcheck]
#[ignore]
#[allow(clippy::needless_pass_by_value)]
fn roundtrip_set_from_vec(elements: Vec<u16>) -> bool {
let mut builder = ListBuilder::default();
let set = builder.set_from_elements(&elements);
let expected: BTreeSet<_> = elements.iter().copied().collect();
let actual = builder.iter_set_reverse(set).copied();
actual.eq(expected.into_iter().rev())
}
#[quickcheck_macros::quickcheck]
#[ignore]
#[allow(clippy::needless_pass_by_value)]
fn roundtrip_set_intersection(a_elements: Vec<u16>, b_elements: Vec<u16>) -> bool {
let mut builder = ListBuilder::default();
let a = builder.set_from_elements(&a_elements);
let b = builder.set_from_elements(&b_elements);
let intersection = builder.intersect(a, b);
let a_set: BTreeSet<_> = a_elements.iter().copied().collect();
let b_set: BTreeSet<_> = b_elements.iter().copied().collect();
let expected: Vec<_> = a_set.intersection(&b_set).copied().collect();
let actual = builder.iter_set_reverse(intersection).copied();
actual.eq(expected.into_iter().rev())
}
impl<K, V> ListBuilder<K, V>
where
K: Clone + Ord,
V: Clone + Eq,
{
fn set_from_pairs<'a, I>(&mut self, pairs: I) -> List<K, V>
where
K: 'a,
V: 'a,
I: IntoIterator<Item = &'a (K, V)>,
I::IntoIter: DoubleEndedIterator,
{
let mut list = List::empty();
for (key, value) in pairs.into_iter().rev() {
list = self
.entry(list, key.clone())
.or_insert_with(|| value.clone());
}
list
}
}
fn join<K, V>(a: &BTreeMap<K, V>, b: &BTreeMap<K, V>) -> BTreeMap<K, (Option<V>, Option<V>)>
where
K: Clone + Ord,
V: Clone + Ord,
{
let mut joined: BTreeMap<K, (Option<V>, Option<V>)> = BTreeMap::new();
for (k, v) in a {
joined.entry(k.clone()).or_default().0 = Some(v.clone());
}
for (k, v) in b {
joined.entry(k.clone()).or_default().1 = Some(v.clone());
}
joined
}
#[quickcheck_macros::quickcheck]
#[ignore]
#[allow(clippy::needless_pass_by_value)]
fn roundtrip_list_from_vec(pairs: Vec<(u16, u16)>) -> bool {
let mut builder = ListBuilder::default();
let list = builder.set_from_pairs(&pairs);
let expected: BTreeMap<_, _> = pairs.iter().copied().collect();
let actual = builder.iter_reverse(list).map(|(k, v)| (*k, *v));
actual.eq(expected.into_iter().rev())
}
#[quickcheck_macros::quickcheck]
#[ignore]
#[allow(clippy::needless_pass_by_value)]
fn roundtrip_list_intersection(
a_elements: Vec<(u16, u16)>,
b_elements: Vec<(u16, u16)>,
) -> bool {
let mut builder = ListBuilder::default();
let a = builder.set_from_pairs(&a_elements);
let b = builder.set_from_pairs(&b_elements);
let intersection = builder.intersect_with(a, b, |a, b| a + b);
let a_map: BTreeMap<_, _> = a_elements.iter().copied().collect();
let b_map: BTreeMap<_, _> = b_elements.iter().copied().collect();
let intersection_map = join(&a_map, &b_map);
let expected: Vec<_> = intersection_map
.into_iter()
.filter_map(|(k, (v1, v2))| Some((k, v1? + v2?)))
.collect();
let actual = builder.iter_reverse(intersection).map(|(k, v)| (*k, *v));
actual.eq(expected.into_iter().rev())
}
}


@@ -0,0 +1,319 @@
use std::fmt;
use std::num::NonZeroU32;
use std::ops::Deref;
use compact_str::{CompactString, ToCompactString};
use ruff_db::files::File;
use ruff_python_ast as ast;
use ruff_python_stdlib::identifiers::is_identifier;
use crate::{db::Db, module_resolver::file_to_module};
/// A module name, e.g. `foo.bar`.
///
/// Always normalized to the absolute form (never a relative module name, i.e., never `.foo`).
#[derive(Clone, Debug, Eq, PartialEq, Hash, PartialOrd, Ord)]
pub struct ModuleName(compact_str::CompactString);
impl ModuleName {
/// Creates a new module name for `name`. Returns `Some` if `name` is a valid, absolute
/// module name and `None` otherwise.
///
/// The module name is invalid if:
///
/// * The name is empty
/// * The name is relative
/// * The name ends with a `.`
/// * The name contains a sequence of multiple dots
/// * A component of a name (the part between two dots) isn't a valid Python identifier.
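///
/// ## Examples
///
/// These mirror the rules listed above (and the [`ModuleName::new_static`] examples):
///
/// ```
/// use ty_python_semantic::ModuleName;
///
/// assert_eq!(ModuleName::new("foo.bar").as_deref(), Some("foo.bar"));
/// assert_eq!(ModuleName::new(".foo"), None);
/// assert_eq!(ModuleName::new("foo-bar"), None);
/// ```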
#[inline]
#[must_use]
pub fn new(name: &str) -> Option<Self> {
Self::is_valid_name(name).then(|| Self(CompactString::from(name)))
}
/// Creates a new module name for `name` where `name` is a static string.
/// Returns `Some` if `name` is a valid, absolute module name and `None` otherwise.
///
/// The module name is invalid if:
///
/// * The name is empty
/// * The name is relative
/// * The name ends with a `.`
/// * The name contains a sequence of multiple dots
/// * A component of a name (the part between two dots) isn't a valid Python identifier.
///
/// ## Examples
///
/// ```
/// use ty_python_semantic::ModuleName;
///
/// assert_eq!(ModuleName::new_static("foo.bar").as_deref(), Some("foo.bar"));
/// assert_eq!(ModuleName::new_static(""), None);
/// assert_eq!(ModuleName::new_static("..foo"), None);
/// assert_eq!(ModuleName::new_static(".foo"), None);
/// assert_eq!(ModuleName::new_static("foo."), None);
/// assert_eq!(ModuleName::new_static("foo..bar"), None);
/// assert_eq!(ModuleName::new_static("2000"), None);
/// ```
#[inline]
#[must_use]
pub fn new_static(name: &'static str) -> Option<Self> {
Self::is_valid_name(name).then(|| Self(CompactString::const_new(name)))
}
#[must_use]
fn is_valid_name(name: &str) -> bool {
!name.is_empty() && name.split('.').all(is_identifier)
}
/// An iterator over the components of the module name:
///
/// # Examples
///
/// ```
/// use ty_python_semantic::ModuleName;
///
/// assert_eq!(ModuleName::new_static("foo.bar.baz").unwrap().components().collect::<Vec<_>>(), vec!["foo", "bar", "baz"]);
/// ```
#[must_use]
pub fn components(&self) -> impl DoubleEndedIterator<Item = &str> {
self.0.split('.')
}
/// The name of this module's immediate parent, if it has a parent.
///
/// # Examples
///
/// ```
/// use ty_python_semantic::ModuleName;
///
/// assert_eq!(ModuleName::new_static("foo.bar").unwrap().parent(), Some(ModuleName::new_static("foo").unwrap()));
/// assert_eq!(ModuleName::new_static("foo.bar.baz").unwrap().parent(), Some(ModuleName::new_static("foo.bar").unwrap()));
/// assert_eq!(ModuleName::new_static("root").unwrap().parent(), None);
/// ```
#[must_use]
pub fn parent(&self) -> Option<ModuleName> {
let (parent, _) = self.0.rsplit_once('.')?;
Some(Self(parent.to_compact_string()))
}
/// Returns `true` if the name starts with `other`.
///
/// This is equivalent to checking if `self` is equal to `other`, or is a sub-module of it.
///
/// # Examples
///
/// ```
/// use ty_python_semantic::ModuleName;
///
/// assert!(ModuleName::new_static("foo.bar").unwrap().starts_with(&ModuleName::new_static("foo").unwrap()));
///
/// assert!(!ModuleName::new_static("foo.bar").unwrap().starts_with(&ModuleName::new_static("bar").unwrap()));
/// assert!(!ModuleName::new_static("foo_bar").unwrap().starts_with(&ModuleName::new_static("foo").unwrap()));
/// ```
#[must_use]
pub fn starts_with(&self, other: &ModuleName) -> bool {
let mut self_components = self.components();
let other_components = other.components();
for other_component in other_components {
if self_components.next() != Some(other_component) {
return false;
}
}
true
}
#[must_use]
#[inline]
pub fn as_str(&self) -> &str {
&self.0
}
/// Construct a [`ModuleName`] from a sequence of parts.
///
/// # Examples
///
/// ```
/// use ty_python_semantic::ModuleName;
///
/// assert_eq!(&*ModuleName::from_components(["a"]).unwrap(), "a");
/// assert_eq!(&*ModuleName::from_components(["a", "b"]).unwrap(), "a.b");
/// assert_eq!(&*ModuleName::from_components(["a", "b", "c"]).unwrap(), "a.b.c");
///
/// assert_eq!(ModuleName::from_components(["a-b"]), None);
/// assert_eq!(ModuleName::from_components(["a", "a-b"]), None);
/// assert_eq!(ModuleName::from_components(["a", "b", "a-b-c"]), None);
/// ```
#[must_use]
pub fn from_components<'a>(components: impl IntoIterator<Item = &'a str>) -> Option<Self> {
let mut components = components.into_iter();
let first_part = components.next()?;
if !is_identifier(first_part) {
return None;
}
let name = if let Some(second_part) = components.next() {
if !is_identifier(second_part) {
return None;
}
let mut name = format!("{first_part}.{second_part}");
for part in components {
if !is_identifier(part) {
return None;
}
name.push('.');
name.push_str(part);
}
CompactString::from(&name)
} else {
CompactString::from(first_part)
};
Some(Self(name))
}
/// Extend `self` with the components of `other`
///
/// # Examples
///
/// ```
/// use ty_python_semantic::ModuleName;
///
/// let mut module_name = ModuleName::new_static("foo").unwrap();
/// module_name.extend(&ModuleName::new_static("bar").unwrap());
/// assert_eq!(&module_name, "foo.bar");
/// module_name.extend(&ModuleName::new_static("baz.eggs.ham").unwrap());
/// assert_eq!(&module_name, "foo.bar.baz.eggs.ham");
/// ```
pub fn extend(&mut self, other: &ModuleName) {
self.0.push('.');
self.0.push_str(other);
}
/// Returns an iterator of this module name and all of its parent modules.
///
/// # Examples
///
/// ```
/// use ty_python_semantic::ModuleName;
///
/// assert_eq!(
/// ModuleName::new_static("foo.bar.baz").unwrap().ancestors().collect::<Vec<_>>(),
/// vec![
/// ModuleName::new_static("foo.bar.baz").unwrap(),
/// ModuleName::new_static("foo.bar").unwrap(),
/// ModuleName::new_static("foo").unwrap(),
/// ],
/// );
/// ```
pub fn ancestors(&self) -> impl Iterator<Item = Self> {
std::iter::successors(Some(self.clone()), Self::parent)
}
pub(crate) fn from_import_statement<'db>(
db: &'db dyn Db,
importing_file: File,
node: &'db ast::StmtImportFrom,
) -> Result<Self, ModuleNameResolutionError> {
let ast::StmtImportFrom {
module,
level,
names: _,
range: _,
} = node;
let module = module.as_deref();
if let Some(level) = NonZeroU32::new(*level) {
relative_module_name(db, importing_file, module, level)
} else {
module
.and_then(Self::new)
.ok_or(ModuleNameResolutionError::InvalidSyntax)
}
}
}
impl Deref for ModuleName {
type Target = str;
#[inline]
fn deref(&self) -> &Self::Target {
self.as_str()
}
}
impl PartialEq<str> for ModuleName {
fn eq(&self, other: &str) -> bool {
self.as_str() == other
}
}
impl PartialEq<ModuleName> for str {
fn eq(&self, other: &ModuleName) -> bool {
self == other.as_str()
}
}
impl std::fmt::Display for ModuleName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.0)
}
}
/// Given a `from .foo import bar` relative import, resolve the relative module
/// we're importing `bar` from into an absolute [`ModuleName`]
/// using the name of the module we're currently analyzing.
///
/// - `level` is the number of dots at the beginning of the relative module name:
/// - `from .foo.bar import baz` => `level == 1`
/// - `from ...foo.bar import baz` => `level == 3`
/// - `tail` is the relative module name stripped of all leading dots:
/// - `from .foo import bar` => `tail == "foo"`
/// - `from ..foo.bar import baz` => `tail == "foo.bar"`
fn relative_module_name(
db: &dyn Db,
importing_file: File,
tail: Option<&str>,
level: NonZeroU32,
) -> Result<ModuleName, ModuleNameResolutionError> {
let module = file_to_module(db, importing_file)
.ok_or(ModuleNameResolutionError::UnknownCurrentModule)?;
let mut level = level.get();
if module.kind().is_package() {
level = level.saturating_sub(1);
}
let mut module_name = module
.name()
.ancestors()
.nth(level as usize)
.ok_or(ModuleNameResolutionError::TooManyDots)?;
if let Some(tail) = tail {
let tail = ModuleName::new(tail).ok_or(ModuleNameResolutionError::InvalidSyntax)?;
module_name.extend(&tail);
}
Ok(module_name)
}
/// Various ways in which resolving a [`ModuleName`]
/// from an [`ast::StmtImport`] or [`ast::StmtImportFrom`] node might fail
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(crate) enum ModuleNameResolutionError {
/// The import statement has invalid syntax
InvalidSyntax,
/// We couldn't resolve the file we're currently analyzing back to a module
/// (Only necessary for relative import statements)
UnknownCurrentModule,
/// The relative import statement seems to take us outside of the module search path
/// (e.g. our current module is `foo.bar`, and the relative import statement in `foo.bar`
/// is `from ....baz import spam`)
TooManyDots,
}


@ -0,0 +1,45 @@
use std::iter::FusedIterator;
pub use module::{KnownModule, Module};
pub use resolver::resolve_module;
pub(crate) use resolver::{file_to_module, SearchPaths};
use ruff_db::system::SystemPath;
use crate::module_resolver::resolver::search_paths;
use crate::Db;
use resolver::SearchPathIterator;
mod module;
mod path;
mod resolver;
mod typeshed;
#[cfg(test)]
mod testing;
/// Returns an iterator over all search paths pointing to a system path
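///
/// A minimal usage sketch (assumes a `db` implementing [`Db`]; illustrative,
/// not from the original source):
///
/// ```rs
/// for path in system_module_search_paths(db) {
///     println!("system search path: {path}");
/// }
/// ```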
pub fn system_module_search_paths(db: &dyn Db) -> SystemModuleSearchPathsIter {
SystemModuleSearchPathsIter {
inner: search_paths(db),
}
}
pub struct SystemModuleSearchPathsIter<'db> {
inner: SearchPathIterator<'db>,
}
impl<'db> Iterator for SystemModuleSearchPathsIter<'db> {
type Item = &'db SystemPath;
fn next(&mut self) -> Option<Self::Item> {
loop {
let next = self.inner.next()?;
if let Some(system_path) = next.as_system_path() {
return Some(system_path);
}
}
}
}
impl FusedIterator for SystemModuleSearchPathsIter<'_> {}


@ -0,0 +1,205 @@
use std::fmt::Formatter;
use std::str::FromStr;
use std::sync::Arc;
use ruff_db::files::File;
use super::path::SearchPath;
use crate::module_name::ModuleName;
/// Representation of a Python module.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Module {
inner: Arc<ModuleInner>,
}
impl Module {
pub(crate) fn new(
name: ModuleName,
kind: ModuleKind,
search_path: SearchPath,
file: File,
) -> Self {
let known = KnownModule::try_from_search_path_and_name(&search_path, &name);
Self {
inner: Arc::new(ModuleInner {
name,
kind,
search_path,
file,
known,
}),
}
}
/// The absolute name of the module (e.g. `foo.bar`)
pub fn name(&self) -> &ModuleName {
&self.inner.name
}
    /// The file containing the source code that defines this module
pub fn file(&self) -> File {
self.inner.file
}
/// Is this a module that we special-case somehow? If so, which one?
pub fn known(&self) -> Option<KnownModule> {
self.inner.known
}
/// Does this module represent the given known module?
pub fn is_known(&self, known_module: KnownModule) -> bool {
self.known() == Some(known_module)
}
/// The search path from which the module was resolved.
pub(crate) fn search_path(&self) -> &SearchPath {
&self.inner.search_path
}
/// Determine whether this module is a single-file module or a package
pub fn kind(&self) -> ModuleKind {
self.inner.kind
}
}
impl std::fmt::Debug for Module {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Module")
.field("name", &self.name())
.field("kind", &self.kind())
.field("file", &self.file())
.field("search_path", &self.search_path())
.finish()
}
}
#[derive(PartialEq, Eq, Hash)]
struct ModuleInner {
name: ModuleName,
kind: ModuleKind,
search_path: SearchPath,
file: File,
known: Option<KnownModule>,
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum ModuleKind {
/// A single-file module (e.g. `foo.py` or `foo.pyi`)
Module,
    /// A Python package (`foo/__init__.py` or `foo/__init__.pyi`)
Package,
}
impl ModuleKind {
pub const fn is_package(self) -> bool {
matches!(self, ModuleKind::Package)
}
pub const fn is_module(self) -> bool {
matches!(self, ModuleKind::Module)
}
}
/// Enumeration of various core stdlib modules in which important types are located
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, strum_macros::EnumString)]
#[cfg_attr(test, derive(strum_macros::EnumIter))]
#[strum(serialize_all = "snake_case")]
pub enum KnownModule {
Builtins,
Enum,
Types,
#[strum(serialize = "_typeshed")]
Typeshed,
TypingExtensions,
Typing,
Sys,
#[allow(dead_code)]
Abc, // currently only used in tests
Dataclasses,
Collections,
Inspect,
TyExtensions,
}
impl KnownModule {
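    /// Returns the module's runtime name, e.g. `"typing_extensions"` for
    /// [`KnownModule::TypingExtensions`] and `"_typeshed"` for [`KnownModule::Typeshed`].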
pub const fn as_str(self) -> &'static str {
match self {
Self::Builtins => "builtins",
Self::Enum => "enum",
Self::Types => "types",
Self::Typing => "typing",
Self::Typeshed => "_typeshed",
Self::TypingExtensions => "typing_extensions",
Self::Sys => "sys",
Self::Abc => "abc",
Self::Dataclasses => "dataclasses",
Self::Collections => "collections",
Self::Inspect => "inspect",
Self::TyExtensions => "ty_extensions",
}
}
pub fn name(self) -> ModuleName {
ModuleName::new_static(self.as_str())
.unwrap_or_else(|| panic!("{self} should be a valid module name!"))
}
pub(crate) fn try_from_search_path_and_name(
search_path: &SearchPath,
name: &ModuleName,
) -> Option<Self> {
if search_path.is_standard_library() {
Self::from_str(name.as_str()).ok()
} else {
None
}
}
pub const fn is_builtins(self) -> bool {
matches!(self, Self::Builtins)
}
pub const fn is_typing(self) -> bool {
matches!(self, Self::Typing)
}
pub const fn is_ty_extensions(self) -> bool {
matches!(self, Self::TyExtensions)
}
pub const fn is_inspect(self) -> bool {
matches!(self, Self::Inspect)
}
pub const fn is_enum(self) -> bool {
matches!(self, Self::Enum)
}
}
impl std::fmt::Display for KnownModule {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}
#[cfg(test)]
mod tests {
use super::*;
use strum::IntoEnumIterator;
#[test]
fn known_module_roundtrip_from_str() {
let stdlib_search_path = SearchPath::vendored_stdlib();
for module in KnownModule::iter() {
let module_name = module.name();
assert_eq!(
KnownModule::try_from_search_path_and_name(&stdlib_search_path, &module_name),
Some(module),
"The strum `EnumString` implementation appears to be incorrect for `{module_name}`"
);
}
}
}

File diff suppressed because it is too large

File diff suppressed because it is too large


@ -0,0 +1,314 @@
use ruff_db::system::{
DbWithTestSystem as _, DbWithWritableSystem as _, SystemPath, SystemPathBuf,
};
use ruff_db::vendored::VendoredPathBuf;
use ruff_python_ast::PythonVersion;
use crate::db::tests::TestDb;
use crate::program::{Program, SearchPathSettings};
use crate::{ProgramSettings, PythonPath, PythonPlatform};
/// A test case for the module resolver.
///
/// You generally shouldn't construct instances of this struct directly;
/// instead, use the [`TestCaseBuilder`].
pub(crate) struct TestCase<T> {
pub(crate) db: TestDb,
pub(crate) src: SystemPathBuf,
pub(crate) stdlib: T,
// Most test cases only ever need a single `site-packages` directory,
// so this is a single directory instead of a `Vec` of directories,
    // like the `Vec` in `PythonPath::KnownSitePackages`.
pub(crate) site_packages: SystemPathBuf,
pub(crate) python_version: PythonVersion,
}
/// A `(file_name, file_contents)` tuple
pub(crate) type FileSpec = (&'static str, &'static str);
/// Specification for a typeshed mock to be created as part of a test
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct MockedTypeshed {
/// The stdlib files to be created in the typeshed mock
pub(crate) stdlib_files: &'static [FileSpec],
/// The contents of the `stdlib/VERSIONS` file
/// to be created in the typeshed mock
pub(crate) versions: &'static str,
}
#[derive(Debug)]
pub(crate) struct VendoredTypeshed;
#[derive(Debug)]
pub(crate) struct UnspecifiedTypeshed;
/// A builder for a module-resolver test case.
///
/// The builder takes care of creating a [`TestDb`]
/// instance, applying the module resolver settings,
/// and creating mock directories for the stdlib, `site-packages`,
/// first-party code, etc.
///
/// For simple tests that do not involve typeshed,
/// test cases can be created as follows:
///
/// ```rs
/// let test_case = TestCaseBuilder::new()
/// .with_src_files(...)
/// .build();
///
/// let test_case2 = TestCaseBuilder::new()
/// .with_site_packages_files(...)
/// .build();
/// ```
///
/// Any test can specify the target Python version that should be used
/// in the module resolver settings:
///
/// ```rs
/// let test_case = TestCaseBuilder::new()
/// .with_src_files(...)
/// .with_python_version(...)
/// .build();
/// ```
///
/// For tests checking that standard-library module resolution is working
/// correctly, you should usually create a [`MockedTypeshed`] instance
/// and pass it to the [`TestCaseBuilder::with_mocked_typeshed`] method.
/// If you need to check something that involves the vendored typeshed stubs
/// we include as part of the binary, you can instead use the
/// [`TestCaseBuilder::with_vendored_typeshed`] method.
/// For either of these, you should almost always try to be explicit
/// about the Python version you want to be specified in the module-resolver
/// settings for the test:
///
/// ```rs
/// const TYPESHED: MockedTypeshed = MockedTypeshed { ... };
///
/// let test_case = TestCaseBuilder::new()
///     .with_mocked_typeshed(TYPESHED)
///     .with_python_version(...)
///     .build();
///
/// let test_case2 = TestCaseBuilder::new()
///     .with_vendored_typeshed()
///     .with_python_version(...)
///     .build();
/// ```
///
/// If you call neither of those methods, the `stdlib` field
/// on the [`TestCase`] instance created by `.build()` will be set
/// to `()`.
pub(crate) struct TestCaseBuilder<T> {
typeshed_option: T,
python_version: PythonVersion,
python_platform: PythonPlatform,
first_party_files: Vec<FileSpec>,
site_packages_files: Vec<FileSpec>,
}
impl<T> TestCaseBuilder<T> {
/// Specify files to be created in the `src` mock directory
pub(crate) fn with_src_files(mut self, files: &[FileSpec]) -> Self {
self.first_party_files.extend(files.iter().copied());
self
}
/// Specify files to be created in the `site-packages` mock directory
pub(crate) fn with_site_packages_files(mut self, files: &[FileSpec]) -> Self {
self.site_packages_files.extend(files.iter().copied());
self
}
/// Specify the Python version the module resolver should assume
pub(crate) fn with_python_version(mut self, python_version: PythonVersion) -> Self {
self.python_version = python_version;
self
}
fn write_mock_directory(
db: &mut TestDb,
location: impl AsRef<SystemPath>,
files: impl IntoIterator<Item = FileSpec>,
) -> SystemPathBuf {
let root = location.as_ref().to_path_buf();
// Make sure to create the directory even if the list of files is empty:
db.memory_file_system().create_directory_all(&root).unwrap();
db.write_files(
files
.into_iter()
.map(|(relative_path, contents)| (root.join(relative_path), contents)),
)
.unwrap();
root
}
}
impl TestCaseBuilder<UnspecifiedTypeshed> {
pub(crate) fn new() -> TestCaseBuilder<UnspecifiedTypeshed> {
Self {
typeshed_option: UnspecifiedTypeshed,
python_version: PythonVersion::default(),
python_platform: PythonPlatform::default(),
first_party_files: vec![],
site_packages_files: vec![],
}
}
/// Use the vendored stdlib stubs included in the Ruff binary for this test case
pub(crate) fn with_vendored_typeshed(self) -> TestCaseBuilder<VendoredTypeshed> {
let TestCaseBuilder {
typeshed_option: _,
python_version,
python_platform,
first_party_files,
site_packages_files,
} = self;
TestCaseBuilder {
typeshed_option: VendoredTypeshed,
python_version,
python_platform,
first_party_files,
site_packages_files,
}
}
/// Use a mock typeshed directory for this test case
pub(crate) fn with_mocked_typeshed(
self,
typeshed: MockedTypeshed,
) -> TestCaseBuilder<MockedTypeshed> {
let TestCaseBuilder {
typeshed_option: _,
python_version,
python_platform,
first_party_files,
site_packages_files,
} = self;
TestCaseBuilder {
typeshed_option: typeshed,
python_version,
python_platform,
first_party_files,
site_packages_files,
}
}
pub(crate) fn build(self) -> TestCase<()> {
let TestCase {
db,
src,
stdlib: _,
site_packages,
python_version,
} = self.with_mocked_typeshed(MockedTypeshed::default()).build();
TestCase {
db,
src,
stdlib: (),
site_packages,
python_version,
}
}
}
impl TestCaseBuilder<MockedTypeshed> {
pub(crate) fn build(self) -> TestCase<SystemPathBuf> {
let TestCaseBuilder {
typeshed_option,
python_version,
python_platform,
first_party_files,
site_packages_files,
} = self;
let mut db = TestDb::new();
let site_packages =
Self::write_mock_directory(&mut db, "/site-packages", site_packages_files);
let src = Self::write_mock_directory(&mut db, "/src", first_party_files);
let typeshed = Self::build_typeshed_mock(&mut db, &typeshed_option);
Program::from_settings(
&db,
ProgramSettings {
python_version,
python_platform,
search_paths: SearchPathSettings {
extra_paths: vec![],
src_roots: vec![src.clone()],
custom_typeshed: Some(typeshed.clone()),
python_path: PythonPath::KnownSitePackages(vec![site_packages.clone()]),
},
},
)
.expect("Valid program settings");
TestCase {
db,
src,
stdlib: typeshed.join("stdlib"),
site_packages,
python_version,
}
}
fn build_typeshed_mock(db: &mut TestDb, typeshed_to_build: &MockedTypeshed) -> SystemPathBuf {
let typeshed = SystemPathBuf::from("/typeshed");
let MockedTypeshed {
stdlib_files,
versions,
} = typeshed_to_build;
Self::write_mock_directory(
db,
typeshed.join("stdlib"),
stdlib_files
.iter()
.copied()
.chain(std::iter::once(("VERSIONS", *versions))),
);
typeshed
}
}
impl TestCaseBuilder<VendoredTypeshed> {
pub(crate) fn build(self) -> TestCase<VendoredPathBuf> {
let TestCaseBuilder {
typeshed_option: VendoredTypeshed,
python_version,
python_platform,
first_party_files,
site_packages_files,
} = self;
let mut db = TestDb::new();
let site_packages =
Self::write_mock_directory(&mut db, "/site-packages", site_packages_files);
let src = Self::write_mock_directory(&mut db, "/src", first_party_files);
Program::from_settings(
&db,
ProgramSettings {
python_version,
python_platform,
search_paths: SearchPathSettings {
python_path: PythonPath::KnownSitePackages(vec![site_packages.clone()]),
..SearchPathSettings::new(vec![src.clone()])
},
},
)
.expect("Valid search path settings");
TestCase {
db,
src,
stdlib: VendoredPathBuf::from("stdlib"),
site_packages,
python_version,
}
}
}


@ -0,0 +1,702 @@
use std::collections::BTreeMap;
use std::fmt;
use std::num::{NonZeroU16, NonZeroUsize};
use std::ops::{RangeFrom, RangeInclusive};
use std::str::FromStr;
use ruff_python_ast::PythonVersion;
use rustc_hash::FxHashMap;
use crate::db::Db;
use crate::module_name::ModuleName;
use crate::Program;
pub(in crate::module_resolver) fn vendored_typeshed_versions(db: &dyn Db) -> TypeshedVersions {
TypeshedVersions::from_str(
&db.vendored()
.read_to_string("stdlib/VERSIONS")
.expect("The vendored typeshed stubs should contain a VERSIONS file"),
)
.expect("The VERSIONS file in the vendored typeshed stubs should be well-formed")
}
pub(crate) fn typeshed_versions(db: &dyn Db) -> &TypeshedVersions {
Program::get(db).search_paths(db).typeshed_versions()
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub(crate) struct TypeshedVersionsParseError {
line_number: Option<NonZeroU16>,
reason: TypeshedVersionsParseErrorKind,
}
impl fmt::Display for TypeshedVersionsParseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let TypeshedVersionsParseError {
line_number,
reason,
} = self;
if let Some(line_number) = line_number {
write!(
f,
"Error while parsing line {line_number} of typeshed's VERSIONS file: {reason}"
)
} else {
write!(f, "Error while parsing typeshed's VERSIONS file: {reason}")
}
}
}
impl std::error::Error for TypeshedVersionsParseError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
if let TypeshedVersionsParseErrorKind::IntegerParsingFailure { err, .. } = &self.reason {
Some(err)
} else {
None
}
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub(super) enum TypeshedVersionsParseErrorKind {
TooManyLines(NonZeroUsize),
UnexpectedNumberOfColons,
InvalidModuleName(String),
UnexpectedNumberOfHyphens,
UnexpectedNumberOfPeriods(String),
IntegerParsingFailure {
version: String,
err: std::num::ParseIntError,
},
}
impl fmt::Display for TypeshedVersionsParseErrorKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::TooManyLines(num_lines) => write!(
f,
"File has too many lines ({num_lines}); maximum allowed is {}",
NonZeroU16::MAX
),
Self::UnexpectedNumberOfColons => {
f.write_str("Expected every non-comment line to have exactly one colon")
}
Self::InvalidModuleName(name) => write!(
f,
"Expected all components of '{name}' to be valid Python identifiers"
),
Self::UnexpectedNumberOfHyphens => {
f.write_str("Expected every non-comment line to have exactly one '-' character")
}
Self::UnexpectedNumberOfPeriods(format) => write!(
f,
"Expected all versions to be in the form {{MAJOR}}.{{MINOR}}; got '{format}'"
),
Self::IntegerParsingFailure { version, err } => write!(
f,
"Failed to convert '{version}' to a pair of integers due to {err}",
),
}
}
}
#[derive(Debug, PartialEq, Eq)]
pub(crate) struct TypeshedVersions(FxHashMap<ModuleName, PyVersionRange>);
impl TypeshedVersions {
#[must_use]
fn exact(&self, module_name: &ModuleName) -> Option<&PyVersionRange> {
self.0.get(module_name)
}
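    /// Queries whether `module` exists in the stdlib at `python_version`,
    /// falling back to the closest ancestor package's VERSIONS entry when the
    /// (sub)module itself has no exact entry.
    ///
    /// Illustrative example, mirroring the tests below: given only the entry
    /// `bar: 2.7-3.10`, querying `bar.eggs` yields `MaybeExists` on Python 3.9
    /// and `DoesNotExist` on Python 3.11.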
#[must_use]
pub(in crate::module_resolver) fn query_module(
&self,
module: &ModuleName,
python_version: PythonVersion,
) -> TypeshedVersionsQueryResult {
if let Some(range) = self.exact(module) {
if range.contains(python_version) {
TypeshedVersionsQueryResult::Exists
} else {
TypeshedVersionsQueryResult::DoesNotExist
}
} else {
let mut module = module.parent();
while let Some(module_to_try) = module {
if let Some(range) = self.exact(&module_to_try) {
                    return if range.contains(python_version) {
                        TypeshedVersionsQueryResult::MaybeExists
                    } else {
                        TypeshedVersionsQueryResult::DoesNotExist
                    };
}
module = module_to_try.parent();
}
TypeshedVersionsQueryResult::DoesNotExist
}
}
}
/// Possible answers [`TypeshedVersions::query_module()`] could give to the question:
/// "Does this module exist in the stdlib at runtime on a certain target version?"
#[derive(Debug, Copy, PartialEq, Eq, Clone, Hash)]
pub(crate) enum TypeshedVersionsQueryResult {
/// The module definitely exists in the stdlib at runtime on the user-specified target version.
///
/// For example:
/// - The target version is Python 3.8
/// - We're querying whether the `asyncio.tasks` module exists in the stdlib
/// - The VERSIONS file contains the line `asyncio.tasks: 3.8-`
Exists,
/// The module definitely does not exist in the stdlib on the user-specified target version.
///
/// For example:
/// - We're querying whether the `foo` module exists in the stdlib
/// - There is no top-level `foo` module in VERSIONS
///
/// OR:
/// - The target version is Python 3.8
/// - We're querying whether the module `importlib.abc` exists in the stdlib
/// - The VERSIONS file contains the line `importlib.abc: 3.10-`,
/// indicating that the module was added in 3.10
///
/// OR:
/// - The target version is Python 3.8
/// - We're querying whether the module `collections.abc` exists in the stdlib
/// - The VERSIONS file does not contain any information about the `collections.abc` submodule,
/// but *does* contain the line `collections: 3.10-`,
/// indicating that the entire `collections` package was added in Python 3.10.
DoesNotExist,
/// The module potentially exists in the stdlib and, if it does,
/// it definitely exists on the user-specified target version.
///
/// This variant is only relevant for submodules,
/// for which the typeshed VERSIONS file does not provide comprehensive information.
/// (The VERSIONS file is guaranteed to provide information about all top-level stdlib modules and packages,
/// but not necessarily about all submodules within each top-level package.)
///
/// For example:
/// - The target version is Python 3.8
/// - We're querying whether the `asyncio.staggered` module exists in the stdlib
    /// - The typeshed VERSIONS file contains the line `asyncio: 3.8-`,
/// indicating that the `asyncio` package was added in Python 3.8,
/// but does not contain any explicit information about the `asyncio.staggered` submodule.
MaybeExists,
}
impl FromStr for TypeshedVersions {
type Err = TypeshedVersionsParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut map = FxHashMap::default();
for (line_index, line) in s.lines().enumerate() {
// humans expect line numbers to be 1-indexed
let line_number = NonZeroUsize::new(line_index.saturating_add(1)).unwrap();
let Ok(line_number) = NonZeroU16::try_from(line_number) else {
return Err(TypeshedVersionsParseError {
line_number: None,
reason: TypeshedVersionsParseErrorKind::TooManyLines(line_number),
});
};
let Some(content) = line.split('#').map(str::trim).next() else {
continue;
};
if content.is_empty() {
continue;
}
let mut parts = content.split(':').map(str::trim);
let (Some(module_name), Some(rest), None) = (parts.next(), parts.next(), parts.next())
else {
return Err(TypeshedVersionsParseError {
line_number: Some(line_number),
reason: TypeshedVersionsParseErrorKind::UnexpectedNumberOfColons,
});
};
let Some(module_name) = ModuleName::new(module_name) else {
return Err(TypeshedVersionsParseError {
line_number: Some(line_number),
reason: TypeshedVersionsParseErrorKind::InvalidModuleName(
module_name.to_string(),
),
});
};
match PyVersionRange::from_str(rest) {
Ok(version) => map.insert(module_name, version),
Err(reason) => {
return Err(TypeshedVersionsParseError {
line_number: Some(line_number),
reason,
})
}
};
}
Ok(Self(map))
}
}
impl fmt::Display for TypeshedVersions {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let sorted_items: BTreeMap<&ModuleName, &PyVersionRange> = self.0.iter().collect();
for (module_name, range) in sorted_items {
writeln!(f, "{module_name}: {range}")?;
}
Ok(())
}
}
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
enum PyVersionRange {
AvailableFrom(RangeFrom<PythonVersion>),
AvailableWithin(RangeInclusive<PythonVersion>),
}
impl PyVersionRange {
#[must_use]
fn contains(&self, version: PythonVersion) -> bool {
match self {
Self::AvailableFrom(inner) => inner.contains(&version),
Self::AvailableWithin(inner) => inner.contains(&version),
}
}
}
impl FromStr for PyVersionRange {
type Err = TypeshedVersionsParseErrorKind;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut parts = s.split('-').map(str::trim);
match (parts.next(), parts.next(), parts.next()) {
(Some(lower), Some(""), None) => {
let lower = python_version_from_versions_file_string(lower)?;
Ok(Self::AvailableFrom(lower..))
}
(Some(lower), Some(upper), None) => {
let lower = python_version_from_versions_file_string(lower)?;
let upper = python_version_from_versions_file_string(upper)?;
Ok(Self::AvailableWithin(lower..=upper))
}
_ => Err(TypeshedVersionsParseErrorKind::UnexpectedNumberOfHyphens),
}
}
}
impl fmt::Display for PyVersionRange {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::AvailableFrom(range_from) => write!(f, "{}-", range_from.start),
Self::AvailableWithin(range_inclusive) => {
write!(f, "{}-{}", range_inclusive.start(), range_inclusive.end())
}
}
}
}
fn python_version_from_versions_file_string(
s: &str,
) -> Result<PythonVersion, TypeshedVersionsParseErrorKind> {
let mut parts = s.split('.').map(str::trim);
let (Some(major), Some(minor), None) = (parts.next(), parts.next(), parts.next()) else {
return Err(TypeshedVersionsParseErrorKind::UnexpectedNumberOfPeriods(
s.to_string(),
));
};
PythonVersion::try_from((major, minor)).map_err(|int_parse_error| {
TypeshedVersionsParseErrorKind::IntegerParsingFailure {
version: s.to_string(),
err: int_parse_error,
}
})
}
#[cfg(test)]
mod tests {
use std::fmt::Write as _;
use std::num::{IntErrorKind, NonZeroU16};
use std::path::Path;
use insta::assert_snapshot;
use crate::db::tests::TestDb;
use super::*;
const TYPESHED_STDLIB_DIR: &str = "stdlib";
const ONE: Option<NonZeroU16> = Some(NonZeroU16::new(1).unwrap());
impl TypeshedVersions {
#[must_use]
fn contains_exact(&self, module: &ModuleName) -> bool {
self.exact(module).is_some()
}
#[must_use]
fn len(&self) -> usize {
self.0.len()
}
}
#[test]
fn can_parse_vendored_versions_file() {
let db = TestDb::new();
let versions = vendored_typeshed_versions(&db);
assert!(versions.len() > 100);
assert!(versions.len() < 1000);
let asyncio = ModuleName::new_static("asyncio").unwrap();
let asyncio_staggered = ModuleName::new_static("asyncio.staggered").unwrap();
let audioop = ModuleName::new_static("audioop").unwrap();
assert!(versions.contains_exact(&asyncio));
assert_eq!(
versions.query_module(&asyncio, PythonVersion::PY310),
TypeshedVersionsQueryResult::Exists
);
assert!(versions.contains_exact(&asyncio_staggered));
assert_eq!(
versions.query_module(&asyncio_staggered, PythonVersion::PY38),
TypeshedVersionsQueryResult::Exists
);
assert_eq!(
versions.query_module(&asyncio_staggered, PythonVersion::PY37),
TypeshedVersionsQueryResult::DoesNotExist
);
assert!(versions.contains_exact(&audioop));
assert_eq!(
versions.query_module(&audioop, PythonVersion::PY312),
TypeshedVersionsQueryResult::Exists
);
assert_eq!(
versions.query_module(&audioop, PythonVersion::PY313),
TypeshedVersionsQueryResult::DoesNotExist
);
}
#[test]
fn typeshed_versions_consistent_with_vendored_stubs() {
let db = TestDb::new();
let vendored_typeshed_versions = vendored_typeshed_versions(&db);
let vendored_typeshed_dir =
Path::new(env!("CARGO_MANIFEST_DIR")).join("../ty_vendored/vendor/typeshed");
let mut empty_iterator = true;
let stdlib_stubs_path = vendored_typeshed_dir.join(TYPESHED_STDLIB_DIR);
for entry in std::fs::read_dir(&stdlib_stubs_path).unwrap() {
empty_iterator = false;
let entry = entry.unwrap();
let absolute_path = entry.path();
let relative_path = absolute_path
.strip_prefix(&stdlib_stubs_path)
.unwrap_or_else(|_| panic!("Expected path to be a child of {stdlib_stubs_path:?} but found {absolute_path:?}"));
let relative_path_str = relative_path.as_os_str().to_str().unwrap_or_else(|| {
panic!("Expected all typeshed paths to be valid UTF-8; got {relative_path:?}")
});
if relative_path_str == "VERSIONS" {
continue;
}
let top_level_module = if let Some(extension) = relative_path.extension() {
// It was a file; strip off the file extension to get the module name:
let extension = extension
.to_str()
                    .unwrap_or_else(|| panic!("Expected all file extensions to be UTF-8; was not true for {relative_path:?}"));
relative_path_str
.strip_suffix(extension)
.and_then(|string| string.strip_suffix('.')).unwrap_or_else(|| {
panic!("Expected path {relative_path_str:?} to end with computed extension {extension:?}")
})
} else {
// It was a directory; no need to do anything to get the module name
relative_path_str
};
let top_level_module = ModuleName::new(top_level_module)
.unwrap_or_else(|| panic!("{top_level_module:?} was not a valid module name!"));
assert!(vendored_typeshed_versions.contains_exact(&top_level_module));
}
assert!(
!empty_iterator,
"Expected there to be at least one file or directory in the vendored typeshed stubs"
);
}
#[test]
fn can_parse_mock_versions_file() {
const VERSIONS: &str = "\
# a comment
# some more comment
# yet more comment
# and some more comment
bar: 2.7-3.10
# more comment
bar.baz: 3.1-3.9
foo: 3.8- # trailing comment
";
let parsed_versions = TypeshedVersions::from_str(VERSIONS).unwrap();
assert_eq!(parsed_versions.len(), 3);
assert_snapshot!(parsed_versions.to_string(), @r"
bar: 2.7-3.10
bar.baz: 3.1-3.9
foo: 3.8-
"
);
}
#[test]
fn version_within_range_parsed_correctly() {
let parsed_versions = TypeshedVersions::from_str("bar: 2.7-3.10").unwrap();
let bar = ModuleName::new_static("bar").unwrap();
assert!(parsed_versions.contains_exact(&bar));
assert_eq!(
parsed_versions.query_module(&bar, PythonVersion::PY37),
TypeshedVersionsQueryResult::Exists
);
assert_eq!(
parsed_versions.query_module(&bar, PythonVersion::PY310),
TypeshedVersionsQueryResult::Exists
);
assert_eq!(
parsed_versions.query_module(&bar, PythonVersion::PY311),
TypeshedVersionsQueryResult::DoesNotExist
);
}
#[test]
fn version_from_range_parsed_correctly() {
let parsed_versions = TypeshedVersions::from_str("foo: 3.8-").unwrap();
let foo = ModuleName::new_static("foo").unwrap();
assert!(parsed_versions.contains_exact(&foo));
assert_eq!(
parsed_versions.query_module(&foo, PythonVersion::PY37),
TypeshedVersionsQueryResult::DoesNotExist
);
assert_eq!(
parsed_versions.query_module(&foo, PythonVersion::PY38),
TypeshedVersionsQueryResult::Exists
);
assert_eq!(
parsed_versions.query_module(&foo, PythonVersion::PY311),
TypeshedVersionsQueryResult::Exists
);
}
#[test]
fn explicit_submodule_parsed_correctly() {
let parsed_versions = TypeshedVersions::from_str("bar.baz: 3.1-3.9").unwrap();
let bar_baz = ModuleName::new_static("bar.baz").unwrap();
assert!(parsed_versions.contains_exact(&bar_baz));
assert_eq!(
parsed_versions.query_module(&bar_baz, PythonVersion::PY37),
TypeshedVersionsQueryResult::Exists
);
assert_eq!(
parsed_versions.query_module(&bar_baz, PythonVersion::PY39),
TypeshedVersionsQueryResult::Exists
);
assert_eq!(
parsed_versions.query_module(&bar_baz, PythonVersion::PY310),
TypeshedVersionsQueryResult::DoesNotExist
);
}
#[test]
fn implicit_submodule_queried_correctly() {
let parsed_versions = TypeshedVersions::from_str("bar: 2.7-3.10").unwrap();
let bar_eggs = ModuleName::new_static("bar.eggs").unwrap();
assert!(!parsed_versions.contains_exact(&bar_eggs));
assert_eq!(
parsed_versions.query_module(&bar_eggs, PythonVersion::PY37),
TypeshedVersionsQueryResult::MaybeExists
);
assert_eq!(
parsed_versions.query_module(&bar_eggs, PythonVersion::PY310),
TypeshedVersionsQueryResult::MaybeExists
);
assert_eq!(
parsed_versions.query_module(&bar_eggs, PythonVersion::PY311),
TypeshedVersionsQueryResult::DoesNotExist
);
}
#[test]
fn nonexistent_module_queried_correctly() {
let parsed_versions = TypeshedVersions::from_str("eggs: 3.8-").unwrap();
let spam = ModuleName::new_static("spam").unwrap();
assert!(!parsed_versions.contains_exact(&spam));
assert_eq!(
parsed_versions.query_module(&spam, PythonVersion::PY37),
TypeshedVersionsQueryResult::DoesNotExist
);
assert_eq!(
parsed_versions.query_module(&spam, PythonVersion::PY313),
TypeshedVersionsQueryResult::DoesNotExist
);
}
#[test]
fn invalid_huge_versions_file() {
let offset = 100;
let too_many = u16::MAX as usize + offset;
let mut massive_versions_file = String::new();
for i in 0..too_many {
let _ = writeln!(&mut massive_versions_file, "x{i}: 3.8-");
}
assert_eq!(
TypeshedVersions::from_str(&massive_versions_file),
Err(TypeshedVersionsParseError {
line_number: None,
reason: TypeshedVersionsParseErrorKind::TooManyLines(
NonZeroUsize::new(too_many + 1 - offset).unwrap()
)
})
);
}
#[test]
fn invalid_typeshed_versions_bad_colon_number() {
assert_eq!(
TypeshedVersions::from_str("foo 3.7"),
Err(TypeshedVersionsParseError {
line_number: ONE,
reason: TypeshedVersionsParseErrorKind::UnexpectedNumberOfColons
})
);
assert_eq!(
TypeshedVersions::from_str("foo:: 3.7"),
Err(TypeshedVersionsParseError {
line_number: ONE,
reason: TypeshedVersionsParseErrorKind::UnexpectedNumberOfColons
})
);
}
#[test]
fn invalid_typeshed_versions_non_identifier_modules() {
assert_eq!(
TypeshedVersions::from_str("not!an!identifier!: 3.7"),
Err(TypeshedVersionsParseError {
line_number: ONE,
reason: TypeshedVersionsParseErrorKind::InvalidModuleName(
"not!an!identifier!".to_string()
)
})
);
assert_eq!(
TypeshedVersions::from_str("(also_not).(an_identifier): 3.7"),
Err(TypeshedVersionsParseError {
line_number: ONE,
reason: TypeshedVersionsParseErrorKind::InvalidModuleName(
"(also_not).(an_identifier)".to_string()
)
})
);
}
#[test]
fn invalid_typeshed_versions_bad_hyphen_number() {
assert_eq!(
TypeshedVersions::from_str("foo: 3.8"),
Err(TypeshedVersionsParseError {
line_number: ONE,
reason: TypeshedVersionsParseErrorKind::UnexpectedNumberOfHyphens
})
);
assert_eq!(
TypeshedVersions::from_str("foo: 3.8--"),
Err(TypeshedVersionsParseError {
line_number: ONE,
reason: TypeshedVersionsParseErrorKind::UnexpectedNumberOfHyphens
})
);
assert_eq!(
TypeshedVersions::from_str("foo: 3.8--3.9"),
Err(TypeshedVersionsParseError {
line_number: ONE,
reason: TypeshedVersionsParseErrorKind::UnexpectedNumberOfHyphens
})
);
}
#[test]
fn invalid_typeshed_versions_bad_period_number() {
assert_eq!(
TypeshedVersions::from_str("foo: 38-"),
Err(TypeshedVersionsParseError {
line_number: ONE,
reason: TypeshedVersionsParseErrorKind::UnexpectedNumberOfPeriods("38".to_string())
})
);
assert_eq!(
TypeshedVersions::from_str("foo: 3..8-"),
Err(TypeshedVersionsParseError {
line_number: ONE,
reason: TypeshedVersionsParseErrorKind::UnexpectedNumberOfPeriods(
"3..8".to_string()
)
})
);
assert_eq!(
TypeshedVersions::from_str("foo: 3.8-3..11"),
Err(TypeshedVersionsParseError {
line_number: ONE,
reason: TypeshedVersionsParseErrorKind::UnexpectedNumberOfPeriods(
"3..11".to_string()
)
})
);
}
#[test]
fn invalid_typeshed_versions_non_digits() {
let err = TypeshedVersions::from_str("foo: 1.two-").unwrap_err();
assert_eq!(err.line_number, ONE);
let TypeshedVersionsParseErrorKind::IntegerParsingFailure { version, err } = err.reason
else {
panic!()
};
assert_eq!(version, "1.two".to_string());
assert_eq!(*err.kind(), IntErrorKind::InvalidDigit);
let err = TypeshedVersions::from_str("foo: 3.8-four.9").unwrap_err();
assert_eq!(err.line_number, ONE);
let TypeshedVersionsParseErrorKind::IntegerParsingFailure { version, err } = err.reason
else {
panic!()
};
assert_eq!(version, "four.9".to_string());
assert_eq!(*err.kind(), IntErrorKind::InvalidDigit);
}
}


@ -0,0 +1,21 @@
use ruff_python_ast::AnyNodeRef;
/// Compact key for a node for use in a hash map.
///
/// Stores the memory address of the node, because using the range and the kind
/// of the node is not enough to uniquely identify them in ASTs resulting from
/// invalid syntax. For example, parsing the input `for` results in a `StmtFor`
/// AST node where both the `target` and the `iter` field are `ExprName` nodes
/// with the same (empty) range `3..3`.
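///
/// Because the key is derived from a node's memory address within its owning
/// AST, it is only meaningful while that AST is alive, and keys must not be
/// compared across re-parses of the same file.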
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub(super) struct NodeKey(usize);
impl NodeKey {
pub(super) fn from_node<'a, N>(node: N) -> Self
where
N: Into<AnyNodeRef<'a>>,
{
let node = node.into();
NodeKey(node.as_ptr().as_ptr() as usize)
}
}


@ -0,0 +1,166 @@
use crate::module_resolver::SearchPaths;
use crate::python_platform::PythonPlatform;
use crate::site_packages::SysPrefixPathOrigin;
use crate::Db;
use anyhow::Context;
use ruff_db::system::{SystemPath, SystemPathBuf};
use ruff_python_ast::PythonVersion;
use salsa::Durability;
use salsa::Setter;
#[salsa::input(singleton)]
pub struct Program {
pub python_version: PythonVersion,
#[return_ref]
pub python_platform: PythonPlatform,
#[return_ref]
pub(crate) search_paths: SearchPaths,
}
impl Program {
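    /// Creates the singleton [`Program`] input from the given settings.
    ///
    /// A minimal usage sketch (assumes a `db` implementing [`Db`] and
    /// pre-built `settings`; illustrative, not from the original source):
    ///
    /// ```rs
    /// let program = Program::from_settings(&db, settings)?;
    /// let version = program.python_version(&db);
    /// ```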
pub fn from_settings(db: &dyn Db, settings: ProgramSettings) -> anyhow::Result<Self> {
let ProgramSettings {
python_version,
python_platform,
search_paths,
} = settings;
tracing::info!("Python version: Python {python_version}, platform: {python_platform}");
let search_paths = SearchPaths::from_settings(db, &search_paths)
.with_context(|| "Invalid search path settings")?;
Ok(
Program::builder(python_version, python_platform, search_paths)
.durability(Durability::HIGH)
.new(db),
)
}
pub fn update_from_settings(
self,
db: &mut dyn Db,
settings: ProgramSettings,
) -> anyhow::Result<()> {
let ProgramSettings {
python_version,
python_platform,
search_paths,
} = settings;
if &python_platform != self.python_platform(db) {
tracing::debug!("Updating python platform: `{python_platform:?}`");
self.set_python_platform(db).to(python_platform);
}
if python_version != self.python_version(db) {
tracing::debug!("Updating python version: Python {python_version}");
self.set_python_version(db).to(python_version);
}
self.update_search_paths(db, &search_paths)?;
Ok(())
}
pub fn update_search_paths(
self,
db: &mut dyn Db,
search_path_settings: &SearchPathSettings,
) -> anyhow::Result<()> {
let search_paths = SearchPaths::from_settings(db, search_path_settings)?;
if self.search_paths(db) != &search_paths {
tracing::debug!("Update search paths");
self.set_search_paths(db).to(search_paths);
}
Ok(())
}
pub fn custom_stdlib_search_path(self, db: &dyn Db) -> Option<&SystemPath> {
self.search_paths(db).custom_stdlib()
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct ProgramSettings {
pub python_version: PythonVersion,
pub python_platform: PythonPlatform,
pub search_paths: SearchPathSettings,
}
/// Configures the search paths for module resolution.
#[derive(Eq, PartialEq, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct SearchPathSettings {
/// List of user-provided paths that should take first priority in the module resolution.
/// Examples in other type checkers are mypy's MYPYPATH environment variable,
/// or pyright's stubPath configuration setting.
pub extra_paths: Vec<SystemPathBuf>,
/// The root of the project, used for finding first-party modules.
pub src_roots: Vec<SystemPathBuf>,
/// Optional path to a "custom typeshed" directory on disk for us to use for standard-library types.
/// If this is not provided, we will fallback to our vendored typeshed stubs for the stdlib,
/// bundled as a zip file in the binary
pub custom_typeshed: Option<SystemPathBuf>,
/// Path to the Python installation from which ty resolves third party dependencies
/// and their type information.
pub python_path: PythonPath,
}
impl SearchPathSettings {
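    /// Creates settings with the given first-party source roots; all other
    /// fields start out empty or unset.
    ///
    /// A minimal sketch (the path is hypothetical; illustrative, not from the
    /// original source):
    ///
    /// ```rs
    /// let settings = SearchPathSettings::new(vec![SystemPathBuf::from("/src")]);
    /// assert!(settings.extra_paths.is_empty());
    /// assert!(settings.custom_typeshed.is_none());
    /// ```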
pub fn new(src_roots: Vec<SystemPathBuf>) -> Self {
Self {
src_roots,
extra_paths: vec![],
custom_typeshed: None,
python_path: PythonPath::KnownSitePackages(vec![]),
}
}
}
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum PythonPath {
/// A path that represents the value of [`sys.prefix`] at runtime in Python
/// for a given Python executable.
///
/// For the case of a virtual environment, where a
/// Python binary is at `/.venv/bin/python`, `sys.prefix` is the path to
/// the virtual environment the Python binary lies inside, i.e. `/.venv`,
/// and `site-packages` will be at `.venv/lib/python3.X/site-packages`.
/// System Python installations generally work the same way: if a system
/// Python installation lies at `/opt/homebrew/bin/python`, `sys.prefix`
/// will be `/opt/homebrew`, and `site-packages` will be at
/// `/opt/homebrew/lib/python3.X/site-packages`.
///
/// [`sys.prefix`]: https://docs.python.org/3/library/sys.html#sys.prefix
SysPrefix(SystemPathBuf, SysPrefixPathOrigin),
/// Tries to discover a virtual environment in the given path.
Discover(SystemPathBuf),
/// Resolved site packages paths.
///
/// This variant is mainly intended for testing where we want to skip resolving `site-packages`
/// because it would unnecessarily complicate the test setup.
KnownSitePackages(Vec<SystemPathBuf>),
}
impl PythonPath {
pub fn from_virtual_env_var(path: impl Into<SystemPathBuf>) -> Self {
Self::SysPrefix(path.into(), SysPrefixPathOrigin::VirtualEnvVar)
}
pub fn from_cli_flag(path: SystemPathBuf) -> Self {
Self::SysPrefix(path, SysPrefixPathOrigin::PythonCliFlag)
}
}


@ -0,0 +1,129 @@
use std::fmt::{Display, Formatter};
/// The target platform to assume when resolving types.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
feature = "serde",
derive(serde::Serialize, serde::Deserialize),
serde(rename_all = "kebab-case")
)]
pub enum PythonPlatform {
/// Do not make any assumptions about the target platform.
All,
/// Assume a specific target platform like `linux`, `darwin` or `win32`.
///
/// We use a string (instead of individual enum variants), as the set of possible platforms
/// may change over time. See <https://docs.python.org/3/library/sys.html#sys.platform> for
/// some known platform identifiers.
#[cfg_attr(feature = "serde", serde(untagged))]
Identifier(String),
}
impl From<String> for PythonPlatform {
fn from(platform: String) -> Self {
match platform.as_str() {
"all" => PythonPlatform::All,
_ => PythonPlatform::Identifier(platform.to_string()),
}
}
}
impl Display for PythonPlatform {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
PythonPlatform::All => f.write_str("all"),
PythonPlatform::Identifier(name) => f.write_str(name),
}
}
}
impl Default for PythonPlatform {
fn default() -> Self {
if cfg!(target_os = "windows") {
PythonPlatform::Identifier("win32".to_string())
} else if cfg!(target_os = "macos") {
PythonPlatform::Identifier("darwin".to_string())
} else if cfg!(target_os = "android") {
PythonPlatform::Identifier("android".to_string())
} else if cfg!(target_os = "ios") {
PythonPlatform::Identifier("ios".to_string())
} else {
PythonPlatform::Identifier("linux".to_string())
}
}
}
#[cfg(feature = "schemars")]
mod schema {
use crate::PythonPlatform;
use schemars::_serde_json::Value;
use schemars::gen::SchemaGenerator;
use schemars::schema::{Metadata, Schema, SchemaObject, SubschemaValidation};
use schemars::JsonSchema;
impl JsonSchema for PythonPlatform {
fn schema_name() -> String {
"PythonPlatform".to_string()
}
fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
Schema::Object(SchemaObject {
// Hard code some well known values, but allow any other string as well.
subschemas: Some(Box::new(SubschemaValidation {
any_of: Some(vec![
Schema::Object(SchemaObject {
instance_type: Some(schemars::schema::InstanceType::String.into()),
..SchemaObject::default()
}),
// Promote well-known values for better auto-completion.
// Using `const` over `enumValues` as recommended [here](https://github.com/SchemaStore/schemastore/blob/master/CONTRIBUTING.md#documenting-enums).
Schema::Object(SchemaObject {
const_value: Some(Value::String("all".to_string())),
metadata: Some(Box::new(Metadata {
description: Some(
"Do not make any assumptions about the target platform."
.to_string(),
),
..Metadata::default()
})),
..SchemaObject::default()
}),
Schema::Object(SchemaObject {
const_value: Some(Value::String("darwin".to_string())),
metadata: Some(Box::new(Metadata {
description: Some("Darwin".to_string()),
..Metadata::default()
})),
..SchemaObject::default()
}),
Schema::Object(SchemaObject {
const_value: Some(Value::String("linux".to_string())),
metadata: Some(Box::new(Metadata {
description: Some("Linux".to_string()),
..Metadata::default()
})),
..SchemaObject::default()
}),
Schema::Object(SchemaObject {
const_value: Some(Value::String("win32".to_string())),
metadata: Some(Box::new(Metadata {
description: Some("Windows".to_string()),
..Metadata::default()
})),
..SchemaObject::default()
}),
]),
..SubschemaValidation::default()
})),
..SchemaObject::default()
})
}
}
}
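#[cfg(test)]
mod tests {
    use super::PythonPlatform;

    // Illustrative test, not part of the original source: `From<String>` maps
    // the literal string "all" to `PythonPlatform::All` and preserves any
    // other string as an identifier that round-trips through `Display`.
    #[test]
    fn from_string_and_display_round_trip() {
        assert_eq!(PythonPlatform::from("all".to_string()), PythonPlatform::All);

        let win = PythonPlatform::from("win32".to_string());
        assert_eq!(win, PythonPlatform::Identifier("win32".to_string()));
        assert_eq!(win.to_string(), "win32");
    }
}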

File diff suppressed because it is too large


@ -0,0 +1,212 @@
use rustc_hash::FxHashMap;
use ruff_index::newtype_index;
use ruff_python_ast as ast;
use ruff_python_ast::ExprRef;
use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
use crate::semantic_index::semantic_index;
use crate::semantic_index::symbol::ScopeId;
use crate::Db;
/// AST ids for a single scope.
///
/// The motivation for building the AST ids per scope isn't to reduce invalidation, because
/// the struct changes whenever the parsed AST changes. Instead, it's mainly that we can
/// build the AST ids struct while building the symbol table, and that doing so preserves the
/// property that IDs of outer scopes are unaffected by changes in inner scopes.
///
/// For example, we don't want adding new statements to `foo` to change the ID of `x = foo()` in:
///
/// ```python
/// def foo():
/// return 5
///
/// x = foo()
/// ```
#[derive(Debug, salsa::Update)]
pub(crate) struct AstIds {
/// Maps expressions to their expression id.
expressions_map: FxHashMap<ExpressionNodeKey, ScopedExpressionId>,
/// Maps expressions which "use" a symbol (that is, [`ast::ExprName`]) to a use id.
uses_map: FxHashMap<ExpressionNodeKey, ScopedUseId>,
}
impl AstIds {
fn expression_id(&self, key: impl Into<ExpressionNodeKey>) -> ScopedExpressionId {
let key = &key.into();
*self.expressions_map.get(key).unwrap_or_else(|| {
panic!("Could not find expression ID for {key:?}");
})
}
fn use_id(&self, key: impl Into<ExpressionNodeKey>) -> ScopedUseId {
self.uses_map[&key.into()]
}
}
fn ast_ids<'db>(db: &'db dyn Db, scope: ScopeId) -> &'db AstIds {
semantic_index(db, scope.file(db)).ast_ids(scope.file_scope_id(db))
}
/// Uniquely identifies a use of a name in a [`crate::semantic_index::symbol::FileScopeId`].
#[newtype_index]
pub struct ScopedUseId;
pub trait HasScopedUseId {
/// Returns the ID that uniquely identifies the use in `scope`.
fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedUseId;
}
impl HasScopedUseId for ast::Identifier {
fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedUseId {
let ast_ids = ast_ids(db, scope);
ast_ids.use_id(self)
}
}
impl HasScopedUseId for ast::ExprName {
fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedUseId {
let expression_ref = ExprRef::from(self);
expression_ref.scoped_use_id(db, scope)
}
}
impl HasScopedUseId for ast::ExprRef<'_> {
fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedUseId {
let ast_ids = ast_ids(db, scope);
ast_ids.use_id(*self)
}
}
/// Uniquely identifies an [`ast::Expr`] in a [`crate::semantic_index::symbol::FileScopeId`].
#[newtype_index]
#[derive(salsa::Update)]
pub struct ScopedExpressionId;
pub trait HasScopedExpressionId {
/// Returns the ID that uniquely identifies the node in `scope`.
fn scoped_expression_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedExpressionId;
}
impl<T: HasScopedExpressionId> HasScopedExpressionId for Box<T> {
fn scoped_expression_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedExpressionId {
self.as_ref().scoped_expression_id(db, scope)
}
}
macro_rules! impl_has_scoped_expression_id {
($ty: ty) => {
impl HasScopedExpressionId for $ty {
fn scoped_expression_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedExpressionId {
let expression_ref = ExprRef::from(self);
expression_ref.scoped_expression_id(db, scope)
}
}
};
}
impl_has_scoped_expression_id!(ast::ExprBoolOp);
impl_has_scoped_expression_id!(ast::ExprName);
impl_has_scoped_expression_id!(ast::ExprBinOp);
impl_has_scoped_expression_id!(ast::ExprUnaryOp);
impl_has_scoped_expression_id!(ast::ExprLambda);
impl_has_scoped_expression_id!(ast::ExprIf);
impl_has_scoped_expression_id!(ast::ExprDict);
impl_has_scoped_expression_id!(ast::ExprSet);
impl_has_scoped_expression_id!(ast::ExprListComp);
impl_has_scoped_expression_id!(ast::ExprSetComp);
impl_has_scoped_expression_id!(ast::ExprDictComp);
impl_has_scoped_expression_id!(ast::ExprGenerator);
impl_has_scoped_expression_id!(ast::ExprAwait);
impl_has_scoped_expression_id!(ast::ExprYield);
impl_has_scoped_expression_id!(ast::ExprYieldFrom);
impl_has_scoped_expression_id!(ast::ExprCompare);
impl_has_scoped_expression_id!(ast::ExprCall);
impl_has_scoped_expression_id!(ast::ExprFString);
impl_has_scoped_expression_id!(ast::ExprStringLiteral);
impl_has_scoped_expression_id!(ast::ExprBytesLiteral);
impl_has_scoped_expression_id!(ast::ExprNumberLiteral);
impl_has_scoped_expression_id!(ast::ExprBooleanLiteral);
impl_has_scoped_expression_id!(ast::ExprNoneLiteral);
impl_has_scoped_expression_id!(ast::ExprEllipsisLiteral);
impl_has_scoped_expression_id!(ast::ExprAttribute);
impl_has_scoped_expression_id!(ast::ExprSubscript);
impl_has_scoped_expression_id!(ast::ExprStarred);
impl_has_scoped_expression_id!(ast::ExprNamed);
impl_has_scoped_expression_id!(ast::ExprList);
impl_has_scoped_expression_id!(ast::ExprTuple);
impl_has_scoped_expression_id!(ast::ExprSlice);
impl_has_scoped_expression_id!(ast::ExprIpyEscapeCommand);
impl_has_scoped_expression_id!(ast::Expr);
impl HasScopedExpressionId for ast::ExprRef<'_> {
fn scoped_expression_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedExpressionId {
let ast_ids = ast_ids(db, scope);
ast_ids.expression_id(*self)
}
}
#[derive(Debug, Default)]
pub(super) struct AstIdsBuilder {
expressions_map: FxHashMap<ExpressionNodeKey, ScopedExpressionId>,
uses_map: FxHashMap<ExpressionNodeKey, ScopedUseId>,
}
impl AstIdsBuilder {
/// Adds `expr` to the expression ids map and returns its id.
pub(super) fn record_expression(&mut self, expr: &ast::Expr) -> ScopedExpressionId {
let expression_id = self.expressions_map.len().into();
self.expressions_map.insert(expr.into(), expression_id);
expression_id
}
/// Adds `expr` to the use ids map and returns its id.
pub(super) fn record_use(&mut self, expr: impl Into<ExpressionNodeKey>) -> ScopedUseId {
let use_id = self.uses_map.len().into();
self.uses_map.insert(expr.into(), use_id);
use_id
}
pub(super) fn finish(mut self) -> AstIds {
self.expressions_map.shrink_to_fit();
self.uses_map.shrink_to_fit();
AstIds {
expressions_map: self.expressions_map,
uses_map: self.uses_map,
}
}
}
/// Node key that can only be constructed for expressions.
pub(crate) mod node_key {
use ruff_python_ast as ast;
use crate::node_key::NodeKey;
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, salsa::Update)]
pub(crate) struct ExpressionNodeKey(NodeKey);
impl From<ast::ExprRef<'_>> for ExpressionNodeKey {
fn from(value: ast::ExprRef<'_>) -> Self {
Self(NodeKey::from_node(value))
}
}
impl From<&ast::Expr> for ExpressionNodeKey {
fn from(value: &ast::Expr) -> Self {
Self(NodeKey::from_node(value))
}
}
impl From<&ast::Identifier> for ExpressionNodeKey {
fn from(value: &ast::Identifier) -> Self {
Self(NodeKey::from_node(value))
}
}
}

File diff suppressed because it is too large


@ -0,0 +1,102 @@
use crate::semantic_index::use_def::FlowSnapshot;
use super::SemanticIndexBuilder;
/// An abstraction over the fact that each scope should have its own [`TryNodeContextStack`]
#[derive(Debug, Default)]
pub(super) struct TryNodeContextStackManager(Vec<TryNodeContextStack>);
impl TryNodeContextStackManager {
/// Push a new [`TryNodeContextStack`] onto the stack of stacks.
///
/// Each [`TryNodeContextStack`] is only valid for a single scope
pub(super) fn enter_nested_scope(&mut self) {
self.0.push(TryNodeContextStack::default());
}
    /// Pop the current [`TryNodeContextStack`] off the stack of stacks.
///
/// Each [`TryNodeContextStack`] is only valid for a single scope
pub(super) fn exit_scope(&mut self) {
let popped_context = self.0.pop();
debug_assert!(
popped_context.is_some(),
"exit_scope() should never be called on an empty stack \
(this indicates an unbalanced `enter_nested_scope()`/`exit_scope()` pair of calls)"
);
}
/// Push a [`TryNodeContext`] onto the [`TryNodeContextStack`]
/// at the top of our stack of stacks
pub(super) fn push_context(&mut self) {
self.current_try_context_stack().push_context();
}
/// Pop a [`TryNodeContext`] off the [`TryNodeContextStack`]
/// at the top of our stack of stacks. Return the Vec of [`FlowSnapshot`]s
/// recorded while we were visiting the `try` suite.
pub(super) fn pop_context(&mut self) -> Vec<FlowSnapshot> {
self.current_try_context_stack().pop_context()
}
    /// Retrieve the stack that is at the top of our stack of stacks.
    /// For each `try` block on that stack, record a snapshot of the current flow state.
pub(super) fn record_definition(&mut self, builder: &SemanticIndexBuilder) {
self.current_try_context_stack().record_definition(builder);
}
/// Retrieve the [`TryNodeContextStack`] that is relevant for the current scope.
fn current_try_context_stack(&mut self) -> &mut TryNodeContextStack {
self.0
.last_mut()
.expect("There should always be at least one `TryBlockContexts` on the stack")
}
}
/// The contexts of nested `try`/`except` blocks for a single scope
#[derive(Debug, Default)]
struct TryNodeContextStack(Vec<TryNodeContext>);
impl TryNodeContextStack {
/// Push a new [`TryNodeContext`] for recording intermediate states
/// while visiting a [`ruff_python_ast::StmtTry`] node that has a `finally` branch.
fn push_context(&mut self) {
self.0.push(TryNodeContext::default());
}
/// Pop a [`TryNodeContext`] off the stack. Return the Vec of [`FlowSnapshot`]s
/// recorded while we were visiting the `try` suite.
fn pop_context(&mut self) -> Vec<FlowSnapshot> {
let TryNodeContext {
try_suite_snapshots,
} = self
.0
.pop()
.expect("Cannot pop a `try` block off an empty `TryBlockContexts` stack");
try_suite_snapshots
}
    /// For each `try` block on the stack, record a snapshot of the current flow state
fn record_definition(&mut self, builder: &SemanticIndexBuilder) {
for context in &mut self.0 {
context.record_definition(builder.flow_snapshot());
}
}
}
/// Context for tracking definitions over the course of a single
/// [`ruff_python_ast::StmtTry`] node
///
/// It will likely be necessary to add more fields to this struct in the future
/// when we add more advanced handling of `finally` branches.
#[derive(Debug, Default)]
struct TryNodeContext {
try_suite_snapshots: Vec<FlowSnapshot>,
}
impl TryNodeContext {
/// Take a record of what the internal state looked like after a definition
fn record_definition(&mut self, snapshot: FlowSnapshot) {
self.try_suite_snapshots.push(snapshot);
}
}

File diff suppressed because it is too large


@ -0,0 +1,68 @@
use crate::ast_node_ref::AstNodeRef;
use crate::db::Db;
use crate::semantic_index::symbol::{FileScopeId, ScopeId};
use ruff_db::files::File;
use ruff_python_ast as ast;
use salsa;
/// Whether this expression should be inferred as a normal expression or
/// a type expression. For example, in `self.x: <annotation> = <value>`, the
/// `<annotation>` is inferred as a type expression, while `<value>` is inferred
/// as a normal expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum ExpressionKind {
Normal,
TypeExpression,
}
/// An independently type-inferable expression.
///
/// Includes constraint expressions (e.g. if tests) and the RHS of an unpacking assignment.
///
/// ## Module-local type
/// This type should not be used as part of any cross-module API because
/// it holds a reference to the AST node. Range-offset changes
/// then propagate through all usages, and deserialization requires
/// reparsing the entire module.
///
/// E.g. don't use this type in:
///
/// * a return type of a cross-module query
/// * a field of a type that is a return type of a cross-module query
/// * an argument of a cross-module query
#[salsa::tracked(debug)]
pub(crate) struct Expression<'db> {
/// The file in which the expression occurs.
pub(crate) file: File,
/// The scope in which the expression occurs.
pub(crate) file_scope: FileScopeId,
/// The expression node.
#[no_eq]
#[tracked]
#[return_ref]
pub(crate) node_ref: AstNodeRef<ast::Expr>,
/// An assignment statement, if this expression is immediately used as the rhs of that
/// assignment.
///
/// (Note that this is the _immediately_ containing assignment — if a complex expression is
/// assigned to some target, only the outermost expression node has this set. The inner
/// expressions are used to build up the assignment result, and are not "immediately assigned"
/// to the target, and so have `None` for this field.)
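    ///
    /// For example, in `x = [a, b]`, only the outer list expression `[a, b]`
    /// has `assigned_to` set; the inner `a` and `b` expressions have `None`.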
#[no_eq]
#[tracked]
pub(crate) assigned_to: Option<AstNodeRef<ast::StmtAssign>>,
/// Should this expression be inferred as a normal expression or a type expression?
pub(crate) kind: ExpressionKind,
count: countme::Count<Expression<'static>>,
}
impl<'db> Expression<'db> {
pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> {
self.file_scope(db).to_scope_id(db, self.file(db))
}
}


@ -0,0 +1,144 @@
//! # Narrowing constraints
//!
//! When building a semantic index for a file, we associate each binding with a _narrowing
//! constraint_, which constrains the type of the binding's symbol. Note that a binding can be
//! associated with a different narrowing constraint at different points in a file. See the
//! [`use_def`][crate::semantic_index::use_def] module for more details.
//!
//! This module defines how narrowing constraints are stored internally.
//!
//! A _narrowing constraint_ consists of a list of _predicates_, each of which corresponds with an
//! expression in the source file (represented by a [`Predicate`]). We need to support the
//! following operations on narrowing constraints:
//!
//! - Adding a new predicate to an existing constraint
//! - Merging two constraints together, which produces the _intersection_ of their predicates
//! - Iterating through the predicates in a constraint
//!
//! In particular, note that we do not need random access to the predicates in a constraint. That
//! means that we can use a simple [_sorted association list_][crate::list] as our data structure.
//! That lets us use a single 32-bit integer to store each narrowing constraint, no matter how many
//! predicates it contains. It also makes merging two narrowing constraints fast, since alists
//! support fast intersection.
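//!
//! For example, intersecting a constraint containing predicates `{p1, p2, p3}`
//! with one containing `{p2, p3, p4}` yields a constraint containing `{p2, p3}`.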
//!
//! Because we visit the contents of each scope in source-file order, and assign scoped IDs in
//! source-file order, that means that we will tend to visit narrowing constraints in order by
//! their predicate IDs. This is exactly how to get the best performance from our alist
//! implementation.
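//!
//! As a minimal sketch of how these operations compose (the predicate IDs below are
//! illustrative; in real code they come from the per-scope predicate arena):
//!
//! ```rust,ignore
//! let mut builder = NarrowingConstraintsBuilder::default();
//!
//! // One control-flow path sees predicates 0 and 1.
//! let mut path_a = ScopedNarrowingConstraint::empty();
//! path_a = builder.add_predicate_to_constraint(path_a, ScopedPredicateId::from_u32(0).into());
//! path_a = builder.add_predicate_to_constraint(path_a, ScopedPredicateId::from_u32(1).into());
//!
//! // Another path sees only predicate 0.
//! let mut path_b = ScopedNarrowingConstraint::empty();
//! path_b = builder.add_predicate_to_constraint(path_b, ScopedPredicateId::from_u32(0).into());
//!
//! // Merging control flow intersects the two lists: only predicate 0 survives.
//! let merged = builder.intersect_constraints(path_a, path_b);
//! ```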
//!
//! [`Predicate`]: crate::semantic_index::predicate::Predicate
use crate::list::{List, ListBuilder, ListSetReverseIterator, ListStorage};
use crate::semantic_index::predicate::ScopedPredicateId;
/// A narrowing constraint associated with a live binding.
///
/// A constraint is a list of [`Predicate`]s that each constrain the type of the binding's symbol.
///
/// [`Predicate`]: crate::semantic_index::predicate::Predicate
pub(crate) type ScopedNarrowingConstraint = List<ScopedNarrowingConstraintPredicate>;
/// One of the [`Predicate`]s in a narrowing constraint, which constrains the type of the
/// binding's symbol.
///
/// Note that those [`Predicate`]s are stored in [their own per-scope
/// arena][crate::semantic_index::predicate::Predicates], so internally we use a
/// [`ScopedPredicateId`] to refer to the underlying predicate.
///
/// [`Predicate`]: crate::semantic_index::predicate::Predicate
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub(crate) struct ScopedNarrowingConstraintPredicate(ScopedPredicateId);
impl ScopedNarrowingConstraintPredicate {
/// Returns the ID of the underlying `Predicate`.
pub(crate) fn predicate(self) -> ScopedPredicateId {
self.0
}
}
impl From<ScopedPredicateId> for ScopedNarrowingConstraintPredicate {
fn from(predicate: ScopedPredicateId) -> ScopedNarrowingConstraintPredicate {
ScopedNarrowingConstraintPredicate(predicate)
}
}
/// A collection of narrowing constraints for a given scope.
#[derive(Debug, Eq, PartialEq)]
pub(crate) struct NarrowingConstraints {
lists: ListStorage<ScopedNarrowingConstraintPredicate>,
}
// Building constraints
// --------------------
/// A builder for creating narrowing constraints.
#[derive(Debug, Default, Eq, PartialEq)]
pub(crate) struct NarrowingConstraintsBuilder {
lists: ListBuilder<ScopedNarrowingConstraintPredicate>,
}
impl NarrowingConstraintsBuilder {
pub(crate) fn build(self) -> NarrowingConstraints {
NarrowingConstraints {
lists: self.lists.build(),
}
}
/// Adds a predicate to an existing narrowing constraint.
pub(crate) fn add_predicate_to_constraint(
&mut self,
constraint: ScopedNarrowingConstraint,
predicate: ScopedNarrowingConstraintPredicate,
) -> ScopedNarrowingConstraint {
self.lists.insert(constraint, predicate)
}
/// Returns the intersection of two narrowing constraints. The result contains the predicates
/// that appear in both inputs.
pub(crate) fn intersect_constraints(
&mut self,
a: ScopedNarrowingConstraint,
b: ScopedNarrowingConstraint,
) -> ScopedNarrowingConstraint {
self.lists.intersect(a, b)
}
}
// Iteration
// ---------
pub(crate) type NarrowingConstraintsIterator<'a> =
std::iter::Copied<ListSetReverseIterator<'a, ScopedNarrowingConstraintPredicate>>;
impl NarrowingConstraints {
/// Iterates over the predicates in a narrowing constraint.
pub(crate) fn iter_predicates(
&self,
set: ScopedNarrowingConstraint,
) -> NarrowingConstraintsIterator<'_> {
self.lists.iter_set_reverse(set).copied()
}
}
// Test support
// ------------
#[cfg(test)]
mod tests {
use super::*;
impl ScopedNarrowingConstraintPredicate {
pub(crate) fn as_u32(self) -> u32 {
self.0.as_u32()
}
}
impl NarrowingConstraintsBuilder {
pub(crate) fn iter_predicates(
&self,
set: ScopedNarrowingConstraint,
) -> NarrowingConstraintsIterator<'_> {
self.lists.iter_set_reverse(set).copied()
}
}
}

View file

@ -0,0 +1,173 @@
//! _Predicates_ are Python expressions whose runtime values can affect type inference.
//!
//! We currently use predicates in two places:
//!
//! - [_Narrowing constraints_][crate::semantic_index::narrowing_constraints] constrain the type of
//! a binding that is visible at a particular use.
//! - [_Visibility constraints_][crate::semantic_index::visibility_constraints] determine the
//! static visibility of a binding, and the reachability of a statement.
use ruff_db::files::File;
use ruff_index::{newtype_index, IndexVec};
use ruff_python_ast::Singleton;
use crate::db::Db;
use crate::semantic_index::expression::Expression;
use crate::semantic_index::global_scope;
use crate::semantic_index::symbol::{FileScopeId, ScopeId, ScopedSymbolId};
/// A scoped identifier for each `Predicate` in a scope.
#[newtype_index]
#[derive(Ord, PartialOrd)]
pub(crate) struct ScopedPredicateId;
/// A collection of predicates for a given scope.
pub(crate) type Predicates<'db> = IndexVec<ScopedPredicateId, Predicate<'db>>;
#[derive(Debug, Default)]
pub(crate) struct PredicatesBuilder<'db> {
predicates: IndexVec<ScopedPredicateId, Predicate<'db>>,
}
impl<'db> PredicatesBuilder<'db> {
/// Adds a predicate. Note that we do not deduplicate predicates. If you add a `Predicate`
/// more than once, you will get distinct `ScopedPredicateId`s for each one. (This lets you
/// model predicates that might evaluate to different values at different points of execution.)
pub(crate) fn add_predicate(&mut self, predicate: Predicate<'db>) -> ScopedPredicateId {
self.predicates.push(predicate)
}
pub(crate) fn build(mut self) -> Predicates<'db> {
self.predicates.shrink_to_fit();
self.predicates
}
}
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update)]
pub(crate) struct Predicate<'db> {
pub(crate) node: PredicateNode<'db>,
pub(crate) is_positive: bool,
}
impl Predicate<'_> {
pub(crate) fn negated(self) -> Self {
Self {
node: self.node,
is_positive: !self.is_positive,
}
}
}
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update)]
pub(crate) enum PredicateNode<'db> {
Expression(Expression<'db>),
Pattern(PatternPredicate<'db>),
StarImportPlaceholder(StarImportPlaceholderPredicate<'db>),
}
/// Pattern kinds for which we support type narrowing and/or static visibility analysis.
#[derive(Debug, Clone, Hash, PartialEq, salsa::Update)]
pub(crate) enum PatternPredicateKind<'db> {
Singleton(Singleton),
Value(Expression<'db>),
Or(Vec<PatternPredicateKind<'db>>),
Class(Expression<'db>),
Unsupported,
}
#[salsa::tracked(debug)]
pub(crate) struct PatternPredicate<'db> {
pub(crate) file: File,
pub(crate) file_scope: FileScopeId,
pub(crate) subject: Expression<'db>,
#[return_ref]
pub(crate) kind: PatternPredicateKind<'db>,
pub(crate) guard: Option<Expression<'db>>,
count: countme::Count<PatternPredicate<'static>>,
}
impl<'db> PatternPredicate<'db> {
pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> {
self.file_scope(db).to_scope_id(db, self.file(db))
}
}
/// A "placeholder predicate" that is used to model the fact that the boundness of a
/// (possible) definition or declaration caused by a `*` import cannot be fully determined
/// until type-inference time. This is essentially the same as a standard visibility constraint,
/// so we reuse the [`Predicate`] infrastructure to model it.
///
/// To illustrate, say we have a module `exporter.py` like so:
///
/// ```py
/// if <condition>:
///     class A: ...
/// ```
///
/// and we have a module `importer.py` like so:
///
/// ```py
/// A = 1
///
/// from exporter import *
/// ```
///
/// Since we cannot know whether or not `<condition>` is true at semantic-index time,
/// we record a definition for `A` in `importer.py` as a result of the `from exporter import *`
/// statement, but place a predicate on it to record the fact that we don't yet
/// know whether this definition will be visible from all control-flow paths or not.
/// Essentially, we model `importer.py` as something similar to this:
///
/// ```py
/// A = 1
///
/// if <star_import_placeholder_predicate>:
///     from exporter import A
/// ```
///
/// At type-check time, the placeholder predicate for the `A` definition is evaluated by
/// attempting to resolve the `A` symbol in `exporter.py`'s global namespace:
/// - If it resolves to a definitely bound symbol, then the predicate resolves to [`Truthiness::AlwaysTrue`]
/// - If it resolves to an unbound symbol, then the predicate resolves to [`Truthiness::AlwaysFalse`]
/// - If it resolves to a possibly bound symbol, then the predicate resolves to [`Truthiness::Ambiguous`]
///
/// [`Truthiness::AlwaysTrue`]: crate::types::Truthiness::AlwaysTrue
/// [`Truthiness::AlwaysFalse`]: crate::types::Truthiness::AlwaysFalse
/// [`Truthiness::Ambiguous`]: crate::types::Truthiness::Ambiguous
#[salsa::tracked(debug)]
pub(crate) struct StarImportPlaceholderPredicate<'db> {
pub(crate) importing_file: File,
/// Each symbol imported by a `*` import has a separate predicate associated with it:
/// this field identifies which symbol that is.
///
/// Note that a [`ScopedSymbolId`] is only meaningful if you also know the scope
/// it is relative to. For this specific struct, however, there's no need to store a
/// separate field to hold the ID of the scope. `StarImportPlaceholderPredicate`s are only created
/// for valid `*`-import definitions, and valid `*`-import definitions can only ever
/// exist in the global scope; thus, we know that the `symbol_id` here will be relative
/// to the global scope of the importing file.
pub(crate) symbol_id: ScopedSymbolId,
pub(crate) referenced_file: File,
}
impl<'db> StarImportPlaceholderPredicate<'db> {
pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> {
// See doc-comment above [`StarImportPlaceholderPredicate::symbol_id`]:
// valid `*`-import definitions can only take place in the global scope.
global_scope(db, self.importing_file(db))
}
}
impl<'db> From<StarImportPlaceholderPredicate<'db>> for Predicate<'db> {
fn from(predicate: StarImportPlaceholderPredicate<'db>) -> Self {
Predicate {
node: PredicateNode::StarImportPlaceholder(predicate),
is_positive: true,
}
}
}

View file

@ -0,0 +1,408 @@
//! A visitor and query to find all global-scope symbols that are exported from a module
//! when a wildcard import is used.
//!
//! For example, if a module `foo` contains `from bar import *`, which symbols from the global
//! scope of `bar` are imported into the global namespace of `foo`?
//!
//! ## Why is this a separate query rather than a part of semantic indexing?
//!
//! This query is called by the [`super::SemanticIndexBuilder`] in order to add the correct
//! [`super::Definition`]s to the semantic index of a module `foo` if `foo` has a
//! `from bar import *` statement in its global namespace. Adding the correct `Definition`s to
//! `foo`'s [`super::SemanticIndex`] requires knowing which symbols are exported from `bar`.
//!
//! If we determined the set of exported names during semantic indexing rather than as a
//! separate query, we would need to complete semantic indexing on `bar` in order to
//! complete analysis of the global namespace of `foo`. Since semantic indexing is somewhat
//! expensive, this would be undesirable. A separate query allows us to avoid this issue.
//!
//! An additional concern is that the recursive nature of this query means that it must be able
//! to handle cycles. We do this using fixpoint iteration; adding fixpoint iteration to the
//! whole [`super::semantic_index()`] query would probably be prohibitively expensive.
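//!
//! As a concrete illustration of the cycle case, the following pair of (hypothetical)
//! modules is mutually recursive; fixpoint iteration starts both export sets out as
//! empty and re-runs the query until the results stabilize:
//!
//! ```py
//! # a.py
//! from b import *
//! X = 1
//! ```
//!
//! ```py
//! # b.py
//! from a import *
//! Y = 2
//! ```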
use ruff_db::{files::File, parsed::parsed_module};
use ruff_python_ast::{
self as ast,
name::Name,
visitor::{walk_expr, walk_pattern, walk_stmt, Visitor},
};
use rustc_hash::FxHashMap;
use crate::{module_name::ModuleName, resolve_module, Db};
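// Salsa cycle support for `exported_names`: if the query participates in a cycle,
// start from an empty export list (`exports_cycle_initial`) and keep re-running the
// query (`CycleRecoveryAction::Iterate`) until the result reaches a fixpoint.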
fn exports_cycle_recover(
_db: &dyn Db,
_value: &[Name],
_count: u32,
_file: File,
) -> salsa::CycleRecoveryAction<Box<[Name]>> {
salsa::CycleRecoveryAction::Iterate
}
fn exports_cycle_initial(_db: &dyn Db, _file: File) -> Box<[Name]> {
Box::default()
}
#[salsa::tracked(return_ref, cycle_fn=exports_cycle_recover, cycle_initial=exports_cycle_initial)]
pub(super) fn exported_names(db: &dyn Db, file: File) -> Box<[Name]> {
let module = parsed_module(db.upcast(), file);
let mut finder = ExportFinder::new(db, file);
finder.visit_body(module.suite());
finder.resolve_exports()
}
struct ExportFinder<'db> {
db: &'db dyn Db,
file: File,
visiting_stub_file: bool,
exports: FxHashMap<&'db Name, PossibleExportKind>,
dunder_all: DunderAll,
}
impl<'db> ExportFinder<'db> {
fn new(db: &'db dyn Db, file: File) -> Self {
Self {
db,
file,
visiting_stub_file: file.is_stub(db.upcast()),
exports: FxHashMap::default(),
dunder_all: DunderAll::NotPresent,
}
}
fn possibly_add_export(&mut self, export: &'db Name, kind: PossibleExportKind) {
self.exports.insert(export, kind);
if export == "__all__" {
self.dunder_all = DunderAll::Present;
}
}
fn resolve_exports(self) -> Box<[Name]> {
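// If the module defines `__all__`, keep every candidate name (including
// underscore-prefixed ones); the default filtering below only applies when
// `__all__` is absent.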
match self.dunder_all {
DunderAll::NotPresent => self
.exports
.into_iter()
.filter_map(|(name, kind)| {
if kind == PossibleExportKind::StubImportWithoutRedundantAlias {
return None;
}
if name.starts_with('_') {
return None;
}
Some(name.clone())
})
.collect(),
DunderAll::Present => self.exports.into_keys().cloned().collect(),
}
}
}
impl<'db> Visitor<'db> for ExportFinder<'db> {
fn visit_alias(&mut self, alias: &'db ast::Alias) {
let ast::Alias {
name,
asname,
range: _,
} = alias;
let name = &name.id;
let asname = asname.as_ref().map(|asname| &asname.id);
// If the source is a stub, names defined by imports are only exported
// if they use the explicit `foo as foo` syntax:
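//
//     import foo as foo       # re-exported
//     from bar import A as A  # re-exported
//     from bar import B       # not re-exported from a stub
//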
let kind = if self.visiting_stub_file && asname.is_none_or(|asname| asname != name) {
PossibleExportKind::StubImportWithoutRedundantAlias
} else {
PossibleExportKind::Normal
};
self.possibly_add_export(asname.unwrap_or(name), kind);
}
fn visit_pattern(&mut self, pattern: &'db ast::Pattern) {
match pattern {
ast::Pattern::MatchAs(ast::PatternMatchAs {
pattern,
name,
range: _,
}) => {
if let Some(pattern) = pattern {
self.visit_pattern(pattern);
}
if let Some(name) = name {
// Wildcard patterns (`case _:`) do not bind names.
// Currently, names with leading underscores are filtered out
// when resolving exports anyway, but this will not always be the case
// (in the future we will want to support modules with `__all__ = ['_']`).
if name != "_" {
self.possibly_add_export(&name.id, PossibleExportKind::Normal);
}
}
}
ast::Pattern::MatchMapping(ast::PatternMatchMapping {
patterns,
rest,
keys: _,
range: _,
}) => {
for pattern in patterns {
self.visit_pattern(pattern);
}
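// A `case {**rest}:` pattern binds `rest` in the enclosing scope.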
if let Some(rest) = rest {
self.possibly_add_export(&rest.id, PossibleExportKind::Normal);
}
}
ast::Pattern::MatchStar(ast::PatternMatchStar { name, range: _ }) => {
if let Some(name) = name {
self.possibly_add_export(&name.id, PossibleExportKind::Normal);
}
}
ast::Pattern::MatchSequence(_)
| ast::Pattern::MatchOr(_)
| ast::Pattern::MatchClass(_) => {
walk_pattern(self, pattern);
}
ast::Pattern::MatchSingleton(_) | ast::Pattern::MatchValue(_) => {}
}
}
fn visit_stmt(&mut self, stmt: &'db ast::Stmt) {
match stmt {
ast::Stmt::ClassDef(ast::StmtClassDef {
name,
decorator_list,
arguments,
type_params: _, // We don't want to visit the type params of the class
body: _, // We don't want to visit the body of the class
range: _,
}) => {
self.possibly_add_export(&name.id, PossibleExportKind::Normal);
for decorator in decorator_list {
self.visit_decorator(decorator);
}
if let Some(arguments) = arguments {
self.visit_arguments(arguments);
}
}
ast::Stmt::FunctionDef(ast::StmtFunctionDef {
name,
decorator_list,
parameters,
returns,
type_params: _, // We don't want to visit the type params of the function
body: _, // We don't want to visit the body of the function
range: _,
is_async: _,
}) => {
self.possibly_add_export(&name.id, PossibleExportKind::Normal);
for decorator in decorator_list {
self.visit_decorator(decorator);
}
self.visit_parameters(parameters);
if let Some(returns) = returns {
self.visit_expr(returns);
}
}
ast::Stmt::AnnAssign(ast::StmtAnnAssign {
target,
value,
annotation,
simple: _,
range: _,
}) => {
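// An annotation without a value (`x: int`) does not bind a name at runtime,
// so it only counts as a possible export when we're in a stub file.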
if value.is_some() || self.visiting_stub_file {
self.visit_expr(target);
}
self.visit_expr(annotation);
if let Some(value) = value {
self.visit_expr(value);
}
}
ast::Stmt::TypeAlias(ast::StmtTypeAlias {
name,
type_params: _,
value: _,
range: _,
}) => {
self.visit_expr(name);
// Neither walrus expressions nor statements can appear in type aliases;
// no need to recursively visit the `value` or `type_params`
}
ast::Stmt::ImportFrom(node) => {
let mut found_star = false;
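// Process a `*` name at most once, even if the list somehow contains several.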
for name in &node.names {
if &name.name.id == "*" {
if !found_star {
found_star = true;
for export in
ModuleName::from_import_statement(self.db, self.file, node)
.ok()
.and_then(|module_name| resolve_module(self.db, &module_name))
.iter()
.flat_map(|module| exported_names(self.db, module.file()))
{
self.possibly_add_export(export, PossibleExportKind::Normal);
}
}
} else {
self.visit_alias(name);
}
}
}
ast::Stmt::Import(_)
| ast::Stmt::AugAssign(_)
| ast::Stmt::While(_)
| ast::Stmt::If(_)
| ast::Stmt::With(_)
| ast::Stmt::Assert(_)
| ast::Stmt::Try(_)
| ast::Stmt::Expr(_)
| ast::Stmt::For(_)
| ast::Stmt::Assign(_)
| ast::Stmt::Match(_) => walk_stmt(self, stmt),
ast::Stmt::Global(_)
| ast::Stmt::Raise(_)
| ast::Stmt::Return(_)
| ast::Stmt::Break(_)
| ast::Stmt::Continue(_)
| ast::Stmt::IpyEscapeCommand(_)
| ast::Stmt::Delete(_)
| ast::Stmt::Nonlocal(_)
| ast::Stmt::Pass(_) => {}
}
}
fn visit_expr(&mut self, expr: &'db ast::Expr) {
match expr {
ast::Expr::Name(ast::ExprName { id, ctx, range: _ }) => {
if ctx.is_store() {
self.possibly_add_export(id, PossibleExportKind::Normal);
}
}
ast::Expr::Lambda(_)
| ast::Expr::BooleanLiteral(_)
| ast::Expr::NoneLiteral(_)
| ast::Expr::NumberLiteral(_)
| ast::Expr::BytesLiteral(_)
| ast::Expr::EllipsisLiteral(_)
| ast::Expr::StringLiteral(_) => {}
// Walrus definitions "leak" from comprehension scopes into the comprehension's
// enclosing scope; they thus need special handling
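// (e.g. after `[last := x for x in range(3)]` at module level, `last` is
// bound in the module's global scope).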
ast::Expr::SetComp(_)
| ast::Expr::ListComp(_)
| ast::Expr::Generator(_)
| ast::Expr::DictComp(_) => {
let mut walrus_finder = WalrusFinder {
export_finder: self,
};
walk_expr(&mut walrus_finder, expr);
}
ast::Expr::BoolOp(_)
| ast::Expr::Named(_)
| ast::Expr::BinOp(_)
| ast::Expr::UnaryOp(_)
| ast::Expr::If(_)
| ast::Expr::Attribute(_)
| ast::Expr::Subscript(_)
| ast::Expr::Starred(_)
| ast::Expr::Call(_)
| ast::Expr::Compare(_)
| ast::Expr::Yield(_)
| ast::Expr::YieldFrom(_)
| ast::Expr::FString(_)
| ast::Expr::Tuple(_)
| ast::Expr::List(_)
| ast::Expr::Slice(_)
| ast::Expr::IpyEscapeCommand(_)
| ast::Expr::Dict(_)
| ast::Expr::Set(_)
| ast::Expr::Await(_) => walk_expr(self, expr),
}
}
}
struct WalrusFinder<'a, 'db> {
export_finder: &'a mut ExportFinder<'db>,
}
impl<'db> Visitor<'db> for WalrusFinder<'_, 'db> {
fn visit_expr(&mut self, expr: &'db ast::Expr) {
match expr {
// It's important for us to short-circuit here for lambdas specifically,
// as walruses cannot leak out of the body of a lambda function.
ast::Expr::Lambda(_)
| ast::Expr::BooleanLiteral(_)
| ast::Expr::NoneLiteral(_)
| ast::Expr::NumberLiteral(_)
| ast::Expr::BytesLiteral(_)
| ast::Expr::EllipsisLiteral(_)
| ast::Expr::StringLiteral(_)
| ast::Expr::Name(_) => {}
ast::Expr::Named(ast::ExprNamed {
target,
value: _,
range: _,
}) => {
if let ast::Expr::Name(ast::ExprName {
id,
ctx: ast::ExprContext::Store,
range: _,
}) = &**target
{
self.export_finder
.possibly_add_export(id, PossibleExportKind::Normal);
}
}
// We must recurse inside nested comprehensions,
// as even a walrus inside a comprehension inside a comprehension in the global scope
// will leak out into the global scope
ast::Expr::DictComp(_)
| ast::Expr::SetComp(_)
| ast::Expr::ListComp(_)
| ast::Expr::Generator(_)
| ast::Expr::BoolOp(_)
| ast::Expr::BinOp(_)
| ast::Expr::UnaryOp(_)
| ast::Expr::If(_)
| ast::Expr::Attribute(_)
| ast::Expr::Subscript(_)
| ast::Expr::Starred(_)
| ast::Expr::Call(_)
| ast::Expr::Compare(_)
| ast::Expr::Yield(_)
| ast::Expr::YieldFrom(_)
| ast::Expr::FString(_)
| ast::Expr::Tuple(_)
| ast::Expr::List(_)
| ast::Expr::Slice(_)
| ast::Expr::IpyEscapeCommand(_)
| ast::Expr::Dict(_)
| ast::Expr::Set(_)
| ast::Expr::Await(_) => walk_expr(self, expr),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PossibleExportKind {
Normal,
StubImportWithoutRedundantAlias,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DunderAll {
NotPresent,
Present,
}

View file

@ -0,0 +1,576 @@
use std::hash::{Hash, Hasher};
use std::ops::Range;
use bitflags::bitflags;
use hashbrown::hash_map::RawEntryMut;
use ruff_db::files::File;
use ruff_db::parsed::ParsedModule;
use ruff_index::{newtype_index, IndexVec};
use ruff_python_ast as ast;
use ruff_python_ast::name::Name;
use rustc_hash::FxHasher;
use crate::ast_node_ref::AstNodeRef;
use crate::node_key::NodeKey;
use crate::semantic_index::visibility_constraints::ScopedVisibilityConstraintId;
use crate::semantic_index::{semantic_index, SymbolMap};
use crate::Db;
#[derive(Eq, PartialEq, Debug)]
pub struct Symbol {
name: Name,
flags: SymbolFlags,
}
impl Symbol {
fn new(name: Name) -> Self {
Self {
name,
flags: SymbolFlags::empty(),
}
}
fn insert_flags(&mut self, flags: SymbolFlags) {
self.flags.insert(flags);
}
/// The symbol's name.
pub fn name(&self) -> &Name {
&self.name
}
/// Is the symbol used in its containing scope?
pub fn is_used(&self) -> bool {
self.flags.contains(SymbolFlags::IS_USED)
}
/// Is the symbol defined in its containing scope?
pub fn is_bound(&self) -> bool {
self.flags.contains(SymbolFlags::IS_BOUND)
}
/// Is the symbol declared in its containing scope?
pub fn is_declared(&self) -> bool {
self.flags.contains(SymbolFlags::IS_DECLARED)
}
}
bitflags! {
/// Flags that can be queried to obtain information about a symbol in a given scope.
///
/// See the doc-comment at the top of [`super::use_def`] for explanations of what it
/// means for a symbol to be *bound* as opposed to *declared*.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
struct SymbolFlags: u8 {
const IS_USED = 1 << 0;
const IS_BOUND = 1 << 1;
const IS_DECLARED = 1 << 2;
/// TODO: This flag is not yet set by anything
const MARKED_GLOBAL = 1 << 3;
/// TODO: This flag is not yet set by anything
const MARKED_NONLOCAL = 1 << 4;
}
}
/// ID that uniquely identifies a symbol in a file.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct FileSymbolId {
scope: FileScopeId,
scoped_symbol_id: ScopedSymbolId,
}
impl FileSymbolId {
pub fn scope(self) -> FileScopeId {
self.scope
}
pub(crate) fn scoped_symbol_id(self) -> ScopedSymbolId {
self.scoped_symbol_id
}
}
impl From<FileSymbolId> for ScopedSymbolId {
fn from(val: FileSymbolId) -> Self {
val.scoped_symbol_id()
}
}
/// Symbol ID that uniquely identifies a symbol inside a [`Scope`].
#[newtype_index]
#[derive(salsa::Update)]
pub struct ScopedSymbolId;
/// A cross-module identifier of a scope that can be used as a salsa query parameter.
#[salsa::tracked(debug)]
pub struct ScopeId<'db> {
pub file: File,
pub file_scope_id: FileScopeId,
count: countme::Count<ScopeId<'static>>,
}
impl<'db> ScopeId<'db> {
pub(crate) fn is_function_like(self, db: &'db dyn Db) -> bool {
self.node(db).scope_kind().is_function_like()
}
pub(crate) fn is_type_parameter(self, db: &'db dyn Db) -> bool {
self.node(db).scope_kind().is_type_parameter()
}
pub(crate) fn node(self, db: &dyn Db) -> &NodeWithScopeKind {
self.scope(db).node()
}
pub(crate) fn scope(self, db: &dyn Db) -> &Scope {
semantic_index(db, self.file(db)).scope(self.file_scope_id(db))
}
#[cfg(test)]
pub(crate) fn name(self, db: &'db dyn Db) -> &'db str {
match self.node(db) {
NodeWithScopeKind::Module => "<module>",
NodeWithScopeKind::Class(class) | NodeWithScopeKind::ClassTypeParameters(class) => {
class.name.as_str()
}
NodeWithScopeKind::Function(function)
| NodeWithScopeKind::FunctionTypeParameters(function) => function.name.as_str(),
NodeWithScopeKind::TypeAlias(type_alias)
| NodeWithScopeKind::TypeAliasTypeParameters(type_alias) => type_alias
.name
.as_name_expr()
.map(|name| name.id.as_str())
.unwrap_or("<type alias>"),
NodeWithScopeKind::Lambda(_) => "<lambda>",
NodeWithScopeKind::ListComprehension(_) => "<listcomp>",
NodeWithScopeKind::SetComprehension(_) => "<setcomp>",
NodeWithScopeKind::DictComprehension(_) => "<dictcomp>",
NodeWithScopeKind::GeneratorExpression(_) => "<generator>",
}
}
}
/// ID that uniquely identifies a scope inside of a module.
#[newtype_index]
#[derive(salsa::Update)]
pub struct FileScopeId;
impl FileScopeId {
/// Returns the scope id of the module-global scope.
pub fn global() -> Self {
FileScopeId::from_u32(0)
}
pub fn is_global(self) -> bool {
self == FileScopeId::global()
}
pub fn to_scope_id(self, db: &dyn Db, file: File) -> ScopeId<'_> {
let index = semantic_index(db, file);
index.scope_ids_by_scope[self]
}
}
#[derive(Debug, salsa::Update)]
pub struct Scope {
parent: Option<FileScopeId>,
node: NodeWithScopeKind,
descendants: Range<FileScopeId>,
reachability: ScopedVisibilityConstraintId,
}
impl Scope {
pub(super) fn new(
parent: Option<FileScopeId>,
node: NodeWithScopeKind,
descendants: Range<FileScopeId>,
reachability: ScopedVisibilityConstraintId,
) -> Self {
Scope {
parent,
node,
descendants,
reachability,
}
}
pub fn parent(&self) -> Option<FileScopeId> {
self.parent
}
pub fn node(&self) -> &NodeWithScopeKind {
&self.node
}
pub fn kind(&self) -> ScopeKind {
self.node().scope_kind()
}
pub fn descendants(&self) -> Range<FileScopeId> {
self.descendants.clone()
}
pub(super) fn extend_descendants(&mut self, children_end: FileScopeId) {
self.descendants = self.descendants.start..children_end;
}
pub(crate) fn is_eager(&self) -> bool {
self.kind().is_eager()
}
pub(crate) fn reachability(&self) -> ScopedVisibilityConstraintId {
self.reachability
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ScopeKind {
Module,
Annotation,
Class,
Function,
Lambda,
Comprehension,
TypeAlias,
}
impl ScopeKind {
pub(crate) fn is_eager(self) -> bool {
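// Module, class, and comprehension bodies are executed eagerly, as soon as
// control flow reaches them; function, lambda, and annotation/type-alias
// scopes are evaluated lazily.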
match self {
ScopeKind::Module | ScopeKind::Class | ScopeKind::Comprehension => true,
ScopeKind::Annotation
| ScopeKind::Function
| ScopeKind::Lambda
| ScopeKind::TypeAlias => false,
}
}
pub(crate) fn is_function_like(self) -> bool {
// Type parameter scopes behave like function scopes in terms of name resolution; CPython's
// symbol table also uses the term "function-like" for these scopes.
matches!(
self,
ScopeKind::Annotation
| ScopeKind::Function
| ScopeKind::Lambda
| ScopeKind::TypeAlias
| ScopeKind::Comprehension
)
}
pub(crate) fn is_class(self) -> bool {
matches!(self, ScopeKind::Class)
}
pub(crate) fn is_type_parameter(self) -> bool {
matches!(self, ScopeKind::Annotation | ScopeKind::TypeAlias)
}
}
/// Symbol table for a specific [`Scope`].
#[derive(Default, salsa::Update)]
pub struct SymbolTable {
/// The symbols in this scope.
symbols: IndexVec<ScopedSymbolId, Symbol>,
/// The symbols indexed by name.
symbols_by_name: SymbolMap,
}
impl SymbolTable {
fn shrink_to_fit(&mut self) {
self.symbols.shrink_to_fit();
}
pub(crate) fn symbol(&self, symbol_id: impl Into<ScopedSymbolId>) -> &Symbol {
&self.symbols[symbol_id.into()]
}
#[allow(unused)]
pub(crate) fn symbol_ids(&self) -> impl Iterator<Item = ScopedSymbolId> {
self.symbols.indices()
}
pub fn symbols(&self) -> impl Iterator<Item = &Symbol> {
self.symbols.iter()
}
/// Returns the symbol named `name`.
pub(crate) fn symbol_by_name(&self, name: &str) -> Option<&Symbol> {
let id = self.symbol_id_by_name(name)?;
Some(self.symbol(id))
}
/// Returns the [`ScopedSymbolId`] of the symbol named `name`.
pub(crate) fn symbol_id_by_name(&self, name: &str) -> Option<ScopedSymbolId> {
let (id, ()) = self
.symbols_by_name
.raw_entry()
.from_hash(Self::hash_name(name), |id| {
self.symbol(*id).name().as_str() == name
})?;
Some(*id)
}
fn hash_name(name: &str) -> u64 {
let mut hasher = FxHasher::default();
name.hash(&mut hasher);
hasher.finish()
}
}
impl PartialEq for SymbolTable {
fn eq(&self, other: &Self) -> bool {
// We don't need to compare the symbols_by_name because the name is already captured in `Symbol`.
self.symbols == other.symbols
}
}
impl Eq for SymbolTable {}
impl std::fmt::Debug for SymbolTable {
/// Exclude the `symbols_by_name` field from the debug output.
/// It's very noisy and not useful for debugging.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("SymbolTable")
.field(&self.symbols)
.finish_non_exhaustive()
}
}
#[derive(Debug, Default)]
pub(super) struct SymbolTableBuilder {
table: SymbolTable,
}
impl SymbolTableBuilder {
pub(super) fn add_symbol(&mut self, name: Name) -> (ScopedSymbolId, bool) {
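// Symbols are interned by name: if a symbol with this name already exists,
// return its ID and `false`; otherwise insert a new symbol keyed by the
// name's hash and return its fresh ID and `true`.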
let hash = SymbolTable::hash_name(&name);
let entry = self
.table
.symbols_by_name
.raw_entry_mut()
.from_hash(hash, |id| self.table.symbols[*id].name() == &name);
match entry {
RawEntryMut::Occupied(entry) => (*entry.key(), false),
RawEntryMut::Vacant(entry) => {
let symbol = Symbol::new(name);
let id = self.table.symbols.push(symbol);
entry.insert_with_hasher(hash, id, (), |id| {
SymbolTable::hash_name(self.table.symbols[*id].name().as_str())
});
(id, true)
}
}
}
pub(super) fn mark_symbol_bound(&mut self, id: ScopedSymbolId) {
self.table.symbols[id].insert_flags(SymbolFlags::IS_BOUND);
}
pub(super) fn mark_symbol_declared(&mut self, id: ScopedSymbolId) {
self.table.symbols[id].insert_flags(SymbolFlags::IS_DECLARED);
}
pub(super) fn mark_symbol_used(&mut self, id: ScopedSymbolId) {
self.table.symbols[id].insert_flags(SymbolFlags::IS_USED);
}
pub(super) fn symbols(&self) -> impl Iterator<Item = &Symbol> {
self.table.symbols()
}
pub(super) fn symbol_id_by_name(&self, name: &str) -> Option<ScopedSymbolId> {
self.table.symbol_id_by_name(name)
}
pub(super) fn symbol(&self, symbol_id: impl Into<ScopedSymbolId>) -> &Symbol {
self.table.symbol(symbol_id)
}
pub(super) fn finish(mut self) -> SymbolTable {
self.table.shrink_to_fit();
self.table
}
}
/// Reference to a node that introduces a new scope.
#[derive(Copy, Clone, Debug)]
pub(crate) enum NodeWithScopeRef<'a> {
Module,
Class(&'a ast::StmtClassDef),
Function(&'a ast::StmtFunctionDef),
Lambda(&'a ast::ExprLambda),
FunctionTypeParameters(&'a ast::StmtFunctionDef),
ClassTypeParameters(&'a ast::StmtClassDef),
TypeAlias(&'a ast::StmtTypeAlias),
TypeAliasTypeParameters(&'a ast::StmtTypeAlias),
ListComprehension(&'a ast::ExprListComp),
SetComprehension(&'a ast::ExprSetComp),
DictComprehension(&'a ast::ExprDictComp),
GeneratorExpression(&'a ast::ExprGenerator),
}
impl NodeWithScopeRef<'_> {
/// Converts the unowned reference to an owned [`NodeWithScopeKind`].
///
/// # Safety
/// The node wrapped by `self` must be a child of `module`.
#[allow(unsafe_code)]
pub(super) unsafe fn to_kind(self, module: ParsedModule) -> NodeWithScopeKind {
match self {
NodeWithScopeRef::Module => NodeWithScopeKind::Module,
NodeWithScopeRef::Class(class) => {
NodeWithScopeKind::Class(AstNodeRef::new(module, class))
}
NodeWithScopeRef::Function(function) => {
NodeWithScopeKind::Function(AstNodeRef::new(module, function))
}
NodeWithScopeRef::TypeAlias(type_alias) => {
NodeWithScopeKind::TypeAlias(AstNodeRef::new(module, type_alias))
}
NodeWithScopeRef::TypeAliasTypeParameters(type_alias) => {
NodeWithScopeKind::TypeAliasTypeParameters(AstNodeRef::new(module, type_alias))
}
NodeWithScopeRef::Lambda(lambda) => {
NodeWithScopeKind::Lambda(AstNodeRef::new(module, lambda))
}
NodeWithScopeRef::FunctionTypeParameters(function) => {
NodeWithScopeKind::FunctionTypeParameters(AstNodeRef::new(module, function))
}
NodeWithScopeRef::ClassTypeParameters(class) => {
NodeWithScopeKind::ClassTypeParameters(AstNodeRef::new(module, class))
}
NodeWithScopeRef::ListComprehension(comprehension) => {
NodeWithScopeKind::ListComprehension(AstNodeRef::new(module, comprehension))
}
NodeWithScopeRef::SetComprehension(comprehension) => {
NodeWithScopeKind::SetComprehension(AstNodeRef::new(module, comprehension))
}
NodeWithScopeRef::DictComprehension(comprehension) => {
NodeWithScopeKind::DictComprehension(AstNodeRef::new(module, comprehension))
}
NodeWithScopeRef::GeneratorExpression(generator) => {
NodeWithScopeKind::GeneratorExpression(AstNodeRef::new(module, generator))
}
}
}
pub(crate) fn node_key(self) -> NodeWithScopeKey {
match self {
NodeWithScopeRef::Module => NodeWithScopeKey::Module,
NodeWithScopeRef::Class(class) => NodeWithScopeKey::Class(NodeKey::from_node(class)),
NodeWithScopeRef::Function(function) => {
NodeWithScopeKey::Function(NodeKey::from_node(function))
}
NodeWithScopeRef::Lambda(lambda) => {
NodeWithScopeKey::Lambda(NodeKey::from_node(lambda))
}
NodeWithScopeRef::FunctionTypeParameters(function) => {
NodeWithScopeKey::FunctionTypeParameters(NodeKey::from_node(function))
}
NodeWithScopeRef::ClassTypeParameters(class) => {
NodeWithScopeKey::ClassTypeParameters(NodeKey::from_node(class))
}
NodeWithScopeRef::TypeAlias(type_alias) => {
NodeWithScopeKey::TypeAlias(NodeKey::from_node(type_alias))
}
NodeWithScopeRef::TypeAliasTypeParameters(type_alias) => {
NodeWithScopeKey::TypeAliasTypeParameters(NodeKey::from_node(type_alias))
}
NodeWithScopeRef::ListComprehension(comprehension) => {
NodeWithScopeKey::ListComprehension(NodeKey::from_node(comprehension))
}
NodeWithScopeRef::SetComprehension(comprehension) => {
NodeWithScopeKey::SetComprehension(NodeKey::from_node(comprehension))
}
NodeWithScopeRef::DictComprehension(comprehension) => {
NodeWithScopeKey::DictComprehension(NodeKey::from_node(comprehension))
}
NodeWithScopeRef::GeneratorExpression(generator) => {
NodeWithScopeKey::GeneratorExpression(NodeKey::from_node(generator))
}
}
}
}
/// Node that introduces a new scope.
#[derive(Clone, Debug, salsa::Update)]
pub enum NodeWithScopeKind {
Module,
Class(AstNodeRef<ast::StmtClassDef>),
ClassTypeParameters(AstNodeRef<ast::StmtClassDef>),
Function(AstNodeRef<ast::StmtFunctionDef>),
FunctionTypeParameters(AstNodeRef<ast::StmtFunctionDef>),
TypeAliasTypeParameters(AstNodeRef<ast::StmtTypeAlias>),
TypeAlias(AstNodeRef<ast::StmtTypeAlias>),
Lambda(AstNodeRef<ast::ExprLambda>),
ListComprehension(AstNodeRef<ast::ExprListComp>),
SetComprehension(AstNodeRef<ast::ExprSetComp>),
DictComprehension(AstNodeRef<ast::ExprDictComp>),
GeneratorExpression(AstNodeRef<ast::ExprGenerator>),
}
impl NodeWithScopeKind {
pub(crate) const fn scope_kind(&self) -> ScopeKind {
match self {
Self::Module => ScopeKind::Module,
Self::Class(_) => ScopeKind::Class,
Self::Function(_) => ScopeKind::Function,
Self::Lambda(_) => ScopeKind::Lambda,
Self::FunctionTypeParameters(_)
| Self::ClassTypeParameters(_)
| Self::TypeAliasTypeParameters(_) => ScopeKind::Annotation,
Self::TypeAlias(_) => ScopeKind::TypeAlias,
Self::ListComprehension(_)
| Self::SetComprehension(_)
| Self::DictComprehension(_)
| Self::GeneratorExpression(_) => ScopeKind::Comprehension,
}
}
pub fn expect_class(&self) -> &ast::StmtClassDef {
match self {
Self::Class(class) => class.node(),
_ => panic!("expected class"),
}
}
pub fn expect_function(&self) -> &ast::StmtFunctionDef {
self.as_function().expect("expected function")
}
pub fn expect_type_alias(&self) -> &ast::StmtTypeAlias {
match self {
Self::TypeAlias(type_alias) => type_alias.node(),
_ => panic!("expected type alias"),
}
}
pub const fn as_function(&self) -> Option<&ast::StmtFunctionDef> {
match self {
Self::Function(function) => Some(function.node()),
_ => None,
}
}
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub(crate) enum NodeWithScopeKey {
Module,
Class(NodeKey),
ClassTypeParameters(NodeKey),
Function(NodeKey),
FunctionTypeParameters(NodeKey),
TypeAlias(NodeKey),
TypeAliasTypeParameters(NodeKey),
Lambda(NodeKey),
ListComprehension(NodeKey),
SetComprehension(NodeKey),
DictComprehension(NodeKey),
GeneratorExpression(NodeKey),
}

File diff suppressed because it is too large

View file

@ -0,0 +1,633 @@
//! Track live bindings per symbol, applicable constraints per binding, and live declarations.
//!
//! These data structures operate entirely on scope-local newtype-indices for definitions and
//! constraints, referring to their location in the `all_definitions` and `all_constraints`
//! indexvecs in [`super::UseDefMapBuilder`].
//!
//! We need to track arbitrary associations between bindings and constraints, not just a single set
//! of currently dominating constraints (where "dominating" means "control flow must have passed
//! through it to reach this point"), because we can have dominating constraints that apply to some
//! bindings but not others, as in this code:
//!
//! ```python
//! x = 1 if flag else None
//! if x is not None:
//!     if flag2:
//!         x = 2 if flag else None
//! x
//! ```
//!
//! The `x is not None` constraint dominates the final use of `x`, but it applies only to the first
//! binding of `x`, not the second, so `None` is a possible value for `x`.
//!
//! And we can't just track, for each binding, an index into a list of dominating constraints,
//! either, because we can have bindings which are still visible, but subject to constraints that
//! are no longer dominating, as in this code:
//!
//! ```python
//! x = 0
//! if flag1:
//!     x = 1 if flag2 else None
//!     assert x is not None
//! x
//! ```
//!
//! From the point of view of the final use of `x`, the `x is not None` constraint no longer
//! dominates, but it does dominate the `x = 1 if flag2 else None` binding, so we have to keep
//! track of that.
//!
//! The data structures use `IndexVec` arenas to store all data compactly and contiguously, while
//! supporting very cheap clones.
//!
//! Tracking live declarations is simpler, since constraints are not involved, but otherwise very
//! similar to tracking live bindings.
use itertools::{EitherOrBoth, Itertools};
use ruff_index::newtype_index;
use smallvec::{smallvec, SmallVec};
use crate::semantic_index::narrowing_constraints::{
NarrowingConstraintsBuilder, ScopedNarrowingConstraint, ScopedNarrowingConstraintPredicate,
};
use crate::semantic_index::visibility_constraints::{
ScopedVisibilityConstraintId, VisibilityConstraintsBuilder,
};
/// A newtype-index for a definition in a particular scope.
#[newtype_index]
#[derive(Ord, PartialOrd)]
pub(super) struct ScopedDefinitionId;
impl ScopedDefinitionId {
/// A special ID that is used to describe an implicit start-of-scope state. When
/// we see that this definition is live, we know that the symbol is (possibly)
/// unbound or undeclared at a given usage site.
/// When creating a use-def-map builder, we always add an empty `None` definition
/// at index 0, so this ID is always present.
pub(super) const UNBOUND: ScopedDefinitionId = ScopedDefinitionId::from_u32(0);
}
/// The maximum number of live bindings or declarations per symbol that we store inline;
/// any more spill over to the heap.
const INLINE_DEFINITIONS_PER_SYMBOL: usize = 4;
/// Live declarations for a single symbol at some point in control flow, with their
/// corresponding visibility constraints.
#[derive(Clone, Debug, Default, PartialEq, Eq, salsa::Update)]
pub(super) struct SymbolDeclarations {
/// A list of live declarations for this symbol, sorted by their `ScopedDefinitionId`
live_declarations: SmallVec<[LiveDeclaration; INLINE_DEFINITIONS_PER_SYMBOL]>,
}
/// One of the live declarations for a single symbol at some point in control flow.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(super) struct LiveDeclaration {
pub(super) declaration: ScopedDefinitionId,
pub(super) visibility_constraint: ScopedVisibilityConstraintId,
}
pub(super) type LiveDeclarationsIterator<'a> = std::slice::Iter<'a, LiveDeclaration>;
impl SymbolDeclarations {
fn undeclared(scope_start_visibility: ScopedVisibilityConstraintId) -> Self {
let initial_declaration = LiveDeclaration {
declaration: ScopedDefinitionId::UNBOUND,
visibility_constraint: scope_start_visibility,
};
Self {
live_declarations: smallvec![initial_declaration],
}
}
/// Record a newly-encountered declaration for this symbol.
fn record_declaration(&mut self, declaration: ScopedDefinitionId) {
// The new declaration replaces all previous live declarations in this path.
self.live_declarations.clear();
self.live_declarations.push(LiveDeclaration {
declaration,
visibility_constraint: ScopedVisibilityConstraintId::ALWAYS_TRUE,
});
}
/// Add given visibility constraint to all live declarations.
pub(super) fn record_visibility_constraint(
&mut self,
visibility_constraints: &mut VisibilityConstraintsBuilder,
constraint: ScopedVisibilityConstraintId,
) {
for declaration in &mut self.live_declarations {
declaration.visibility_constraint = visibility_constraints
.add_and_constraint(declaration.visibility_constraint, constraint);
}
}
/// Return an iterator over live declarations for this symbol.
pub(super) fn iter(&self) -> LiveDeclarationsIterator<'_> {
self.live_declarations.iter()
}
/// Iterate over the IDs of each currently live declaration for this symbol
fn iter_declarations(&self) -> impl Iterator<Item = ScopedDefinitionId> + '_ {
self.iter().map(|lb| lb.declaration)
}
fn simplify_visibility_constraints(&mut self, other: SymbolDeclarations) {
// If the set of live declarations hasn't changed, don't simplify.
if self.live_declarations.len() != other.live_declarations.len()
|| !self.iter_declarations().eq(other.iter_declarations())
{
return;
}
for (declaration, other_declaration) in self
.live_declarations
.iter_mut()
.zip(other.live_declarations)
{
declaration.visibility_constraint = other_declaration.visibility_constraint;
}
}
fn merge(&mut self, b: Self, visibility_constraints: &mut VisibilityConstraintsBuilder) {
let a = std::mem::take(self);
// Invariant: merge_join_by consumes the two iterators in sorted order, which ensures that
// the merged `live_declarations` vec remains sorted. If a definition is found in both `a`
// and `b`, we compose the visibility constraints from the two paths with a ternary OR
// (declarations carry no narrowing constraints). If a
// definition is found in only one path, it is used as-is.
let a = a.live_declarations.into_iter();
let b = b.live_declarations.into_iter();
for zipped in a.merge_join_by(b, |a, b| a.declaration.cmp(&b.declaration)) {
match zipped {
EitherOrBoth::Both(a, b) => {
let visibility_constraint = visibility_constraints
.add_or_constraint(a.visibility_constraint, b.visibility_constraint);
self.live_declarations.push(LiveDeclaration {
declaration: a.declaration,
visibility_constraint,
});
}
EitherOrBoth::Left(declaration) | EitherOrBoth::Right(declaration) => {
self.live_declarations.push(declaration);
}
}
}
}
}
/// Live bindings for a single symbol at some point in control flow. Each live binding comes
/// with a set of narrowing constraints and a visibility constraint.
#[derive(Clone, Debug, Default, PartialEq, Eq, salsa::Update)]
pub(super) struct SymbolBindings {
/// A list of live bindings for this symbol, sorted by their `ScopedDefinitionId`
live_bindings: SmallVec<[LiveBinding; INLINE_DEFINITIONS_PER_SYMBOL]>,
}
/// One of the live bindings for a single symbol at some point in control flow.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(super) struct LiveBinding {
pub(super) binding: ScopedDefinitionId,
pub(super) narrowing_constraint: ScopedNarrowingConstraint,
pub(super) visibility_constraint: ScopedVisibilityConstraintId,
}
pub(super) type LiveBindingsIterator<'a> = std::slice::Iter<'a, LiveBinding>;
impl SymbolBindings {
fn unbound(scope_start_visibility: ScopedVisibilityConstraintId) -> Self {
let initial_binding = LiveBinding {
binding: ScopedDefinitionId::UNBOUND,
narrowing_constraint: ScopedNarrowingConstraint::empty(),
visibility_constraint: scope_start_visibility,
};
Self {
live_bindings: smallvec![initial_binding],
}
}
/// Record a newly-encountered binding for this symbol.
pub(super) fn record_binding(
&mut self,
binding: ScopedDefinitionId,
visibility_constraint: ScopedVisibilityConstraintId,
) {
// The new binding replaces all previous live bindings in this path, and has no
// constraints.
self.live_bindings.clear();
self.live_bindings.push(LiveBinding {
binding,
narrowing_constraint: ScopedNarrowingConstraint::empty(),
visibility_constraint,
});
}
/// Add given constraint to all live bindings.
pub(super) fn record_narrowing_constraint(
&mut self,
narrowing_constraints: &mut NarrowingConstraintsBuilder,
predicate: ScopedNarrowingConstraintPredicate,
) {
for binding in &mut self.live_bindings {
binding.narrowing_constraint = narrowing_constraints
.add_predicate_to_constraint(binding.narrowing_constraint, predicate);
}
}
/// Add given visibility constraint to all live bindings.
pub(super) fn record_visibility_constraint(
&mut self,
visibility_constraints: &mut VisibilityConstraintsBuilder,
constraint: ScopedVisibilityConstraintId,
) {
for binding in &mut self.live_bindings {
binding.visibility_constraint = visibility_constraints
.add_and_constraint(binding.visibility_constraint, constraint);
}
}
/// Iterate over currently live bindings for this symbol
pub(super) fn iter(&self) -> LiveBindingsIterator<'_> {
self.live_bindings.iter()
}
/// Iterate over the IDs of each currently live binding for this symbol
fn iter_bindings(&self) -> impl Iterator<Item = ScopedDefinitionId> + '_ {
self.iter().map(|lb| lb.binding)
}
fn simplify_visibility_constraints(&mut self, other: SymbolBindings) {
// If the set of live bindings hasn't changed, don't simplify.
if self.live_bindings.len() != other.live_bindings.len()
|| !self.iter_bindings().eq(other.iter_bindings())
{
return;
}
for (binding, other_binding) in self.live_bindings.iter_mut().zip(other.live_bindings) {
binding.visibility_constraint = other_binding.visibility_constraint;
}
}
fn merge(
&mut self,
b: Self,
narrowing_constraints: &mut NarrowingConstraintsBuilder,
visibility_constraints: &mut VisibilityConstraintsBuilder,
) {
let a = std::mem::take(self);
// Invariant: merge_join_by consumes the two iterators in sorted order, which ensures that
// the merged `live_bindings` vec remains sorted. If a definition is found in both `a` and
// `b`, we compose the constraints from the two paths in an appropriate way (intersection
// for narrowing constraints; ternary OR for visibility constraints). If a definition is
// found in only one path, it is used as-is.
let a = a.live_bindings.into_iter();
let b = b.live_bindings.into_iter();
for zipped in a.merge_join_by(b, |a, b| a.binding.cmp(&b.binding)) {
match zipped {
EitherOrBoth::Both(a, b) => {
// If the same definition is visible through both paths, any constraint
// that applies on only one path is irrelevant to the resulting type from
// unioning the two paths, so we intersect the constraints.
let narrowing_constraint = narrowing_constraints
.intersect_constraints(a.narrowing_constraint, b.narrowing_constraint);
// For visibility constraints, we merge them using a ternary OR operation:
let visibility_constraint = visibility_constraints
.add_or_constraint(a.visibility_constraint, b.visibility_constraint);
self.live_bindings.push(LiveBinding {
binding: a.binding,
narrowing_constraint,
visibility_constraint,
});
}
EitherOrBoth::Left(binding) | EitherOrBoth::Right(binding) => {
self.live_bindings.push(binding);
}
}
}
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub(in crate::semantic_index) struct SymbolState {
declarations: SymbolDeclarations,
bindings: SymbolBindings,
}
impl SymbolState {
/// Return a new [`SymbolState`] representing an unbound, undeclared symbol.
pub(super) fn undefined(scope_start_visibility: ScopedVisibilityConstraintId) -> Self {
Self {
declarations: SymbolDeclarations::undeclared(scope_start_visibility),
bindings: SymbolBindings::unbound(scope_start_visibility),
}
}
/// Record a newly-encountered binding for this symbol.
pub(super) fn record_binding(
&mut self,
binding_id: ScopedDefinitionId,
visibility_constraint: ScopedVisibilityConstraintId,
) {
debug_assert_ne!(binding_id, ScopedDefinitionId::UNBOUND);
self.bindings
.record_binding(binding_id, visibility_constraint);
}
/// Add given constraint to all live bindings.
pub(super) fn record_narrowing_constraint(
&mut self,
narrowing_constraints: &mut NarrowingConstraintsBuilder,
constraint: ScopedNarrowingConstraintPredicate,
) {
self.bindings
.record_narrowing_constraint(narrowing_constraints, constraint);
}
/// Add given visibility constraint to all live bindings.
pub(super) fn record_visibility_constraint(
&mut self,
visibility_constraints: &mut VisibilityConstraintsBuilder,
constraint: ScopedVisibilityConstraintId,
) {
self.bindings
.record_visibility_constraint(visibility_constraints, constraint);
self.declarations
.record_visibility_constraint(visibility_constraints, constraint);
}
/// Simplifies this snapshot to have the same visibility constraints as a previous point in the
/// control flow, but only if the set of live bindings or declarations for this symbol hasn't
/// changed.
pub(super) fn simplify_visibility_constraints(&mut self, snapshot_state: SymbolState) {
self.bindings
.simplify_visibility_constraints(snapshot_state.bindings);
self.declarations
.simplify_visibility_constraints(snapshot_state.declarations);
}
/// Record a newly-encountered declaration of this symbol.
pub(super) fn record_declaration(&mut self, declaration_id: ScopedDefinitionId) {
self.declarations.record_declaration(declaration_id);
}
/// Merge another [`SymbolState`] into this one.
pub(super) fn merge(
&mut self,
b: SymbolState,
narrowing_constraints: &mut NarrowingConstraintsBuilder,
visibility_constraints: &mut VisibilityConstraintsBuilder,
) {
self.bindings
.merge(b.bindings, narrowing_constraints, visibility_constraints);
self.declarations
.merge(b.declarations, visibility_constraints);
}
pub(super) fn bindings(&self) -> &SymbolBindings {
&self.bindings
}
pub(super) fn declarations(&self) -> &SymbolDeclarations {
&self.declarations
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::semantic_index::predicate::ScopedPredicateId;
#[track_caller]
fn assert_bindings(
narrowing_constraints: &NarrowingConstraintsBuilder,
symbol: &SymbolState,
expected: &[&str],
) {
let actual = symbol
.bindings()
.iter()
.map(|live_binding| {
let def_id = live_binding.binding;
let def = if def_id == ScopedDefinitionId::UNBOUND {
"unbound".into()
} else {
def_id.as_u32().to_string()
};
let predicates = narrowing_constraints
.iter_predicates(live_binding.narrowing_constraint)
.map(|idx| idx.as_u32().to_string())
.collect::<Vec<_>>()
.join(", ");
format!("{def}<{predicates}>")
})
.collect::<Vec<_>>();
assert_eq!(actual, expected);
}
#[track_caller]
pub(crate) fn assert_declarations(symbol: &SymbolState, expected: &[&str]) {
let actual = symbol
.declarations()
.iter()
.map(
|LiveDeclaration {
declaration,
visibility_constraint: _,
}| {
if *declaration == ScopedDefinitionId::UNBOUND {
"undeclared".into()
} else {
declaration.as_u32().to_string()
}
},
)
.collect::<Vec<_>>();
assert_eq!(actual, expected);
}
#[test]
fn unbound() {
let narrowing_constraints = NarrowingConstraintsBuilder::default();
let sym = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE);
assert_bindings(&narrowing_constraints, &sym, &["unbound<>"]);
}
#[test]
fn with() {
let narrowing_constraints = NarrowingConstraintsBuilder::default();
let mut sym = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE);
sym.record_binding(
ScopedDefinitionId::from_u32(1),
ScopedVisibilityConstraintId::ALWAYS_TRUE,
);
assert_bindings(&narrowing_constraints, &sym, &["1<>"]);
}
#[test]
fn record_constraint() {
let mut narrowing_constraints = NarrowingConstraintsBuilder::default();
let mut sym = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE);
sym.record_binding(
ScopedDefinitionId::from_u32(1),
ScopedVisibilityConstraintId::ALWAYS_TRUE,
);
let predicate = ScopedPredicateId::from_u32(0).into();
sym.record_narrowing_constraint(&mut narrowing_constraints, predicate);
assert_bindings(&narrowing_constraints, &sym, &["1<0>"]);
}
#[test]
fn merge() {
let mut narrowing_constraints = NarrowingConstraintsBuilder::default();
let mut visibility_constraints = VisibilityConstraintsBuilder::default();
// merging the same definition with the same constraint keeps the constraint
let mut sym1a = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE);
sym1a.record_binding(
ScopedDefinitionId::from_u32(1),
ScopedVisibilityConstraintId::ALWAYS_TRUE,
);
let predicate = ScopedPredicateId::from_u32(0).into();
sym1a.record_narrowing_constraint(&mut narrowing_constraints, predicate);
let mut sym1b = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE);
sym1b.record_binding(
ScopedDefinitionId::from_u32(1),
ScopedVisibilityConstraintId::ALWAYS_TRUE,
);
let predicate = ScopedPredicateId::from_u32(0).into();
sym1b.record_narrowing_constraint(&mut narrowing_constraints, predicate);
sym1a.merge(
sym1b,
&mut narrowing_constraints,
&mut visibility_constraints,
);
let mut sym1 = sym1a;
assert_bindings(&narrowing_constraints, &sym1, &["1<0>"]);
// merging the same definition with differing constraints drops all constraints
let mut sym2a = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE);
sym2a.record_binding(
ScopedDefinitionId::from_u32(2),
ScopedVisibilityConstraintId::ALWAYS_TRUE,
);
let predicate = ScopedPredicateId::from_u32(1).into();
sym2a.record_narrowing_constraint(&mut narrowing_constraints, predicate);
let mut sym2b = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE);
sym2b.record_binding(
ScopedDefinitionId::from_u32(2),
ScopedVisibilityConstraintId::ALWAYS_TRUE,
);
let predicate = ScopedPredicateId::from_u32(2).into();
sym2b.record_narrowing_constraint(&mut narrowing_constraints, predicate);
sym2a.merge(
sym2b,
&mut narrowing_constraints,
&mut visibility_constraints,
);
let sym2 = sym2a;
assert_bindings(&narrowing_constraints, &sym2, &["2<>"]);
// merging a constrained definition with unbound keeps both
let mut sym3a = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE);
sym3a.record_binding(
ScopedDefinitionId::from_u32(3),
ScopedVisibilityConstraintId::ALWAYS_TRUE,
);
let predicate = ScopedPredicateId::from_u32(3).into();
sym3a.record_narrowing_constraint(&mut narrowing_constraints, predicate);
let sym3b = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE);
sym3a.merge(
sym3b,
&mut narrowing_constraints,
&mut visibility_constraints,
);
let sym3 = sym3a;
assert_bindings(&narrowing_constraints, &sym3, &["unbound<>", "3<3>"]);
// merging different definitions keeps them each with their existing constraints
sym1.merge(
sym3,
&mut narrowing_constraints,
&mut visibility_constraints,
);
let sym = sym1;
assert_bindings(&narrowing_constraints, &sym, &["unbound<>", "1<0>", "3<3>"]);
}
#[test]
fn no_declaration() {
let sym = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE);
assert_declarations(&sym, &["undeclared"]);
}
#[test]
fn record_declaration() {
let mut sym = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE);
sym.record_declaration(ScopedDefinitionId::from_u32(1));
assert_declarations(&sym, &["1"]);
}
#[test]
fn record_declaration_override() {
let mut sym = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE);
sym.record_declaration(ScopedDefinitionId::from_u32(1));
sym.record_declaration(ScopedDefinitionId::from_u32(2));
assert_declarations(&sym, &["2"]);
}
#[test]
fn record_declaration_merge() {
let mut narrowing_constraints = NarrowingConstraintsBuilder::default();
let mut visibility_constraints = VisibilityConstraintsBuilder::default();
let mut sym = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE);
sym.record_declaration(ScopedDefinitionId::from_u32(1));
let mut sym2 = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE);
sym2.record_declaration(ScopedDefinitionId::from_u32(2));
sym.merge(
sym2,
&mut narrowing_constraints,
&mut visibility_constraints,
);
assert_declarations(&sym, &["1", "2"]);
}
#[test]
fn record_declaration_merge_partial_undeclared() {
let mut narrowing_constraints = NarrowingConstraintsBuilder::default();
let mut visibility_constraints = VisibilityConstraintsBuilder::default();
let mut sym = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE);
sym.record_declaration(ScopedDefinitionId::from_u32(1));
let sym2 = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE);
sym.merge(
sym2,
&mut narrowing_constraints,
&mut visibility_constraints,
);
assert_declarations(&sym, &["undeclared", "1"]);
}
}

View file

@ -0,0 +1,670 @@
//! # Visibility constraints
//!
//! During semantic index building, we collect visibility constraints for each binding and
//! declaration. These constraints are then used during type-checking to determine the static
//! visibility of a certain definition. This allows us to re-analyze control flow during type
//! checking, potentially "hiding" some branches that we can statically determine to never be
//! taken. Consider the following example, in which we add implicit "unbound" definitions at the
//! start of the scope. Note how visibility constraints can apply to bindings outside of the
//! `if`-statement:
//! ```py
//! x = <unbound> # not a live binding for the use of x below, shadowed by `x = 1`
//! y = <unbound> # visibility constraint: ~test
//!
//! x = 1 # visibility constraint: ~test
//! if test:
//!     x = 2 # visibility constraint: test
//!
//!     y = 2 # visibility constraint: test
//!
//! use(x)
//! use(y)
//! ```
//! The static truthiness of the `test` condition can either be always-false, ambiguous, or
//! always-true. Similarly, we have the same three options when evaluating a visibility constraint.
//! This outcome determines the visibility of a definition: always-true means that the definition
//! is definitely visible for a given use, always-false means that the definition is definitely
//! not visible, and ambiguous means that we might see this definition or not. In the latter case,
//! we need to consider both options during type inference and boundness analysis. For the example
//! above, these are the possible type inference / boundness results for the uses of `x` and `y`:
//!
//! ```text
//! | `test` truthiness | `~test` truthiness | type of `x` | boundness of `y` |
//! |-------------------|--------------------|-----------------|------------------|
//! | always false | always true | `Literal[1]` | unbound |
//! | ambiguous | ambiguous | `Literal[1, 2]` | possibly unbound |
//! | always true | always false | `Literal[2]` | bound |
//! ```
//!
//! ### Sequential constraints (ternary AND)
//!
//! As we have seen above, visibility constraints can apply outside of a control flow element.
//! So we need to consider the possibility that multiple constraints apply to the same binding.
//! Here, we consider what happens if multiple `if`-statements lead to a sequence of constraints.
//! Consider the following example:
//! ```py
//! x = 0
//!
//! if test1:
//! x = 1
//!
//! if test2:
//! x = 2
//! ```
//! The binding `x = 2` is easy to analyze. Its visibility corresponds to the truthiness of `test2`.
//! For the `x = 1` binding, things are a bit more interesting. It is always visible if `test1` is
//! always-true *and* `test2` is always-false. It is never visible if `test1` is always-false *or*
//! `test2` is always-true. And it is ambiguous otherwise. This corresponds to a ternary *test1 AND
//! ~test2* operation in three-valued Kleene logic [Kleene]:
//!
//! ```text
//! | AND | always-false | ambiguous | always-true |
//! |--------------|--------------|--------------|--------------|
//! | always-false | always-false | always-false | always-false |
//! | ambiguous    | always-false | ambiguous    | ambiguous    |
//! | always-true  | always-false | ambiguous    | always-true  |
//! ```
//!
//! The `x = 0` binding can be handled similarly, with the difference that both `test1` and `test2`
//! are negated:
//! ```py
//! x = 0 # ~test1 AND ~test2
//!
//! if test1:
//! x = 1 # test1 AND ~test2
//!
//! if test2:
//! x = 2 # test2
//! ```
//!
//! ### Merged constraints (ternary OR)
//!
//! Finally, we consider what happens in "parallel" control flow. Consider the following example
//! where we have omitted the test condition for the outer `if` for clarity:
//! ```py
//! x = 0
//!
//! if <…>:
//! if test1:
//! x = 1
//! else:
//! if test2:
//! x = 2
//!
//! use(x)
//! ```
//! At the usage of `x`, i.e. after control flow has been merged again, the visibility of the `x =
//! 0` binding behaves as follows: the binding is always visible if `test1` is always-false *or*
//! `test2` is always-false; and it is never visible if `test1` is always-true *and* `test2` is
//! always-true. This corresponds to a ternary *OR* operation in Kleene logic:
//!
//! ```text
//! | OR | always-false | ambiguous | always-true |
//! |--------------|--------------|--------------|--------------|
//! | always-false | always-false | ambiguous    | always-true  |
//! | ambiguous    | ambiguous    | ambiguous    | always-true  |
//! | always-true  | always-true  | always-true  | always-true  |
//! ```
//!
//! Using this, we can annotate the visibility constraints for the example above:
//! ```py
//! x = 0 # ~test1 OR ~test2
//!
//! if <…>:
//! if test1:
//! x = 1 # test1
//! else:
//! if test2:
//! x = 2 # test2
//!
//! use(x)
//! ```
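//!
//! As a minimal, self-contained sketch (illustrative only, not this module's
//! actual API), the two Kleene tables above can be written as `match`es on a
//! three-valued enum:
//!
//! ```rust
//! #[derive(Clone, Copy, PartialEq, Debug)]
//! enum Ternary {
//!     AlwaysFalse,
//!     Ambiguous,
//!     AlwaysTrue,
//! }
//!
//! impl Ternary {
//!     /// Ternary AND: always-false dominates, then ambiguous, then always-true.
//!     fn and(self, other: Self) -> Self {
//!         match (self, other) {
//!             (Ternary::AlwaysFalse, _) | (_, Ternary::AlwaysFalse) => Ternary::AlwaysFalse,
//!             (Ternary::Ambiguous, _) | (_, Ternary::Ambiguous) => Ternary::Ambiguous,
//!             _ => Ternary::AlwaysTrue,
//!         }
//!     }
//!
//!     /// Ternary OR: always-true dominates, then ambiguous, then always-false.
//!     fn or(self, other: Self) -> Self {
//!         match (self, other) {
//!             (Ternary::AlwaysTrue, _) | (_, Ternary::AlwaysTrue) => Ternary::AlwaysTrue,
//!             (Ternary::Ambiguous, _) | (_, Ternary::Ambiguous) => Ternary::Ambiguous,
//!             _ => Ternary::AlwaysFalse,
//!         }
//!     }
//! }
//!
//! assert_eq!(Ternary::AlwaysTrue.and(Ternary::Ambiguous), Ternary::Ambiguous);
//! assert_eq!(Ternary::Ambiguous.or(Ternary::AlwaysTrue), Ternary::AlwaysTrue);
//! ```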
//!
//! ### Explicit ambiguity
//!
//! In some cases, we explicitly add an “ambiguous” constraint to all bindings
//! in a certain control flow path. We do this when branching on something that we cannot (or
//! intentionally do not want to) analyze statically. `for` loops are one example:
//! ```py
//! x = <unbound>
//!
//! for _ in range(2):
//! x = 1
//! ```
//! Here, we report an ambiguous visibility constraint before branching off. If we don't do this,
//! the `x = <unbound>` binding would be considered unconditionally visible in the no-loop case.
//! And since the other branch does not have the live `x = <unbound>` binding, we would incorrectly
//! create a state where the `x = <unbound>` binding is always visible.
//!
//!
//! ### Representing formulas
//!
//! Given everything above, we can represent a visibility constraint as a _ternary formula_. This
//! is like a boolean formula (which maps several true/false variables to a single true/false
//! result), but which allows the third "ambiguous" value in addition to "true" and "false".
//!
//! [_Binary decision diagrams_][bdd] (BDDs) are a common way to represent boolean formulas when
//! doing program analysis. We extend this to a _ternary decision diagram_ (TDD) to support
//! ambiguous values.
//!
//! A TDD is a graph, and a ternary formula is represented by a node in this graph. There are three
//! possible leaf nodes representing the "true", "false", and "ambiguous" constant functions.
//! Interior nodes consist of a ternary variable to evaluate, and outgoing edges for whether the
//! variable evaluates to true, false, or ambiguous.
//!
//! Our TDDs are _reduced_ and _ordered_ (as is typical for BDDs).
//!
//! An ordered TDD means that variables appear in the same order in all paths within the graph.
//!
//! A reduced TDD means two things: First, we intern the graph nodes, so that we only keep a single
//! copy of interior nodes with the same contents. Second, we eliminate any nodes that are "noops",
//! where the "true" and "false" outgoing edges lead to the same node. (This implies that it
//! doesn't matter what value that variable has when evaluating the formula, and we can leave it
//! out of the evaluation chain completely.)
//!
//! Reduced and ordered decision diagrams are _normal forms_, which means that two equivalent
//! formulas (which have the same outputs for every combination of inputs) are represented by
//! exactly the same graph node. (Because of interning, this is not _equal_ nodes, but _identical_
//! ones.) That means that we can compare formulas for equivalence in constant time, and in
//! particular, can check whether a visibility constraint is statically always true or false,
//! regardless of any Python program state, by seeing if the constraint's formula is the "true" or
//! "false" leaf node.
//!
//! [Kleene]: <https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics>
//! [bdd]: https://en.wikipedia.org/wiki/Binary_decision_diagram
use std::cmp::Ordering;
use ruff_index::{Idx, IndexVec};
use rustc_hash::FxHashMap;
use crate::semantic_index::expression::Expression;
use crate::semantic_index::predicate::{
PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, Predicates, ScopedPredicateId,
};
use crate::semantic_index::symbol_table;
use crate::symbol::imported_symbol;
use crate::types::{infer_expression_type, Truthiness, Type};
use crate::Db;
/// A ternary formula that defines under what conditions a binding is visible. (A ternary formula
/// is just like a boolean formula, but with `Ambiguous` as a third potential result. See the
/// module documentation for more details.)
///
/// The primitive atoms of the formula are [`Predicate`]s, which express some property of the
/// runtime state of the code that we are analyzing.
///
/// We assume that each atom has a stable value each time that the formula is evaluated. An atom
/// that resolves to `Ambiguous` might be true or false, and we can't tell which — but within that
/// evaluation, we assume that the atom has the _same_ unknown value each time it appears. That
/// allows us to perform simplifications like `A ∨ !A → true` and `A ∧ !A → false`.
///
/// That means that when you are constructing a formula, you might need to create distinct atoms
/// for a particular [`Predicate`], if your formula needs to consider how a particular runtime
/// property might be different at different points in the execution of the program.
///
/// Visibility constraints are normalized, so equivalent constraints are guaranteed to have equal
/// IDs.
#[derive(Clone, Copy, Eq, Hash, PartialEq)]
pub(crate) struct ScopedVisibilityConstraintId(u32);
impl std::fmt::Debug for ScopedVisibilityConstraintId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut f = f.debug_tuple("ScopedVisibilityConstraintId");
match *self {
// We use format_args instead of rendering the strings directly so that we don't get
// any quotes in the output: ScopedVisibilityConstraintId(AlwaysTrue) instead of
// ScopedVisibilityConstraintId("AlwaysTrue").
ALWAYS_TRUE => f.field(&format_args!("AlwaysTrue")),
AMBIGUOUS => f.field(&format_args!("Ambiguous")),
ALWAYS_FALSE => f.field(&format_args!("AlwaysFalse")),
_ => f.field(&self.0),
};
f.finish()
}
}
// Internal details:
//
// There are 3 terminals, with hard-coded constraint IDs: true, ambiguous, and false.
//
// _Atoms_ are the underlying Predicates, which are the variables that are evaluated by the
// ternary function.
//
// _Interior nodes_ provide the TDD structure for the formula. Interior nodes are stored in an
// arena Vec, with the constraint ID providing an index into the arena.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct InteriorNode {
/// A "variable" that is evaluated as part of a TDD ternary function. For visibility
/// constraints, this is a `Predicate` that represents some runtime property of the Python
/// code that we are evaluating.
atom: ScopedPredicateId,
if_true: ScopedVisibilityConstraintId,
if_ambiguous: ScopedVisibilityConstraintId,
if_false: ScopedVisibilityConstraintId,
}
impl ScopedVisibilityConstraintId {
/// A special ID that is used for an "always true" / "always visible" constraint.
pub(crate) const ALWAYS_TRUE: ScopedVisibilityConstraintId =
ScopedVisibilityConstraintId(0xffff_ffff);
/// A special ID that is used for an ambiguous constraint.
pub(crate) const AMBIGUOUS: ScopedVisibilityConstraintId =
ScopedVisibilityConstraintId(0xffff_fffe);
/// A special ID that is used for an "always false" / "never visible" constraint.
pub(crate) const ALWAYS_FALSE: ScopedVisibilityConstraintId =
ScopedVisibilityConstraintId(0xffff_fffd);
fn is_terminal(self) -> bool {
self.0 >= SMALLEST_TERMINAL.0
}
}
impl Idx for ScopedVisibilityConstraintId {
#[inline]
fn new(value: usize) -> Self {
assert!(value <= (SMALLEST_TERMINAL.0 as usize));
#[allow(clippy::cast_possible_truncation)]
Self(value as u32)
}
#[inline]
fn index(self) -> usize {
debug_assert!(!self.is_terminal());
self.0 as usize
}
}
// Rebind some constants locally so that we don't need as many qualifiers below.
const ALWAYS_TRUE: ScopedVisibilityConstraintId = ScopedVisibilityConstraintId::ALWAYS_TRUE;
const AMBIGUOUS: ScopedVisibilityConstraintId = ScopedVisibilityConstraintId::AMBIGUOUS;
const ALWAYS_FALSE: ScopedVisibilityConstraintId = ScopedVisibilityConstraintId::ALWAYS_FALSE;
const SMALLEST_TERMINAL: ScopedVisibilityConstraintId = ALWAYS_FALSE;
/// A collection of visibility constraints for a given scope.
#[derive(Debug, PartialEq, Eq, salsa::Update)]
pub(crate) struct VisibilityConstraints {
interiors: IndexVec<ScopedVisibilityConstraintId, InteriorNode>,
}
#[derive(Debug, Default, PartialEq, Eq)]
pub(crate) struct VisibilityConstraintsBuilder {
interiors: IndexVec<ScopedVisibilityConstraintId, InteriorNode>,
interior_cache: FxHashMap<InteriorNode, ScopedVisibilityConstraintId>,
not_cache: FxHashMap<ScopedVisibilityConstraintId, ScopedVisibilityConstraintId>,
and_cache: FxHashMap<
(ScopedVisibilityConstraintId, ScopedVisibilityConstraintId),
ScopedVisibilityConstraintId,
>,
or_cache: FxHashMap<
(ScopedVisibilityConstraintId, ScopedVisibilityConstraintId),
ScopedVisibilityConstraintId,
>,
}
impl VisibilityConstraintsBuilder {
pub(crate) fn build(self) -> VisibilityConstraints {
VisibilityConstraints {
interiors: self.interiors,
}
}
/// Returns whether `a` or `b` has a "larger" atom. TDDs are ordered such that interior nodes
/// can only have edges to "larger" nodes. Terminals are considered to have a larger atom than
/// any internal node, since they are leaf nodes.
fn cmp_atoms(
&self,
a: ScopedVisibilityConstraintId,
b: ScopedVisibilityConstraintId,
) -> Ordering {
if a == b || (a.is_terminal() && b.is_terminal()) {
Ordering::Equal
} else if a.is_terminal() {
Ordering::Greater
} else if b.is_terminal() {
Ordering::Less
} else {
self.interiors[a].atom.cmp(&self.interiors[b].atom)
}
}
/// Adds an interior node, ensuring that we always use the same visibility constraint ID for
/// equal nodes.
fn add_interior(&mut self, node: InteriorNode) -> ScopedVisibilityConstraintId {
// If the true and false branches lead to the same node, we can override the ambiguous
// branch to go there too. And this node is then redundant and can be reduced.
if node.if_true == node.if_false {
return node.if_true;
}
*self
.interior_cache
.entry(node)
.or_insert_with(|| self.interiors.push(node))
}
/// Adds a new visibility constraint that checks a single [`Predicate`].
///
/// [`ScopedPredicateId`]s are the “variables” that are evaluated by a TDD. A TDD variable has
/// the same value no matter how many times it appears in the ternary formula that the TDD
/// represents.
///
/// However, we sometimes have to model how a `Predicate` can have a different runtime
/// value at different points in the execution of the program. To handle this, you can take
/// advantage of the fact that the [`Predicates`] arena does not deduplicate `Predicate`s.
/// You can add a `Predicate` multiple times, yielding different `ScopedPredicateId`s, which
/// you can then create separate TDD atoms for.
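    ///
    /// A hypothetical sketch (the arena method name here is an assumption, not
    /// part of this module's API):
    ///
    /// ```text
    /// let p1 = predicates.add_predicate(predicate); // first occurrence
    /// let p2 = predicates.add_predicate(predicate); // second occurrence
    /// let v1 = builder.add_atom(p1);
    /// let v2 = builder.add_atom(p2);
    /// assert_ne!(v1, v2); // two independent TDD variables
    /// ```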
pub(crate) fn add_atom(
&mut self,
predicate: ScopedPredicateId,
) -> ScopedVisibilityConstraintId {
self.add_interior(InteriorNode {
atom: predicate,
if_true: ALWAYS_TRUE,
if_ambiguous: AMBIGUOUS,
if_false: ALWAYS_FALSE,
})
}
/// Adds a new visibility constraint that is the ternary NOT of an existing one.
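    ///
    /// Negation recurses structurally: each outgoing edge of a node is negated
    /// while the node's atom stays in place. Results are memoized in `not_cache`
    /// so that shared subgraphs are only rewritten once.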
pub(crate) fn add_not_constraint(
&mut self,
a: ScopedVisibilityConstraintId,
) -> ScopedVisibilityConstraintId {
if a == ALWAYS_TRUE {
return ALWAYS_FALSE;
} else if a == AMBIGUOUS {
return AMBIGUOUS;
} else if a == ALWAYS_FALSE {
return ALWAYS_TRUE;
}
if let Some(cached) = self.not_cache.get(&a) {
return *cached;
}
let a_node = self.interiors[a];
let if_true = self.add_not_constraint(a_node.if_true);
let if_ambiguous = self.add_not_constraint(a_node.if_ambiguous);
let if_false = self.add_not_constraint(a_node.if_false);
let result = self.add_interior(InteriorNode {
atom: a_node.atom,
if_true,
if_ambiguous,
if_false,
});
self.not_cache.insert(a, result);
result
}
/// Adds a new visibility constraint that is the ternary OR of two existing ones.
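    ///
    /// After the terminal cases are handled, the three `cmp_atoms` arms below
    /// perform a Shannon-style expansion: we branch on whichever input has the
    /// smaller atom, which preserves the TDD's variable ordering.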
pub(crate) fn add_or_constraint(
&mut self,
a: ScopedVisibilityConstraintId,
b: ScopedVisibilityConstraintId,
) -> ScopedVisibilityConstraintId {
match (a, b) {
(ALWAYS_TRUE, _) | (_, ALWAYS_TRUE) => return ALWAYS_TRUE,
(ALWAYS_FALSE, other) | (other, ALWAYS_FALSE) => return other,
(AMBIGUOUS, AMBIGUOUS) => return AMBIGUOUS,
_ => {}
}
// OR is commutative, which lets us halve the cache requirements
let (a, b) = if b.0 < a.0 { (b, a) } else { (a, b) };
if let Some(cached) = self.or_cache.get(&(a, b)) {
return *cached;
}
let (atom, if_true, if_ambiguous, if_false) = match self.cmp_atoms(a, b) {
Ordering::Equal => {
let a_node = self.interiors[a];
let b_node = self.interiors[b];
let if_true = self.add_or_constraint(a_node.if_true, b_node.if_true);
let if_false = self.add_or_constraint(a_node.if_false, b_node.if_false);
let if_ambiguous = if if_true == if_false {
if_true
} else {
self.add_or_constraint(a_node.if_ambiguous, b_node.if_ambiguous)
};
(a_node.atom, if_true, if_ambiguous, if_false)
}
Ordering::Less => {
let a_node = self.interiors[a];
let if_true = self.add_or_constraint(a_node.if_true, b);
let if_false = self.add_or_constraint(a_node.if_false, b);
let if_ambiguous = if if_true == if_false {
if_true
} else {
self.add_or_constraint(a_node.if_ambiguous, b)
};
(a_node.atom, if_true, if_ambiguous, if_false)
}
Ordering::Greater => {
let b_node = self.interiors[b];
let if_true = self.add_or_constraint(a, b_node.if_true);
let if_false = self.add_or_constraint(a, b_node.if_false);
let if_ambiguous = if if_true == if_false {
if_true
} else {
self.add_or_constraint(a, b_node.if_ambiguous)
};
(b_node.atom, if_true, if_ambiguous, if_false)
}
};
let result = self.add_interior(InteriorNode {
atom,
if_true,
if_ambiguous,
if_false,
});
self.or_cache.insert((a, b), result);
result
}
/// Adds a new visibility constraint that is the ternary AND of two existing ones.
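    ///
    /// The expansion strategy mirrors [`Self::add_or_constraint`]; only the
    /// terminal cases differ.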
pub(crate) fn add_and_constraint(
&mut self,
a: ScopedVisibilityConstraintId,
b: ScopedVisibilityConstraintId,
) -> ScopedVisibilityConstraintId {
match (a, b) {
(ALWAYS_FALSE, _) | (_, ALWAYS_FALSE) => return ALWAYS_FALSE,
(ALWAYS_TRUE, other) | (other, ALWAYS_TRUE) => return other,
(AMBIGUOUS, AMBIGUOUS) => return AMBIGUOUS,
_ => {}
}
// AND is commutative, which lets us halve the cache requirements
let (a, b) = if b.0 < a.0 { (b, a) } else { (a, b) };
if let Some(cached) = self.and_cache.get(&(a, b)) {
return *cached;
}
let (atom, if_true, if_ambiguous, if_false) = match self.cmp_atoms(a, b) {
Ordering::Equal => {
let a_node = self.interiors[a];
let b_node = self.interiors[b];
let if_true = self.add_and_constraint(a_node.if_true, b_node.if_true);
let if_false = self.add_and_constraint(a_node.if_false, b_node.if_false);
let if_ambiguous = if if_true == if_false {
if_true
} else {
self.add_and_constraint(a_node.if_ambiguous, b_node.if_ambiguous)
};
(a_node.atom, if_true, if_ambiguous, if_false)
}
Ordering::Less => {
let a_node = self.interiors[a];
let if_true = self.add_and_constraint(a_node.if_true, b);
let if_false = self.add_and_constraint(a_node.if_false, b);
let if_ambiguous = if if_true == if_false {
if_true
} else {
self.add_and_constraint(a_node.if_ambiguous, b)
};
(a_node.atom, if_true, if_ambiguous, if_false)
}
Ordering::Greater => {
let b_node = self.interiors[b];
let if_true = self.add_and_constraint(a, b_node.if_true);
let if_false = self.add_and_constraint(a, b_node.if_false);
let if_ambiguous = if if_true == if_false {
if_true
} else {
self.add_and_constraint(a, b_node.if_ambiguous)
};
(b_node.atom, if_true, if_ambiguous, if_false)
}
};
let result = self.add_interior(InteriorNode {
atom,
if_true,
if_ambiguous,
if_false,
});
self.and_cache.insert((a, b), result);
result
}
}
impl VisibilityConstraints {
/// Analyze the statically known visibility for a given visibility constraint.
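    ///
    /// Evaluation walks the TDD starting from `id`: at each interior node, the
    /// node's predicate is analyzed once and the matching edge (true, ambiguous,
    /// or false) is followed, until one of the three terminals is reached.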
pub(crate) fn evaluate<'db>(
&self,
db: &'db dyn Db,
predicates: &Predicates<'db>,
mut id: ScopedVisibilityConstraintId,
) -> Truthiness {
loop {
let node = match id {
ALWAYS_TRUE => return Truthiness::AlwaysTrue,
AMBIGUOUS => return Truthiness::Ambiguous,
ALWAYS_FALSE => return Truthiness::AlwaysFalse,
_ => self.interiors[id],
};
let predicate = &predicates[node.atom];
match Self::analyze_single(db, predicate) {
Truthiness::AlwaysTrue => id = node.if_true,
Truthiness::Ambiguous => id = node.if_ambiguous,
Truthiness::AlwaysFalse => id = node.if_false,
}
}
}
fn analyze_single_pattern_predicate_kind<'db>(
db: &'db dyn Db,
predicate_kind: &PatternPredicateKind<'db>,
subject: Expression<'db>,
) -> Truthiness {
match predicate_kind {
PatternPredicateKind::Value(value) => {
let subject_ty = infer_expression_type(db, subject);
let value_ty = infer_expression_type(db, *value);
if subject_ty.is_single_valued(db) {
Truthiness::from(subject_ty.is_equivalent_to(db, value_ty))
} else {
Truthiness::Ambiguous
}
}
PatternPredicateKind::Singleton(singleton) => {
let subject_ty = infer_expression_type(db, subject);
let singleton_ty = match singleton {
ruff_python_ast::Singleton::None => Type::none(db),
ruff_python_ast::Singleton::True => Type::BooleanLiteral(true),
ruff_python_ast::Singleton::False => Type::BooleanLiteral(false),
};
debug_assert!(singleton_ty.is_singleton(db));
if subject_ty.is_equivalent_to(db, singleton_ty) {
Truthiness::AlwaysTrue
} else if subject_ty.is_disjoint_from(db, singleton_ty) {
Truthiness::AlwaysFalse
} else {
Truthiness::Ambiguous
}
}
PatternPredicateKind::Or(predicates) => {
use std::ops::ControlFlow;
let (ControlFlow::Break(truthiness) | ControlFlow::Continue(truthiness)) =
predicates
.iter()
.map(|p| Self::analyze_single_pattern_predicate_kind(db, p, subject))
// this is just a "max", but with a slight optimization: `AlwaysTrue` is the "greatest" possible element, so we short-circuit if we get there
.try_fold(Truthiness::AlwaysFalse, |acc, next| match (acc, next) {
(Truthiness::AlwaysTrue, _) | (_, Truthiness::AlwaysTrue) => {
ControlFlow::Break(Truthiness::AlwaysTrue)
}
(Truthiness::Ambiguous, _) | (_, Truthiness::Ambiguous) => {
ControlFlow::Continue(Truthiness::Ambiguous)
}
(Truthiness::AlwaysFalse, Truthiness::AlwaysFalse) => {
ControlFlow::Continue(Truthiness::AlwaysFalse)
}
});
truthiness
}
PatternPredicateKind::Class(class_expr) => {
let subject_ty = infer_expression_type(db, subject);
let class_ty = infer_expression_type(db, *class_expr).to_instance(db);
class_ty.map_or(Truthiness::Ambiguous, |class_ty| {
if subject_ty.is_subtype_of(db, class_ty) {
Truthiness::AlwaysTrue
} else if subject_ty.is_disjoint_from(db, class_ty) {
Truthiness::AlwaysFalse
} else {
Truthiness::Ambiguous
}
})
}
PatternPredicateKind::Unsupported => Truthiness::Ambiguous,
}
}
fn analyze_single_pattern_predicate(db: &dyn Db, predicate: PatternPredicate) -> Truthiness {
let truthiness = Self::analyze_single_pattern_predicate_kind(
db,
predicate.kind(db),
predicate.subject(db),
);
if truthiness == Truthiness::AlwaysTrue && predicate.guard(db).is_some() {
// Fall back to ambiguous, the guard might change the result.
// TODO: actually analyze guard truthiness
Truthiness::Ambiguous
} else {
truthiness
}
}
fn analyze_single(db: &dyn Db, predicate: &Predicate) -> Truthiness {
match predicate.node {
PredicateNode::Expression(test_expr) => {
let ty = infer_expression_type(db, test_expr);
ty.bool(db).negate_if(!predicate.is_positive)
}
PredicateNode::Pattern(inner) => Self::analyze_single_pattern_predicate(db, inner),
PredicateNode::StarImportPlaceholder(star_import) => {
let symbol_table = symbol_table(db, star_import.scope(db));
let symbol_name = symbol_table.symbol(star_import.symbol_id(db)).name();
match imported_symbol(db, star_import.referenced_file(db), symbol_name).symbol {
crate::symbol::Symbol::Type(_, crate::symbol::Boundness::Bound) => {
Truthiness::AlwaysTrue
}
crate::symbol::Symbol::Type(_, crate::symbol::Boundness::PossiblyUnbound) => {
Truthiness::Ambiguous
}
crate::symbol::Symbol::Unbound => Truthiness::AlwaysFalse,
}
}
}
}
}

View file

@ -0,0 +1,241 @@
use ruff_db::files::{File, FilePath};
use ruff_db::source::line_index;
use ruff_python_ast as ast;
use ruff_python_ast::{Expr, ExprRef};
use ruff_source_file::LineIndex;
use crate::module_name::ModuleName;
use crate::module_resolver::{resolve_module, Module};
use crate::semantic_index::ast_ids::HasScopedExpressionId;
use crate::semantic_index::semantic_index;
use crate::types::{binding_type, infer_scope_types, Type};
use crate::Db;
pub struct SemanticModel<'db> {
db: &'db dyn Db,
file: File,
}
impl<'db> SemanticModel<'db> {
pub fn new(db: &'db dyn Db, file: File) -> Self {
Self { db, file }
}
// TODO we don't actually want to expose the Db directly to lint rules, but we need to find a
// solution for exposing information from types
pub fn db(&self) -> &dyn Db {
self.db
}
pub fn file_path(&self) -> &FilePath {
self.file.path(self.db)
}
pub fn line_index(&self) -> LineIndex {
line_index(self.db.upcast(), self.file)
}
pub fn resolve_module(&self, module_name: &ModuleName) -> Option<Module> {
resolve_module(self.db, module_name)
}
}
pub trait HasType {
/// Returns the inferred type of `self`.
///
/// ## Panics
    /// May panic if `self` is from a different file than `model`.
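    ///
    /// Hypothetical usage (the `db`, `file`, and `expr` values are assumed to
    /// come from elsewhere):
    ///
    /// ```text
    /// let model = SemanticModel::new(db, file);
    /// let ty = expr.inferred_type(&model);
    /// ```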
fn inferred_type<'db>(&self, model: &SemanticModel<'db>) -> Type<'db>;
}
impl HasType for ast::ExprRef<'_> {
fn inferred_type<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> {
let index = semantic_index(model.db, model.file);
let file_scope = index.expression_scope_id(*self);
let scope = file_scope.to_scope_id(model.db, model.file);
let expression_id = self.scoped_expression_id(model.db, scope);
infer_scope_types(model.db, scope).expression_type(expression_id)
}
}
macro_rules! impl_expression_has_type {
($ty: ty) => {
impl HasType for $ty {
#[inline]
fn inferred_type<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> {
let expression_ref = ExprRef::from(self);
expression_ref.inferred_type(model)
}
}
};
}
impl_expression_has_type!(ast::ExprBoolOp);
impl_expression_has_type!(ast::ExprNamed);
impl_expression_has_type!(ast::ExprBinOp);
impl_expression_has_type!(ast::ExprUnaryOp);
impl_expression_has_type!(ast::ExprLambda);
impl_expression_has_type!(ast::ExprIf);
impl_expression_has_type!(ast::ExprDict);
impl_expression_has_type!(ast::ExprSet);
impl_expression_has_type!(ast::ExprListComp);
impl_expression_has_type!(ast::ExprSetComp);
impl_expression_has_type!(ast::ExprDictComp);
impl_expression_has_type!(ast::ExprGenerator);
impl_expression_has_type!(ast::ExprAwait);
impl_expression_has_type!(ast::ExprYield);
impl_expression_has_type!(ast::ExprYieldFrom);
impl_expression_has_type!(ast::ExprCompare);
impl_expression_has_type!(ast::ExprCall);
impl_expression_has_type!(ast::ExprFString);
impl_expression_has_type!(ast::ExprStringLiteral);
impl_expression_has_type!(ast::ExprBytesLiteral);
impl_expression_has_type!(ast::ExprNumberLiteral);
impl_expression_has_type!(ast::ExprBooleanLiteral);
impl_expression_has_type!(ast::ExprNoneLiteral);
impl_expression_has_type!(ast::ExprEllipsisLiteral);
impl_expression_has_type!(ast::ExprAttribute);
impl_expression_has_type!(ast::ExprSubscript);
impl_expression_has_type!(ast::ExprStarred);
impl_expression_has_type!(ast::ExprName);
impl_expression_has_type!(ast::ExprList);
impl_expression_has_type!(ast::ExprTuple);
impl_expression_has_type!(ast::ExprSlice);
impl_expression_has_type!(ast::ExprIpyEscapeCommand);
impl HasType for ast::Expr {
fn inferred_type<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> {
match self {
Expr::BoolOp(inner) => inner.inferred_type(model),
Expr::Named(inner) => inner.inferred_type(model),
Expr::BinOp(inner) => inner.inferred_type(model),
Expr::UnaryOp(inner) => inner.inferred_type(model),
Expr::Lambda(inner) => inner.inferred_type(model),
Expr::If(inner) => inner.inferred_type(model),
Expr::Dict(inner) => inner.inferred_type(model),
Expr::Set(inner) => inner.inferred_type(model),
Expr::ListComp(inner) => inner.inferred_type(model),
Expr::SetComp(inner) => inner.inferred_type(model),
Expr::DictComp(inner) => inner.inferred_type(model),
Expr::Generator(inner) => inner.inferred_type(model),
Expr::Await(inner) => inner.inferred_type(model),
Expr::Yield(inner) => inner.inferred_type(model),
Expr::YieldFrom(inner) => inner.inferred_type(model),
Expr::Compare(inner) => inner.inferred_type(model),
Expr::Call(inner) => inner.inferred_type(model),
Expr::FString(inner) => inner.inferred_type(model),
Expr::StringLiteral(inner) => inner.inferred_type(model),
Expr::BytesLiteral(inner) => inner.inferred_type(model),
Expr::NumberLiteral(inner) => inner.inferred_type(model),
Expr::BooleanLiteral(inner) => inner.inferred_type(model),
Expr::NoneLiteral(inner) => inner.inferred_type(model),
Expr::EllipsisLiteral(inner) => inner.inferred_type(model),
Expr::Attribute(inner) => inner.inferred_type(model),
Expr::Subscript(inner) => inner.inferred_type(model),
Expr::Starred(inner) => inner.inferred_type(model),
Expr::Name(inner) => inner.inferred_type(model),
Expr::List(inner) => inner.inferred_type(model),
Expr::Tuple(inner) => inner.inferred_type(model),
Expr::Slice(inner) => inner.inferred_type(model),
Expr::IpyEscapeCommand(inner) => inner.inferred_type(model),
}
}
}
macro_rules! impl_binding_has_ty {
($ty: ty) => {
impl HasType for $ty {
#[inline]
fn inferred_type<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> {
let index = semantic_index(model.db, model.file);
let binding = index.expect_single_definition(self);
binding_type(model.db, binding)
}
}
};
}
impl_binding_has_ty!(ast::StmtFunctionDef);
impl_binding_has_ty!(ast::StmtClassDef);
impl_binding_has_ty!(ast::Parameter);
impl_binding_has_ty!(ast::ParameterWithDefault);
impl_binding_has_ty!(ast::ExceptHandlerExceptHandler);
impl HasType for ast::Alias {
fn inferred_type<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> {
if &self.name == "*" {
return Type::Never;
}
let index = semantic_index(model.db, model.file);
binding_type(model.db, index.expect_single_definition(self))
}
}
#[cfg(test)]
mod tests {
use ruff_db::files::system_path_to_file;
use ruff_db::parsed::parsed_module;
use crate::db::tests::TestDbBuilder;
use crate::{HasType, SemanticModel};
#[test]
fn function_type() -> anyhow::Result<()> {
let db = TestDbBuilder::new()
.with_file("/src/foo.py", "def test(): pass")
.build()?;
let foo = system_path_to_file(&db, "/src/foo.py").unwrap();
let ast = parsed_module(&db, foo);
let function = ast.suite()[0].as_function_def_stmt().unwrap();
let model = SemanticModel::new(&db, foo);
let ty = function.inferred_type(&model);
assert!(ty.is_function_literal());
Ok(())
}
#[test]
fn class_type() -> anyhow::Result<()> {
let db = TestDbBuilder::new()
.with_file("/src/foo.py", "class Test: pass")
.build()?;
let foo = system_path_to_file(&db, "/src/foo.py").unwrap();
let ast = parsed_module(&db, foo);
let class = ast.suite()[0].as_class_def_stmt().unwrap();
let model = SemanticModel::new(&db, foo);
let ty = class.inferred_type(&model);
assert!(ty.is_class_literal());
Ok(())
}
#[test]
fn alias_type() -> anyhow::Result<()> {
let db = TestDbBuilder::new()
.with_file("/src/foo.py", "class Test: pass")
.with_file("/src/bar.py", "from foo import Test")
.build()?;
let bar = system_path_to_file(&db, "/src/bar.py").unwrap();
let ast = parsed_module(&db, bar);
let import = ast.suite()[0].as_import_from_stmt().unwrap();
let alias = &import.names[0];
let model = SemanticModel::new(&db, bar);
let ty = alias.inferred_type(&model);
assert!(ty.is_class_literal());
Ok(())
}
}

View file

@ -0,0 +1,916 @@
//! Utilities for finding the `site-packages` directory,
//! into which third-party packages are installed.
//!
//! The routines exposed by this module have different behaviour depending
//! on the platform of the *host machine*, which may be
//! different from the *target platform for type checking*. (A user
//! might be running ty on a Windows machine, but might
//! reasonably ask us to type-check code assuming that the code runs
//! on Linux.)
use std::fmt;
use std::fmt::Display;
use std::io;
use std::num::NonZeroUsize;
use std::ops::Deref;
use ruff_db::system::{System, SystemPath, SystemPathBuf};
use ruff_python_ast::PythonVersion;
type SitePackagesDiscoveryResult<T> = Result<T, SitePackagesDiscoveryError>;
/// Abstraction for a Python virtual environment.
///
/// Most of this information is derived from the virtual environment's `pyvenv.cfg` file.
/// The format of this file is not defined anywhere, and exactly which keys are present
/// depends on the tool that was used to create the virtual environment.
#[derive(Debug)]
pub(crate) struct VirtualEnvironment {
venv_path: SysPrefixPath,
base_executable_home_path: PythonHomePath,
include_system_site_packages: bool,
/// The version of the Python executable that was used to create this virtual environment.
///
/// The Python version is encoded under different keys and in different formats
/// by different virtual-environment creation tools,
/// and the key is never read by the standard-library `site.py` module,
/// so it's possible that we might not be able to find this information
/// in an acceptable format under any of the keys we expect.
/// This field will be `None` if so.
version: Option<PythonVersion>,
}
impl VirtualEnvironment {
pub(crate) fn new(
path: impl AsRef<SystemPath>,
origin: SysPrefixPathOrigin,
system: &dyn System,
) -> SitePackagesDiscoveryResult<Self> {
Self::new_impl(path.as_ref(), origin, system)
}
fn new_impl(
path: &SystemPath,
origin: SysPrefixPathOrigin,
system: &dyn System,
) -> SitePackagesDiscoveryResult<Self> {
fn pyvenv_cfg_line_number(index: usize) -> NonZeroUsize {
index.checked_add(1).and_then(NonZeroUsize::new).unwrap()
}
let venv_path = SysPrefixPath::new(path, origin, system)?;
let pyvenv_cfg_path = venv_path.join("pyvenv.cfg");
tracing::debug!("Attempting to parse virtual environment metadata at '{pyvenv_cfg_path}'");
let pyvenv_cfg = system
.read_to_string(&pyvenv_cfg_path)
.map_err(|io_err| SitePackagesDiscoveryError::NoPyvenvCfgFile(origin, io_err))?;
let mut include_system_site_packages = false;
let mut base_executable_home_path = None;
let mut version_info_string = None;
// A `pyvenv.cfg` file *looks* like a `.ini` file, but actually isn't valid `.ini` syntax!
// The Python standard-library's `site` module parses these files by splitting each line on
// '=' characters, so that's what we should do as well.
//
// See also: https://snarky.ca/how-virtual-environments-work/
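        // For illustration only (the exact set of keys varies by tool), a
        // `pyvenv.cfg` written by a tool like `uv` might look like:
        //
        //     home = /usr/bin
        //     implementation = CPython
        //     version_info = 3.12.3
        //     include-system-site-packages = false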
for (index, line) in pyvenv_cfg.lines().enumerate() {
if let Some((key, value)) = line.split_once('=') {
let key = key.trim();
if key.is_empty() {
return Err(SitePackagesDiscoveryError::PyvenvCfgParseError(
pyvenv_cfg_path,
PyvenvCfgParseErrorKind::MalformedKeyValuePair {
line_number: pyvenv_cfg_line_number(index),
},
));
}
let value = value.trim();
if value.is_empty() {
return Err(SitePackagesDiscoveryError::PyvenvCfgParseError(
pyvenv_cfg_path,
PyvenvCfgParseErrorKind::MalformedKeyValuePair {
line_number: pyvenv_cfg_line_number(index),
},
));
}
if value.contains('=') {
return Err(SitePackagesDiscoveryError::PyvenvCfgParseError(
pyvenv_cfg_path,
PyvenvCfgParseErrorKind::TooManyEquals {
line_number: pyvenv_cfg_line_number(index),
},
));
}
match key {
"include-system-site-packages" => {
include_system_site_packages = value.eq_ignore_ascii_case("true");
}
"home" => base_executable_home_path = Some(value),
// `virtualenv` and `uv` call this key `version_info`,
// but the stdlib venv module calls it `version`
"version" | "version_info" => version_info_string = Some(value),
_ => continue,
}
}
}
// The `home` key is read by the standard library's `site.py` module,
// so if it's missing from the `pyvenv.cfg` file
// (or the provided value is invalid),
// it's reasonable to consider the virtual environment irredeemably broken.
let Some(base_executable_home_path) = base_executable_home_path else {
return Err(SitePackagesDiscoveryError::PyvenvCfgParseError(
pyvenv_cfg_path,
PyvenvCfgParseErrorKind::NoHomeKey,
));
};
let base_executable_home_path = PythonHomePath::new(base_executable_home_path, system)
.map_err(|io_err| {
SitePackagesDiscoveryError::PyvenvCfgParseError(
pyvenv_cfg_path,
PyvenvCfgParseErrorKind::InvalidHomeValue(io_err),
)
})?;
// but the `version`/`version_info` key is not read by the standard library,
// and is provided under different keys depending on which virtual-environment creation tool
// created the `pyvenv.cfg` file. Lenient parsing is appropriate here:
// the file isn't really *invalid* if it doesn't have this key,
// or if the value doesn't parse according to our expectations.
let version = version_info_string.and_then(|version_string| {
let mut version_info_parts = version_string.split('.');
let (major, minor) = (version_info_parts.next()?, version_info_parts.next()?);
PythonVersion::try_from((major, minor)).ok()
});
let metadata = Self {
venv_path,
base_executable_home_path,
include_system_site_packages,
version,
};
tracing::trace!("Resolved metadata for virtual environment: {metadata:?}");
Ok(metadata)
}
/// Return a list of `site-packages` directories that are available from this virtual environment
///
/// See the documentation for `site_packages_dir_from_sys_prefix` for more details.
pub(crate) fn site_packages_directories(
&self,
system: &dyn System,
) -> SitePackagesDiscoveryResult<Vec<SystemPathBuf>> {
let VirtualEnvironment {
venv_path,
base_executable_home_path,
include_system_site_packages,
version,
} = self;
let mut site_packages_directories = vec![site_packages_directory_from_sys_prefix(
venv_path, *version, system,
)?];
if *include_system_site_packages {
let system_sys_prefix =
SysPrefixPath::from_executable_home_path(base_executable_home_path);
// If we fail to resolve the `sys.prefix` path from the base executable home path,
// or if we fail to resolve the `site-packages` from the `sys.prefix` path,
// we should probably print a warning but *not* abort type checking
if let Some(sys_prefix_path) = system_sys_prefix {
match site_packages_directory_from_sys_prefix(&sys_prefix_path, *version, system) {
Ok(site_packages_directory) => {
site_packages_directories.push(site_packages_directory);
}
Err(error) => tracing::warn!(
"{error}. System site-packages will not be used for module resolution."
),
}
} else {
tracing::warn!(
"Failed to resolve `sys.prefix` of the system Python installation \
from the `home` value in the `pyvenv.cfg` file at `{}`. \
System site-packages will not be used for module resolution.",
venv_path.join("pyvenv.cfg")
);
}
}
tracing::debug!("Resolved site-packages directories for this virtual environment are: {site_packages_directories:?}");
Ok(site_packages_directories)
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum SitePackagesDiscoveryError {
#[error("Invalid {1}: `{0}` could not be canonicalized")]
VenvDirCanonicalizationError(SystemPathBuf, SysPrefixPathOrigin, #[source] io::Error),
#[error("Invalid {1}: `{0}` does not point to a directory on disk")]
VenvDirIsNotADirectory(SystemPathBuf, SysPrefixPathOrigin),
#[error("{0} points to a broken venv with no pyvenv.cfg file")]
NoPyvenvCfgFile(SysPrefixPathOrigin, #[source] io::Error),
#[error("Failed to parse the pyvenv.cfg file at {0} because {1}")]
PyvenvCfgParseError(SystemPathBuf, PyvenvCfgParseErrorKind),
#[error("Failed to search the `lib` directory of the Python installation at {1} for `site-packages`")]
CouldNotReadLibDirectory(#[source] io::Error, SysPrefixPath),
#[error("Could not find the `site-packages` directory for the Python installation at {0}")]
NoSitePackagesDirFound(SysPrefixPath),
}
/// The various ways in which parsing a `pyvenv.cfg` file could fail
#[derive(Debug)]
pub(crate) enum PyvenvCfgParseErrorKind {
TooManyEquals { line_number: NonZeroUsize },
MalformedKeyValuePair { line_number: NonZeroUsize },
NoHomeKey,
InvalidHomeValue(io::Error),
}
impl fmt::Display for PyvenvCfgParseErrorKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::TooManyEquals { line_number } => {
write!(f, "line {line_number} has too many '=' characters")
}
Self::MalformedKeyValuePair { line_number } => write!(
f,
"line {line_number} has a malformed `<key> = <value>` pair"
),
Self::NoHomeKey => f.write_str("the file does not have a `home` key"),
Self::InvalidHomeValue(io_err) => {
write!(
f,
"the following error was encountered \
when trying to resolve the `home` value to a directory on disk: {io_err}"
)
}
}
}
}
/// Attempt to retrieve the `site-packages` directory
/// associated with a given Python installation.
///
/// The location of the `site-packages` directory can vary according to the
/// Python version that this installation represents. The Python version may
/// or may not be known at this point, which is why the `python_version`
/// parameter is an `Option`.
fn site_packages_directory_from_sys_prefix(
sys_prefix_path: &SysPrefixPath,
python_version: Option<PythonVersion>,
system: &dyn System,
) -> SitePackagesDiscoveryResult<SystemPathBuf> {
tracing::debug!("Searching for site-packages directory in {sys_prefix_path}");
if cfg!(target_os = "windows") {
let site_packages = sys_prefix_path.join(r"Lib\site-packages");
return system
.is_directory(&site_packages)
.then_some(site_packages)
.ok_or(SitePackagesDiscoveryError::NoSitePackagesDirFound(
sys_prefix_path.to_owned(),
));
}
// In the Python standard library's `site.py` module (used for finding `site-packages`
// at runtime), we can find this in [the non-Windows branch]:
//
// ```py
// libdirs = [sys.platlibdir]
// if sys.platlibdir != "lib":
// libdirs.append("lib")
// ```
//
// Pyright therefore searches for both a `lib/python3.X/site-packages` directory
    // and a `lib64/python3.X/site-packages` directory on non-macOS Unix systems,
// since `sys.platlibdir` can sometimes be equal to `"lib64"`.
//
// However, we only care about the `site-packages` directory insofar as it allows
// us to discover Python source code that can be used for inferring type
// information regarding third-party dependencies. That means that we don't need
// to care about any possible `lib64/site-packages` directories, since
// [the `sys`-module documentation] states that `sys.platlibdir` is *only* ever
// used for C extensions, never for pure-Python modules.
//
// [the non-Windows branch]: https://github.com/python/cpython/blob/a8be8fc6c4682089be45a87bd5ee1f686040116c/Lib/site.py#L401-L410
// [the `sys`-module documentation]: https://docs.python.org/3/library/sys.html#sys.platlibdir
// If we were able to figure out what Python version this installation is,
// we should be able to avoid iterating through all items in the `lib/` directory:
if let Some(version) = python_version {
let expected_path = sys_prefix_path.join(format!("lib/python{version}/site-packages"));
if system.is_directory(&expected_path) {
return Ok(expected_path);
}
if version.free_threaded_build_available() {
// Nearly the same as `expected_path`, but with an additional `t` after {version}:
let alternative_path =
sys_prefix_path.join(format!("lib/python{version}t/site-packages"));
if system.is_directory(&alternative_path) {
return Ok(alternative_path);
}
}
}
// Either we couldn't figure out the version before calling this function
// (e.g., from a `pyvenv.cfg` file if this was a venv),
// or we couldn't find a `site-packages` folder at the expected location given
// the parsed version
//
// Note: the `python3.x` part of the `site-packages` path can't be computed from
// the `--python-version` the user has passed, as they might be running Python 3.12 locally
// even if they've requested that we type check their code "as if" they're running 3.8.
for entry_result in system
.read_directory(&sys_prefix_path.join("lib"))
.map_err(|io_err| {
SitePackagesDiscoveryError::CouldNotReadLibDirectory(io_err, sys_prefix_path.to_owned())
})?
{
let Ok(entry) = entry_result else {
continue;
};
if !entry.file_type().is_directory() {
continue;
}
let mut path = entry.into_path();
let name = path
.file_name()
.expect("File name to be non-null because path is guaranteed to be a child of `lib`");
if !name.starts_with("python3.") {
continue;
}
path.push("site-packages");
if system.is_directory(&path) {
return Ok(path);
}
}
Err(SitePackagesDiscoveryError::NoSitePackagesDirFound(
sys_prefix_path.to_owned(),
))
}
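// For example (following the logic above): given a `sys.prefix` of `/opt/venv`
// and a known Python 3.12, this resolves to
// `/opt/venv/lib/python3.12/site-packages` on Unix and to
// `/opt/venv/Lib/site-packages` on Windows.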
/// A path that represents the value of [`sys.prefix`] at runtime in Python
/// for a given Python executable.
///
/// For the case of a virtual environment, where a
/// Python binary is at `/.venv/bin/python`, `sys.prefix` is the path to
/// the virtual environment the Python binary lies inside, i.e. `/.venv`,
/// and `site-packages` will be at `.venv/lib/python3.X/site-packages`.
/// System Python installations generally work the same way: if a system
/// Python installation lies at `/opt/homebrew/bin/python`, `sys.prefix`
/// will be `/opt/homebrew`, and `site-packages` will be at
/// `/opt/homebrew/lib/python3.X/site-packages`.
///
/// [`sys.prefix`]: https://docs.python.org/3/library/sys.html#sys.prefix
#[derive(Debug, PartialEq, Eq, Clone)]
pub(crate) struct SysPrefixPath {
inner: SystemPathBuf,
origin: SysPrefixPathOrigin,
}
impl SysPrefixPath {
fn new(
unvalidated_path: impl AsRef<SystemPath>,
origin: SysPrefixPathOrigin,
system: &dyn System,
) -> SitePackagesDiscoveryResult<Self> {
Self::new_impl(unvalidated_path.as_ref(), origin, system)
}
fn new_impl(
unvalidated_path: &SystemPath,
origin: SysPrefixPathOrigin,
system: &dyn System,
) -> SitePackagesDiscoveryResult<Self> {
// It's important to resolve symlinks here rather than simply making the path absolute,
// since system Python installations often only put symlinks in the "expected"
// locations for `home` and `site-packages`
let canonicalized = system
.canonicalize_path(unvalidated_path)
.map_err(|io_err| {
SitePackagesDiscoveryError::VenvDirCanonicalizationError(
unvalidated_path.to_path_buf(),
origin,
io_err,
)
})?;
system
.is_directory(&canonicalized)
.then_some(Self {
inner: canonicalized,
origin,
})
.ok_or_else(|| {
SitePackagesDiscoveryError::VenvDirIsNotADirectory(
unvalidated_path.to_path_buf(),
origin,
)
})
}
fn from_executable_home_path(path: &PythonHomePath) -> Option<Self> {
// No need to check whether `path.parent()` is a directory:
// the parent of a canonicalised path that is known to exist
// is guaranteed to be a directory.
if cfg!(target_os = "windows") {
Some(Self {
inner: path.to_path_buf(),
origin: SysPrefixPathOrigin::Derived,
})
} else {
path.parent().map(|path| Self {
inner: path.to_path_buf(),
origin: SysPrefixPathOrigin::Derived,
})
}
}
}
impl Deref for SysPrefixPath {
type Target = SystemPath;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl fmt::Display for SysPrefixPath {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "`sys.prefix` path `{}`", self.inner)
}
}
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum SysPrefixPathOrigin {
PythonCliFlag,
VirtualEnvVar,
Derived,
LocalVenv,
}
impl Display for SysPrefixPathOrigin {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Self::PythonCliFlag => f.write_str("`--python` argument"),
Self::VirtualEnvVar => f.write_str("`VIRTUAL_ENV` environment variable"),
Self::Derived => f.write_str("derived `sys.prefix` path"),
Self::LocalVenv => f.write_str("local virtual environment"),
}
}
}
/// The value given by the `home` key in `pyvenv.cfg` files.
///
/// This is equivalent to `{sys_prefix_path}/bin`, and points
/// to a directory in which a Python executable can be found.
/// Confusingly, it is *not* the same as the [`PYTHONHOME`]
/// environment variable that Python provides! However, it's
/// consistent among all mainstream creators of Python virtual
/// environments (the stdlib Python `venv` module, the third-party
/// `virtualenv` library, and `uv`), was specified by
/// [the original PEP adding the `venv` module],
/// and it's one of the few fields that's read by the Python
/// standard library's `site.py` module.
///
/// Although it doesn't appear to be specified anywhere,
/// all existing virtual environment tools always use an absolute path
/// for the `home` value, and the Python standard library also assumes
/// that the `home` value will be an absolute path.
///
/// Other values, such as the path to the Python executable or the
/// base-executable `sys.prefix` value, are either only provided in
/// `pyvenv.cfg` files by some virtual-environment creators,
/// or are included under different keys depending on which
/// virtual-environment creation tool you've used.
///
/// [`PYTHONHOME`]: https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHOME
/// [the original PEP adding the `venv` module]: https://peps.python.org/pep-0405/
#[derive(Debug, PartialEq, Eq)]
struct PythonHomePath(SystemPathBuf);
impl PythonHomePath {
fn new(path: impl AsRef<SystemPath>, system: &dyn System) -> io::Result<Self> {
let path = path.as_ref();
// It's important to resolve symlinks here rather than simply making the path absolute,
// since system Python installations often only put symlinks in the "expected"
// locations for `home` and `site-packages`
let canonicalized = system.canonicalize_path(path)?;
system
.is_directory(&canonicalized)
.then_some(Self(canonicalized))
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "not a directory"))
}
}
impl Deref for PythonHomePath {
type Target = SystemPath;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl fmt::Display for PythonHomePath {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "`home` location `{}`", self.0)
}
}
impl PartialEq<SystemPath> for PythonHomePath {
fn eq(&self, other: &SystemPath) -> bool {
&*self.0 == other
}
}
impl PartialEq<SystemPathBuf> for PythonHomePath {
fn eq(&self, other: &SystemPathBuf) -> bool {
self == &**other
}
}
#[cfg(test)]
mod tests {
use ruff_db::system::TestSystem;
use super::*;
struct VirtualEnvironmentTester {
system: TestSystem,
minor_version: u8,
free_threaded: bool,
system_site_packages: bool,
pyvenv_cfg_version_field: Option<&'static str>,
}
impl VirtualEnvironmentTester {
/// Builds a mock virtual environment, and returns the path to the venv
fn build_mock_venv(&self) -> SystemPathBuf {
let VirtualEnvironmentTester {
system,
minor_version,
system_site_packages,
free_threaded,
pyvenv_cfg_version_field,
} = self;
let memory_fs = system.memory_file_system();
let unix_site_packages = if *free_threaded {
format!("lib/python3.{minor_version}t/site-packages")
} else {
format!("lib/python3.{minor_version}/site-packages")
};
let system_install_sys_prefix =
SystemPathBuf::from(&*format!("/Python3.{minor_version}"));
let (system_home_path, system_exe_path, system_site_packages_path) =
if cfg!(target_os = "windows") {
let system_home_path = system_install_sys_prefix.clone();
let system_exe_path = system_home_path.join("python.exe");
let system_site_packages_path =
system_install_sys_prefix.join(r"Lib\site-packages");
(system_home_path, system_exe_path, system_site_packages_path)
} else {
let system_home_path = system_install_sys_prefix.join("bin");
let system_exe_path = system_home_path.join("python");
let system_site_packages_path =
system_install_sys_prefix.join(&unix_site_packages);
(system_home_path, system_exe_path, system_site_packages_path)
};
memory_fs.write_file_all(system_exe_path, "").unwrap();
memory_fs
.create_directory_all(&system_site_packages_path)
.unwrap();
let venv_sys_prefix = SystemPathBuf::from("/.venv");
let (venv_exe, site_packages_path) = if cfg!(target_os = "windows") {
(
venv_sys_prefix.join(r"Scripts\python.exe"),
venv_sys_prefix.join(r"Lib\site-packages"),
)
} else {
(
venv_sys_prefix.join("bin/python"),
venv_sys_prefix.join(&unix_site_packages),
)
};
memory_fs.write_file_all(&venv_exe, "").unwrap();
memory_fs.create_directory_all(&site_packages_path).unwrap();
let pyvenv_cfg_path = venv_sys_prefix.join("pyvenv.cfg");
let mut pyvenv_cfg_contents = format!("home = {system_home_path}\n");
if let Some(version_field) = pyvenv_cfg_version_field {
pyvenv_cfg_contents.push_str(version_field);
pyvenv_cfg_contents.push('\n');
}
// Deliberately using weird casing here to test that our pyvenv.cfg parsing is case-insensitive:
if *system_site_packages {
pyvenv_cfg_contents.push_str("include-system-site-packages = TRuE\n");
}
memory_fs
.write_file_all(pyvenv_cfg_path, &pyvenv_cfg_contents)
.unwrap();
venv_sys_prefix
}
fn test(self) {
let venv_path = self.build_mock_venv();
let venv = VirtualEnvironment::new(
venv_path.clone(),
SysPrefixPathOrigin::VirtualEnvVar,
&self.system,
)
.unwrap();
assert_eq!(
venv.venv_path,
SysPrefixPath {
inner: self.system.canonicalize_path(&venv_path).unwrap(),
origin: SysPrefixPathOrigin::VirtualEnvVar,
}
);
assert_eq!(venv.include_system_site_packages, self.system_site_packages);
if self.pyvenv_cfg_version_field.is_some() {
assert_eq!(
venv.version,
Some(PythonVersion {
major: 3,
minor: self.minor_version
})
);
} else {
assert_eq!(venv.version, None);
}
let expected_home = if cfg!(target_os = "windows") {
SystemPathBuf::from(&*format!(r"\Python3.{}", self.minor_version))
} else {
SystemPathBuf::from(&*format!("/Python3.{}/bin", self.minor_version))
};
assert_eq!(venv.base_executable_home_path, expected_home);
let site_packages_directories = venv.site_packages_directories(&self.system).unwrap();
let expected_venv_site_packages = if cfg!(target_os = "windows") {
SystemPathBuf::from(r"\.venv\Lib\site-packages")
} else if self.free_threaded {
SystemPathBuf::from(&*format!(
"/.venv/lib/python3.{}t/site-packages",
self.minor_version
))
} else {
SystemPathBuf::from(&*format!(
"/.venv/lib/python3.{}/site-packages",
self.minor_version
))
};
let expected_system_site_packages = if cfg!(target_os = "windows") {
SystemPathBuf::from(&*format!(
r"\Python3.{}\Lib\site-packages",
self.minor_version
))
} else if self.free_threaded {
SystemPathBuf::from(&*format!(
"/Python3.{minor_version}/lib/python3.{minor_version}t/site-packages",
minor_version = self.minor_version
))
} else {
SystemPathBuf::from(&*format!(
"/Python3.{minor_version}/lib/python3.{minor_version}/site-packages",
minor_version = self.minor_version
))
};
if self.system_site_packages {
assert_eq!(
&site_packages_directories,
&[expected_venv_site_packages, expected_system_site_packages]
);
} else {
assert_eq!(&site_packages_directories, &[expected_venv_site_packages]);
}
}
}
#[test]
fn can_find_site_packages_directory_no_version_field_in_pyvenv_cfg() {
let tester = VirtualEnvironmentTester {
system: TestSystem::default(),
minor_version: 12,
free_threaded: false,
system_site_packages: false,
pyvenv_cfg_version_field: None,
};
tester.test();
}
#[test]
fn can_find_site_packages_directory_venv_style_version_field_in_pyvenv_cfg() {
let tester = VirtualEnvironmentTester {
system: TestSystem::default(),
minor_version: 12,
free_threaded: false,
system_site_packages: false,
pyvenv_cfg_version_field: Some("version = 3.12"),
};
tester.test();
}
#[test]
fn can_find_site_packages_directory_uv_style_version_field_in_pyvenv_cfg() {
let tester = VirtualEnvironmentTester {
system: TestSystem::default(),
minor_version: 12,
free_threaded: false,
system_site_packages: false,
pyvenv_cfg_version_field: Some("version_info = 3.12"),
};
tester.test();
}
#[test]
fn can_find_site_packages_directory_virtualenv_style_version_field_in_pyvenv_cfg() {
let tester = VirtualEnvironmentTester {
system: TestSystem::default(),
minor_version: 12,
free_threaded: false,
system_site_packages: false,
pyvenv_cfg_version_field: Some("version_info = 3.12.0rc2"),
};
tester.test();
}
#[test]
fn can_find_site_packages_directory_freethreaded_build() {
let tester = VirtualEnvironmentTester {
system: TestSystem::default(),
minor_version: 13,
free_threaded: true,
system_site_packages: false,
pyvenv_cfg_version_field: Some("version_info = 3.13"),
};
tester.test();
}
#[test]
fn finds_system_site_packages() {
let tester = VirtualEnvironmentTester {
system: TestSystem::default(),
minor_version: 13,
free_threaded: true,
system_site_packages: true,
pyvenv_cfg_version_field: Some("version_info = 3.13"),
};
tester.test();
}
#[test]
fn reject_venv_that_does_not_exist() {
let system = TestSystem::default();
assert!(matches!(
VirtualEnvironment::new("/.venv", SysPrefixPathOrigin::VirtualEnvVar, &system),
Err(SitePackagesDiscoveryError::VenvDirCanonicalizationError(..))
));
}
#[test]
fn reject_venv_that_is_not_a_directory() {
let system = TestSystem::default();
system
.memory_file_system()
.write_file_all("/.venv", "")
.unwrap();
assert!(matches!(
VirtualEnvironment::new("/.venv", SysPrefixPathOrigin::VirtualEnvVar, &system),
Err(SitePackagesDiscoveryError::VenvDirIsNotADirectory(..))
));
}
#[test]
fn reject_venv_with_no_pyvenv_cfg_file() {
let system = TestSystem::default();
system
.memory_file_system()
.create_directory_all("/.venv")
.unwrap();
assert!(matches!(
VirtualEnvironment::new("/.venv", SysPrefixPathOrigin::VirtualEnvVar, &system),
Err(SitePackagesDiscoveryError::NoPyvenvCfgFile(
SysPrefixPathOrigin::VirtualEnvVar,
_
))
));
}
#[test]
fn parsing_pyvenv_cfg_with_too_many_equals() {
let system = TestSystem::default();
let memory_fs = system.memory_file_system();
let pyvenv_cfg_path = SystemPathBuf::from("/.venv/pyvenv.cfg");
memory_fs
.write_file_all(&pyvenv_cfg_path, "home = bar = /.venv/bin")
.unwrap();
let venv_result =
VirtualEnvironment::new("/.venv", SysPrefixPathOrigin::VirtualEnvVar, &system);
assert!(matches!(
venv_result,
Err(SitePackagesDiscoveryError::PyvenvCfgParseError(
path,
PyvenvCfgParseErrorKind::TooManyEquals { line_number }
))
if path == pyvenv_cfg_path && Some(line_number) == NonZeroUsize::new(1)
));
}
#[test]
fn parsing_pyvenv_cfg_with_key_but_no_value_fails() {
let system = TestSystem::default();
let memory_fs = system.memory_file_system();
let pyvenv_cfg_path = SystemPathBuf::from("/.venv/pyvenv.cfg");
memory_fs
.write_file_all(&pyvenv_cfg_path, "home =")
.unwrap();
let venv_result =
VirtualEnvironment::new("/.venv", SysPrefixPathOrigin::VirtualEnvVar, &system);
assert!(matches!(
venv_result,
Err(SitePackagesDiscoveryError::PyvenvCfgParseError(
path,
PyvenvCfgParseErrorKind::MalformedKeyValuePair { line_number }
))
if path == pyvenv_cfg_path && Some(line_number) == NonZeroUsize::new(1)
));
}
#[test]
fn parsing_pyvenv_cfg_with_value_but_no_key_fails() {
let system = TestSystem::default();
let memory_fs = system.memory_file_system();
let pyvenv_cfg_path = SystemPathBuf::from("/.venv/pyvenv.cfg");
memory_fs
.write_file_all(&pyvenv_cfg_path, "= whatever")
.unwrap();
let venv_result =
VirtualEnvironment::new("/.venv", SysPrefixPathOrigin::VirtualEnvVar, &system);
assert!(matches!(
venv_result,
Err(SitePackagesDiscoveryError::PyvenvCfgParseError(
path,
PyvenvCfgParseErrorKind::MalformedKeyValuePair { line_number }
))
if path == pyvenv_cfg_path && Some(line_number) == NonZeroUsize::new(1)
));
}
#[test]
fn parsing_pyvenv_cfg_with_no_home_key_fails() {
let system = TestSystem::default();
let memory_fs = system.memory_file_system();
let pyvenv_cfg_path = SystemPathBuf::from("/.venv/pyvenv.cfg");
memory_fs.write_file_all(&pyvenv_cfg_path, "").unwrap();
let venv_result =
VirtualEnvironment::new("/.venv", SysPrefixPathOrigin::VirtualEnvVar, &system);
assert!(matches!(
venv_result,
Err(SitePackagesDiscoveryError::PyvenvCfgParseError(
path,
PyvenvCfgParseErrorKind::NoHomeKey
))
if path == pyvenv_cfg_path
));
}
#[test]
fn parsing_pyvenv_cfg_with_invalid_home_key_fails() {
let system = TestSystem::default();
let memory_fs = system.memory_file_system();
let pyvenv_cfg_path = SystemPathBuf::from("/.venv/pyvenv.cfg");
memory_fs
.write_file_all(&pyvenv_cfg_path, "home = foo")
.unwrap();
let venv_result =
VirtualEnvironment::new("/.venv", SysPrefixPathOrigin::VirtualEnvVar, &system);
assert!(matches!(
venv_result,
Err(SitePackagesDiscoveryError::PyvenvCfgParseError(
path,
PyvenvCfgParseErrorKind::InvalidHomeValue(_)
))
if path == pyvenv_cfg_path
));
}
}

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

@ -0,0 +1,984 @@
//! Smart builders for union and intersection types.
//!
//! Invariants we maintain here:
//! * No single-element union types (should just be the contained type instead.)
//! * No single-positive-element intersection types. Single-negative-element are OK, we don't
//! have a standalone negation type so there's no other representation for this.
//! * The same type should never appear more than once in a union or intersection. (This should
//! be expanded to cover subtyping -- see below -- but for now we only implement it for type
//! identity.)
//! * Disjunctive normal form (DNF): the tree of unions and intersections can never be deeper
//! than a union-of-intersections. Unions cannot contain other unions (the inner union just
//! flattens into the outer one), intersections cannot contain other intersections (also
//! flattens), and intersections cannot contain unions (the intersection distributes over the
//! union, inverting it into a union-of-intersections).
//! * No type in a union can be a subtype of any other type in the union (just eliminate the
//! subtype from the union).
//! * No type in an intersection can be a supertype of any other type in the intersection (just
//! eliminate the supertype from the intersection).
//! * An intersection containing two non-overlapping types simplifies to [`Type::Never`].
//!
//! The implication of these invariants is that a [`UnionBuilder`] does not necessarily build a
//! [`Type::Union`]. For example, if only one type is added to the [`UnionBuilder`], `build()` will
//! just return that type directly. The same is true for [`IntersectionBuilder`]; for example, if a
//! union type is added to the intersection, it will distribute and [`IntersectionBuilder::build`]
//! may end up returning a [`Type::Union`] of intersections.
//!
//! ## Performance
//!
//! In practice, there are two kinds of unions found in the wild: relatively-small unions made up
//! of normal user types (classes, etc), and large unions made up of literals, which can occur via
//! large enums (not yet implemented) or from string/integer/bytes literals, which can grow due to
//! literal arithmetic or operations on literal strings/bytes. For normal unions, it's most
//! efficient to just store the member types in a vector, and do O(n^2) `is_subtype_of` checks to
//! maintain the union in simplified form. But literal unions can grow to a size where this becomes
//! a performance problem. For this reason, we group literal types in `UnionBuilder`. Since every
//! different string literal type shares exactly the same possible super-types, and none of them
//! are subtypes of each other (unless exactly the same literal type), we can avoid many
//! unnecessary `is_subtype_of` checks.
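//!
//! For example (a quick sketch of the first invariant; `setup_db` is the test
//! helper used by the tests at the bottom of this file):
//!
//! ```ignore
//! let db = setup_db();
//! // A union with a single member builds to just that member.
//! let ty = UnionBuilder::new(&db).add(Type::IntLiteral(1)).build();
//! assert_eq!(ty, Type::IntLiteral(1));
//! ```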
use crate::types::{
BytesLiteralType, IntersectionType, KnownClass, StringLiteralType, Type,
TypeVarBoundOrConstraints, UnionType,
};
use crate::{Db, FxOrderSet};
use smallvec::SmallVec;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LiteralKind {
Int,
String,
Bytes,
}
impl<'db> Type<'db> {
/// Return `true` if this type can be a supertype of some literals of `kind` and not others.
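    /// For example, `AlwaysTruthy` splits int literals: it is a supertype of
    /// `Literal[1]` but not of `Literal[0]`. Plain `int`, by contrast, is a
    /// supertype of every int literal, so it does not split them.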
fn splits_literals(self, db: &'db dyn Db, kind: LiteralKind) -> bool {
match (self, kind) {
(Type::AlwaysFalsy | Type::AlwaysTruthy, _) => true,
(Type::StringLiteral(_), LiteralKind::String) => true,
(Type::BytesLiteral(_), LiteralKind::Bytes) => true,
(Type::IntLiteral(_), LiteralKind::Int) => true,
(Type::Intersection(intersection), _) => {
intersection
.positive(db)
.iter()
.any(|ty| ty.splits_literals(db, kind))
|| intersection
.negative(db)
.iter()
.any(|ty| ty.splits_literals(db, kind))
}
(Type::Union(union), _) => union
.elements(db)
.iter()
.any(|ty| ty.splits_literals(db, kind)),
_ => false,
}
}
}
enum UnionElement<'db> {
IntLiterals(FxOrderSet<i64>),
StringLiterals(FxOrderSet<StringLiteralType<'db>>),
BytesLiterals(FxOrderSet<BytesLiteralType<'db>>),
Type(Type<'db>),
}
impl<'db> UnionElement<'db> {
/// Try reducing this `UnionElement` given the presence in the same union of `other_type`.
///
    /// If this `UnionElement` is a group of literals, filter out any literals that are
    /// subsumed by `other_type`, and return `ReduceResult::KeepIf` with a boolean value
    /// indicating whether the remaining group of literals should be kept in the union.
///
/// If this `UnionElement` is some other type, return `ReduceResult::Type` so `UnionBuilder`
/// can perform more complex checks on it.
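    ///
    /// For example, if this element is the int-literal group `{0, 1, 2}` and
    /// `other_type` is `int`, every literal is a subtype of `int`, so the whole
    /// group is eliminated via `ReduceResult::KeepIf(false)`.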
fn try_reduce(&mut self, db: &'db dyn Db, other_type: Type<'db>) -> ReduceResult<'db> {
match self {
UnionElement::IntLiterals(literals) => {
if other_type.splits_literals(db, LiteralKind::Int) {
let mut collapse = false;
let negated = other_type.negate(db);
literals.retain(|literal| {
let ty = Type::IntLiteral(*literal);
if negated.is_subtype_of(db, ty) {
collapse = true;
}
!ty.is_subtype_of(db, other_type)
});
if collapse {
ReduceResult::CollapseToObject
} else {
ReduceResult::KeepIf(!literals.is_empty())
}
} else {
ReduceResult::KeepIf(
!Type::IntLiteral(literals[0]).is_subtype_of(db, other_type),
)
}
}
UnionElement::StringLiterals(literals) => {
if other_type.splits_literals(db, LiteralKind::String) {
let mut collapse = false;
let negated = other_type.negate(db);
literals.retain(|literal| {
let ty = Type::StringLiteral(*literal);
if negated.is_subtype_of(db, ty) {
collapse = true;
}
!ty.is_subtype_of(db, other_type)
});
if collapse {
ReduceResult::CollapseToObject
} else {
ReduceResult::KeepIf(!literals.is_empty())
}
} else {
ReduceResult::KeepIf(
!Type::StringLiteral(literals[0]).is_subtype_of(db, other_type),
)
}
}
UnionElement::BytesLiterals(literals) => {
if other_type.splits_literals(db, LiteralKind::Bytes) {
let mut collapse = false;
let negated = other_type.negate(db);
literals.retain(|literal| {
let ty = Type::BytesLiteral(*literal);
if negated.is_subtype_of(db, ty) {
collapse = true;
}
!ty.is_subtype_of(db, other_type)
});
if collapse {
ReduceResult::CollapseToObject
} else {
ReduceResult::KeepIf(!literals.is_empty())
}
} else {
ReduceResult::KeepIf(
!Type::BytesLiteral(literals[0]).is_subtype_of(db, other_type),
)
}
}
UnionElement::Type(existing) => ReduceResult::Type(*existing),
}
}
}
enum ReduceResult<'db> {
/// Reduction of this `UnionElement` is complete; keep it in the union if the nested
/// boolean is true, eliminate it from the union if false.
KeepIf(bool),
/// Collapse this entire union to `object`.
CollapseToObject,
    /// The given `Type` can stand in for the entire `UnionElement` for further union
/// simplification checks.
Type(Type<'db>),
}
// TODO increase this once we extend `UnionElement` throughout all union/intersection
// representations, so that we can make large unions of literals fast in all operations.
const MAX_UNION_LITERALS: usize = 200;
pub(crate) struct UnionBuilder<'db> {
elements: Vec<UnionElement<'db>>,
db: &'db dyn Db,
}
impl<'db> UnionBuilder<'db> {
pub(crate) fn new(db: &'db dyn Db) -> Self {
Self {
db,
elements: vec![],
}
}
pub(crate) fn is_empty(&self) -> bool {
self.elements.is_empty()
}
/// Collapse the union to a single type: `object`.
fn collapse_to_object(&mut self) {
self.elements.clear();
self.elements
.push(UnionElement::Type(Type::object(self.db)));
}
/// Adds a type to this union.
pub(crate) fn add(mut self, ty: Type<'db>) -> Self {
self.add_in_place(ty);
self
}
/// Adds a type to this union.
pub(crate) fn add_in_place(&mut self, ty: Type<'db>) {
match ty {
Type::Union(union) => {
let new_elements = union.elements(self.db);
self.elements.reserve(new_elements.len());
for element in new_elements {
self.add_in_place(*element);
}
}
// Adding `Never` to a union is a no-op.
Type::Never => {}
// If adding a string literal, look for an existing `UnionElement::StringLiterals` to
// add it to, or an existing element that is a super-type of string literals, which
// means we shouldn't add it. Otherwise, add a new `UnionElement::StringLiterals`
// containing it.
Type::StringLiteral(literal) => {
let mut found = false;
let ty_negated = ty.negate(self.db);
for element in &mut self.elements {
match element {
UnionElement::StringLiterals(literals) => {
if literals.len() >= MAX_UNION_LITERALS {
let replace_with = KnownClass::Str.to_instance(self.db);
self.add_in_place(replace_with);
return;
}
literals.insert(literal);
found = true;
break;
}
UnionElement::Type(existing) => {
if ty.is_subtype_of(self.db, *existing) {
return;
}
if ty_negated.is_subtype_of(self.db, *existing) {
// The type that includes both this new element, and its negation
// (or a supertype of its negation), must be simply `object`.
self.collapse_to_object();
return;
}
}
_ => {}
}
}
if !found {
self.elements
.push(UnionElement::StringLiterals(FxOrderSet::from_iter([
literal,
])));
}
}
// Same for bytes literals as for string literals, above.
Type::BytesLiteral(literal) => {
let mut found = false;
let ty_negated = ty.negate(self.db);
for element in &mut self.elements {
match element {
UnionElement::BytesLiterals(literals) => {
if literals.len() >= MAX_UNION_LITERALS {
let replace_with = KnownClass::Bytes.to_instance(self.db);
self.add_in_place(replace_with);
return;
}
literals.insert(literal);
found = true;
break;
}
UnionElement::Type(existing) => {
if ty.is_subtype_of(self.db, *existing) {
return;
}
if ty_negated.is_subtype_of(self.db, *existing) {
// The type that includes both this new element, and its negation
// (or a supertype of its negation), must be simply `object`.
self.collapse_to_object();
return;
}
}
_ => {}
}
}
if !found {
self.elements
.push(UnionElement::BytesLiterals(FxOrderSet::from_iter([
literal,
])));
}
}
// And same for int literals as well.
Type::IntLiteral(literal) => {
let mut found = false;
let ty_negated = ty.negate(self.db);
for element in &mut self.elements {
match element {
UnionElement::IntLiterals(literals) => {
if literals.len() >= MAX_UNION_LITERALS {
let replace_with = KnownClass::Int.to_instance(self.db);
self.add_in_place(replace_with);
return;
}
literals.insert(literal);
found = true;
break;
}
UnionElement::Type(existing) => {
if ty.is_subtype_of(self.db, *existing) {
return;
}
if ty_negated.is_subtype_of(self.db, *existing) {
// The type that includes both this new element, and its negation
// (or a supertype of its negation), must be simply `object`.
self.collapse_to_object();
return;
}
}
_ => {}
}
}
if !found {
self.elements
.push(UnionElement::IntLiterals(FxOrderSet::from_iter([literal])));
}
}
// Adding `object` to a union results in `object`.
ty if ty.is_object(self.db) => {
self.collapse_to_object();
}
_ => {
let bool_pair = if let Type::BooleanLiteral(b) = ty {
Some(Type::BooleanLiteral(!b))
} else {
None
};
let mut to_add = ty;
let mut to_remove = SmallVec::<[usize; 2]>::new();
let ty_negated = ty.negate(self.db);
for (index, element) in self.elements.iter_mut().enumerate() {
let element_type = match element.try_reduce(self.db, ty) {
ReduceResult::KeepIf(keep) => {
if !keep {
to_remove.push(index);
}
continue;
}
ReduceResult::Type(ty) => ty,
ReduceResult::CollapseToObject => {
self.collapse_to_object();
return;
}
};
if Some(element_type) == bool_pair {
to_add = KnownClass::Bool.to_instance(self.db);
to_remove.push(index);
// The type we are adding is a BooleanLiteral, which doesn't have any
// subtypes. And we just found that the union already contained our
// mirror-image BooleanLiteral, so it can't also contain bool or any
// supertype of bool. Therefore, we are done.
break;
}
if ty.is_gradual_equivalent_to(self.db, element_type)
|| ty.is_subtype_of(self.db, element_type)
|| element_type.is_object(self.db)
{
return;
} else if element_type.is_subtype_of(self.db, ty) {
to_remove.push(index);
} else if ty_negated.is_subtype_of(self.db, element_type) {
// We add `ty` to the union. We just checked that `~ty` is a subtype of an
// existing `element`. This also means that `~ty | ty` is a subtype of
// `element | ty`, because both elements in the first union are subtypes of
// the corresponding elements in the second union. But `~ty | ty` is just
// `object`. Since `object` is a subtype of `element | ty`, we can only
// conclude that `element | ty` must be `object` (object has no other
// supertypes). This means we can simplify the whole union to just
// `object`, since all other potential elements would also be subtypes of
// `object`.
self.collapse_to_object();
return;
}
}
if let Some((&first, rest)) = to_remove.split_first() {
self.elements[first] = UnionElement::Type(to_add);
// We iterate in descending order to keep remaining indices valid after `swap_remove`.
for &index in rest.iter().rev() {
self.elements.swap_remove(index);
}
} else {
self.elements.push(UnionElement::Type(to_add));
}
}
}
}
pub(crate) fn build(self) -> Type<'db> {
let mut types = vec![];
for element in self.elements {
match element {
UnionElement::IntLiterals(literals) => {
types.extend(literals.into_iter().map(Type::IntLiteral));
}
UnionElement::StringLiterals(literals) => {
types.extend(literals.into_iter().map(Type::StringLiteral));
}
UnionElement::BytesLiterals(literals) => {
types.extend(literals.into_iter().map(Type::BytesLiteral));
}
UnionElement::Type(ty) => types.push(ty),
}
}
match types.len() {
0 => Type::Never,
1 => types[0],
_ => Type::Union(UnionType::new(self.db, types.into_boxed_slice())),
}
}
}
#[derive(Clone)]
pub(crate) struct IntersectionBuilder<'db> {
// Really this builds a union-of-intersections, because we always keep our set-theoretic types
// in disjunctive normal form (DNF), a union of intersections. In the simplest case there's
// just a single intersection in this vector, and we are building a single intersection type,
// but if a union is added to the intersection, we'll distribute ourselves over that union and
// create a union of intersections.
intersections: Vec<InnerIntersectionBuilder<'db>>,
db: &'db dyn Db,
}
impl<'db> IntersectionBuilder<'db> {
pub(crate) fn new(db: &'db dyn Db) -> Self {
Self {
db,
intersections: vec![InnerIntersectionBuilder::default()],
}
}
fn empty(db: &'db dyn Db) -> Self {
Self {
db,
intersections: vec![],
}
}
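    /// Adds a positive element to this intersection.
    ///
    /// Unions distribute (see the comment inside): for example, building
    /// `(T1 | T2) & T3` yields the DNF form `(T1 & T3) | (T2 & T3)`.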
pub(crate) fn add_positive(mut self, ty: Type<'db>) -> Self {
if let Type::Union(union) = ty {
// Distribute ourself over this union: for each union element, clone ourself and
// intersect with that union element, then create a new union-of-intersections with all
// of those sub-intersections in it. E.g. if `self` is a simple intersection `T1 & T2`
// and we add `T3 | T4` to the intersection, we don't get `T1 & T2 & (T3 | T4)` (that's
// not in DNF), we distribute the union and get `(T1 & T3) | (T2 & T3) | (T1 & T4) |
// (T2 & T4)`. If `self` is already a union-of-intersections `(T1 & T2) | (T3 & T4)`
// and we add `T5 | T6` to it, that flattens all the way out to `(T1 & T2 & T5) | (T1 &
// T2 & T6) | (T3 & T4 & T5) ...` -- you get the idea.
union
.elements(self.db)
.iter()
.map(|elem| self.clone().add_positive(*elem))
.fold(IntersectionBuilder::empty(self.db), |mut builder, sub| {
builder.intersections.extend(sub.intersections);
builder
})
} else {
// If we are already a union-of-intersections, distribute the new intersected element
// across all of those intersections.
for inner in &mut self.intersections {
inner.add_positive(self.db, ty);
}
self
}
}
pub(crate) fn add_negative(mut self, ty: Type<'db>) -> Self {
// See comments above in `add_positive`; this is just the negated version.
if let Type::Union(union) = ty {
for elem in union.elements(self.db) {
self = self.add_negative(*elem);
}
self
} else if let Type::Intersection(intersection) = ty {
// (A | B) & ~(C & ~D)
// -> (A | B) & (~C | D)
// -> ((A | B) & ~C) | ((A | B) & D)
// i.e. if we have an intersection of positive constraints C
// and negative constraints D, then our new intersection
// is (existing & ~C) | (existing & D)
let positive_side = intersection
.positive(self.db)
.iter()
// we negate all the positive constraints while distributing
.map(|elem| self.clone().add_negative(*elem));
let negative_side = intersection
.negative(self.db)
.iter()
// all negative constraints end up becoming positive constraints
.map(|elem| self.clone().add_positive(*elem));
positive_side.chain(negative_side).fold(
IntersectionBuilder::empty(self.db),
|mut builder, sub| {
builder.intersections.extend(sub.intersections);
builder
},
)
} else {
for inner in &mut self.intersections {
inner.add_negative(self.db, ty);
}
self
}
}
pub(crate) fn build(mut self) -> Type<'db> {
// Avoid allocating the UnionBuilder unnecessarily if we have just one intersection:
if self.intersections.len() == 1 {
self.intersections.pop().unwrap().build(self.db)
} else {
UnionType::from_elements(
self.db,
self.intersections
.into_iter()
.map(|inner| inner.build(self.db)),
)
}
}
}
#[derive(Debug, Clone, Default)]
struct InnerIntersectionBuilder<'db> {
positive: FxOrderSet<Type<'db>>,
negative: FxOrderSet<Type<'db>>,
}
impl<'db> InnerIntersectionBuilder<'db> {
/// Adds a positive type to this intersection.
fn add_positive(&mut self, db: &'db dyn Db, mut new_positive: Type<'db>) {
match new_positive {
// `LiteralString & AlwaysTruthy` -> `LiteralString & ~Literal[""]`
Type::AlwaysTruthy if self.positive.contains(&Type::LiteralString) => {
self.add_negative(db, Type::string_literal(db, ""));
}
// `LiteralString & AlwaysFalsy` -> `Literal[""]`
Type::AlwaysFalsy if self.positive.swap_remove(&Type::LiteralString) => {
self.add_positive(db, Type::string_literal(db, ""));
}
// `AlwaysTruthy & LiteralString` -> `LiteralString & ~Literal[""]`
Type::LiteralString if self.positive.swap_remove(&Type::AlwaysTruthy) => {
self.add_positive(db, Type::LiteralString);
self.add_negative(db, Type::string_literal(db, ""));
}
// `AlwaysFalsy & LiteralString` -> `Literal[""]`
Type::LiteralString if self.positive.swap_remove(&Type::AlwaysFalsy) => {
self.add_positive(db, Type::string_literal(db, ""));
}
// `LiteralString & ~AlwaysTruthy` -> `LiteralString & AlwaysFalsy` -> `Literal[""]`
Type::LiteralString if self.negative.swap_remove(&Type::AlwaysTruthy) => {
self.add_positive(db, Type::string_literal(db, ""));
}
// `LiteralString & ~AlwaysFalsy` -> `LiteralString & ~Literal[""]`
Type::LiteralString if self.negative.swap_remove(&Type::AlwaysFalsy) => {
self.add_positive(db, Type::LiteralString);
self.add_negative(db, Type::string_literal(db, ""));
}
// `(A & B & ~C) & (D & E & ~F)` -> `A & B & D & E & ~C & ~F`
Type::Intersection(other) => {
for pos in other.positive(db) {
self.add_positive(db, *pos);
}
for neg in other.negative(db) {
self.add_negative(db, *neg);
}
}
_ => {
let known_instance = new_positive
.into_nominal_instance()
.and_then(|instance| instance.class().known(db));
if known_instance == Some(KnownClass::Object) {
// `object & T` -> `T`; it is always redundant to add `object` to an intersection
return;
}
let addition_is_bool_instance = known_instance == Some(KnownClass::Bool);
for (index, existing_positive) in self.positive.iter().enumerate() {
match existing_positive {
// `AlwaysTruthy & bool` -> `Literal[True]`
Type::AlwaysTruthy if addition_is_bool_instance => {
new_positive = Type::BooleanLiteral(true);
}
// `AlwaysFalsy & bool` -> `Literal[False]`
Type::AlwaysFalsy if addition_is_bool_instance => {
new_positive = Type::BooleanLiteral(false);
}
Type::NominalInstance(instance)
if instance.class().is_known(db, KnownClass::Bool) =>
{
match new_positive {
// `bool & AlwaysTruthy` -> `Literal[True]`
Type::AlwaysTruthy => {
new_positive = Type::BooleanLiteral(true);
}
// `bool & AlwaysFalsy` -> `Literal[False]`
Type::AlwaysFalsy => {
new_positive = Type::BooleanLiteral(false);
}
_ => continue,
}
}
_ => continue,
}
self.positive.swap_remove_index(index);
break;
}
if addition_is_bool_instance {
for (index, existing_negative) in self.negative.iter().enumerate() {
match existing_negative {
// `bool & ~Literal[False]` -> `Literal[True]`
// `bool & ~Literal[True]` -> `Literal[False]`
Type::BooleanLiteral(bool_value) => {
new_positive = Type::BooleanLiteral(!bool_value);
}
// `bool & ~AlwaysTruthy` -> `Literal[False]`
Type::AlwaysTruthy => {
new_positive = Type::BooleanLiteral(false);
}
// `bool & ~AlwaysFalsy` -> `Literal[True]`
Type::AlwaysFalsy => {
new_positive = Type::BooleanLiteral(true);
}
_ => continue,
}
self.negative.swap_remove_index(index);
break;
}
}
let mut to_remove = SmallVec::<[usize; 1]>::new();
for (index, existing_positive) in self.positive.iter().enumerate() {
// S & T = S if S <: T
if existing_positive.is_subtype_of(db, new_positive)
|| existing_positive.is_gradual_equivalent_to(db, new_positive)
{
return;
}
// same rule, reverse order
if new_positive.is_subtype_of(db, *existing_positive) {
to_remove.push(index);
}
// A & B = Never if A and B are disjoint
if new_positive.is_disjoint_from(db, *existing_positive) {
*self = Self::default();
self.positive.insert(Type::Never);
return;
}
}
for index in to_remove.into_iter().rev() {
self.positive.swap_remove_index(index);
}
let mut to_remove = SmallVec::<[usize; 1]>::new();
for (index, existing_negative) in self.negative.iter().enumerate() {
// S & ~T = Never if S <: T
if new_positive.is_subtype_of(db, *existing_negative) {
*self = Self::default();
self.positive.insert(Type::Never);
return;
}
// A & ~B = A if A and B are disjoint
if existing_negative.is_disjoint_from(db, new_positive) {
to_remove.push(index);
}
}
for index in to_remove.into_iter().rev() {
self.negative.swap_remove_index(index);
}
self.positive.insert(new_positive);
}
}
}
/// Adds a negative type to this intersection.
fn add_negative(&mut self, db: &'db dyn Db, new_negative: Type<'db>) {
let contains_bool = || {
self.positive
.iter()
.filter_map(|ty| ty.into_nominal_instance())
.filter_map(|instance| instance.class().known(db))
.any(KnownClass::is_bool)
};
match new_negative {
Type::Intersection(inter) => {
for pos in inter.positive(db) {
self.add_negative(db, *pos);
}
for neg in inter.negative(db) {
self.add_positive(db, *neg);
}
}
Type::Never => {
// Adding ~Never to an intersection is a no-op.
}
Type::NominalInstance(instance) if instance.class().is_object(db) => {
// Adding ~object to an intersection results in Never.
*self = Self::default();
self.positive.insert(Type::Never);
}
ty @ Type::Dynamic(_) => {
// Adding any of these types to the negative side of an intersection
// is equivalent to adding it to the positive side. We do this to
// simplify the representation.
self.add_positive(db, ty);
}
// `bool & ~AlwaysTruthy` -> `bool & Literal[False]`
// `bool & ~Literal[True]` -> `bool & Literal[False]`
Type::AlwaysTruthy | Type::BooleanLiteral(true) if contains_bool() => {
self.add_positive(db, Type::BooleanLiteral(false));
}
// `LiteralString & ~AlwaysTruthy` -> `LiteralString & Literal[""]`
Type::AlwaysTruthy if self.positive.contains(&Type::LiteralString) => {
self.add_positive(db, Type::string_literal(db, ""));
}
// `bool & ~AlwaysFalsy` -> `bool & Literal[True]`
// `bool & ~Literal[False]` -> `bool & Literal[True]`
Type::AlwaysFalsy | Type::BooleanLiteral(false) if contains_bool() => {
self.add_positive(db, Type::BooleanLiteral(true));
}
// `LiteralString & ~AlwaysFalsy` -> `LiteralString & ~Literal[""]`
Type::AlwaysFalsy if self.positive.contains(&Type::LiteralString) => {
self.add_negative(db, Type::string_literal(db, ""));
}
_ => {
let mut to_remove = SmallVec::<[usize; 1]>::new();
for (index, existing_negative) in self.negative.iter().enumerate() {
// ~S & ~T = ~T if S <: T
if existing_negative.is_subtype_of(db, new_negative)
|| existing_negative.is_gradual_equivalent_to(db, new_negative)
{
to_remove.push(index);
}
// same rule, reverse order
if new_negative.is_subtype_of(db, *existing_negative) {
return;
}
}
for index in to_remove.into_iter().rev() {
self.negative.swap_remove_index(index);
}
for existing_positive in &self.positive {
// S & ~T = Never if S <: T
if existing_positive.is_subtype_of(db, new_negative) {
*self = Self::default();
self.positive.insert(Type::Never);
return;
}
// A & ~B = A if A and B are disjoint
if existing_positive.is_disjoint_from(db, new_negative) {
return;
}
}
self.negative.insert(new_negative);
}
}
}
/// Tries to simplify any constrained typevars in the intersection:
///
/// - If the intersection contains a positive entry for exactly one of the constraints, we can
/// remove the typevar (effectively replacing it with that one positive constraint).
///
/// - If the intersection contains negative entries for all but one of the constraints, we can
/// remove the negative constraints and replace the typevar with the remaining positive
/// constraint.
///
/// - If the intersection contains negative entries for all of the constraints, the overall
/// intersection is `Never`.
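    ///
    /// For example, with a typevar `T` constrained to `(int, str)`: `T & int`
    /// simplifies to `int` (first rule), `T & ~int` simplifies to `str` (second
    /// rule), and `T & ~int & ~str` simplifies to `Never` (third rule).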
fn simplify_constrained_typevars(&mut self, db: &'db dyn Db) {
let mut to_add = SmallVec::<[Type<'db>; 1]>::new();
let mut positive_to_remove = SmallVec::<[usize; 1]>::new();
for (typevar_index, ty) in self.positive.iter().enumerate() {
let Type::TypeVar(typevar) = ty else {
continue;
};
let Some(TypeVarBoundOrConstraints::Constraints(constraints)) =
typevar.bound_or_constraints(db)
else {
continue;
};
// Determine which constraints appear as positive entries in the intersection. Note
// that we shouldn't have duplicate entries in the positive or negative lists, so we
// don't need to worry about finding any particular constraint more than once.
let constraints = constraints.elements(db);
let mut positive_constraint_count = 0;
for positive in &self.positive {
// This linear search should be fine as long as we don't encounter typevars with
// thousands of constraints.
positive_constraint_count += constraints
.iter()
.filter(|c| c.is_subtype_of(db, *positive))
.count();
}
// If precisely one constraint appears as a positive element, we can replace the
// typevar with that positive constraint.
if positive_constraint_count == 1 {
positive_to_remove.push(typevar_index);
continue;
}
// Determine which constraints appear as negative entries in the intersection.
let mut to_remove = Vec::with_capacity(constraints.len());
let mut remaining_constraints: Vec<_> = constraints.iter().copied().map(Some).collect();
for (negative_index, negative) in self.negative.iter().enumerate() {
// This linear search should be fine as long as we don't encounter typevars with
// thousands of constraints.
let matching_constraints = constraints
.iter()
.enumerate()
.filter(|(_, c)| c.is_subtype_of(db, *negative));
for (constraint_index, _) in matching_constraints {
to_remove.push(negative_index);
remaining_constraints[constraint_index] = None;
}
}
let mut iter = remaining_constraints.into_iter().flatten();
let Some(remaining_constraint) = iter.next() else {
// All of the typevar constraints have been removed, so the entire intersection is
// `Never`.
*self = Self::default();
self.positive.insert(Type::Never);
return;
};
let more_than_one_remaining_constraint = iter.next().is_some();
if more_than_one_remaining_constraint {
// This typevar cannot be simplified.
continue;
}
// Only one typevar constraint remains. Remove all of the negative constraints, and
// replace the typevar itself with the remaining positive constraint.
to_add.push(remaining_constraint);
positive_to_remove.push(typevar_index);
}
// We don't need to sort the positive list, since we only append to it in increasing order.
for index in positive_to_remove.into_iter().rev() {
self.positive.swap_remove_index(index);
}
for remaining_constraint in to_add {
self.add_positive(db, remaining_constraint);
}
}
fn build(mut self, db: &'db dyn Db) -> Type<'db> {
self.simplify_constrained_typevars(db);
match (self.positive.len(), self.negative.len()) {
(0, 0) => Type::object(db),
(1, 0) => self.positive[0],
_ => {
self.positive.shrink_to_fit();
self.negative.shrink_to_fit();
Type::Intersection(IntersectionType::new(db, self.positive, self.negative))
}
}
}
}
#[cfg(test)]
mod tests {
use super::{IntersectionBuilder, Type, UnionBuilder, UnionType};
use crate::db::tests::setup_db;
use crate::types::{KnownClass, Truthiness};
use test_case::test_case;
#[test]
fn build_union_no_elements() {
let db = setup_db();
let empty_union = UnionBuilder::new(&db).build();
assert_eq!(empty_union, Type::Never);
}
#[test]
fn build_union_single_element() {
let db = setup_db();
let t0 = Type::IntLiteral(0);
let union = UnionType::from_elements(&db, [t0]);
assert_eq!(union, t0);
}
#[test]
fn build_union_two_elements() {
let db = setup_db();
let t0 = Type::IntLiteral(0);
let t1 = Type::IntLiteral(1);
let union = UnionType::from_elements(&db, [t0, t1]).expect_union();
assert_eq!(union.elements(&db), &[t0, t1]);
}
#[test]
fn build_intersection_empty_intersection_equals_object() {
let db = setup_db();
let intersection = IntersectionBuilder::new(&db).build();
assert_eq!(intersection, Type::object(&db));
}
#[test_case(Type::BooleanLiteral(true))]
#[test_case(Type::BooleanLiteral(false))]
#[test_case(Type::AlwaysTruthy)]
#[test_case(Type::AlwaysFalsy)]
fn build_intersection_simplify_split_bool(t_splitter: Type) {
let db = setup_db();
let bool_value = t_splitter.bool(&db) == Truthiness::AlwaysTrue;
// We add t_object in various orders (in first or second position) in
// the tests below to ensure that the boolean simplification eliminates
// everything from the intersection, not just `bool`.
let t_object = Type::object(&db);
let t_bool = KnownClass::Bool.to_instance(&db);
let ty = IntersectionBuilder::new(&db)
.add_positive(t_object)
.add_positive(t_bool)
.add_negative(t_splitter)
.build();
assert_eq!(ty, Type::BooleanLiteral(!bool_value));
let ty = IntersectionBuilder::new(&db)
.add_positive(t_bool)
.add_positive(t_object)
.add_negative(t_splitter)
.build();
assert_eq!(ty, Type::BooleanLiteral(!bool_value));
let ty = IntersectionBuilder::new(&db)
.add_positive(t_object)
.add_negative(t_splitter)
.add_positive(t_bool)
.build();
assert_eq!(ty, Type::BooleanLiteral(!bool_value));
let ty = IntersectionBuilder::new(&db)
.add_negative(t_splitter)
.add_positive(t_object)
.add_positive(t_bool)
.build();
assert_eq!(ty, Type::BooleanLiteral(!bool_value));
}
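    // Two hedged sketches of the invariants documented at the top of this file.
    #[test]
    fn build_union_simplifies_subsumed_literals() {
        let db = setup_db();
        let t0 = Type::IntLiteral(0);
        let t_int = KnownClass::Int.to_instance(&db);
        // `Literal[0] | int` simplifies to `int`, since `Literal[0]` is a
        // subtype of `int`.
        assert_eq!(UnionType::from_elements(&db, [t0, t_int]), t_int);
    }
    #[test]
    fn build_intersection_distributes_over_union() {
        let db = setup_db();
        let union = UnionType::from_elements(&db, [Type::IntLiteral(0), Type::IntLiteral(1)]);
        // Adding a union to an (empty) intersection distributes over the union's
        // elements and rebuilds the same union, preserving DNF.
        let ty = IntersectionBuilder::new(&db).add_positive(union).build();
        assert_eq!(ty, union);
    }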
}


@ -0,0 +1,69 @@
use super::context::InferContext;
use super::{CallableSignature, Signature, Signatures, Type};
use crate::Db;
mod arguments;
mod bind;
pub(super) use arguments::{Argument, CallArgumentTypes, CallArguments};
pub(super) use bind::{Bindings, CallableBinding};
/// Wraps a [`Bindings`] for an unsuccessful call with information about why the call was
/// unsuccessful.
///
/// The bindings are boxed so that we do not pass around large `Err` variants on the stack.
#[derive(Debug)]
pub(crate) struct CallError<'db>(pub(crate) CallErrorKind, pub(crate) Box<Bindings<'db>>);
/// The reason why calling a type failed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CallErrorKind {
/// The type is not callable. For a union type, _none_ of the union elements are callable.
NotCallable,
/// The type is not callable with the given arguments.
///
/// `BindingError` takes precedence over `PossiblyNotCallable`: for a union type, there might
/// be some union elements that are not callable at all, but the call arguments are not
/// compatible with at least one of the callable elements.
BindingError,
/// Not all of the elements of a union type are callable, but the call arguments are compatible
/// with all of the callable elements.
PossiblyNotCallable,
}
#[derive(Debug)]
pub(super) enum CallDunderError<'db> {
/// The dunder attribute exists but it can't be called with the given arguments.
///
/// This includes non-callable dunder attributes that are possibly unbound.
CallError(CallErrorKind, Box<Bindings<'db>>),
/// The type has the specified dunder method and it is callable
/// with the specified arguments without any binding errors
/// but it is possibly unbound.
PossiblyUnbound(Box<Bindings<'db>>),
/// The dunder method with the specified name is missing.
MethodNotAvailable,
}
impl<'db> CallDunderError<'db> {
pub(super) fn return_type(&self, db: &'db dyn Db) -> Option<Type<'db>> {
match self {
Self::MethodNotAvailable | Self::CallError(CallErrorKind::NotCallable, _) => None,
Self::CallError(_, bindings) => Some(bindings.return_type(db)),
Self::PossiblyUnbound(bindings) => Some(bindings.return_type(db)),
}
}
pub(super) fn fallback_return_type(&self, db: &'db dyn Db) -> Type<'db> {
self.return_type(db).unwrap_or(Type::unknown())
}
}
impl<'db> From<CallError<'db>> for CallDunderError<'db> {
fn from(CallError(kind, bindings): CallError<'db>) -> Self {
Self::CallError(kind, bindings)
}
}


@ -0,0 +1,124 @@
use std::borrow::Cow;
use std::ops::{Deref, DerefMut};
use super::Type;
/// Arguments for a single call, in source order.
#[derive(Clone, Debug, Default)]
pub(crate) struct CallArguments<'a>(Vec<Argument<'a>>);
impl<'a> CallArguments<'a> {
/// Prepend an optional extra synthetic argument (for a `self` or `cls` parameter) to the front
    /// of this argument list. (If `bound_self` is `None`, we return the argument list
    /// unmodified.)
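    ///
    /// A sketch (`bound_self_ty` stands in for some hypothetical `Type`):
    ///
    /// ```ignore
    /// let args: CallArguments = [Argument::Positional].into_iter().collect();
    /// // `None` leaves the list unchanged; `Some(_)` prepends `Argument::Synthetic`.
    /// assert_eq!(args.with_self(None).len(), 1);
    /// assert_eq!(args.with_self(Some(bound_self_ty)).len(), 2);
    /// ```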
pub(crate) fn with_self(&self, bound_self: Option<Type<'_>>) -> Cow<Self> {
if bound_self.is_some() {
let arguments = std::iter::once(Argument::Synthetic)
.chain(self.0.iter().copied())
.collect();
Cow::Owned(CallArguments(arguments))
} else {
Cow::Borrowed(self)
}
}
pub(crate) fn len(&self) -> usize {
self.0.len()
}
pub(crate) fn iter(&self) -> impl Iterator<Item = Argument<'a>> + '_ {
self.0.iter().copied()
}
}
impl<'a> FromIterator<Argument<'a>> for CallArguments<'a> {
fn from_iter<T: IntoIterator<Item = Argument<'a>>>(iter: T) -> Self {
Self(iter.into_iter().collect())
}
}
#[derive(Clone, Copy, Debug)]
pub(crate) enum Argument<'a> {
/// The synthetic `self` or `cls` argument, which doesn't appear explicitly at the call site.
Synthetic,
/// A positional argument.
Positional,
/// A starred positional argument (e.g. `*args`).
Variadic,
/// A keyword argument (e.g. `a=1`).
Keyword(&'a str),
/// The double-starred keywords argument (e.g. `**kwargs`).
Keywords,
}
/// Arguments for a single call, in source order, along with inferred types for each argument.
#[derive(Clone, Debug, Default)]
pub(crate) struct CallArgumentTypes<'a, 'db> {
arguments: CallArguments<'a>,
types: Vec<Type<'db>>,
}
impl<'a, 'db> CallArgumentTypes<'a, 'db> {
/// Create a [`CallArgumentTypes`] with no arguments.
pub(crate) fn none() -> Self {
Self::default()
}
/// Create a [`CallArgumentTypes`] from an iterator over non-variadic positional argument
/// types.
pub(crate) fn positional(positional_tys: impl IntoIterator<Item = Type<'db>>) -> Self {
let types: Vec<_> = positional_tys.into_iter().collect();
let arguments = CallArguments(vec![Argument::Positional; types.len()]);
Self { arguments, types }
}
/// Create a new [`CallArgumentTypes`] to store the inferred types of the arguments in a
/// [`CallArguments`]. Uses the provided callback to infer each argument type.
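    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let arguments: CallArguments = [Argument::Positional, Argument::Positional]
    ///     .into_iter()
    ///     .collect();
    /// // Infer `Unknown` for every argument.
    /// let types = CallArgumentTypes::new(arguments, |_index, _argument| Type::unknown());
    /// assert_eq!(types.len(), 2);
    /// ```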
pub(crate) fn new<F>(arguments: CallArguments<'a>, mut f: F) -> Self
where
F: FnMut(usize, Argument<'a>) -> Type<'db>,
{
let types = arguments
.iter()
.enumerate()
.map(|(idx, argument)| f(idx, argument))
.collect();
Self { arguments, types }
}
/// Prepend an optional extra synthetic argument (for a `self` or `cls` parameter) to the front
    /// of this argument list. (If `bound_self` is `None`, we return the argument list
    /// unmodified.)
pub(crate) fn with_self(&self, bound_self: Option<Type<'db>>) -> Cow<Self> {
if let Some(bound_self) = bound_self {
let arguments = CallArguments(
std::iter::once(Argument::Synthetic)
.chain(self.arguments.0.iter().copied())
.collect(),
);
let types = std::iter::once(bound_self)
.chain(self.types.iter().copied())
.collect();
Cow::Owned(CallArgumentTypes { arguments, types })
} else {
Cow::Borrowed(self)
}
}
pub(crate) fn iter(&self) -> impl Iterator<Item = (Argument<'a>, Type<'db>)> + '_ {
self.arguments.iter().zip(self.types.iter().copied())
}
}
impl<'a> Deref for CallArgumentTypes<'a, '_> {
type Target = CallArguments<'a>;
fn deref(&self) -> &CallArguments<'a> {
&self.arguments
}
}
impl<'a> DerefMut for CallArgumentTypes<'a, '_> {
fn deref_mut(&mut self) -> &mut CallArguments<'a> {
&mut self.arguments
}
}

File diff suppressed because it is too large

File diff suppressed because it is too large

@ -0,0 +1,283 @@
use crate::types::generics::GenericContext;
use crate::types::{
todo_type, ClassType, DynamicType, KnownClass, KnownInstanceType, MroIterator, Type,
};
use crate::Db;
/// Enumeration of the possible kinds of types we allow in class bases.
///
/// This is much more limited than the [`Type`] enum: all types that would be invalid to have as a
/// class base are transformed into [`ClassBase::unknown()`]
///
/// Note that a non-specialized generic class _cannot_ be a class base. When we see a
/// non-specialized generic class in any type expression (including the list of base classes), we
/// automatically construct the default specialization for that class.
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, salsa::Update)]
pub enum ClassBase<'db> {
Dynamic(DynamicType),
Class(ClassType<'db>),
/// Although `Protocol` is not a class in typeshed's stubs, it is at runtime,
/// and can appear in the MRO of a class.
Protocol,
/// Bare `Generic` cannot be subclassed directly in user code,
/// but nonetheless appears in the MRO of classes that inherit from `Generic[T]`,
/// `Protocol[T]`, or bare `Protocol`.
Generic(Option<GenericContext<'db>>),
}
impl<'db> ClassBase<'db> {
pub(crate) const fn any() -> Self {
Self::Dynamic(DynamicType::Any)
}
pub(crate) const fn unknown() -> Self {
Self::Dynamic(DynamicType::Unknown)
}
pub(crate) fn display(self, db: &'db dyn Db) -> impl std::fmt::Display + 'db {
struct Display<'db> {
base: ClassBase<'db>,
db: &'db dyn Db,
}
impl std::fmt::Display for Display<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.base {
ClassBase::Dynamic(dynamic) => dynamic.fmt(f),
ClassBase::Class(class @ ClassType::NonGeneric(_)) => {
write!(f, "<class '{}'>", class.name(self.db))
}
ClassBase::Class(ClassType::Generic(alias)) => {
write!(f, "<class '{}'>", alias.display(self.db))
}
ClassBase::Protocol => f.write_str("typing.Protocol"),
ClassBase::Generic(generic_context) => {
f.write_str("typing.Generic")?;
if let Some(generic_context) = generic_context {
write!(f, "{}", generic_context.display(self.db))?;
}
Ok(())
}
}
}
}
Display { base: self, db }
}
/// Return a `ClassBase` representing the class `builtins.object`
pub(super) fn object(db: &'db dyn Db) -> Self {
KnownClass::Object
.to_class_literal(db)
.to_class_type(db)
.map_or(Self::unknown(), Self::Class)
}
/// Attempt to resolve `ty` into a `ClassBase`.
///
/// Return `None` if `ty` is not an acceptable type for a class base.
pub(super) fn try_from_type(db: &'db dyn Db, ty: Type<'db>) -> Option<Self> {
match ty {
Type::Dynamic(dynamic) => Some(Self::Dynamic(dynamic)),
Type::ClassLiteral(literal) => {
if literal.is_known(db, KnownClass::Any) {
Some(Self::Dynamic(DynamicType::Any))
} else if literal.is_known(db, KnownClass::NamedTuple) {
Self::try_from_type(db, KnownClass::Tuple.to_class_literal(db))
} else {
Some(Self::Class(literal.default_specialization(db)))
}
}
Type::GenericAlias(generic) => Some(Self::Class(ClassType::Generic(generic))),
Type::NominalInstance(instance)
if instance.class().is_known(db, KnownClass::GenericAlias) =>
{
Self::try_from_type(db, todo_type!("GenericAlias instance"))
}
Type::Union(_) => None, // TODO -- forces consideration of multiple possible MROs?
Type::Intersection(_) => None, // TODO -- probably incorrect?
Type::NominalInstance(_) => None, // TODO -- handle `__mro_entries__`?
Type::PropertyInstance(_) => None,
Type::Never
| Type::BooleanLiteral(_)
| Type::FunctionLiteral(_)
| Type::Callable(..)
| Type::BoundMethod(_)
| Type::MethodWrapper(_)
| Type::WrapperDescriptor(_)
| Type::DataclassDecorator(_)
| Type::DataclassTransformer(_)
| Type::BytesLiteral(_)
| Type::IntLiteral(_)
| Type::StringLiteral(_)
| Type::LiteralString
| Type::Tuple(_)
| Type::SliceLiteral(_)
| Type::ModuleLiteral(_)
| Type::SubclassOf(_)
| Type::TypeVar(_)
| Type::BoundSuper(_)
| Type::ProtocolInstance(_)
| Type::AlwaysFalsy
| Type::AlwaysTruthy => None,
Type::KnownInstance(known_instance) => match known_instance {
KnownInstanceType::TypeVar(_)
| KnownInstanceType::TypeAliasType(_)
| KnownInstanceType::Annotated
| KnownInstanceType::Literal
| KnownInstanceType::LiteralString
| KnownInstanceType::Union
| KnownInstanceType::NoReturn
| KnownInstanceType::Never
| KnownInstanceType::Final
| KnownInstanceType::NotRequired
| KnownInstanceType::TypeGuard
| KnownInstanceType::TypeIs
| KnownInstanceType::TypingSelf
| KnownInstanceType::Unpack
| KnownInstanceType::ClassVar
| KnownInstanceType::Concatenate
| KnownInstanceType::Required
| KnownInstanceType::TypeAlias
| KnownInstanceType::ReadOnly
| KnownInstanceType::Optional
| KnownInstanceType::Not
| KnownInstanceType::Intersection
| KnownInstanceType::TypeOf
| KnownInstanceType::CallableTypeOf
| KnownInstanceType::AlwaysTruthy
| KnownInstanceType::AlwaysFalsy => None,
KnownInstanceType::Unknown => Some(Self::unknown()),
KnownInstanceType::Any => Some(Self::any()),
// TODO: Classes inheriting from `typing.Type` et al. also have `Generic` in their MRO
KnownInstanceType::Dict => {
Self::try_from_type(db, KnownClass::Dict.to_class_literal(db))
}
KnownInstanceType::List => {
Self::try_from_type(db, KnownClass::List.to_class_literal(db))
}
KnownInstanceType::Type => {
Self::try_from_type(db, KnownClass::Type.to_class_literal(db))
}
KnownInstanceType::Tuple => {
Self::try_from_type(db, KnownClass::Tuple.to_class_literal(db))
}
KnownInstanceType::Set => {
Self::try_from_type(db, KnownClass::Set.to_class_literal(db))
}
KnownInstanceType::FrozenSet => {
Self::try_from_type(db, KnownClass::FrozenSet.to_class_literal(db))
}
KnownInstanceType::ChainMap => {
Self::try_from_type(db, KnownClass::ChainMap.to_class_literal(db))
}
KnownInstanceType::Counter => {
Self::try_from_type(db, KnownClass::Counter.to_class_literal(db))
}
KnownInstanceType::DefaultDict => {
Self::try_from_type(db, KnownClass::DefaultDict.to_class_literal(db))
}
KnownInstanceType::Deque => {
Self::try_from_type(db, KnownClass::Deque.to_class_literal(db))
}
KnownInstanceType::OrderedDict => {
Self::try_from_type(db, KnownClass::OrderedDict.to_class_literal(db))
}
KnownInstanceType::TypedDict => Self::try_from_type(db, todo_type!("TypedDict")),
KnownInstanceType::Callable => {
Self::try_from_type(db, todo_type!("Support for Callable as a base class"))
}
KnownInstanceType::Protocol => Some(ClassBase::Protocol),
KnownInstanceType::Generic(generic_context) => {
Some(ClassBase::Generic(generic_context))
}
},
}
}
pub(super) fn into_class(self) -> Option<ClassType<'db>> {
match self {
Self::Class(class) => Some(class),
Self::Dynamic(_) | Self::Generic(_) | Self::Protocol => None,
}
}
/// Iterate over the MRO of this base
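    ///
    /// For example, the MRO of `ClassBase::Protocol` is
    /// `[Protocol, Generic, object]`, built via `ClassBaseMroIterator::length_3`.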
pub(super) fn mro(self, db: &'db dyn Db) -> impl Iterator<Item = ClassBase<'db>> {
match self {
ClassBase::Protocol => {
ClassBaseMroIterator::length_3(db, self, ClassBase::Generic(None))
}
ClassBase::Dynamic(DynamicType::SubscriptedProtocol) => {
ClassBaseMroIterator::length_3(db, self, ClassBase::Generic(None))
}
ClassBase::Dynamic(_) | ClassBase::Generic(_) => {
ClassBaseMroIterator::length_2(db, self)
}
ClassBase::Class(class) => ClassBaseMroIterator::from_class(db, class),
}
}
}
impl<'db> From<ClassType<'db>> for ClassBase<'db> {
fn from(value: ClassType<'db>) -> Self {
ClassBase::Class(value)
}
}
impl<'db> From<ClassBase<'db>> for Type<'db> {
fn from(value: ClassBase<'db>) -> Self {
match value {
ClassBase::Dynamic(dynamic) => Type::Dynamic(dynamic),
ClassBase::Class(class) => class.into(),
ClassBase::Protocol => Type::KnownInstance(KnownInstanceType::Protocol),
ClassBase::Generic(generic_context) => {
Type::KnownInstance(KnownInstanceType::Generic(generic_context))
}
}
}
}
impl<'db> From<&ClassBase<'db>> for Type<'db> {
fn from(value: &ClassBase<'db>) -> Self {
Self::from(*value)
}
}
/// An iterator over the MRO of a class base.
enum ClassBaseMroIterator<'db> {
Length2(core::array::IntoIter<ClassBase<'db>, 2>),
Length3(core::array::IntoIter<ClassBase<'db>, 3>),
FromClass(MroIterator<'db>),
}
impl<'db> ClassBaseMroIterator<'db> {
/// Iterate over an MRO of length 2 that consists of `first_element` and then `object`.
fn length_2(db: &'db dyn Db, first_element: ClassBase<'db>) -> Self {
ClassBaseMroIterator::Length2([first_element, ClassBase::object(db)].into_iter())
}
/// Iterate over an MRO of length 3 that consists of `first_element`, then `second_element`, then `object`.
fn length_3(db: &'db dyn Db, element_1: ClassBase<'db>, element_2: ClassBase<'db>) -> Self {
ClassBaseMroIterator::Length3([element_1, element_2, ClassBase::object(db)].into_iter())
}
/// Iterate over the MRO of an arbitrary class. The MRO may be of any length.
fn from_class(db: &'db dyn Db, class: ClassType<'db>) -> Self {
ClassBaseMroIterator::FromClass(class.iter_mro(db))
}
}
impl<'db> Iterator for ClassBaseMroIterator<'db> {
type Item = ClassBase<'db>;
fn next(&mut self) -> Option<Self::Item> {
match self {
Self::Length2(iter) => iter.next(),
Self::Length3(iter) => iter.next(),
Self::FromClass(iter) => iter.next(),
}
}
}
impl std::iter::FusedIterator for ClassBaseMroIterator<'_> {}


@ -0,0 +1,542 @@
use std::fmt;
use drop_bomb::DebugDropBomb;
use ruff_db::{
diagnostic::{Annotation, Diagnostic, DiagnosticId, IntoDiagnosticMessage, Severity, Span},
files::File,
};
use ruff_text_size::{Ranged, TextRange};
use super::{binding_type, Type, TypeCheckDiagnostics};
use crate::semantic_index::symbol::ScopeId;
use crate::{
lint::{LintId, LintMetadata},
suppression::suppressions,
Db,
};
use crate::{semantic_index::semantic_index, types::FunctionDecorators};
/// Context for inferring the types of a single file.
///
/// At least one context exists for every inferred region, but it's
/// possible that inferring a sub-region, like an unpack assignment, creates
/// a sub-context.
///
/// Tracks the reported diagnostics of the inferred region.
///
/// ## Consuming
/// It's important that the context is explicitly consumed before being dropped, by
/// calling [`InferContext::finish`]; the returned diagnostics must be stored
/// on the current [`TypeInference`](super::infer::TypeInference) result.
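///
/// A sketch of the expected lifecycle:
///
/// ```ignore
/// let mut context = InferContext::new(db, scope);
/// // ... infer the region, reporting diagnostics through `context` ...
/// let diagnostics = context.finish(); // must be consumed before drop
/// ```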
pub(crate) struct InferContext<'db> {
db: &'db dyn Db,
scope: ScopeId<'db>,
file: File,
diagnostics: std::cell::RefCell<TypeCheckDiagnostics>,
no_type_check: InNoTypeCheck,
bomb: DebugDropBomb,
}
impl<'db> InferContext<'db> {
pub(crate) fn new(db: &'db dyn Db, scope: ScopeId<'db>) -> Self {
Self {
db,
scope,
file: scope.file(db),
diagnostics: std::cell::RefCell::new(TypeCheckDiagnostics::default()),
no_type_check: InNoTypeCheck::default(),
bomb: DebugDropBomb::new("`InferContext` needs to be explicitly consumed by calling `::finish` to prevent accidental loss of diagnostics."),
}
}
/// The file for which the types are inferred.
pub(crate) fn file(&self) -> File {
self.file
}
/// Create a span with the range of the given expression
/// in the file being currently type checked.
///
/// If you're creating a diagnostic with snippets in files
/// other than this one, you should create the span directly
/// and not use this convenience API.
pub(crate) fn span<T: Ranged>(&self, ranged: T) -> Span {
Span::from(self.file()).with_range(ranged.range())
}
/// Create a secondary annotation attached to the range of the given value in
/// the file currently being type checked.
///
/// The annotation returned has no message attached to it.
pub(crate) fn secondary<T: Ranged>(&self, ranged: T) -> Annotation {
Annotation::secondary(self.span(ranged))
}
pub(crate) fn db(&self) -> &'db dyn Db {
self.db
}
pub(crate) fn extend(&mut self, other: &TypeCheckDiagnostics) {
self.diagnostics.get_mut().extend(other);
}
/// Optionally return a builder for a lint diagnostic guard.
///
/// If the current context believes a diagnostic should be reported for
/// the given lint, then a builder is returned that enables building a
/// lint diagnostic guard. The guard can then be used, via its `DerefMut`
/// implementation, to directly mutate a `Diagnostic`.
///
/// The severity of the diagnostic returned is automatically determined
/// by the given lint and configuration. The message given to
/// `LintDiagnosticGuardBuilder::to_diagnostic` is used to construct the
/// initial diagnostic and should be considered the "top-level message" of
/// the diagnostic. (i.e., If nothing else about the diagnostic is seen,
/// aside from its identifier, the message is probably the thing you'd pick
/// to show.)
///
/// The diagnostic constructed also includes a primary annotation with a
/// `Span` derived from the range given attached to the `File` in this
/// typing context. (That means the range given _must_ be valid for the
/// `File` currently being type checked.) This primary annotation does
/// not have a message attached to it, but callers can attach one via
/// `LintDiagnosticGuard::set_primary_message`.
///
/// After using the builder to make a guard, once the guard is dropped, the
/// diagnostic is added to the context, unless there is something in the
/// diagnostic that excludes it. (Currently, no such conditions exist.)
///
/// If callers need to create a non-lint diagnostic, you'll want to use the
/// lower level `InferContext::report_diagnostic` routine.
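    ///
    /// A hedged usage sketch (`SOME_LINT` and `node` are hypothetical):
    ///
    /// ```ignore
    /// if let Some(builder) = context.report_lint(&SOME_LINT, node) {
    ///     let mut diag = builder.to_diagnostic("top-level message");
    ///     diag.set_primary_message("message for the primary annotation");
    /// } // the diagnostic is added to the context when the guard drops
    /// ```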
pub(super) fn report_lint<'ctx, T: Ranged>(
&'ctx self,
lint: &'static LintMetadata,
ranged: T,
) -> Option<LintDiagnosticGuardBuilder<'ctx, 'db>> {
LintDiagnosticGuardBuilder::new(self, lint, ranged.range())
}
/// Optionally return a builder for a diagnostic guard.
///
/// This only returns a builder if the current context allows a diagnostic
/// with the given information to be added. In general, the requirements
/// here are quite a bit less than for `InferContext::report_lint`, since
/// this routine doesn't take rule selection into account (among other
/// things).
///
/// After using the builder to make a guard, once the guard is dropped, the
/// diagnostic is added to the context, unless there is something in the
/// diagnostic that excludes it. (Currently, no such conditions exist.)
///
/// Callers should generally prefer adding a lint diagnostic via
/// `InferContext::report_lint` whenever possible.
pub(super) fn report_diagnostic<'ctx>(
&'ctx self,
id: DiagnosticId,
severity: Severity,
) -> Option<DiagnosticGuardBuilder<'ctx, 'db>> {
DiagnosticGuardBuilder::new(self, id, severity)
}
pub(super) fn set_in_no_type_check(&mut self, no_type_check: InNoTypeCheck) {
self.no_type_check = no_type_check;
}
fn is_in_no_type_check(&self) -> bool {
match self.no_type_check {
InNoTypeCheck::Possibly => {
// Accessing the semantic index here is fine because
// the index belongs to the same file as for which we emit the diagnostic.
let index = semantic_index(self.db, self.file);
let scope_id = self.scope.file_scope_id(self.db);
// Inspect all ancestor function scopes by walking bottom up and infer the function's type.
let mut function_scope_tys = index
.ancestor_scopes(scope_id)
.filter_map(|(_, scope)| scope.node().as_function())
.map(|node| binding_type(self.db, index.expect_single_definition(node)))
.filter_map(Type::into_function_literal);
// Iterate over all functions and test if any is decorated with `@no_type_check`.
function_scope_tys.any(|function_ty| {
function_ty.has_known_decorator(self.db, FunctionDecorators::NO_TYPE_CHECK)
})
}
InNoTypeCheck::Yes => true,
}
}
/// Are we currently inferring types in a stub file?
pub(crate) fn in_stub(&self) -> bool {
self.file.is_stub(self.db().upcast())
}
#[must_use]
pub(crate) fn finish(mut self) -> TypeCheckDiagnostics {
self.bomb.defuse();
let mut diagnostics = self.diagnostics.into_inner();
diagnostics.shrink_to_fit();
diagnostics
}
}
impl fmt::Debug for InferContext<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("TyContext")
.field("file", &self.file)
.field("diagnostics", &self.diagnostics)
.field("defused", &self.bomb)
.finish()
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, Default)]
pub(crate) enum InNoTypeCheck {
/// The inference might be in a `no_type_check` block but only if any
/// ancestor function is decorated with `@no_type_check`.
#[default]
Possibly,
/// The inference is known to be in an `@no_type_check` decorated function.
Yes,
}
/// An abstraction for mutating a diagnostic through the lens of a lint.
///
/// Callers can build this guard by starting with `InferContext::report_lint`.
///
/// There are two primary functions of this guard, which mutably derefs to
/// a `Diagnostic`:
///
/// * On `Drop`, the underlying diagnostic is added to the typing context.
/// * Some convenience methods for mutating the underlying `Diagnostic`
/// in lint context. For example, `LintDiagnosticGuard::set_primary_message`
/// will attach a message to the primary span on the diagnostic.
pub(super) struct LintDiagnosticGuard<'db, 'ctx> {
/// The typing context.
ctx: &'ctx InferContext<'db>,
/// The diagnostic that we want to report.
///
/// This is always `Some` until the `Drop` impl.
diag: Option<Diagnostic>,
}
impl LintDiagnosticGuard<'_, '_> {
/// Set the message on the primary annotation for this diagnostic.
///
/// If a message already exists on the primary annotation, then this
/// overwrites the existing message.
///
/// This message is associated with the primary annotation created
/// for every `Diagnostic` that uses the `LintDiagnosticGuard` API.
/// Specifically, the annotation is derived from the `TextRange` given to
/// the `InferContext::report_lint` API.
///
/// Callers can add additional primary or secondary annotations via the
/// `DerefMut` trait implementation to a `Diagnostic`.
pub(super) fn set_primary_message(&mut self, message: impl IntoDiagnosticMessage) {
// N.B. It is normally bad juju to define `self` methods
// on types that implement `Deref`. Instead, it's idiomatic
// to do `fn foo(this: &mut LintDiagnosticGuard)`, which in
// turn forces callers to use
        // `LintDiagnosticGuard::foo(&mut guard, message)`. But this is
// supremely annoying for what is expected to be a common
// case.
//
// Moreover, most of the downside that comes from these sorts
// of methods is a semver hazard. Because the deref target type
// could also define a method by the same name, and that leads
// to confusion. But we own all the code involved here and
// there is no semver boundary. So... ¯\_(ツ)_/¯ ---AG
// OK because we know the diagnostic was constructed with a single
// primary annotation that will always come before any other annotation
// in the diagnostic. (This relies on the `Diagnostic` API not exposing
// any methods for removing annotations or re-ordering them, which is
// true as of 2025-04-11.)
let ann = self.primary_annotation_mut().unwrap();
ann.set_message(message);
}
}
impl std::ops::Deref for LintDiagnosticGuard<'_, '_> {
type Target = Diagnostic;
fn deref(&self) -> &Diagnostic {
// OK because `self.diag` is only `None` within `Drop`.
self.diag.as_ref().unwrap()
}
}
/// Return a mutable borrow of the diagnostic in this guard.
///
/// Callers may mutate the diagnostic to add new sub-diagnostics
/// or annotations.
///
/// The diagnostic is added to the typing context, if appropriate,
/// when this guard is dropped.
impl std::ops::DerefMut for LintDiagnosticGuard<'_, '_> {
fn deref_mut(&mut self) -> &mut Diagnostic {
// OK because `self.diag` is only `None` within `Drop`.
self.diag.as_mut().unwrap()
}
}
/// Finishes use of this guard.
///
/// This will add the lint as a diagnostic to the typing context if
/// appropriate. The diagnostic may be skipped, for example, if there is a
/// relevant suppression.
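/// For example (a sketch, assuming ty's `ty: ignore[...]` suppression
/// comment syntax; the rule name is illustrative), a lint reported for the
/// line below would be skipped:
///
/// ```py
/// x = some_undefined_name  # ty: ignore[unresolved-reference]
/// ```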
impl Drop for LintDiagnosticGuard<'_, '_> {
fn drop(&mut self) {
// OK because the only way `self.diag` is `None`
// is via this impl, which can only run at most
// once.
let diag = self.diag.take().unwrap();
self.ctx.diagnostics.borrow_mut().push(diag);
}
}
/// A builder for constructing a lint diagnostic guard.
///
/// This type exists to separate the phases of "check if a diagnostic should
/// be reported" and "build the actual diagnostic." It's why, for example,
/// `InferContext::report_lint` only requires a `LintMetadata` (and a range),
/// but this builder further requires a message before one can mutate the
/// diagnostic. This is because the `LintMetadata` can be used to derive
/// the diagnostic ID and its severity (based on configuration). Combined
/// with a message, you get the minimum amount of data required to build a
/// `Diagnostic`.
///
/// Additionally, the range is used to construct a primary annotation (without
/// a message) using the file currently being type checked. The range given to
/// `InferContext::report_lint` must be from the file currently being type
/// checked.
///
/// If callers need to report a diagnostic with an identifier type other
/// than `DiagnosticId::Lint`, then they should use the more general
/// `InferContext::report_diagnostic` API. But note that this API will not take
/// rule selection or suppressions into account.
///
/// # When is the diagnostic added?
///
/// If `InferContext::report_lint` does not return a builder, then it is
/// known that the diagnostic should not be reported. This can happen when
/// the diagnostic is disabled or suppressed (among other reasons).
pub(super) struct LintDiagnosticGuardBuilder<'db, 'ctx> {
ctx: &'ctx InferContext<'db>,
id: DiagnosticId,
severity: Severity,
primary_span: Span,
}
impl<'db, 'ctx> LintDiagnosticGuardBuilder<'db, 'ctx> {
fn new(
ctx: &'ctx InferContext<'db>,
lint: &'static LintMetadata,
range: TextRange,
) -> Option<LintDiagnosticGuardBuilder<'db, 'ctx>> {
// The comment below was copied from the original
// implementation of diagnostic reporting. The code
// has been refactored, but this still kind of looked
// relevant, so I've preserved the note. ---AG
//
// TODO: Don't emit the diagnostic if:
// * The enclosing node contains any syntax errors
// * The rule is disabled for this file. We probably want to introduce a new query that
// returns a rule selector for a given file that respects the package's settings,
// any global pragma comments in the file, and any per-file-ignores.
if !ctx.db.is_file_open(ctx.file) {
return None;
}
let lint_id = LintId::of(lint);
// Skip the diagnostic if the rule is disabled.
let severity = ctx.db.rule_selection().severity(lint_id)?;
// If we're inside an `@no_type_check` function, we can bail now.
if ctx.is_in_no_type_check() {
return None;
}
let id = DiagnosticId::Lint(lint.name());
let suppressions = suppressions(ctx.db(), ctx.file());
if let Some(suppression) = suppressions.find_suppression(range, lint_id) {
ctx.diagnostics.borrow_mut().mark_used(suppression.id());
return None;
}
let primary_span = Span::from(ctx.file()).with_range(range);
Some(LintDiagnosticGuardBuilder {
ctx,
id,
severity,
primary_span,
})
}
/// Create a new lint diagnostic guard.
///
/// This initializes a new diagnostic using the given message along with
/// the ID and severity derived from the `LintMetadata` used to create
/// this builder. The diagnostic also includes a primary annotation
/// without a message. To add a message to this primary annotation, use
/// `LintDiagnosticGuard::set_primary_message`.
///
/// The diagnostic can be further mutated on the guard via its `DerefMut`
/// impl to `Diagnostic`.
pub(super) fn into_diagnostic(
self,
message: impl std::fmt::Display,
) -> LintDiagnosticGuard<'db, 'ctx> {
let mut diag = Diagnostic::new(self.id, self.severity, message);
// This is why `LintDiagnosticGuard::set_primary_message` exists.
// We add the primary annotation here (because it's required), but
// the optional message can be added later. We could accept it here
// in this method, but we already accept the main diagnostic
// message. So the two messages would be easy to confuse.
diag.annotate(Annotation::primary(self.primary_span.clone()));
LintDiagnosticGuard {
ctx: self.ctx,
diag: Some(diag),
}
}
}
/// An abstraction for mutating a diagnostic.
///
/// Callers can build this guard by starting with
/// `InferContext::report_diagnostic`.
///
/// Callers likely should use `LintDiagnosticGuard` via
/// `InferContext::report_lint` instead. This guard is only intended for use
/// with non-lint diagnostics. It is fundamentally lower level and easier to
/// get things wrong by using it.
///
/// Unlike `LintDiagnosticGuard`, this API does not guarantee that the
/// constructed `Diagnostic` has a primary annotation, nor that the
/// annotation's file matches the file being type checked. As a result, if
/// either invariant is violated, the `Drop` impl on `DiagnosticGuard` will
/// panic.
pub(super) struct DiagnosticGuard<'db, 'ctx> {
ctx: &'ctx InferContext<'db>,
/// The diagnostic that we want to report.
///
/// This is always `Some` until the `Drop` impl.
diag: Option<Diagnostic>,
}
impl std::ops::Deref for DiagnosticGuard<'_, '_> {
type Target = Diagnostic;
fn deref(&self) -> &Diagnostic {
// OK because `self.diag` is only `None` within `Drop`.
self.diag.as_ref().unwrap()
}
}
/// Return a mutable borrow of the diagnostic in this guard.
///
/// Callers may mutate the diagnostic to add new sub-diagnostics
/// or annotations.
///
/// The diagnostic is added to the typing context, if appropriate,
/// when this guard is dropped.
impl std::ops::DerefMut for DiagnosticGuard<'_, '_> {
fn deref_mut(&mut self) -> &mut Diagnostic {
// OK because `self.diag` is only `None` within `Drop`.
self.diag.as_mut().unwrap()
}
}
/// Finishes use of this guard.
///
/// This will add the diagnostic to the typing context if appropriate.
///
/// # Panics
///
/// This panics when the underlying diagnostic lacks a primary
/// annotation, or if it has one and its file doesn't match the file
/// being type checked.
impl Drop for DiagnosticGuard<'_, '_> {
fn drop(&mut self) {
// OK because the only way `self.diag` is `None`
// is via this impl, which can only run at most
// once.
let diag = self.diag.take().unwrap();
if std::thread::panicking() {
// Don't submit diagnostics when panicking because they might be incomplete.
return;
}
let Some(ann) = diag.primary_annotation() else {
panic!(
"All diagnostics reported by `InferContext` must have a \
primary annotation, but diagnostic {id} does not",
id = diag.id(),
);
};
let expected_file = self.ctx.file();
let got_file = ann.get_span().file();
assert_eq!(
expected_file,
got_file,
"All diagnostics reported by `InferContext` must have a \
primary annotation whose file matches the file of the \
current typing context, but diagnostic {id} has file \
{got_file:?} and we expected {expected_file:?}",
id = diag.id(),
);
self.ctx.diagnostics.borrow_mut().push(diag);
}
}
/// A builder for constructing a diagnostic guard.
///
/// This type exists to separate the phases of "check if a diagnostic should
/// be reported" and "build the actual diagnostic." It's why, for example,
/// `InferContext::report_diagnostic` only requires an ID and a severity, but
/// this builder further requires a message (with those three things being the
/// minimal amount of information with which to construct a diagnostic) before
/// one can mutate the diagnostic.
pub(super) struct DiagnosticGuardBuilder<'db, 'ctx> {
ctx: &'ctx InferContext<'db>,
id: DiagnosticId,
severity: Severity,
}
impl<'db, 'ctx> DiagnosticGuardBuilder<'db, 'ctx> {
fn new(
ctx: &'ctx InferContext<'db>,
id: DiagnosticId,
severity: Severity,
) -> Option<DiagnosticGuardBuilder<'db, 'ctx>> {
if !ctx.db.is_file_open(ctx.file) {
return None;
}
Some(DiagnosticGuardBuilder { ctx, id, severity })
}
/// Create a new guard.
///
/// This initializes a new diagnostic using the given message along with
/// the ID and severity used to create this builder.
///
/// The diagnostic can be further mutated on the guard via its `DerefMut`
/// impl to `Diagnostic`.
pub(super) fn into_diagnostic(
self,
message: impl std::fmt::Display,
) -> DiagnosticGuard<'db, 'ctx> {
let diag = Some(Diagnostic::new(self.id, self.severity, message));
DiagnosticGuard {
ctx: self.ctx,
diag,
}
}
}


@@ -0,0 +1,39 @@
use crate::semantic_index::definition::Definition;
use crate::{Db, Module};
use ruff_db::files::FileRange;
use ruff_db::source::source_text;
use ruff_text_size::{TextLen, TextRange};
#[derive(Debug, PartialEq, Eq, Hash)]
pub enum TypeDefinition<'db> {
Module(Module),
Class(Definition<'db>),
Function(Definition<'db>),
TypeVar(Definition<'db>),
TypeAlias(Definition<'db>),
}
impl TypeDefinition<'_> {
pub fn focus_range(&self, db: &dyn Db) -> Option<FileRange> {
match self {
Self::Module(_) => None,
Self::Class(definition)
| Self::Function(definition)
| Self::TypeVar(definition)
| Self::TypeAlias(definition) => Some(definition.focus_range(db)),
}
}
pub fn full_range(&self, db: &dyn Db) -> FileRange {
match self {
Self::Module(module) => {
let source = source_text(db.upcast(), module.file());
FileRange::new(module.file(), TextRange::up_to(source.text_len()))
}
Self::Class(definition)
| Self::Function(definition)
| Self::TypeVar(definition)
| Self::TypeAlias(definition) => definition.full_range(db),
}
}
}

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -0,0 +1,533 @@
use ruff_python_ast as ast;
use rustc_hash::FxHashMap;
use crate::semantic_index::SemanticIndex;
use crate::types::signatures::{Parameter, Parameters, Signature};
use crate::types::{
declaration_type, KnownInstanceType, Type, TypeVarBoundOrConstraints, TypeVarInstance,
UnionType,
};
use crate::{Db, FxOrderSet};
/// A list of formal type variables for a generic function, class, or type alias.
///
/// TODO: Handle nested generic contexts better, with actual parent links to the lexically
/// containing context.
#[salsa::interned(debug)]
pub struct GenericContext<'db> {
#[return_ref]
pub(crate) variables: FxOrderSet<TypeVarInstance<'db>>,
}
impl<'db> GenericContext<'db> {
/// Creates a generic context from a list of PEP-695 type parameters.
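///
/// For example, each of the type parameters below contributes one typevar
/// to the resulting context (a sketch using PEP 695 syntax; `C` and `f` are
/// hypothetical):
///
/// ```py
/// class C[T, U: int]: ...
/// def f[V](x: V) -> V: ...
/// ```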
pub(crate) fn from_type_params(
db: &'db dyn Db,
index: &'db SemanticIndex<'db>,
type_params_node: &ast::TypeParams,
) -> Self {
let variables: FxOrderSet<_> = type_params_node
.iter()
.filter_map(|type_param| Self::variable_from_type_param(db, index, type_param))
.collect();
Self::new(db, variables)
}
fn variable_from_type_param(
db: &'db dyn Db,
index: &'db SemanticIndex<'db>,
type_param_node: &ast::TypeParam,
) -> Option<TypeVarInstance<'db>> {
match type_param_node {
ast::TypeParam::TypeVar(node) => {
let definition = index.expect_single_definition(node);
let Type::KnownInstance(KnownInstanceType::TypeVar(typevar)) =
declaration_type(db, definition).inner_type()
else {
panic!("typevar should be inferred as a TypeVarInstance");
};
Some(typevar)
}
// TODO: Support these!
ast::TypeParam::ParamSpec(_) => None,
ast::TypeParam::TypeVarTuple(_) => None,
}
}
/// Creates a generic context from the legacy `TypeVar`s that appear in a function parameter
/// list.
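///
/// For example (a sketch with hypothetical names), the parameters and
/// return type of `f` yield a context containing `T`:
///
/// ```py
/// from typing import TypeVar
///
/// T = TypeVar("T")
///
/// def f(x: T, y: int = 0) -> T: ...
/// ```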
pub(crate) fn from_function_params(
db: &'db dyn Db,
parameters: &Parameters<'db>,
return_type: Option<Type<'db>>,
) -> Option<Self> {
let mut variables = FxOrderSet::default();
for param in parameters {
if let Some(ty) = param.annotated_type() {
ty.find_legacy_typevars(db, &mut variables);
}
if let Some(ty) = param.default_type() {
ty.find_legacy_typevars(db, &mut variables);
}
}
if let Some(ty) = return_type {
ty.find_legacy_typevars(db, &mut variables);
}
if variables.is_empty() {
return None;
}
Some(Self::new(db, variables))
}
/// Creates a generic context from the legacy `TypeVar`s that appear in a class's base class
/// list.
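///
/// For example (a sketch with hypothetical names), the base class list of
/// `C` yields a context containing `T`:
///
/// ```py
/// from typing import Generic, TypeVar
///
/// T = TypeVar("T")
///
/// class C(Generic[T]): ...
/// ```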
pub(crate) fn from_base_classes(
db: &'db dyn Db,
bases: impl Iterator<Item = Type<'db>>,
) -> Option<Self> {
let mut variables = FxOrderSet::default();
for base in bases {
base.find_legacy_typevars(db, &mut variables);
}
if variables.is_empty() {
return None;
}
Some(Self::new(db, variables))
}
pub(crate) fn len(self, db: &'db dyn Db) -> usize {
self.variables(db).len()
}
pub(crate) fn signature(self, db: &'db dyn Db) -> Signature<'db> {
let parameters = Parameters::new(
self.variables(db)
.iter()
.map(|typevar| Self::parameter_from_typevar(db, *typevar)),
);
Signature::new(parameters, None)
}
fn parameter_from_typevar(db: &'db dyn Db, typevar: TypeVarInstance<'db>) -> Parameter<'db> {
let mut parameter = Parameter::positional_only(Some(typevar.name(db).clone()));
match typevar.bound_or_constraints(db) {
Some(TypeVarBoundOrConstraints::UpperBound(bound)) => {
// TODO: This should be a type form.
parameter = parameter.with_annotated_type(bound);
}
Some(TypeVarBoundOrConstraints::Constraints(constraints)) => {
// TODO: This should be a new type variant where only these exact types are
// assignable, and not subclasses of them, nor a union of them.
parameter = parameter
.with_annotated_type(UnionType::from_elements(db, constraints.iter(db)));
}
None => {}
}
parameter
}
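/// Creates the specialization in which each typevar is mapped to its
/// default type, falling back to `Unknown` for typevars without a default.
///
/// A sketch using PEP 696 default syntax (`C` is hypothetical):
///
/// ```py
/// class C[T, U = int]: ...  # default specialization: T -> Unknown, U -> int
/// ```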
pub(crate) fn default_specialization(self, db: &'db dyn Db) -> Specialization<'db> {
let types = self
.variables(db)
.iter()
.map(|typevar| typevar.default_ty(db).unwrap_or(Type::unknown()))
.collect();
self.specialize(db, types)
}
pub(crate) fn identity_specialization(self, db: &'db dyn Db) -> Specialization<'db> {
let types = self
.variables(db)
.iter()
.map(|typevar| Type::TypeVar(*typevar))
.collect();
self.specialize(db, types)
}
pub(crate) fn unknown_specialization(self, db: &'db dyn Db) -> Specialization<'db> {
let types = vec![Type::unknown(); self.variables(db).len()];
self.specialize(db, types.into())
}
pub(crate) fn is_subset_of(self, db: &'db dyn Db, other: GenericContext<'db>) -> bool {
self.variables(db).is_subset(other.variables(db))
}
/// Creates a specialization of this generic context. Panics if the length of `types` does not
/// match the number of typevars in the generic context.
pub(crate) fn specialize(
self,
db: &'db dyn Db,
types: Box<[Type<'db>]>,
) -> Specialization<'db> {
assert!(self.variables(db).len() == types.len());
Specialization::new(db, self, types)
}
}
/// An assignment of a specific type to each type variable in a generic scope.
///
/// TODO: Handle nested specializations better, with actual parent links to the specialization of
/// the lexically containing context.
#[salsa::interned(debug)]
pub struct Specialization<'db> {
pub(crate) generic_context: GenericContext<'db>,
#[return_ref]
pub(crate) types: Box<[Type<'db>]>,
}
impl<'db> Specialization<'db> {
/// Applies a specialization to this specialization. This is used, for instance, when a generic
/// class inherits from a generic alias:
///
/// ```py
/// class A[T]: ...
/// class B[U](A[U]): ...
/// ```
///
/// `B` is a generic class, whose MRO includes the generic alias `A[U]`, which specializes `A`
/// with the specialization `{T: U}`. If `B` is specialized to `B[int]`, with specialization
/// `{U: int}`, we can apply the second specialization to the first, resulting in `T: int`.
/// That lets us produce the generic alias `A[int]`, which is the corresponding entry in the
/// MRO of `B[int]`.
pub(crate) fn apply_specialization(self, db: &'db dyn Db, other: Specialization<'db>) -> Self {
let types: Box<[_]> = self
.types(db)
.into_iter()
.map(|ty| ty.apply_specialization(db, other))
.collect();
Specialization::new(db, self.generic_context(db), types)
}
/// Combines two specializations of the same generic context. If either specialization maps a
/// typevar to `Type::Unknown`, the other specialization's mapping is used. If both map the
/// typevar to a known type, those types are unioned together.
///
/// Panics if the two specializations are not for the same generic context.
pub(crate) fn combine(self, db: &'db dyn Db, other: Self) -> Self {
let generic_context = self.generic_context(db);
assert!(other.generic_context(db) == generic_context);
// TODO special-casing Unknown to mean "no mapping" is not right here, and can give
// confusing/wrong results in cases where there was a mapping found for a typevar, and it
// was of type Unknown. We should probably add a bitset or similar to Specialization that
// explicitly tells us which typevars are mapped.
let types: Box<[_]> = self
.types(db)
.into_iter()
.zip(other.types(db))
.map(|(self_type, other_type)| match (self_type, other_type) {
(unknown, known) | (known, unknown) if unknown.is_unknown() => *known,
_ => UnionType::from_elements(db, [self_type, other_type]),
})
.collect();
Specialization::new(db, self.generic_context(db), types)
}
pub(crate) fn normalized(self, db: &'db dyn Db) -> Self {
let types: Box<[_]> = self.types(db).iter().map(|ty| ty.normalized(db)).collect();
Self::new(db, self.generic_context(db), types)
}
/// Returns the type that a typevar is specialized to, or None if the typevar isn't part of
/// this specialization.
pub(crate) fn get(self, db: &'db dyn Db, typevar: TypeVarInstance<'db>) -> Option<Type<'db>> {
let index = self
.generic_context(db)
.variables(db)
.get_index_of(&typevar)?;
Some(self.types(db)[index])
}
pub(crate) fn is_subtype_of(self, db: &'db dyn Db, other: Specialization<'db>) -> bool {
let generic_context = self.generic_context(db);
if generic_context != other.generic_context(db) {
return false;
}
for ((_typevar, self_type), other_type) in (generic_context.variables(db).into_iter())
.zip(self.types(db))
.zip(other.types(db))
{
if matches!(self_type, Type::Dynamic(_)) || matches!(other_type, Type::Dynamic(_)) {
return false;
}
// TODO: We currently treat all typevars as invariant. Once we track the actual
// variance of each typevar, these checks should change:
// - covariant: verify that self_type <: other_type
// - contravariant: verify that other_type <: self_type
// - invariant: verify that self_type == other_type
// - bivariant: skip, can't make subtyping false
if !self_type.is_equivalent_to(db, *other_type) {
return false;
}
}
true
}
pub(crate) fn is_equivalent_to(self, db: &'db dyn Db, other: Specialization<'db>) -> bool {
let generic_context = self.generic_context(db);
if generic_context != other.generic_context(db) {
return false;
}
for ((_typevar, self_type), other_type) in (generic_context.variables(db).into_iter())
.zip(self.types(db))
.zip(other.types(db))
{
if matches!(self_type, Type::Dynamic(_)) || matches!(other_type, Type::Dynamic(_)) {
return false;
}
// TODO: We currently treat all typevars as invariant. Once we track the actual
// variance of each typevar, these checks should change:
// - covariant: verify that self_type == other_type
// - contravariant: verify that other_type == self_type
// - invariant: verify that self_type == other_type
// - bivariant: skip, can't make equivalence false
if !self_type.is_equivalent_to(db, *other_type) {
return false;
}
}
true
}
pub(crate) fn is_assignable_to(self, db: &'db dyn Db, other: Specialization<'db>) -> bool {
let generic_context = self.generic_context(db);
if generic_context != other.generic_context(db) {
return false;
}
for ((_typevar, self_type), other_type) in (generic_context.variables(db).into_iter())
.zip(self.types(db))
.zip(other.types(db))
{
if matches!(self_type, Type::Dynamic(_)) || matches!(other_type, Type::Dynamic(_)) {
continue;
}
// TODO: We currently treat all typevars as invariant. Once we track the actual
// variance of each typevar, these checks should change:
// - covariant: verify that self_type <: other_type
// - contravariant: verify that other_type <: self_type
// - invariant: verify that self_type == other_type
// - bivariant: skip, can't make assignability false
if !self_type.is_gradual_equivalent_to(db, *other_type) {
return false;
}
}
true
}
pub(crate) fn is_gradual_equivalent_to(
self,
db: &'db dyn Db,
other: Specialization<'db>,
) -> bool {
let generic_context = self.generic_context(db);
if generic_context != other.generic_context(db) {
return false;
}
for ((_typevar, self_type), other_type) in (generic_context.variables(db).into_iter())
.zip(self.types(db))
.zip(other.types(db))
{
// TODO: We currently treat all typevars as invariant. Once we track the actual
// variance of each typevar, these checks should change:
// - covariant: verify that self_type == other_type
// - contravariant: verify that other_type == self_type
// - invariant: verify that self_type == other_type
// - bivariant: skip, can't make equivalence false
if !self_type.is_gradual_equivalent_to(db, *other_type) {
return false;
}
}
true
}
pub(crate) fn find_legacy_typevars(
self,
db: &'db dyn Db,
typevars: &mut FxOrderSet<TypeVarInstance<'db>>,
) {
for ty in self.types(db) {
ty.find_legacy_typevars(db, typevars);
}
}
}
/// Performs type inference between parameter annotations and argument types, producing a
/// specialization of a generic function.
pub(crate) struct SpecializationBuilder<'db> {
db: &'db dyn Db,
types: FxHashMap<TypeVarInstance<'db>, Type<'db>>,
}
impl<'db> SpecializationBuilder<'db> {
pub(crate) fn new(db: &'db dyn Db) -> Self {
Self {
db,
types: FxHashMap::default(),
}
}
pub(crate) fn build(&mut self, generic_context: GenericContext<'db>) -> Specialization<'db> {
let types: Box<[_]> = generic_context
.variables(self.db)
.iter()
.map(|variable| {
self.types
.get(variable)
.copied()
.unwrap_or(variable.default_ty(self.db).unwrap_or(Type::unknown()))
})
.collect();
Specialization::new(self.db, generic_context, types)
}
fn add_type_mapping(&mut self, typevar: TypeVarInstance<'db>, ty: Type<'db>) {
self.types
.entry(typevar)
.and_modify(|existing| {
*existing = UnionType::from_elements(self.db, [*existing, ty]);
})
.or_insert(ty);
}
pub(crate) fn infer(
&mut self,
formal: Type<'db>,
actual: Type<'db>,
) -> Result<(), SpecializationError<'db>> {
// If the actual type is a subtype of the formal type, then return without adding any new
// type mappings. (Note that if the formal type contains any typevars, this check will
// fail, since no non-typevar types are assignable to a typevar. Also note that we are
// checking _subtyping_, not _assignability_, so that we do specialize typevars to dynamic
// argument types; and we have a special case for `Never`, which is a subtype of all types,
// but which we also do want as a specialization candidate.)
//
// In particular, this handles a case like
//
// ```py
// def f[T](t: T | None): ...
//
// f(None)
// ```
//
// without specializing `T` to `None`.
if !actual.is_never() && actual.is_subtype_of(self.db, formal) {
return Ok(());
}
match (formal, actual) {
(Type::TypeVar(typevar), _) => match typevar.bound_or_constraints(self.db) {
Some(TypeVarBoundOrConstraints::UpperBound(bound)) => {
if !actual.is_assignable_to(self.db, bound) {
return Err(SpecializationError::MismatchedBound {
typevar,
argument: actual,
});
}
self.add_type_mapping(typevar, actual);
}
Some(TypeVarBoundOrConstraints::Constraints(constraints)) => {
for constraint in constraints.iter(self.db) {
if actual.is_assignable_to(self.db, *constraint) {
self.add_type_mapping(typevar, *constraint);
return Ok(());
}
}
return Err(SpecializationError::MismatchedConstraint {
typevar,
argument: actual,
});
}
_ => {
self.add_type_mapping(typevar, actual);
}
},
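// Tuples of equal length are inferred element-wise. A sketch
// (hypothetical function):
//
// ```py
// def f[T, U](pair: tuple[T, U]): ...
//
// f((1, "a"))  # `T` is inferred from `1`, `U` from `"a"`
// ```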
(Type::Tuple(formal_tuple), Type::Tuple(actual_tuple)) => {
let formal_elements = formal_tuple.elements(self.db);
let actual_elements = actual_tuple.elements(self.db);
if formal_elements.len() == actual_elements.len() {
for (formal_element, actual_element) in
formal_elements.iter().zip(actual_elements)
{
self.infer(*formal_element, *actual_element)?;
}
}
}
(Type::Union(formal), _) => {
// TODO: We haven't implemented a full unification solver yet. If typevars appear
// in multiple union elements, we ideally want to express that _only one_ of them
// needs to match, and that we should infer the smallest type mapping that allows
// that.
//
// For now, we punt on handling multiple typevar elements. Instead, if _precisely
// one_ union element _is_ a typevar (not _contains_ a typevar), then we go ahead
// and add a mapping between that typevar and the actual type. (Note that we've
// already handled above the case where the actual is assignable to a _non-typevar_
// union element.)
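//
// A sketch of the single-typevar-element case (hypothetical function):
//
// ```py
// def f[T](x: T | None): ...
//
// f(1)  # `1` is not assignable to `None`, and `T` is the only typevar
//       # element of the union, so `T` maps to the argument type
// ```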
let mut typevars = formal.iter(self.db).filter_map(|ty| match ty {
Type::TypeVar(typevar) => Some(*typevar),
_ => None,
});
let typevar = typevars.next();
let additional_typevars = typevars.next();
if let (Some(typevar), None) = (typevar, additional_typevars) {
self.add_type_mapping(typevar, actual);
}
}
(Type::Intersection(formal), _) => {
// The actual type must be assignable to every (positive) element of the
// formal intersection, so we must infer type mappings for each of them. (The
// actual type must also be disjoint from every negative element of the
// intersection, but that doesn't help us infer any type mappings.)
for positive in formal.iter_positive(self.db) {
self.infer(positive, actual)?;
}
}
// TODO: Add more forms that we can structurally induct into: type[C], callables
_ => {}
}
Ok(())
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum SpecializationError<'db> {
MismatchedBound {
typevar: TypeVarInstance<'db>,
argument: Type<'db>,
},
MismatchedConstraint {
typevar: TypeVarInstance<'db>,
argument: Type<'db>,
},
}
impl<'db> SpecializationError<'db> {
pub(crate) fn typevar(&self) -> TypeVarInstance<'db> {
match self {
Self::MismatchedBound { typevar, .. } => *typevar,
Self::MismatchedConstraint { typevar, .. } => *typevar,
}
}
pub(crate) fn argument_type(&self) -> Type<'db> {
match self {
Self::MismatchedBound { argument, .. } => *argument,
Self::MismatchedConstraint { argument, .. } => *argument,
}
}
}

File diff suppressed because it is too large Load diff


@@ -0,0 +1,297 @@
//! Instance types: both nominal and structural.
use super::protocol_class::ProtocolInterface;
use super::{ClassType, KnownClass, SubclassOfType, Type};
use crate::symbol::{Symbol, SymbolAndQualifiers};
use crate::Db;
pub(super) use synthesized_protocol::SynthesizedProtocolType;
impl<'db> Type<'db> {
pub(crate) fn instance(db: &'db dyn Db, class: ClassType<'db>) -> Self {
if class.class_literal(db).0.is_protocol(db) {
Self::ProtocolInstance(ProtocolInstanceType(Protocol::FromClass(class)))
} else {
Self::NominalInstance(NominalInstanceType { class })
}
}
pub(crate) const fn into_nominal_instance(self) -> Option<NominalInstanceType<'db>> {
match self {
Type::NominalInstance(instance_type) => Some(instance_type),
_ => None,
}
}
/// Return `true` if `self` conforms to the interface described by `protocol`.
///
/// TODO: we may need to split this into two methods in the future, once we start
/// differentiating between fully-static and non-fully-static protocols.
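///
/// A sketch of the current existence-only check (hypothetical classes):
///
/// ```py
/// from typing import Protocol
///
/// class HasName(Protocol):
///     name: str
///
/// class Point:
///     name: str = "origin"
///
/// # `Point` instances satisfy `HasName`: the member `name` is bound.
/// ```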
pub(super) fn satisfies_protocol(
self,
db: &'db dyn Db,
protocol: ProtocolInstanceType<'db>,
) -> bool {
// TODO: this should consider the types of the protocol members
// as well as whether each member *exists* on `self`.
protocol
.0
.interface(db)
.members()
.all(|member| !self.member(db, member.name()).symbol.is_unbound())
}
}
/// A type representing the set of runtime objects which are instances of a certain nominal class.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update)]
pub struct NominalInstanceType<'db> {
// Keep this field private, so that the only way of constructing `NominalInstanceType` instances
// is through the `Type::instance` constructor function.
class: ClassType<'db>,
}
impl<'db> NominalInstanceType<'db> {
pub(super) fn class(self) -> ClassType<'db> {
self.class
}
pub(super) fn is_subtype_of(self, db: &'db dyn Db, other: Self) -> bool {
// N.B. The subclass relation is fully static
self.class.is_subclass_of(db, other.class)
}
pub(super) fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool {
self.class.is_equivalent_to(db, other.class)
}
pub(super) fn is_assignable_to(self, db: &'db dyn Db, other: Self) -> bool {
self.class.is_assignable_to(db, other.class)
}
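/// Returns `true` if the two nominal instance types can share no
/// inhabitants, e.g. when one class is `@final` and not a subclass of the
/// other (a sketch with hypothetical classes):
///
/// ```py
/// from typing import final
///
/// @final
/// class A: ...
/// class B: ...
///
/// # No runtime object can be an instance of both `A` and `B`.
/// ```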
pub(super) fn is_disjoint_from(self, db: &'db dyn Db, other: Self) -> bool {
if self.class.is_final(db) && !self.class.is_subclass_of(db, other.class) {
return true;
}
if other.class.is_final(db) && !other.class.is_subclass_of(db, self.class) {
return true;
}
// Check to see whether the metaclasses of `self` and `other` are disjoint.
// Avoid this check if the metaclass of either `self` or `other` is `type`,
// however, since we end up with infinite recursion in that case due to the fact
// that `type` is its own metaclass (and we know that `type` cannot be disjoint
// from any metaclass, anyway).
let type_type = KnownClass::Type.to_instance(db);
let self_metaclass = self.class.metaclass_instance_type(db);
if self_metaclass == type_type {
return false;
}
let other_metaclass = other.class.metaclass_instance_type(db);
if other_metaclass == type_type {
return false;
}
self_metaclass.is_disjoint_from(db, other_metaclass)
}
pub(super) fn is_gradual_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool {
self.class.is_gradual_equivalent_to(db, other.class)
}
pub(super) fn is_singleton(self, db: &'db dyn Db) -> bool {
self.class.known(db).is_some_and(KnownClass::is_singleton)
}
pub(super) fn is_single_valued(self, db: &'db dyn Db) -> bool {
self.class
.known(db)
.is_some_and(KnownClass::is_single_valued)
}
pub(super) fn to_meta_type(self, db: &'db dyn Db) -> Type<'db> {
SubclassOfType::from(db, self.class)
}
}
impl<'db> From<NominalInstanceType<'db>> for Type<'db> {
fn from(value: NominalInstanceType<'db>) -> Self {
Self::NominalInstance(value)
}
}
/// A `ProtocolInstanceType` represents the set of all possible runtime objects
/// that conform to the interface described by a certain protocol.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, PartialOrd, Ord, salsa::Update)]
pub struct ProtocolInstanceType<'db>(
// Keep the inner field here private,
// so that the only way of constructing `ProtocolInstanceType` instances
// is through the `Type::instance` constructor function.
Protocol<'db>,
);
impl<'db> ProtocolInstanceType<'db> {
pub(super) fn inner(self) -> Protocol<'db> {
self.0
}
/// Return the meta-type of this protocol-instance type.
pub(super) fn to_meta_type(self, db: &'db dyn Db) -> Type<'db> {
match self.0 {
Protocol::FromClass(class) => SubclassOfType::from(db, class),
// TODO: we can and should do better here.
//
// This is supported by mypy, and should be supported by us as well.
// We'll need to come up with a better solution for the meta-type of
// synthesized protocols to solve this:
//
// ```py
// from typing import Callable
//
// def foo(x: Callable[[], int]) -> None:
// reveal_type(type(x)) # mypy: "type[def (builtins.int) -> builtins.str]"
// reveal_type(type(x).__call__) # mypy: "def (*args: Any, **kwds: Any) -> Any"
// ```
Protocol::Synthesized(_) => KnownClass::Type.to_instance(db),
}
}
/// Return a "normalized" version of this `Protocol` type.
///
/// See [`Type::normalized`] for more details.
pub(super) fn normalized(self, db: &'db dyn Db) -> Type<'db> {
let object = KnownClass::Object.to_instance(db);
if object.satisfies_protocol(db, self) {
return object;
}
match self.0 {
Protocol::FromClass(_) => Type::ProtocolInstance(Self(Protocol::Synthesized(
SynthesizedProtocolType::new(db, self.0.interface(db).clone()),
))),
Protocol::Synthesized(_) => Type::ProtocolInstance(self),
}
}
/// Return `true` if any of the members of this protocol type contain any `Todo` types.
pub(super) fn contains_todo(self, db: &'db dyn Db) -> bool {
self.0.interface(db).contains_todo(db)
}
/// Return `true` if this protocol type is fully static.
pub(super) fn is_fully_static(self, db: &'db dyn Db) -> bool {
self.0.interface(db).is_fully_static(db)
}
/// Return `true` if this protocol type is a subtype of the protocol `other`.
pub(super) fn is_subtype_of(self, db: &'db dyn Db, other: Self) -> bool {
self.is_fully_static(db) && other.is_fully_static(db) && self.is_assignable_to(db, other)
}
/// Return `true` if this protocol type is assignable to the protocol `other`.
///
/// TODO: consider the types of the members as well as their existence
pub(super) fn is_assignable_to(self, db: &'db dyn Db, other: Self) -> bool {
other
.0
.interface(db)
.is_sub_interface_of(self.0.interface(db))
}
/// Return `true` if this protocol type is equivalent to the protocol `other`.
///
/// TODO: consider the types of the members as well as their existence
pub(super) fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool {
self.is_fully_static(db)
&& other.is_fully_static(db)
&& self.normalized(db) == other.normalized(db)
}
/// Return `true` if this protocol type is gradually equivalent to the protocol `other`.
///
/// TODO: consider the types of the members as well as their existence
pub(super) fn is_gradual_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool {
self.normalized(db) == other.normalized(db)
}
/// Return `true` if this protocol type is disjoint from the protocol `other`.
///
/// TODO: a protocol `X` is disjoint from a protocol `Y` if `X` and `Y`
/// have a member with the same name but disjoint types
#[expect(clippy::unused_self)]
pub(super) fn is_disjoint_from(self, _db: &'db dyn Db, _other: Self) -> bool {
false
}
pub(crate) fn instance_member(self, db: &'db dyn Db, name: &str) -> SymbolAndQualifiers<'db> {
match self.inner() {
Protocol::FromClass(class) => class.instance_member(db, name),
Protocol::Synthesized(synthesized) => synthesized
.interface(db)
.member_by_name(name)
.map(|member| SymbolAndQualifiers {
symbol: Symbol::bound(member.ty()),
qualifiers: member.qualifiers(),
})
.unwrap_or_else(|| KnownClass::Object.to_instance(db).instance_member(db, name)),
}
}
}
/// An enumeration of the two kinds of protocol types: those that originate from a class
/// definition in source code, and those that are synthesized from a set of members.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update, PartialOrd, Ord)]
pub(super) enum Protocol<'db> {
FromClass(ClassType<'db>),
Synthesized(SynthesizedProtocolType<'db>),
}
impl<'db> Protocol<'db> {
/// Return the members of this protocol type
fn interface(self, db: &'db dyn Db) -> &'db ProtocolInterface<'db> {
match self {
Self::FromClass(class) => class
.class_literal(db)
.0
.into_protocol_class(db)
.expect("Protocol class literal should be a protocol class")
.interface(db),
Self::Synthesized(synthesized) => synthesized.interface(db),
}
}
}
mod synthesized_protocol {
use crate::db::Db;
use crate::types::protocol_class::ProtocolInterface;
/// A "synthesized" protocol type that is dissociated from a class definition in source code.
///
/// Two synthesized protocol types with the same members will share the same Salsa ID,
/// making them easy to compare for equivalence. A synthesized protocol type is therefore
/// returned by [`super::ProtocolInstanceType::normalized`] so that two protocols with the same members
/// will be understood as equivalent even in the context of differently ordered unions or intersections.
///
/// The constructor method of this type maintains the invariant that a synthesized protocol type
/// is always constructed from a *normalized* protocol interface.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update, PartialOrd, Ord)]
pub(in crate::types) struct SynthesizedProtocolType<'db>(SynthesizedProtocolTypeInner<'db>);
impl<'db> SynthesizedProtocolType<'db> {
pub(super) fn new(db: &'db dyn Db, interface: ProtocolInterface<'db>) -> Self {
Self(SynthesizedProtocolTypeInner::new(
db,
interface.normalized(db),
))
}
pub(in crate::types) fn interface(self, db: &'db dyn Db) -> &'db ProtocolInterface<'db> {
self.0.interface(db)
}
}
#[salsa::interned(debug)]
struct SynthesizedProtocolTypeInner<'db> {
#[return_ref]
interface: ProtocolInterface<'db>,
}
}


@@ -0,0 +1,408 @@
//! The `KnownInstance` type.
//!
//! Despite its name, this is quite a different type from [`super::NominalInstanceType`].
//! For the vast majority of instance-types in Python, we cannot say how many possible
//! inhabitants there are or could be of that type at runtime. Each variant of the
//! [`KnownInstanceType`] enum, however, represents a specific runtime symbol
//! that requires heavy special-casing in the type system. Thus any one `KnownInstance`
//! variant can only be inhabited by one or two specific objects at runtime with
//! locations that are known in advance.
use std::fmt::Display;
use super::generics::GenericContext;
use super::{class::KnownClass, ClassType, Truthiness, Type, TypeAliasType, TypeVarInstance};
use crate::db::Db;
use crate::module_resolver::{file_to_module, KnownModule};
use ruff_db::files::File;
/// Enumeration of specific runtime symbols that are special enough
/// that they can each be considered to inhabit a unique type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, salsa::Update)]
pub enum KnownInstanceType<'db> {
/// The symbol `typing.Annotated` (which can also be found as `typing_extensions.Annotated`)
Annotated,
/// The symbol `typing.Literal` (which can also be found as `typing_extensions.Literal`)
Literal,
/// The symbol `typing.LiteralString` (which can also be found as `typing_extensions.LiteralString`)
LiteralString,
/// The symbol `typing.Optional` (which can also be found as `typing_extensions.Optional`)
Optional,
/// The symbol `typing.Union` (which can also be found as `typing_extensions.Union`)
Union,
/// The symbol `typing.NoReturn` (which can also be found as `typing_extensions.NoReturn`)
NoReturn,
/// The symbol `typing.Never` available since 3.11 (which can also be found as `typing_extensions.Never`)
Never,
/// The symbol `typing.Any` (which can also be found as `typing_extensions.Any`)
/// This is not used since typeshed switched to representing `Any` as a class; now we use
/// `KnownClass::Any` instead. But we still support the old `Any = object()` representation, at
/// least for now. TODO maybe remove?
Any,
/// The symbol `typing.Tuple` (which can also be found as `typing_extensions.Tuple`)
Tuple,
/// The symbol `typing.List` (which can also be found as `typing_extensions.List`)
List,
/// The symbol `typing.Dict` (which can also be found as `typing_extensions.Dict`)
Dict,
/// The symbol `typing.Set` (which can also be found as `typing_extensions.Set`)
Set,
/// The symbol `typing.FrozenSet` (which can also be found as `typing_extensions.FrozenSet`)
FrozenSet,
/// The symbol `typing.ChainMap` (which can also be found as `typing_extensions.ChainMap`)
ChainMap,
/// The symbol `typing.Counter` (which can also be found as `typing_extensions.Counter`)
Counter,
/// The symbol `typing.DefaultDict` (which can also be found as `typing_extensions.DefaultDict`)
DefaultDict,
/// The symbol `typing.Deque` (which can also be found as `typing_extensions.Deque`)
Deque,
/// The symbol `typing.OrderedDict` (which can also be found as `typing_extensions.OrderedDict`)
OrderedDict,
/// The symbol `typing.Protocol` (which can also be found as `typing_extensions.Protocol`)
Protocol,
/// The symbol `typing.Generic` (which can also be found as `typing_extensions.Generic`)
Generic(Option<GenericContext<'db>>),
/// The symbol `typing.Type` (which can also be found as `typing_extensions.Type`)
Type,
/// A single instance of `typing.TypeVar`
TypeVar(TypeVarInstance<'db>),
/// A single instance of `typing.TypeAliasType` (PEP 695 type alias)
TypeAliasType(TypeAliasType<'db>),
/// The symbol `ty_extensions.Unknown`
Unknown,
/// The symbol `ty_extensions.AlwaysTruthy`
AlwaysTruthy,
/// The symbol `ty_extensions.AlwaysFalsy`
AlwaysFalsy,
/// The symbol `ty_extensions.Not`
Not,
/// The symbol `ty_extensions.Intersection`
Intersection,
/// The symbol `ty_extensions.TypeOf`
TypeOf,
/// The symbol `ty_extensions.CallableTypeOf`
CallableTypeOf,
/// The symbol `typing.Callable`
/// (which can also be found as `typing_extensions.Callable` or as `collections.abc.Callable`)
Callable,
// Various special forms, special aliases and type qualifiers that we don't yet understand
// (all currently inferred as TODO in most contexts):
TypingSelf,
Final,
ClassVar,
Concatenate,
Unpack,
Required,
NotRequired,
TypeAlias,
TypeGuard,
TypedDict,
TypeIs,
ReadOnly,
// TODO: fill this enum out with more special forms, etc.
}
impl<'db> KnownInstanceType<'db> {
/// Evaluate the known instance in boolean context
pub(crate) const fn bool(self) -> Truthiness {
match self {
Self::Annotated
| Self::Literal
| Self::LiteralString
| Self::Optional
// This is a legacy `TypeVar` _outside_ of any generic class or function, so it's
// AlwaysTrue. The truthiness of a typevar inside of a generic class or function
// depends on its bounds and constraints; but that's represented by `Type::TypeVar` and
// handled elsewhere.
| Self::TypeVar(_)
| Self::Union
| Self::NoReturn
| Self::Never
| Self::Any
| Self::Tuple
| Self::Type
| Self::TypingSelf
| Self::Final
| Self::ClassVar
| Self::Callable
| Self::Concatenate
| Self::Unpack
| Self::Required
| Self::NotRequired
| Self::TypeAlias
| Self::TypeGuard
| Self::TypedDict
| Self::TypeIs
| Self::List
| Self::Dict
| Self::DefaultDict
| Self::Set
| Self::FrozenSet
| Self::Counter
| Self::Deque
| Self::ChainMap
| Self::OrderedDict
| Self::Protocol
| Self::Generic(_)
| Self::ReadOnly
| Self::TypeAliasType(_)
| Self::Unknown
| Self::AlwaysTruthy
| Self::AlwaysFalsy
| Self::Not
| Self::Intersection
| Self::TypeOf
| Self::CallableTypeOf => Truthiness::AlwaysTrue,
}
}
/// Return the repr of the symbol at runtime
pub(crate) fn repr(self, db: &'db dyn Db) -> impl Display + 'db {
KnownInstanceRepr {
known_instance: self,
db,
}
}
/// Return the [`KnownClass`] which this symbol is an instance of
pub(crate) const fn class(self) -> KnownClass {
match self {
Self::Annotated => KnownClass::SpecialForm,
Self::Literal => KnownClass::SpecialForm,
Self::LiteralString => KnownClass::SpecialForm,
Self::Optional => KnownClass::SpecialForm,
Self::Union => KnownClass::SpecialForm,
Self::NoReturn => KnownClass::SpecialForm,
Self::Never => KnownClass::SpecialForm,
Self::Any => KnownClass::Object,
Self::Tuple => KnownClass::SpecialForm,
Self::Type => KnownClass::SpecialForm,
Self::TypingSelf => KnownClass::SpecialForm,
Self::Final => KnownClass::SpecialForm,
Self::ClassVar => KnownClass::SpecialForm,
Self::Callable => KnownClass::SpecialForm,
Self::Concatenate => KnownClass::SpecialForm,
Self::Unpack => KnownClass::SpecialForm,
Self::Required => KnownClass::SpecialForm,
Self::NotRequired => KnownClass::SpecialForm,
Self::TypeAlias => KnownClass::SpecialForm,
Self::TypeGuard => KnownClass::SpecialForm,
Self::TypedDict => KnownClass::SpecialForm,
Self::TypeIs => KnownClass::SpecialForm,
Self::ReadOnly => KnownClass::SpecialForm,
Self::List => KnownClass::StdlibAlias,
Self::Dict => KnownClass::StdlibAlias,
Self::DefaultDict => KnownClass::StdlibAlias,
Self::Set => KnownClass::StdlibAlias,
Self::FrozenSet => KnownClass::StdlibAlias,
Self::Counter => KnownClass::StdlibAlias,
Self::Deque => KnownClass::StdlibAlias,
Self::ChainMap => KnownClass::StdlibAlias,
Self::OrderedDict => KnownClass::StdlibAlias,
Self::Protocol => KnownClass::SpecialForm, // actually `_ProtocolMeta` at runtime but this is what typeshed says
Self::Generic(_) => KnownClass::SpecialForm, // actually `type` at runtime but this is what typeshed says
Self::TypeVar(_) => KnownClass::TypeVar,
Self::TypeAliasType(_) => KnownClass::TypeAliasType,
Self::TypeOf => KnownClass::SpecialForm,
Self::Not => KnownClass::SpecialForm,
Self::Intersection => KnownClass::SpecialForm,
Self::CallableTypeOf => KnownClass::SpecialForm,
Self::Unknown => KnownClass::Object,
Self::AlwaysTruthy => KnownClass::Object,
Self::AlwaysFalsy => KnownClass::Object,
}
}
/// Return the instance type which this type is a subtype of.
///
/// For example, the symbol `typing.Literal` is an instance of `typing._SpecialForm`,
/// so `KnownInstanceType::Literal.instance_fallback(db)`
/// returns `Type::NominalInstance(NominalInstanceType { class: <typing._SpecialForm> })`.
pub(super) fn instance_fallback(self, db: &dyn Db) -> Type {
self.class().to_instance(db)
}
/// Return `true` if this symbol is an instance of `class`.
pub(super) fn is_instance_of(self, db: &'db dyn Db, class: ClassType<'db>) -> bool {
self.class().is_subclass_of(db, class)
}
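/// Try to map the symbol `symbol_name`, defined in `file`, to a
/// `KnownInstanceType` variant, checking that the symbol's module is one it
/// can validly originate from. A sketch:
///
/// ```py
/// from typing import Literal         # recognized as `KnownInstanceType::Literal`
/// from ty_extensions import Unknown  # recognized as `KnownInstanceType::Unknown`
/// ```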
pub(super) fn try_from_file_and_name(
db: &'db dyn Db,
file: File,
symbol_name: &str,
) -> Option<Self> {
let candidate = match symbol_name {
"Any" => Self::Any,
"ClassVar" => Self::ClassVar,
"Deque" => Self::Deque,
"List" => Self::List,
"Dict" => Self::Dict,
"DefaultDict" => Self::DefaultDict,
"Set" => Self::Set,
"FrozenSet" => Self::FrozenSet,
"Counter" => Self::Counter,
"ChainMap" => Self::ChainMap,
"OrderedDict" => Self::OrderedDict,
"Generic" => Self::Generic(None),
"Protocol" => Self::Protocol,
"Optional" => Self::Optional,
"Union" => Self::Union,
"NoReturn" => Self::NoReturn,
"Tuple" => Self::Tuple,
"Type" => Self::Type,
"Callable" => Self::Callable,
"Annotated" => Self::Annotated,
"Literal" => Self::Literal,
"Never" => Self::Never,
"Self" => Self::TypingSelf,
"Final" => Self::Final,
"Unpack" => Self::Unpack,
"Required" => Self::Required,
"TypeAlias" => Self::TypeAlias,
"TypeGuard" => Self::TypeGuard,
"TypedDict" => Self::TypedDict,
"TypeIs" => Self::TypeIs,
"ReadOnly" => Self::ReadOnly,
"Concatenate" => Self::Concatenate,
"NotRequired" => Self::NotRequired,
"LiteralString" => Self::LiteralString,
"Unknown" => Self::Unknown,
"AlwaysTruthy" => Self::AlwaysTruthy,
"AlwaysFalsy" => Self::AlwaysFalsy,
"Not" => Self::Not,
"Intersection" => Self::Intersection,
"TypeOf" => Self::TypeOf,
"CallableTypeOf" => Self::CallableTypeOf,
_ => return None,
};
candidate
.check_module(file_to_module(db, file)?.known()?)
.then_some(candidate)
}
/// Return `true` if `module` is a module from which this `KnownInstance` variant can validly originate.
///
/// Most variants can only exist in one module, which is the same as `self.class().canonical_module()`.
/// Some variants could validly be defined in either `typing` or `typing_extensions`, however.
pub(super) fn check_module(self, module: KnownModule) -> bool {
match self {
Self::Any
| Self::ClassVar
| Self::Deque
| Self::List
| Self::Dict
| Self::DefaultDict
| Self::Set
| Self::FrozenSet
| Self::Counter
| Self::ChainMap
| Self::OrderedDict
| Self::Optional
| Self::Union
| Self::NoReturn
| Self::Tuple
| Self::Type
| Self::Generic(_)
| Self::Callable => module.is_typing(),
Self::Annotated
| Self::Protocol
| Self::Literal
| Self::LiteralString
| Self::Never
| Self::TypingSelf
| Self::Final
| Self::Concatenate
| Self::Unpack
| Self::Required
| Self::NotRequired
| Self::TypeAlias
| Self::TypeGuard
| Self::TypedDict
| Self::TypeIs
| Self::ReadOnly
| Self::TypeAliasType(_)
| Self::TypeVar(_) => {
matches!(module, KnownModule::Typing | KnownModule::TypingExtensions)
}
Self::Unknown
| Self::AlwaysTruthy
| Self::AlwaysFalsy
| Self::Not
| Self::Intersection
| Self::TypeOf
| Self::CallableTypeOf => module.is_ty_extensions(),
}
}
pub(super) fn to_meta_type(self, db: &'db dyn Db) -> Type<'db> {
self.class().to_class_literal(db)
}
}
struct KnownInstanceRepr<'db> {
known_instance: KnownInstanceType<'db>,
db: &'db dyn Db,
}
impl Display for KnownInstanceRepr<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.known_instance {
KnownInstanceType::Annotated => f.write_str("typing.Annotated"),
KnownInstanceType::Literal => f.write_str("typing.Literal"),
KnownInstanceType::LiteralString => f.write_str("typing.LiteralString"),
KnownInstanceType::Optional => f.write_str("typing.Optional"),
KnownInstanceType::Union => f.write_str("typing.Union"),
KnownInstanceType::NoReturn => f.write_str("typing.NoReturn"),
KnownInstanceType::Never => f.write_str("typing.Never"),
KnownInstanceType::Any => f.write_str("typing.Any"),
KnownInstanceType::Tuple => f.write_str("typing.Tuple"),
KnownInstanceType::Type => f.write_str("typing.Type"),
KnownInstanceType::TypingSelf => f.write_str("typing.Self"),
KnownInstanceType::Final => f.write_str("typing.Final"),
KnownInstanceType::ClassVar => f.write_str("typing.ClassVar"),
KnownInstanceType::Callable => f.write_str("typing.Callable"),
KnownInstanceType::Concatenate => f.write_str("typing.Concatenate"),
KnownInstanceType::Unpack => f.write_str("typing.Unpack"),
KnownInstanceType::Required => f.write_str("typing.Required"),
KnownInstanceType::NotRequired => f.write_str("typing.NotRequired"),
KnownInstanceType::TypeAlias => f.write_str("typing.TypeAlias"),
KnownInstanceType::TypeGuard => f.write_str("typing.TypeGuard"),
KnownInstanceType::TypedDict => f.write_str("typing.TypedDict"),
KnownInstanceType::TypeIs => f.write_str("typing.TypeIs"),
KnownInstanceType::List => f.write_str("typing.List"),
KnownInstanceType::Dict => f.write_str("typing.Dict"),
KnownInstanceType::DefaultDict => f.write_str("typing.DefaultDict"),
KnownInstanceType::Set => f.write_str("typing.Set"),
KnownInstanceType::FrozenSet => f.write_str("typing.FrozenSet"),
KnownInstanceType::Counter => f.write_str("typing.Counter"),
KnownInstanceType::Deque => f.write_str("typing.Deque"),
KnownInstanceType::ChainMap => f.write_str("typing.ChainMap"),
KnownInstanceType::OrderedDict => f.write_str("typing.OrderedDict"),
KnownInstanceType::Protocol => f.write_str("typing.Protocol"),
KnownInstanceType::Generic(generic_context) => {
f.write_str("typing.Generic")?;
if let Some(generic_context) = generic_context {
write!(f, "{}", generic_context.display(self.db))?;
}
Ok(())
}
KnownInstanceType::ReadOnly => f.write_str("typing.ReadOnly"),
// This is a legacy `TypeVar` _outside_ of any generic class or function, so we render
// it as an instance of `typing.TypeVar`. Inside of a generic class or function, we'll
// have a `Type::TypeVar(_)`, which is rendered as the typevar's name.
KnownInstanceType::TypeVar(_) => f.write_str("typing.TypeVar"),
KnownInstanceType::TypeAliasType(_) => f.write_str("typing.TypeAliasType"),
KnownInstanceType::Unknown => f.write_str("ty_extensions.Unknown"),
KnownInstanceType::AlwaysTruthy => f.write_str("ty_extensions.AlwaysTruthy"),
KnownInstanceType::AlwaysFalsy => f.write_str("ty_extensions.AlwaysFalsy"),
KnownInstanceType::Not => f.write_str("ty_extensions.Not"),
KnownInstanceType::Intersection => f.write_str("ty_extensions.Intersection"),
KnownInstanceType::TypeOf => f.write_str("ty_extensions.TypeOf"),
KnownInstanceType::CallableTypeOf => f.write_str("ty_extensions.CallableTypeOf"),
}
}
}


@@ -0,0 +1,390 @@
use std::collections::VecDeque;
use std::ops::Deref;
use rustc_hash::FxHashSet;
use crate::types::class_base::ClassBase;
use crate::types::generics::Specialization;
use crate::types::{ClassLiteral, ClassType, Type};
use crate::Db;
/// The inferred method resolution order of a given class.
///
/// An MRO cannot contain non-specialized generic classes. (This is why [`ClassBase`] contains a
/// [`ClassType`], not a [`ClassLiteral`].) Any generic classes in a base class list are always
/// specialized — either because the class is explicitly specialized if there is a subscript
/// expression, or because we create the default specialization if there isn't.
///
/// The MRO of a non-specialized generic class can contain generic classes that are specialized
/// with a typevar from the inheriting class. When the inheriting class is specialized, the MRO of
/// the resulting generic alias will substitute those type variables accordingly. For instance, in
/// the following example, the MRO of `D[int]` includes `C[int]`, and the MRO of `D[U]` includes
/// `C[U]` (which is a generic alias, not a non-specialized generic class):
///
/// ```py
/// class C[T]: ...
/// class D[U](C[U]): ...
/// ```
///
/// See [`ClassType::iter_mro`] for more details.
#[derive(PartialEq, Eq, Clone, Debug, salsa::Update)]
pub(super) struct Mro<'db>(Box<[ClassBase<'db>]>);
impl<'db> Mro<'db> {
/// Attempt to resolve the MRO of a given class. Because we derive the MRO from the list of
/// base classes in the class definition, this operation is performed on a [class
/// literal][ClassLiteral], not a [class type][ClassType]. (You can _also_ get the MRO of a
/// class type, but this is done by first getting the MRO of the underlying class literal, and
/// specializing each base class as needed if the class type is a generic alias.)
///
/// In the event that a possible list of bases would (or could) lead to a `TypeError` being
/// raised at runtime due to an unresolvable MRO, we infer the MRO of the class as being `[<the
/// class in question>, Unknown, object]`. This seems most likely to reduce the possibility of
/// cascading errors elsewhere. (For a generic class, the first entry in this fallback MRO uses
/// the default specialization of the class's type variables.)
///
/// (We emit a diagnostic warning about the runtime `TypeError` in
/// [`super::infer::TypeInferenceBuilder::infer_region_scope`].)
pub(super) fn of_class(
db: &'db dyn Db,
class: ClassLiteral<'db>,
specialization: Option<Specialization<'db>>,
) -> Result<Self, MroError<'db>> {
Self::of_class_impl(db, class, specialization).map_err(|err| {
err.into_mro_error(db, class.apply_optional_specialization(db, specialization))
})
}
pub(super) fn from_error(db: &'db dyn Db, class: ClassType<'db>) -> Self {
Self::from([
ClassBase::Class(class),
ClassBase::unknown(),
ClassBase::object(db),
])
}
fn of_class_impl(
db: &'db dyn Db,
class: ClassLiteral<'db>,
specialization: Option<Specialization<'db>>,
) -> Result<Self, MroErrorKind<'db>> {
let class_bases = class.explicit_bases(db);
if !class_bases.is_empty() && class.inheritance_cycle(db).is_some() {
// We emit errors for cyclically defined classes elsewhere.
// It's important that we don't even try to infer the MRO for a cyclically defined class,
// or we'll end up in an infinite loop.
return Ok(Mro::from_error(
db,
class.apply_optional_specialization(db, specialization),
));
}
match class_bases {
// `builtins.object` is the special case:
// the only class in Python that has an MRO with length <2
[] if class.is_object(db) => Ok(Self::from([
// object is not generic, so the default specialization should be a no-op
ClassBase::Class(class.apply_optional_specialization(db, specialization)),
])),
// All other classes in Python have an MRO with length >=2.
// Even if a class has no explicit base classes,
// it will implicitly inherit from `object` at runtime;
// `object` will appear in the class's `__bases__` list and `__mro__`:
//
// ```pycon
// >>> class Foo: ...
// ...
// >>> Foo.__bases__
// (<class 'object'>,)
// >>> Foo.__mro__
// (<class '__main__.Foo'>, <class 'object'>)
// ```
[] => Ok(Self::from([
ClassBase::Class(class.apply_optional_specialization(db, specialization)),
ClassBase::object(db),
])),
// Fast path for a class that has only a single explicit base.
//
// This *could* theoretically be handled by the final branch below,
// but it's a common case (i.e., worth optimizing for),
// and the `c3_merge` function requires lots of allocations.
[single_base] => ClassBase::try_from_type(db, *single_base).map_or_else(
|| Err(MroErrorKind::InvalidBases(Box::from([(0, *single_base)]))),
|single_base| {
Ok(std::iter::once(ClassBase::Class(
class.apply_optional_specialization(db, specialization),
))
.chain(single_base.mro(db))
.collect())
},
),
// The class has multiple explicit bases.
//
// We'll fall back to a full implementation of the C3-merge algorithm to determine
// what MRO Python will give this class at runtime
// (if an MRO is indeed resolvable at all!)
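// A classic unresolvable case (hypothetical classes): no linearization
// can place `A` both before `B` (per `C`'s base order) and after `B`
// (per `B`'s own MRO):
//
// ```py
// class A: ...
// class B(A): ...
// class C(A, B): ...  # TypeError: unresolvable MRO
// ```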
multiple_bases => {
let mut valid_bases = vec![];
let mut invalid_bases = vec![];
for (i, base) in multiple_bases.iter().enumerate() {
match ClassBase::try_from_type(db, *base) {
Some(valid_base) => valid_bases.push(valid_base),
None => invalid_bases.push((i, *base)),
}
}
if !invalid_bases.is_empty() {
return Err(MroErrorKind::InvalidBases(invalid_bases.into_boxed_slice()));
}
let mut seqs = vec![VecDeque::from([ClassBase::Class(
class.apply_optional_specialization(db, specialization),
)])];
for base in &valid_bases {
seqs.push(base.mro(db).collect());
}
seqs.push(valid_bases.iter().copied().collect());
c3_merge(seqs).ok_or_else(|| {
let mut seen_bases = FxHashSet::default();
let mut duplicate_bases = vec![];
for (index, base) in valid_bases
.iter()
.enumerate()
.filter_map(|(index, base)| Some((index, base.into_class()?)))
{
if !seen_bases.insert(base) {
let (base_class_literal, _) = base.class_literal(db);
duplicate_bases.push((index, base_class_literal));
}
}
if duplicate_bases.is_empty() {
MroErrorKind::UnresolvableMro {
bases_list: valid_bases.into_boxed_slice(),
}
} else {
MroErrorKind::DuplicateBases(duplicate_bases.into_boxed_slice())
}
})
}
}
}
}
impl<'db, const N: usize> From<[ClassBase<'db>; N]> for Mro<'db> {
fn from(value: [ClassBase<'db>; N]) -> Self {
Self(Box::from(value))
}
}
impl<'db> From<Vec<ClassBase<'db>>> for Mro<'db> {
fn from(value: Vec<ClassBase<'db>>) -> Self {
Self(value.into_boxed_slice())
}
}
impl<'db> Deref for Mro<'db> {
type Target = [ClassBase<'db>];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<'db> FromIterator<ClassBase<'db>> for Mro<'db> {
fn from_iter<T: IntoIterator<Item = ClassBase<'db>>>(iter: T) -> Self {
Self(iter.into_iter().collect())
}
}
/// Iterator that yields elements of a class's MRO.
///
/// We avoid materialising the *full* MRO unless it is actually necessary:
/// - Materialising the full MRO is expensive
/// - We need to do it for every class in the code that we're checking, to make sure that no
///   class definition would cause an exception to be raised at runtime. But the same does *not*
///   necessarily apply for every class in third-party and stdlib dependencies: we never emit
///   diagnostics about non-first-party code.
/// - However, we *do* need to resolve attribute accesses on classes/instances from
/// third-party and stdlib dependencies. That requires iterating over the MRO of third-party/stdlib
/// classes, but not necessarily the *whole* MRO: often just the first element is enough.
/// Luckily we know that for any class `X`, the first element of `X`'s MRO will always be `X` itself.
/// We can therefore avoid resolving the full MRO for many third-party/stdlib classes while still
/// being faithful to the runtime semantics.
///
/// Even for first-party code, where we will have to resolve the MRO for every class we encounter,
/// loading the cached MRO comes with a certain amount of overhead, so it's best to avoid calling the
/// Salsa-tracked [`ClassLiteral::try_mro`] method unless it's absolutely necessary.
pub(super) struct MroIterator<'db> {
db: &'db dyn Db,
/// The class whose MRO we're iterating over
class: ClassLiteral<'db>,
/// The specialization to apply to each MRO element, if any
specialization: Option<Specialization<'db>>,
/// Whether or not we've already yielded the first element of the MRO
first_element_yielded: bool,
/// Iterator over all elements of the MRO except the first.
///
/// The full MRO is expensive to materialize, so this field is `None`
/// unless we actually *need* to iterate past the first element of the MRO,
/// at which point it is lazily materialized.
subsequent_elements: Option<std::slice::Iter<'db, ClassBase<'db>>>,
}
impl<'db> MroIterator<'db> {
pub(super) fn new(
db: &'db dyn Db,
class: ClassLiteral<'db>,
specialization: Option<Specialization<'db>>,
) -> Self {
Self {
db,
class,
specialization,
first_element_yielded: false,
subsequent_elements: None,
}
}
/// Materialize the full MRO of the class.
/// Return an iterator over that MRO which skips the first element of the MRO.
fn full_mro_except_first_element(&mut self) -> impl Iterator<Item = ClassBase<'db>> + '_ {
self.subsequent_elements
.get_or_insert_with(|| {
let mut full_mro_iter = match self.class.try_mro(self.db, self.specialization) {
Ok(mro) => mro.iter(),
Err(error) => error.fallback_mro().iter(),
};
full_mro_iter.next();
full_mro_iter
})
.copied()
}
}
impl<'db> Iterator for MroIterator<'db> {
type Item = ClassBase<'db>;
fn next(&mut self) -> Option<Self::Item> {
if !self.first_element_yielded {
self.first_element_yielded = true;
return Some(ClassBase::Class(
self.class
.apply_optional_specialization(self.db, self.specialization),
));
}
self.full_mro_except_first_element().next()
}
}
impl std::iter::FusedIterator for MroIterator<'_> {}
#[derive(Debug, PartialEq, Eq, salsa::Update)]
pub(super) struct MroError<'db> {
kind: MroErrorKind<'db>,
fallback_mro: Mro<'db>,
}
impl<'db> MroError<'db> {
/// Return an [`MroErrorKind`] variant describing why we could not resolve the MRO for this class.
pub(super) fn reason(&self) -> &MroErrorKind<'db> {
&self.kind
}
/// Return the fallback MRO we should infer for this class during type inference
/// (since accurate resolution of its "true" MRO was impossible)
pub(super) fn fallback_mro(&self) -> &Mro<'db> {
&self.fallback_mro
}
}
/// Possible ways in which attempting to resolve the MRO of a class might fail.
#[derive(Debug, PartialEq, Eq, salsa::Update)]
pub(super) enum MroErrorKind<'db> {
/// The class inherits from one or more invalid bases.
///
/// To avoid excessive complexity in our implementation,
/// we only permit classes to inherit from class-literal types,
/// `Todo`, `Unknown` or `Any`. Anything else results in us
/// emitting a diagnostic.
///
/// This variant records the indices and types of class bases
/// that we deem to be invalid. The indices are the indices of nodes
/// in the bases list of the class's [`StmtClassDef`](ruff_python_ast::StmtClassDef) node.
/// Each index is the index of a node representing an invalid base.
InvalidBases(Box<[(usize, Type<'db>)]>),
/// The class has one or more duplicate bases.
///
/// This variant records the indices and [`ClassLiteral`]s
/// of the duplicate bases. The indices are the indices of nodes
/// in the bases list of the class's [`StmtClassDef`](ruff_python_ast::StmtClassDef) node.
/// Each index is the index of a node representing a duplicate base.
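    ///
    /// For example, `class C(A, A): ...` raises a `TypeError`
    /// ("duplicate base class") at runtime.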
DuplicateBases(Box<[(usize, ClassLiteral<'db>)]>),
/// The MRO is otherwise unresolvable through the C3-merge algorithm.
///
/// See [`c3_merge`] for more details.
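    ///
    /// For example, no consistent MRO exists for `C` below: the bases list requires
    /// `A` to precede `B`, while `B`'s own MRO requires `B` to precede `A`:
    ///
    /// ```python
    /// class A: ...
    /// class B(A): ...
    /// class C(A, B): ...  # raises TypeError at runtime
    /// ```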
UnresolvableMro { bases_list: Box<[ClassBase<'db>]> },
}
impl<'db> MroErrorKind<'db> {
pub(super) fn into_mro_error(self, db: &'db dyn Db, class: ClassType<'db>) -> MroError<'db> {
MroError {
kind: self,
fallback_mro: Mro::from_error(db, class),
}
}
}
/// Implementation of the [C3-merge algorithm] for calculating a Python class's
/// [method resolution order].
///
/// [C3-merge algorithm]: https://docs.python.org/3/howto/mro.html#python-2-3-mro
/// [method resolution order]: https://docs.python.org/3/glossary.html#term-method-resolution-order
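///
/// As a quick illustration on the classic diamond:
///
/// ```python
/// class A: ...
/// class B(A): ...
/// class C(A): ...
/// class D(B, C): ...
///
/// assert D.__mro__ == (D, B, C, A, object)
/// ```
///
/// Here `c3_merge` receives the sequences `[[D], [B, A, object], [C, A, object], [B, C]]`
/// and produces `[D, B, C, A, object]`, matching the runtime MRO.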
fn c3_merge(mut sequences: Vec<VecDeque<ClassBase>>) -> Option<Mro> {
// Most MROs aren't that long...
let mut mro = Vec::with_capacity(8);
loop {
sequences.retain(|sequence| !sequence.is_empty());
if sequences.is_empty() {
return Some(Mro::from(mro));
}
// If the candidate exists "deeper down" in the inheritance hierarchy,
// we should refrain from adding it to the MRO for now. Add the first candidate
// for which this does not hold true. If this holds true for all candidates,
// return `None`; it will be impossible to find a consistent MRO for the class
// with the given bases.
let mro_entry = sequences.iter().find_map(|outer_sequence| {
let candidate = outer_sequence[0];
let not_head = sequences
.iter()
.all(|sequence| sequence.iter().skip(1).all(|base| base != &candidate));
not_head.then_some(candidate)
})?;
mro.push(mro_entry);
// Make sure we don't try to add the candidate to the MRO twice:
for sequence in &mut sequences {
if sequence[0] == mro_entry {
sequence.pop_front();
}
}
}
}


@ -0,0 +1,857 @@
use crate::semantic_index::ast_ids::HasScopedExpressionId;
use crate::semantic_index::definition::Definition;
use crate::semantic_index::expression::Expression;
use crate::semantic_index::predicate::{
PatternPredicate, PatternPredicateKind, Predicate, PredicateNode,
};
use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable};
use crate::semantic_index::symbol_table;
use crate::types::infer::infer_same_file_expression_type;
use crate::types::{
infer_expression_types, IntersectionBuilder, KnownClass, SubclassOfType, Truthiness, Type,
UnionBuilder,
};
use crate::Db;
use itertools::Itertools;
use ruff_python_ast as ast;
use ruff_python_ast::{BoolOp, ExprBoolOp};
use rustc_hash::FxHashMap;
use std::collections::hash_map::Entry;
use std::sync::Arc;
use super::UnionType;
/// Return the type constraint that `test` (if true) would place on `definition`, if any.
///
/// For example, if we have this code:
///
/// ```python
/// y = 1 if flag else None
/// x = 1 if flag else None
/// if x is not None:
/// ...
/// ```
///
/// The `test` expression `x is not None` places the constraint "not None" on the definition of
/// `x`, so in that case we'd return `Some(Type::Intersection(negative=[Type::None]))`.
///
/// But if we called this with the same `test` expression, but the `definition` of `y`, no
/// constraint is applied to that definition, so we'd just return `None`.
pub(crate) fn infer_narrowing_constraint<'db>(
db: &'db dyn Db,
predicate: Predicate<'db>,
definition: Definition<'db>,
) -> Option<Type<'db>> {
let constraints = match predicate.node {
PredicateNode::Expression(expression) => {
if predicate.is_positive {
all_narrowing_constraints_for_expression(db, expression)
} else {
all_negative_narrowing_constraints_for_expression(db, expression)
}
}
PredicateNode::Pattern(pattern) => {
if predicate.is_positive {
all_narrowing_constraints_for_pattern(db, pattern)
} else {
all_negative_narrowing_constraints_for_pattern(db, pattern)
}
}
PredicateNode::StarImportPlaceholder(_) => return None,
};
if let Some(constraints) = constraints {
constraints.get(&definition.symbol(db)).copied()
} else {
None
}
}
#[allow(clippy::ref_option)]
#[salsa::tracked(return_ref)]
fn all_narrowing_constraints_for_pattern<'db>(
db: &'db dyn Db,
pattern: PatternPredicate<'db>,
) -> Option<NarrowingConstraints<'db>> {
NarrowingConstraintsBuilder::new(db, PredicateNode::Pattern(pattern), true).finish()
}
#[allow(clippy::ref_option)]
#[salsa::tracked(
return_ref,
cycle_fn=constraints_for_expression_cycle_recover,
cycle_initial=constraints_for_expression_cycle_initial,
)]
fn all_narrowing_constraints_for_expression<'db>(
db: &'db dyn Db,
expression: Expression<'db>,
) -> Option<NarrowingConstraints<'db>> {
NarrowingConstraintsBuilder::new(db, PredicateNode::Expression(expression), true).finish()
}
#[allow(clippy::ref_option)]
#[salsa::tracked(
return_ref,
cycle_fn=negative_constraints_for_expression_cycle_recover,
cycle_initial=negative_constraints_for_expression_cycle_initial,
)]
fn all_negative_narrowing_constraints_for_expression<'db>(
db: &'db dyn Db,
expression: Expression<'db>,
) -> Option<NarrowingConstraints<'db>> {
NarrowingConstraintsBuilder::new(db, PredicateNode::Expression(expression), false).finish()
}
#[allow(clippy::ref_option)]
#[salsa::tracked(return_ref)]
fn all_negative_narrowing_constraints_for_pattern<'db>(
db: &'db dyn Db,
pattern: PatternPredicate<'db>,
) -> Option<NarrowingConstraints<'db>> {
NarrowingConstraintsBuilder::new(db, PredicateNode::Pattern(pattern), false).finish()
}
#[allow(clippy::ref_option)]
fn constraints_for_expression_cycle_recover<'db>(
_db: &'db dyn Db,
_value: &Option<NarrowingConstraints<'db>>,
_count: u32,
_expression: Expression<'db>,
) -> salsa::CycleRecoveryAction<Option<NarrowingConstraints<'db>>> {
salsa::CycleRecoveryAction::Iterate
}
fn constraints_for_expression_cycle_initial<'db>(
_db: &'db dyn Db,
_expression: Expression<'db>,
) -> Option<NarrowingConstraints<'db>> {
None
}
#[allow(clippy::ref_option)]
fn negative_constraints_for_expression_cycle_recover<'db>(
_db: &'db dyn Db,
_value: &Option<NarrowingConstraints<'db>>,
_count: u32,
_expression: Expression<'db>,
) -> salsa::CycleRecoveryAction<Option<NarrowingConstraints<'db>>> {
salsa::CycleRecoveryAction::Iterate
}
fn negative_constraints_for_expression_cycle_initial<'db>(
_db: &'db dyn Db,
_expression: Expression<'db>,
) -> Option<NarrowingConstraints<'db>> {
None
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KnownConstraintFunction {
/// `builtins.isinstance`
IsInstance,
/// `builtins.issubclass`
IsSubclass,
}
impl KnownConstraintFunction {
/// Generate a constraint from the type of a `classinfo` argument to `isinstance` or `issubclass`.
///
    /// The `classinfo` argument can be a class literal or a tuple of (tuples of) class literals.
    /// PEP 604 union types are not yet supported. Returns `None` if the `classinfo` argument has
    /// an invalid type.
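    ///
    /// A sketch of the resulting narrowing (the exact display of the revealed
    /// type may differ):
    ///
    /// ```python
    /// from typing import reveal_type
    ///
    /// def f(x: object) -> None:
    ///     if isinstance(x, (int, str)):
    ///         reveal_type(x)  # revealed: int | str
    /// ```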
fn generate_constraint<'db>(self, db: &'db dyn Db, classinfo: Type<'db>) -> Option<Type<'db>> {
let constraint_fn = |class| match self {
KnownConstraintFunction::IsInstance => Type::instance(db, class),
KnownConstraintFunction::IsSubclass => SubclassOfType::from(db, class),
};
match classinfo {
Type::Tuple(tuple) => {
let mut builder = UnionBuilder::new(db);
for element in tuple.elements(db) {
builder = builder.add(self.generate_constraint(db, *element)?);
}
Some(builder.build())
}
Type::ClassLiteral(class_literal) => {
                // At runtime (on Python 3.11+), this will return `True` for classes that actually
                // do inherit `typing.Any` and `False` otherwise. We could model that behavior more
                // accurately in the future.
if class_literal.is_known(db, KnownClass::Any) {
None
} else {
Some(constraint_fn(class_literal.default_specialization(db)))
}
}
Type::SubclassOf(subclass_of_ty) => {
subclass_of_ty.subclass_of().into_class().map(constraint_fn)
}
_ => None,
}
}
}
type NarrowingConstraints<'db> = FxHashMap<ScopedSymbolId, Type<'db>>;
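/// Merge two sets of narrowing constraints with "and" semantics: for example, in
/// `isinstance(x, A) and isinstance(x, B)`, the two constraints on `x` intersect
/// to `A & B`.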
fn merge_constraints_and<'db>(
into: &mut NarrowingConstraints<'db>,
from: NarrowingConstraints<'db>,
db: &'db dyn Db,
) {
for (key, value) in from {
match into.entry(key) {
Entry::Occupied(mut entry) => {
*entry.get_mut() = IntersectionBuilder::new(db)
.add_positive(*entry.get())
.add_positive(value)
.build();
}
Entry::Vacant(entry) => {
entry.insert(value);
}
}
}
}
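/// Merge two sets of narrowing constraints with "or" semantics: for example, in
/// `isinstance(x, A) or isinstance(x, B)`, the two constraints on `x` union to
/// `A | B`. A symbol constrained in only one of the two branches is widened to
/// `object`, since the other branch places no constraint on it.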
fn merge_constraints_or<'db>(
into: &mut NarrowingConstraints<'db>,
from: &NarrowingConstraints<'db>,
db: &'db dyn Db,
) {
for (key, value) in from {
match into.entry(*key) {
Entry::Occupied(mut entry) => {
*entry.get_mut() = UnionBuilder::new(db).add(*entry.get()).add(*value).build();
}
Entry::Vacant(entry) => {
entry.insert(Type::object(db));
}
}
}
for (key, value) in into.iter_mut() {
if !from.contains_key(key) {
*value = Type::object(db);
}
}
}
fn negate_if<'db>(constraints: &mut NarrowingConstraints<'db>, db: &'db dyn Db, yes: bool) {
for (_symbol, ty) in constraints.iter_mut() {
*ty = ty.negate_if(db, yes);
}
}
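/// Return the variable name that `expr` refers to, if it is a `Name` expression
/// or a named expression (`x := ...`) whose target is a `Name`.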
fn expr_name(expr: &ast::Expr) -> Option<&ast::name::Name> {
match expr {
ast::Expr::Named(ast::ExprNamed { target, .. }) => match target.as_ref() {
ast::Expr::Name(ast::ExprName { id, .. }) => Some(id),
_ => None,
},
ast::Expr::Name(ast::ExprName { id, .. }) => Some(id),
_ => None,
}
}
struct NarrowingConstraintsBuilder<'db> {
db: &'db dyn Db,
predicate: PredicateNode<'db>,
is_positive: bool,
}
impl<'db> NarrowingConstraintsBuilder<'db> {
fn new(db: &'db dyn Db, predicate: PredicateNode<'db>, is_positive: bool) -> Self {
Self {
db,
predicate,
is_positive,
}
}
fn finish(mut self) -> Option<NarrowingConstraints<'db>> {
let constraints: Option<NarrowingConstraints<'db>> = match self.predicate {
PredicateNode::Expression(expression) => {
self.evaluate_expression_predicate(expression, self.is_positive)
}
PredicateNode::Pattern(pattern) => {
self.evaluate_pattern_predicate(pattern, self.is_positive)
}
PredicateNode::StarImportPlaceholder(_) => return None,
};
if let Some(mut constraints) = constraints {
constraints.shrink_to_fit();
Some(constraints)
} else {
None
}
}
fn evaluate_expression_predicate(
&mut self,
expression: Expression<'db>,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
let expression_node = expression.node_ref(self.db).node();
self.evaluate_expression_node_predicate(expression_node, expression, is_positive)
}
fn evaluate_expression_node_predicate(
&mut self,
expression_node: &ruff_python_ast::Expr,
expression: Expression<'db>,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
match expression_node {
ast::Expr::Name(name) => Some(self.evaluate_expr_name(name, is_positive)),
ast::Expr::Compare(expr_compare) => {
self.evaluate_expr_compare(expr_compare, expression, is_positive)
}
ast::Expr::Call(expr_call) => {
self.evaluate_expr_call(expr_call, expression, is_positive)
}
ast::Expr::UnaryOp(unary_op) if unary_op.op == ast::UnaryOp::Not => {
self.evaluate_expression_node_predicate(&unary_op.operand, expression, !is_positive)
}
ast::Expr::BoolOp(bool_op) => self.evaluate_bool_op(bool_op, expression, is_positive),
ast::Expr::Named(expr_named) => self.evaluate_expr_named(expr_named, is_positive),
_ => None,
}
}
fn evaluate_pattern_predicate_kind(
&mut self,
pattern_predicate_kind: &PatternPredicateKind<'db>,
subject: Expression<'db>,
) -> Option<NarrowingConstraints<'db>> {
match pattern_predicate_kind {
PatternPredicateKind::Singleton(singleton) => {
self.evaluate_match_pattern_singleton(subject, *singleton)
}
PatternPredicateKind::Class(cls) => self.evaluate_match_pattern_class(subject, *cls),
PatternPredicateKind::Value(expr) => self.evaluate_match_pattern_value(subject, *expr),
PatternPredicateKind::Or(predicates) => {
self.evaluate_match_pattern_or(subject, predicates)
}
PatternPredicateKind::Unsupported => None,
}
}
fn evaluate_pattern_predicate(
&mut self,
pattern: PatternPredicate<'db>,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
let subject = pattern.subject(self.db);
self.evaluate_pattern_predicate_kind(pattern.kind(self.db), subject)
.map(|mut constraints| {
negate_if(&mut constraints, self.db, !is_positive);
constraints
})
}
fn symbols(&self) -> Arc<SymbolTable> {
symbol_table(self.db, self.scope())
}
fn scope(&self) -> ScopeId<'db> {
match self.predicate {
PredicateNode::Expression(expression) => expression.scope(self.db),
PredicateNode::Pattern(pattern) => pattern.scope(self.db),
PredicateNode::StarImportPlaceholder(definition) => definition.scope(self.db),
}
}
#[track_caller]
fn expect_expr_name_symbol(&self, symbol: &str) -> ScopedSymbolId {
self.symbols()
.symbol_id_by_name(symbol)
.expect("We should always have a symbol for every `Name` node")
}
fn evaluate_expr_name(
&mut self,
expr_name: &ast::ExprName,
is_positive: bool,
) -> NarrowingConstraints<'db> {
let ast::ExprName { id, .. } = expr_name;
let symbol = self.expect_expr_name_symbol(id);
let ty = if is_positive {
Type::AlwaysFalsy.negate(self.db)
} else {
Type::AlwaysTruthy.negate(self.db)
};
NarrowingConstraints::from_iter([(symbol, ty)])
}
fn evaluate_expr_named(
&mut self,
expr_named: &ast::ExprNamed,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
if let ast::Expr::Name(expr_name) = expr_named.target.as_ref() {
Some(self.evaluate_expr_name(expr_name, is_positive))
} else {
None
}
}
fn evaluate_expr_eq(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option<Type<'db>> {
// We can only narrow on equality checks against single-valued types.
if rhs_ty.is_single_valued(self.db) || rhs_ty.is_union_of_single_valued(self.db) {
// The fully-general (and more efficient) approach here would be to introduce a
// `NeverEqualTo` type that can wrap a single-valued type, and then simply return
// `~NeverEqualTo(rhs_ty)` here and let union/intersection builder sort it out. This is
// how we handle `AlwaysTruthy` and `AlwaysFalsy`. But this means we have to deal with
// this type everywhere, and possibly have it show up unsimplified in some cases, and
// so we instead prefer to just do the simplification here. (Another hybrid option that
// would be similar to this, but more efficient, would be to allow narrowing to return
// something that is not a type, and handle this not-a-type in `symbol_from_bindings`,
// instead of intersecting with a type.)
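            //
            // For example, given `x: Literal["a", "b"]`, the test `x == "a"` produces
            // the constraint `~Literal["b"]` (`Literal["b"]` is single-valued and
            // disjoint from `Literal["a"]`, so the two can never compare equal);
            // intersecting that constraint with `x`'s type narrows it to `Literal["a"]`.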
// Return `true` if it is possible for any two inhabitants of the given types to
// compare equal to each other; otherwise return `false`.
fn could_compare_equal<'db>(
db: &'db dyn Db,
left_ty: Type<'db>,
right_ty: Type<'db>,
) -> bool {
if !left_ty.is_disjoint_from(db, right_ty) {
// If types overlap, they have inhabitants in common; it's definitely possible
// for an object to compare equal to itself.
return true;
}
match (left_ty, right_ty) {
// In order to be sure a union type cannot compare equal to another type, it
// must be true that no element of the union can compare equal to that type.
(Type::Union(union), _) => union
.elements(db)
.iter()
.any(|ty| could_compare_equal(db, *ty, right_ty)),
(_, Type::Union(union)) => union
.elements(db)
.iter()
.any(|ty| could_compare_equal(db, left_ty, *ty)),
// Boolean literals and int literals are disjoint, and single valued, and yet
// `True == 1` and `False == 0`.
(Type::BooleanLiteral(b), Type::IntLiteral(i))
| (Type::IntLiteral(i), Type::BooleanLiteral(b)) => i64::from(b) == i,
// Other than the above cases, two single-valued disjoint types cannot compare
// equal.
_ => !(left_ty.is_single_valued(db) && right_ty.is_single_valued(db)),
}
}
// Return `true` if `lhs_ty` consists only of `LiteralString` and types that cannot
// compare equal to `rhs_ty`.
fn can_narrow_to_rhs<'db>(
db: &'db dyn Db,
lhs_ty: Type<'db>,
rhs_ty: Type<'db>,
) -> bool {
match lhs_ty {
Type::Union(union) => union
.elements(db)
.iter()
.all(|ty| can_narrow_to_rhs(db, *ty, rhs_ty)),
// Either `rhs_ty` is a string literal, in which case we can narrow to it (no
// other string literal could compare equal to it), or it is not a string
// literal, in which case (given that it is single-valued), LiteralString
// cannot compare equal to it.
Type::LiteralString => true,
_ => !could_compare_equal(db, lhs_ty, rhs_ty),
}
}
// Filter `ty` to just the types that cannot be equal to `rhs_ty`.
fn filter_to_cannot_be_equal<'db>(
db: &'db dyn Db,
ty: Type<'db>,
rhs_ty: Type<'db>,
) -> Type<'db> {
match ty {
Type::Union(union) => {
union.map(db, |ty| filter_to_cannot_be_equal(db, *ty, rhs_ty))
}
// Treat `bool` as `Literal[True, False]`.
Type::NominalInstance(instance)
if instance.class().is_known(db, KnownClass::Bool) =>
{
UnionType::from_elements(
db,
[Type::BooleanLiteral(true), Type::BooleanLiteral(false)]
.into_iter()
.map(|ty| filter_to_cannot_be_equal(db, ty, rhs_ty)),
)
}
_ => {
if ty.is_single_valued(db) && !could_compare_equal(db, ty, rhs_ty) {
ty
} else {
Type::Never
}
}
}
}
Some(if can_narrow_to_rhs(self.db, lhs_ty, rhs_ty) {
rhs_ty
} else {
filter_to_cannot_be_equal(self.db, lhs_ty, rhs_ty).negate(self.db)
})
} else {
None
}
}
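    /// Narrow on `!=` checks, e.g. for `x: bool`, the test `x != 0` produces the
    /// constraint `~Literal[False]`, since `False == 0` at runtime.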
fn evaluate_expr_ne(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option<Type<'db>> {
match (lhs_ty, rhs_ty) {
(Type::NominalInstance(instance), Type::IntLiteral(i))
if instance.class().is_known(self.db, KnownClass::Bool) =>
{
if i == 0 {
Some(Type::BooleanLiteral(false).negate(self.db))
} else if i == 1 {
Some(Type::BooleanLiteral(true).negate(self.db))
} else {
None
}
}
(_, Type::BooleanLiteral(b)) => Some(
UnionType::from_elements(self.db, [rhs_ty, Type::IntLiteral(i64::from(b))])
.negate(self.db),
),
_ if rhs_ty.is_single_valued(self.db) => Some(rhs_ty.negate(self.db)),
_ => None,
}
}
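    /// Narrow on `in` checks, e.g. for `x: Literal["a", "b", "c"]`, the test
    /// `x in ("a", "b")` produces the constraint `Literal["a", "b"]`. Similarly,
    /// `x in "ab"` produces `Literal["a", "b"]`, since a string's elements are its
    /// characters.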
fn evaluate_expr_in(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option<Type<'db>> {
if lhs_ty.is_single_valued(self.db) || lhs_ty.is_union_of_single_valued(self.db) {
match rhs_ty {
Type::Tuple(rhs_tuple) => Some(UnionType::from_elements(
self.db,
rhs_tuple.elements(self.db),
)),
Type::StringLiteral(string_literal) => Some(UnionType::from_elements(
self.db,
string_literal
.iter_each_char(self.db)
.map(Type::StringLiteral),
)),
_ => None,
}
} else {
None
}
}
fn evaluate_expr_compare_op(
&mut self,
lhs_ty: Type<'db>,
rhs_ty: Type<'db>,
op: ast::CmpOp,
) -> Option<Type<'db>> {
match op {
ast::CmpOp::IsNot => {
if rhs_ty.is_singleton(self.db) {
let ty = IntersectionBuilder::new(self.db)
.add_negative(rhs_ty)
.build();
Some(ty)
} else {
// Non-singletons cannot be safely narrowed using `is not`
None
}
}
ast::CmpOp::Is => Some(rhs_ty),
ast::CmpOp::Eq => self.evaluate_expr_eq(lhs_ty, rhs_ty),
ast::CmpOp::NotEq => self.evaluate_expr_ne(lhs_ty, rhs_ty),
ast::CmpOp::In => self.evaluate_expr_in(lhs_ty, rhs_ty),
ast::CmpOp::NotIn => self
.evaluate_expr_in(lhs_ty, rhs_ty)
.map(|ty| ty.negate(self.db)),
_ => None,
}
}
fn evaluate_expr_compare(
&mut self,
expr_compare: &ast::ExprCompare,
expression: Expression<'db>,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
fn is_narrowing_target_candidate(expr: &ast::Expr) -> bool {
matches!(
expr,
ast::Expr::Name(_) | ast::Expr::Call(_) | ast::Expr::Named(_)
)
}
let ast::ExprCompare {
range: _,
left,
ops,
comparators,
} = expr_compare;
// Performance optimization: early return if there are no potential narrowing targets.
if !is_narrowing_target_candidate(left)
&& comparators
.iter()
.all(|c| !is_narrowing_target_candidate(c))
{
return None;
}
if !is_positive && comparators.len() > 1 {
// We can't negate a constraint made by a multi-comparator expression, since we can't
// know which comparison part is the one being negated.
        // For example, the negation of `x is 1 is y is 2` would be `(x is not 1) or (y is not 1) or (y is not 2)`
// and that requires cross-symbol constraints, which we don't support yet.
return None;
}
let scope = self.scope();
let inference = infer_expression_types(self.db, expression);
let comparator_tuples = std::iter::once(&**left)
.chain(comparators)
.tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>();
let mut constraints = NarrowingConstraints::default();
let mut last_rhs_ty: Option<Type> = None;
for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) {
let lhs_ty = last_rhs_ty.unwrap_or_else(|| {
inference.expression_type(left.scoped_expression_id(self.db, scope))
});
let rhs_ty = inference.expression_type(right.scoped_expression_id(self.db, scope));
last_rhs_ty = Some(rhs_ty);
match left {
ast::Expr::Name(_) | ast::Expr::Named(_) => {
if let Some(id) = expr_name(left) {
let symbol = self.expect_expr_name_symbol(id);
let op = if is_positive { *op } else { op.negate() };
if let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, op) {
constraints.insert(symbol, ty);
}
}
}
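            // Handle narrowing on `type(x) is C` / `type(x) is not C` checks,
            // e.g. `if type(x) is int:` narrows `x` to `int`.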
ast::Expr::Call(ast::ExprCall {
range: _,
func: callable,
arguments:
ast::Arguments {
args,
keywords,
range: _,
},
}) if keywords.is_empty() => {
let rhs_class = match rhs_ty {
Type::ClassLiteral(class) => class,
Type::GenericAlias(alias) => alias.origin(self.db),
_ => {
continue;
}
};
let id = match &**args {
[first] => match expr_name(first) {
Some(id) => id,
None => continue,
},
_ => continue,
};
let is_valid_constraint = if is_positive {
op == &ast::CmpOp::Is
} else {
op == &ast::CmpOp::IsNot
};
if !is_valid_constraint {
continue;
}
let callable_type =
inference.expression_type(callable.scoped_expression_id(self.db, scope));
if callable_type
.into_class_literal()
.is_some_and(|c| c.is_known(self.db, KnownClass::Type))
{
let symbol = self.expect_expr_name_symbol(id);
constraints.insert(
symbol,
Type::instance(self.db, rhs_class.unknown_specialization(self.db)),
);
}
}
_ => {}
}
}
Some(constraints)
}
fn evaluate_expr_call(
&mut self,
expr_call: &ast::ExprCall,
expression: Expression<'db>,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
let scope = self.scope();
let inference = infer_expression_types(self.db, expression);
let callable_ty =
inference.expression_type(expr_call.func.scoped_expression_id(self.db, scope));
// TODO: add support for PEP 604 union types on the right hand side of `isinstance`
// and `issubclass`, for example `isinstance(x, str | (int | float))`.
match callable_ty {
Type::FunctionLiteral(function_type) if expr_call.arguments.keywords.is_empty() => {
let function = function_type.known(self.db)?.into_constraint_function()?;
let (id, class_info) = match &*expr_call.arguments.args {
[first, class_info] => match expr_name(first) {
Some(id) => (id, class_info),
None => return None,
},
_ => return None,
};
let symbol = self.expect_expr_name_symbol(id);
let class_info_ty =
inference.expression_type(class_info.scoped_expression_id(self.db, scope));
function
.generate_constraint(self.db, class_info_ty)
.map(|constraint| {
NarrowingConstraints::from_iter([(
symbol,
constraint.negate_if(self.db, !is_positive),
)])
})
}
// for the expression `bool(E)`, we further narrow the type based on `E`
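            // (e.g. `if bool(x is not None):` narrows `x` exactly as `if x is not None:` would)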
Type::ClassLiteral(class_type)
if expr_call.arguments.args.len() == 1
&& expr_call.arguments.keywords.is_empty()
&& class_type.is_known(self.db, KnownClass::Bool) =>
{
self.evaluate_expression_node_predicate(
&expr_call.arguments.args[0],
expression,
is_positive,
)
}
_ => None,
}
}
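    /// Narrow on a `match` singleton pattern, e.g. an arm `case None:` narrows the
    /// subject to `None` within that arm.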
fn evaluate_match_pattern_singleton(
&mut self,
subject: Expression<'db>,
singleton: ast::Singleton,
) -> Option<NarrowingConstraints<'db>> {
let symbol = self.expect_expr_name_symbol(&subject.node_ref(self.db).as_name_expr()?.id);
let ty = match singleton {
ast::Singleton::None => Type::none(self.db),
ast::Singleton::True => Type::BooleanLiteral(true),
ast::Singleton::False => Type::BooleanLiteral(false),
};
Some(NarrowingConstraints::from_iter([(symbol, ty)]))
}
fn evaluate_match_pattern_class(
&mut self,
subject: Expression<'db>,
cls: Expression<'db>,
) -> Option<NarrowingConstraints<'db>> {
let symbol = self.expect_expr_name_symbol(&subject.node_ref(self.db).as_name_expr()?.id);
let ty = infer_same_file_expression_type(self.db, cls).to_instance(self.db)?;
Some(NarrowingConstraints::from_iter([(symbol, ty)]))
}
fn evaluate_match_pattern_value(
&mut self,
subject: Expression<'db>,
value: Expression<'db>,
) -> Option<NarrowingConstraints<'db>> {
let symbol = self.expect_expr_name_symbol(&subject.node_ref(self.db).as_name_expr()?.id);
let ty = infer_same_file_expression_type(self.db, value);
Some(NarrowingConstraints::from_iter([(symbol, ty)]))
}
fn evaluate_match_pattern_or(
&mut self,
subject: Expression<'db>,
predicates: &Vec<PatternPredicateKind<'db>>,
) -> Option<NarrowingConstraints<'db>> {
let db = self.db;
predicates
.iter()
.filter_map(|predicate| self.evaluate_pattern_predicate_kind(predicate, subject))
.reduce(|mut constraints, constraints_| {
merge_constraints_or(&mut constraints, &constraints_, db);
constraints
})
}
fn evaluate_bool_op(
&mut self,
expr_bool_op: &ExprBoolOp,
expression: Expression<'db>,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
let inference = infer_expression_types(self.db, expression);
let scope = self.scope();
let mut sub_constraints = expr_bool_op
.values
.iter()
            // filter out arms with statically-known truthiness
.filter(|expr| {
inference
.expression_type(expr.scoped_expression_id(self.db, scope))
.bool(self.db)
!= match expr_bool_op.op {
BoolOp::And => Truthiness::AlwaysTrue,
BoolOp::Or => Truthiness::AlwaysFalse,
}
})
.map(|sub_expr| {
self.evaluate_expression_node_predicate(sub_expr, expression, is_positive)
})
.collect::<Vec<_>>();
match (expr_bool_op.op, is_positive) {
(BoolOp::And, true) | (BoolOp::Or, false) => {
let mut aggregation: Option<NarrowingConstraints> = None;
for sub_constraint in sub_constraints.into_iter().flatten() {
if let Some(ref mut some_aggregation) = aggregation {
merge_constraints_and(some_aggregation, sub_constraint, self.db);
} else {
aggregation = Some(sub_constraint);
}
}
aggregation
}
(BoolOp::Or, true) | (BoolOp::And, false) => {
let (first, rest) = sub_constraints.split_first_mut()?;
if let Some(ref mut first) = first {
for rest_constraint in rest {
if let Some(rest_constraint) = rest_constraint {
merge_constraints_or(first, rest_constraint, self.db);
} else {
return None;
}
}
}
first.clone()
}
}
}
}


@ -0,0 +1,306 @@
//! This module contains quickcheck-based property tests for `Type`s.
//!
//! These tests are disabled by default, as they are non-deterministic and slow. You can
//! run them explicitly using:
//!
//! ```sh
//! cargo test -p ty_python_semantic -- --ignored types::property_tests::stable
//! ```
//!
//! The number of tests (default: 100) can be controlled by setting the `QUICKCHECK_TESTS`
//! environment variable. For example:
//!
//! ```sh
//! QUICKCHECK_TESTS=10000 cargo test …
//! ```
//!
//! If you want to run these tests for a longer period of time, it's advisable to run them
//! in release mode. As some tests are slower than others, run them in a
//! loop until they fail:
//!
//! ```sh
//! export QUICKCHECK_TESTS=100000
//! while cargo test --release -p ty_python_semantic -- \
//! --ignored types::property_tests::stable; do :; done
//! ```
mod setup;
mod type_generation;
use type_generation::{intersection, union};
/// A macro to define a property test for types.
///
/// The `$test_name` identifier specifies the name of the test function. The `$db` identifier
/// is used to refer to the salsa database in the property to be tested. The actual property is
/// specified using the syntax:
///
/// `forall types t1, t2, ..., tn . <property>`
///
/// where `t1`, `t2`, ..., `tn` are identifiers that represent arbitrary types, and `<property>`
/// is an expression using these identifiers.
///
macro_rules! type_property_test {
($test_name:ident, $db:ident, forall types $($types:ident),+ . $property:expr) => {
#[quickcheck_macros::quickcheck]
#[ignore]
fn $test_name($($types: crate::types::property_tests::type_generation::Ty),+) -> bool {
let $db = &crate::types::property_tests::setup::get_cached_db();
$(let $types = $types.into_type($db);)+
$property
}
};
// A property test with a logical implication.
($name:ident, $db:ident, forall types $($types:ident),+ . $premise:expr => $conclusion:expr) => {
type_property_test!($name, $db, forall types $($types),+ . !($premise) || ($conclusion));
};
}
mod stable {
use super::union;
use crate::types::{CallableType, Type};
// Reflexivity: `T` is equivalent to itself.
type_property_test!(
equivalent_to_is_reflexive, db,
forall types t. t.is_fully_static(db) => t.is_equivalent_to(db, t)
);
// Symmetry: If `S` is equivalent to `T`, then `T` must be equivalent to `S`.
// Note that this (trivially) holds true for gradual types as well.
type_property_test!(
equivalent_to_is_symmetric, db,
forall types s, t. s.is_equivalent_to(db, t) => t.is_equivalent_to(db, s)
);
// Transitivity: If `S` is equivalent to `T` and `T` is equivalent to `U`, then `S` must be equivalent to `U`.
type_property_test!(
equivalent_to_is_transitive, db,
forall types s, t, u. s.is_equivalent_to(db, t) && t.is_equivalent_to(db, u) => s.is_equivalent_to(db, u)
);
// Symmetry: If `S` is gradual equivalent to `T`, `T` is gradual equivalent to `S`.
type_property_test!(
gradual_equivalent_to_is_symmetric, db,
forall types s, t. s.is_gradual_equivalent_to(db, t) => t.is_gradual_equivalent_to(db, s)
);
// A fully static type `T` is a subtype of itself.
type_property_test!(
subtype_of_is_reflexive, db,
forall types t. t.is_fully_static(db) => t.is_subtype_of(db, t)
);
// `S <: T` and `T <: U` implies that `S <: U`.
type_property_test!(
subtype_of_is_transitive, db,
forall types s, t, u. s.is_subtype_of(db, t) && t.is_subtype_of(db, u) => s.is_subtype_of(db, u)
);
// `S <: T` and `T <: S` implies that `S` is equivalent to `T`.
type_property_test!(
subtype_of_is_antisymmetric, db,
forall types s, t. s.is_subtype_of(db, t) && t.is_subtype_of(db, s) => s.is_equivalent_to(db, t)
);
// `T` is not disjoint from itself, unless `T` is `Never`.
type_property_test!(
disjoint_from_is_irreflexive, db,
forall types t. t.is_disjoint_from(db, t) => t.is_never()
);
// `S` is disjoint from `T` implies that `T` is disjoint from `S`.
type_property_test!(
disjoint_from_is_symmetric, db,
forall types s, t. s.is_disjoint_from(db, t) == t.is_disjoint_from(db, s)
);
// `S <: T` implies that `S` is not disjoint from `T`, unless `S` is `Never`.
type_property_test!(
subtype_of_implies_not_disjoint_from, db,
forall types s, t. s.is_subtype_of(db, t) => !s.is_disjoint_from(db, t) || s.is_never()
);
// `S <: T` implies that `S` can be assigned to `T`.
type_property_test!(
subtype_of_implies_assignable_to, db,
forall types s, t. s.is_subtype_of(db, t) => s.is_assignable_to(db, t)
);
// If `T` is a singleton, it is also single-valued.
type_property_test!(
singleton_implies_single_valued, db,
forall types t. t.is_singleton(db) => t.is_single_valued(db)
);
// If `T` contains a gradual form, it should not participate in equivalence
type_property_test!(
non_fully_static_types_do_not_participate_in_equivalence, db,
forall types s, t. !s.is_fully_static(db) => !s.is_equivalent_to(db, t) && !t.is_equivalent_to(db, s)
);
// If `T` contains a gradual form, it should not participate in subtyping
type_property_test!(
non_fully_static_types_do_not_participate_in_subtyping, db,
forall types s, t. !s.is_fully_static(db) => !s.is_subtype_of(db, t) && !t.is_subtype_of(db, s)
);
// All types should be assignable to `object`
type_property_test!(
all_types_assignable_to_object, db,
forall types t. t.is_assignable_to(db, Type::object(db))
);
// And for fully static types, they should also be subtypes of `object`
type_property_test!(
all_fully_static_types_subtype_of_object, db,
forall types t. t.is_fully_static(db) => t.is_subtype_of(db, Type::object(db))
);
// Never should be assignable to every type
type_property_test!(
never_assignable_to_every_type, db,
forall types t. Type::Never.is_assignable_to(db, t)
);
// And it should be a subtype of all fully static types
type_property_test!(
never_subtype_of_every_fully_static_type, db,
forall types t. t.is_fully_static(db) => Type::Never.is_subtype_of(db, t)
);
// Similar to `Never`, a fully-static "bottom" callable type should be a subtype of all
// fully-static callable types
type_property_test!(
bottom_callable_is_subtype_of_all_fully_static_callable, db,
forall types t. t.is_callable_type() && t.is_fully_static(db)
=> CallableType::bottom(db).is_subtype_of(db, t)
);
// For any two fully static types, each type in the pair must be a subtype of their union.
type_property_test!(
all_fully_static_type_pairs_are_subtype_of_their_union, db,
forall types s, t.
s.is_fully_static(db) && t.is_fully_static(db)
=> s.is_subtype_of(db, union(db, [s, t])) && t.is_subtype_of(db, union(db, [s, t]))
);
// A fully static type does not have any materializations.
// Thus, two equivalent (fully static) types are also gradual equivalent.
type_property_test!(
two_equivalent_types_are_also_gradual_equivalent, db,
forall types s, t. s.is_equivalent_to(db, t) => s.is_gradual_equivalent_to(db, t)
);
// Two gradual equivalent fully static types are also equivalent.
type_property_test!(
two_gradual_equivalent_fully_static_types_are_also_equivalent, db,
forall types s, t.
s.is_fully_static(db) && s.is_gradual_equivalent_to(db, t) => s.is_equivalent_to(db, t)
);
// `T` can be assigned to itself.
type_property_test!(
assignable_to_is_reflexive, db,
forall types t. t.is_assignable_to(db, t)
);
// For *any* pair of types, whether fully static or not,
// each of the pair should be assignable to the union of the two.
type_property_test!(
all_type_pairs_are_assignable_to_their_union, db,
forall types s, t. s.is_assignable_to(db, union(db, [s, t])) && t.is_assignable_to(db, union(db, [s, t]))
);
}
/// This module contains property tests that currently lead to many false positives.
///
/// The reason for this is our insufficient understanding of equivalence of types. For
/// example, we currently consider `int | str` and `str | int` to be different types.
/// Similar issues exist for intersection types. Once this is resolved, we can move these
/// tests to the `stable` section. In the meantime, it can still be useful to run these
/// tests (using [`types::property_tests::flaky`]) to see if there are any new obvious bugs.
mod flaky {
use itertools::Itertools;
use super::{intersection, union};
// Negating `T` twice is equivalent to `T`.
type_property_test!(
double_negation_is_identity, db,
forall types t. t.negate(db).negate(db).is_equivalent_to(db, t)
);
// ~T should be disjoint from T
type_property_test!(
negation_is_disjoint, db,
forall types t. t.is_fully_static(db) => t.negate(db).is_disjoint_from(db, t)
);
// For two fully static types, their intersection must be a subtype of each type in the pair.
type_property_test!(
all_fully_static_type_pairs_are_supertypes_of_their_intersection, db,
forall types s, t.
s.is_fully_static(db) && t.is_fully_static(db)
=> intersection(db, [s, t]).is_subtype_of(db, s) && intersection(db, [s, t]).is_subtype_of(db, t)
);
// And for non-fully-static types, the intersection of a pair of types
// should be assignable to both types of the pair.
// Currently fails due to https://github.com/astral-sh/ruff/issues/14899
type_property_test!(
all_type_pairs_can_be_assigned_from_their_intersection, db,
forall types s, t. intersection(db, [s, t]).is_assignable_to(db, s) && intersection(db, [s, t]).is_assignable_to(db, t)
);
// Equal element sets of intersections implies equivalence
// flaky at least in part because of https://github.com/astral-sh/ruff/issues/15513
type_property_test!(
intersection_equivalence_not_order_dependent, db,
forall types s, t, u.
s.is_fully_static(db) && t.is_fully_static(db) && u.is_fully_static(db)
=> [s, t, u]
.into_iter()
.permutations(3)
.map(|trio_of_types| intersection(db, trio_of_types))
.permutations(2)
.all(|vec_of_intersections| vec_of_intersections[0].is_equivalent_to(db, vec_of_intersections[1]))
);
// Equal element sets of unions implies equivalence
// flaky at least in part because of https://github.com/astral-sh/ruff/issues/15513
type_property_test!(
union_equivalence_not_order_dependent, db,
forall types s, t, u.
s.is_fully_static(db) && t.is_fully_static(db) && u.is_fully_static(db)
=> [s, t, u]
.into_iter()
.permutations(3)
.map(|trio_of_types| union(db, trio_of_types))
.permutations(2)
.all(|vec_of_unions| vec_of_unions[0].is_equivalent_to(db, vec_of_unions[1]))
);
// `S | T` is always a supertype of `S`.
// Thus, `S` is never disjoint from `S | T`.
type_property_test!(
constituent_members_of_union_is_not_disjoint_from_that_union, db,
forall types s, t.
!s.is_disjoint_from(db, union(db, [s, t])) && !t.is_disjoint_from(db, union(db, [s, t]))
);
// If `S <: T`, then `~T <: ~S`.
//
// DO NOT STABILISE this test until the mdtests here pass:
// https://github.com/astral-sh/ruff/blob/2711e08eb8eb38d1ce323aae0517fede371cba15/crates/ty_python_semantic/resources/mdtest/type_properties/is_subtype_of.md?plain=1#L276-L315
//
// This test has flakes relating to those subtyping and simplification tests
// (see https://github.com/astral-sh/ruff/issues/16913), but it is hard to
// reliably trigger the flakes when running this test manually as the flakes
// occur very rarely (even running the test with several million seeds does
// not always reliably reproduce the flake).
type_property_test!(
negation_reverses_subtype_order, db,
forall types s, t. s.is_subtype_of(db, t) => t.negate(db).is_subtype_of(db, s.negate(db))
);
}


@ -0,0 +1,9 @@
use crate::db::tests::{setup_db, TestDb};
use std::sync::{Arc, Mutex, OnceLock};
static CACHED_DB: OnceLock<Arc<Mutex<TestDb>>> = OnceLock::new();
pub(crate) fn get_cached_db() -> TestDb {
let db = CACHED_DB.get_or_init(|| Arc::new(Mutex::new(setup_db())));
db.lock().unwrap().clone()
}


@ -0,0 +1,469 @@
use crate::db::tests::TestDb;
use crate::symbol::{builtins_symbol, known_module_symbol};
use crate::types::{
BoundMethodType, CallableType, IntersectionBuilder, KnownClass, KnownInstanceType, Parameter,
Parameters, Signature, SubclassOfType, TupleType, Type, UnionType,
};
use crate::{Db, KnownModule};
use hashbrown::HashSet;
use quickcheck::{Arbitrary, Gen};
use ruff_python_ast::name::Name;
/// A test representation of a type that can be transformed unambiguously into a real Type,
/// given a db.
///
/// TODO: We should add some variants that exercise generic classes and specializations thereof.
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum Ty {
Never,
Unknown,
None,
Any,
IntLiteral(i64),
BooleanLiteral(bool),
StringLiteral(&'static str),
LiteralString,
BytesLiteral(&'static str),
// BuiltinInstance("str") corresponds to an instance of the builtin `str` class
BuiltinInstance(&'static str),
/// Members of the `abc` stdlib module
AbcInstance(&'static str),
AbcClassLiteral(&'static str),
TypingLiteral,
// BuiltinClassLiteral("str") corresponds to the builtin `str` class object itself
BuiltinClassLiteral(&'static str),
KnownClassInstance(KnownClass),
Union(Vec<Ty>),
Intersection {
pos: Vec<Ty>,
neg: Vec<Ty>,
},
Tuple(Vec<Ty>),
SubclassOfAny,
SubclassOfBuiltinClass(&'static str),
SubclassOfAbcClass(&'static str),
AlwaysTruthy,
AlwaysFalsy,
BuiltinsFunction(&'static str),
BuiltinsBoundMethod {
class: &'static str,
method: &'static str,
},
Callable {
params: CallableParams,
returns: Option<Box<Ty>>,
},
}
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum CallableParams {
GradualForm,
List(Vec<Param>),
}
impl CallableParams {
pub(crate) fn into_parameters(self, db: &TestDb) -> Parameters<'_> {
match self {
CallableParams::GradualForm => Parameters::gradual_form(),
CallableParams::List(params) => Parameters::new(params.into_iter().map(|param| {
let mut parameter = match param.kind {
ParamKind::PositionalOnly => Parameter::positional_only(param.name),
ParamKind::PositionalOrKeyword => {
Parameter::positional_or_keyword(param.name.unwrap())
}
ParamKind::Variadic => Parameter::variadic(param.name.unwrap()),
ParamKind::KeywordOnly => Parameter::keyword_only(param.name.unwrap()),
ParamKind::KeywordVariadic => Parameter::keyword_variadic(param.name.unwrap()),
};
if let Some(annotated_ty) = param.annotated_ty {
parameter = parameter.with_annotated_type(annotated_ty.into_type(db));
}
if let Some(default_ty) = param.default_ty {
parameter = parameter.with_default_type(default_ty.into_type(db));
}
parameter
})),
}
}
}
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct Param {
kind: ParamKind,
name: Option<Name>,
annotated_ty: Option<Ty>,
default_ty: Option<Ty>,
}
#[derive(Debug, Clone, Copy, PartialEq)]
enum ParamKind {
PositionalOnly,
PositionalOrKeyword,
Variadic,
KeywordOnly,
KeywordVariadic,
}
#[salsa::tracked]
fn create_bound_method<'db>(
db: &'db dyn Db,
function: Type<'db>,
builtins_class: Type<'db>,
) -> Type<'db> {
Type::BoundMethod(BoundMethodType::new(
db,
function.expect_function_literal(),
builtins_class.to_instance(db).unwrap(),
))
}
impl Ty {
pub(crate) fn into_type(self, db: &TestDb) -> Type<'_> {
match self {
Ty::Never => Type::Never,
Ty::Unknown => Type::unknown(),
Ty::None => Type::none(db),
Ty::Any => Type::any(),
Ty::IntLiteral(n) => Type::IntLiteral(n),
Ty::StringLiteral(s) => Type::string_literal(db, s),
Ty::BooleanLiteral(b) => Type::BooleanLiteral(b),
Ty::LiteralString => Type::LiteralString,
Ty::BytesLiteral(s) => Type::bytes_literal(db, s.as_bytes()),
Ty::BuiltinInstance(s) => builtins_symbol(db, s)
.symbol
.expect_type()
.to_instance(db)
.unwrap(),
Ty::AbcInstance(s) => known_module_symbol(db, KnownModule::Abc, s)
.symbol
.expect_type()
.to_instance(db)
.unwrap(),
Ty::AbcClassLiteral(s) => known_module_symbol(db, KnownModule::Abc, s)
.symbol
.expect_type(),
Ty::TypingLiteral => Type::KnownInstance(KnownInstanceType::Literal),
Ty::BuiltinClassLiteral(s) => builtins_symbol(db, s).symbol.expect_type(),
Ty::KnownClassInstance(known_class) => known_class.to_instance(db),
Ty::Union(tys) => {
UnionType::from_elements(db, tys.into_iter().map(|ty| ty.into_type(db)))
}
Ty::Intersection { pos, neg } => {
let mut builder = IntersectionBuilder::new(db);
for p in pos {
builder = builder.add_positive(p.into_type(db));
}
for n in neg {
builder = builder.add_negative(n.into_type(db));
}
builder.build()
}
Ty::Tuple(tys) => {
let elements = tys.into_iter().map(|ty| ty.into_type(db));
TupleType::from_elements(db, elements)
}
Ty::SubclassOfAny => SubclassOfType::subclass_of_any(),
Ty::SubclassOfBuiltinClass(s) => SubclassOfType::from(
db,
builtins_symbol(db, s)
.symbol
.expect_type()
.expect_class_literal()
.default_specialization(db),
),
Ty::SubclassOfAbcClass(s) => SubclassOfType::from(
db,
known_module_symbol(db, KnownModule::Abc, s)
.symbol
.expect_type()
.expect_class_literal()
.default_specialization(db),
),
Ty::AlwaysTruthy => Type::AlwaysTruthy,
Ty::AlwaysFalsy => Type::AlwaysFalsy,
Ty::BuiltinsFunction(name) => builtins_symbol(db, name).symbol.expect_type(),
Ty::BuiltinsBoundMethod { class, method } => {
let builtins_class = builtins_symbol(db, class).symbol.expect_type();
let function = builtins_class.member(db, method).symbol.expect_type();
create_bound_method(db, function, builtins_class)
}
Ty::Callable { params, returns } => Type::Callable(CallableType::single(
db,
Signature::new(
params.into_parameters(db),
returns.map(|ty| ty.into_type(db)),
),
)),
}
}
}
fn arbitrary_core_type(g: &mut Gen) -> Ty {
// We could select a random integer here, but this would make it much less
// likely to explore interesting edge cases:
let int_lit = Ty::IntLiteral(*g.choose(&[-2, -1, 0, 1, 2]).unwrap());
let bool_lit = Ty::BooleanLiteral(bool::arbitrary(g));
g.choose(&[
Ty::Never,
Ty::Unknown,
Ty::None,
Ty::Any,
int_lit,
bool_lit,
Ty::StringLiteral(""),
Ty::StringLiteral("a"),
Ty::LiteralString,
Ty::BytesLiteral(""),
Ty::BytesLiteral("\x00"),
Ty::KnownClassInstance(KnownClass::Object),
Ty::KnownClassInstance(KnownClass::Str),
Ty::KnownClassInstance(KnownClass::Int),
Ty::KnownClassInstance(KnownClass::Bool),
Ty::KnownClassInstance(KnownClass::List),
Ty::KnownClassInstance(KnownClass::Tuple),
Ty::KnownClassInstance(KnownClass::FunctionType),
Ty::KnownClassInstance(KnownClass::SpecialForm),
Ty::KnownClassInstance(KnownClass::TypeVar),
Ty::KnownClassInstance(KnownClass::TypeAliasType),
Ty::KnownClassInstance(KnownClass::NoDefaultType),
Ty::TypingLiteral,
Ty::BuiltinClassLiteral("str"),
Ty::BuiltinClassLiteral("int"),
Ty::BuiltinClassLiteral("bool"),
Ty::BuiltinClassLiteral("object"),
Ty::BuiltinInstance("type"),
Ty::AbcInstance("ABC"),
Ty::AbcInstance("ABCMeta"),
Ty::SubclassOfAny,
Ty::SubclassOfBuiltinClass("object"),
Ty::SubclassOfBuiltinClass("str"),
Ty::SubclassOfBuiltinClass("type"),
Ty::AbcClassLiteral("ABC"),
Ty::AbcClassLiteral("ABCMeta"),
Ty::SubclassOfAbcClass("ABC"),
Ty::SubclassOfAbcClass("ABCMeta"),
Ty::AlwaysTruthy,
Ty::AlwaysFalsy,
Ty::BuiltinsFunction("chr"),
Ty::BuiltinsFunction("ascii"),
Ty::BuiltinsBoundMethod {
class: "str",
method: "isascii",
},
Ty::BuiltinsBoundMethod {
class: "int",
method: "bit_length",
},
])
.unwrap()
.clone()
}
/// Constructs an arbitrary type.
///
/// The `size` parameter controls the depth of the type tree. For example,
/// a simple type like `int` has a size of 0, `Union[int, str]` has a size
/// of 1, `tuple[int, Union[str, bytes]]` has a size of 2, etc.
fn arbitrary_type(g: &mut Gen, size: u32) -> Ty {
if size == 0 {
arbitrary_core_type(g)
} else {
match u32::arbitrary(g) % 5 {
0 => arbitrary_core_type(g),
1 => Ty::Union(
(0..*g.choose(&[2, 3]).unwrap())
.map(|_| arbitrary_type(g, size - 1))
.collect(),
),
2 => Ty::Tuple(
(0..*g.choose(&[0, 1, 2]).unwrap())
.map(|_| arbitrary_type(g, size - 1))
.collect(),
),
3 => Ty::Intersection {
pos: (0..*g.choose(&[0, 1, 2]).unwrap())
.map(|_| arbitrary_type(g, size - 1))
.collect(),
neg: (0..*g.choose(&[0, 1, 2]).unwrap())
.map(|_| arbitrary_type(g, size - 1))
.collect(),
},
4 => Ty::Callable {
params: match u32::arbitrary(g) % 2 {
0 => CallableParams::GradualForm,
1 => CallableParams::List(arbitrary_parameter_list(g, size)),
_ => unreachable!(),
},
returns: arbitrary_optional_type(g, size - 1).map(Box::new),
},
_ => unreachable!(),
}
}
}
fn arbitrary_parameter_list(g: &mut Gen, size: u32) -> Vec<Param> {
let mut params: Vec<Param> = vec![];
let mut used_names = HashSet::new();
// First, choose the number of parameters to generate.
for _ in 0..*g.choose(&[0, 1, 2, 3, 4, 5]).unwrap() {
// Next, choose the kind of parameters that can be generated based on the last parameter.
let next_kind = match params.last().map(|p| p.kind) {
None | Some(ParamKind::PositionalOnly) => *g
.choose(&[
ParamKind::PositionalOnly,
ParamKind::PositionalOrKeyword,
ParamKind::Variadic,
ParamKind::KeywordOnly,
ParamKind::KeywordVariadic,
])
.unwrap(),
Some(ParamKind::PositionalOrKeyword) => *g
.choose(&[
ParamKind::PositionalOrKeyword,
ParamKind::Variadic,
ParamKind::KeywordOnly,
ParamKind::KeywordVariadic,
])
.unwrap(),
Some(ParamKind::Variadic | ParamKind::KeywordOnly) => *g
.choose(&[ParamKind::KeywordOnly, ParamKind::KeywordVariadic])
.unwrap(),
Some(ParamKind::KeywordVariadic) => {
// There can't be any other parameter kind after a keyword variadic parameter.
break;
}
};
let name = loop {
let name = if matches!(next_kind, ParamKind::PositionalOnly) {
arbitrary_optional_name(g)
} else {
Some(arbitrary_name(g))
};
if let Some(name) = name {
if used_names.insert(name.clone()) {
break Some(name);
}
} else {
break None;
}
};
params.push(Param {
kind: next_kind,
name,
annotated_ty: arbitrary_optional_type(g, size),
default_ty: if matches!(next_kind, ParamKind::Variadic | ParamKind::KeywordVariadic) {
None
} else {
arbitrary_optional_type(g, size)
},
});
}
params
}
fn arbitrary_optional_type(g: &mut Gen, size: u32) -> Option<Ty> {
match u32::arbitrary(g) % 2 {
0 => None,
1 => Some(arbitrary_type(g, size)),
_ => unreachable!(),
}
}
fn arbitrary_name(g: &mut Gen) -> Name {
Name::new(format!("n{}", u32::arbitrary(g) % 10))
}
fn arbitrary_optional_name(g: &mut Gen) -> Option<Name> {
match u32::arbitrary(g) % 2 {
0 => None,
1 => Some(arbitrary_name(g)),
_ => unreachable!(),
}
}
impl Arbitrary for Ty {
fn arbitrary(g: &mut Gen) -> Ty {
const MAX_SIZE: u32 = 2;
arbitrary_type(g, MAX_SIZE)
}
fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
match self.clone() {
Ty::Union(types) => Box::new(types.shrink().filter_map(|elts| match elts.len() {
0 => None,
1 => Some(elts.into_iter().next().unwrap()),
_ => Some(Ty::Union(elts)),
})),
Ty::Tuple(types) => Box::new(types.shrink().filter_map(|elts| match elts.len() {
0 => None,
1 => Some(elts.into_iter().next().unwrap()),
_ => Some(Ty::Tuple(elts)),
})),
Ty::Intersection { pos, neg } => {
// Shrinking on intersections is not exhaustive!
//
// We try to shrink the positive side or the negative side,
// but we aren't shrinking both at the same time.
//
// This should remove positive or negative constraints but
// won't shrink (A & B & ~C & ~D) to (A & ~C) in one shrink
// iteration.
//
// Instead, it hopes that (A & B & ~C) or (A & ~C & ~D) fails
// so that shrinking can happen there.
let pos_orig = pos.clone();
let neg_orig = neg.clone();
Box::new(
// we shrink negative constraints first, as
// intersections with only negative constraints are
// more confusing
neg.shrink()
.map(move |shrunk_neg| Ty::Intersection {
pos: pos_orig.clone(),
neg: shrunk_neg,
})
.chain(pos.shrink().map(move |shrunk_pos| Ty::Intersection {
pos: shrunk_pos,
neg: neg_orig.clone(),
}))
.filter_map(|ty| {
if let Ty::Intersection { pos, neg } = &ty {
match (pos.len(), neg.len()) {
// an empty intersection does not mean
// anything
(0, 0) => None,
// a single positive element should be
// unwrapped
(1, 0) => Some(pos[0].clone()),
_ => Some(ty),
}
} else {
unreachable!()
}
}),
)
}
_ => Box::new(std::iter::empty()),
}
}
}
pub(crate) fn intersection<'db>(
db: &'db TestDb,
tys: impl IntoIterator<Item = Type<'db>>,
) -> Type<'db> {
let mut builder = IntersectionBuilder::new(db);
for ty in tys {
builder = builder.add_positive(ty);
}
builder.build()
}
pub(crate) fn union<'db>(db: &'db TestDb, tys: impl IntoIterator<Item = Type<'db>>) -> Type<'db> {
UnionType::from_elements(db, tys)
}


@ -0,0 +1,260 @@
use std::{collections::BTreeMap, ops::Deref};
use itertools::Itertools;
use ruff_python_ast::name::Name;
use crate::{
db::Db,
semantic_index::{symbol_table, use_def_map},
symbol::{symbol_from_bindings, symbol_from_declarations},
types::{ClassBase, ClassLiteral, KnownFunction, Type, TypeQualifiers},
};
impl<'db> ClassLiteral<'db> {
/// Returns `Some` if this is a protocol class, `None` otherwise.
pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option<ProtocolClassLiteral<'db>> {
self.is_protocol(db).then_some(ProtocolClassLiteral(self))
}
}
/// Representation of a single `Protocol` class definition.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(super) struct ProtocolClassLiteral<'db>(ClassLiteral<'db>);
impl<'db> ProtocolClassLiteral<'db> {
/// Returns the protocol members of this class.
///
/// A protocol's members define the interface declared by the protocol.
/// They therefore determine how the protocol should behave with regards to
/// assignability and subtyping.
///
/// The list of members consists of all bindings and declarations that take place
/// in the protocol's class body, except for a list of excluded attributes which should
/// not be taken into account. (This list includes `__init__` and `__new__`, which can
/// legally be defined on protocol classes but do not constitute protocol members.)
///
/// It is illegal for a protocol class to have any instance attributes that are not declared
/// in the protocol's class body. If any are assigned to, they are not taken into account in
/// the protocol's list of members.
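    ///
    /// For example, in:
    ///
    /// ```python
    /// from typing import Protocol
    ///
    /// class HasName(Protocol):
    ///     name: str
    ///     def __init__(self) -> None: ...
    /// ```
    ///
    /// the protocol's interface consists of the single member `name`;
    /// `__init__` is excluded.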
pub(super) fn interface(self, db: &'db dyn Db) -> &'db ProtocolInterface<'db> {
let _span = tracing::trace_span!("protocol_members", "class='{}'", self.name(db)).entered();
cached_protocol_interface(db, *self)
}
pub(super) fn is_runtime_checkable(self, db: &'db dyn Db) -> bool {
self.known_function_decorators(db)
.contains(&KnownFunction::RuntimeCheckable)
}
}
impl<'db> Deref for ProtocolClassLiteral<'db> {
type Target = ClassLiteral<'db>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// The interface of a protocol: the members of that protocol, and the types of those members.
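///
/// Implementation note: the `BTreeMap` keeps members sorted by name, so
/// iteration over the interface is deterministic.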
#[derive(Debug, PartialEq, Eq, salsa::Update, Default, Clone, Hash)]
pub(super) struct ProtocolInterface<'db>(BTreeMap<Name, ProtocolMemberData<'db>>);
impl<'db> ProtocolInterface<'db> {
/// Iterate over the members of this protocol.
pub(super) fn members<'a>(&'a self) -> impl ExactSizeIterator<Item = ProtocolMember<'a, 'db>> {
self.0.iter().map(|(name, data)| ProtocolMember {
name,
ty: data.ty,
qualifiers: data.qualifiers,
})
}
pub(super) fn member_by_name<'a>(&self, name: &'a str) -> Option<ProtocolMember<'a, 'db>> {
self.0.get(name).map(|data| ProtocolMember {
name,
ty: data.ty,
qualifiers: data.qualifiers,
})
}
/// Return `true` if all members of this protocol are fully static.
pub(super) fn is_fully_static(&self, db: &'db dyn Db) -> bool {
self.members().all(|member| member.ty.is_fully_static(db))
}
/// Return `true` if all members of `self` are also members of `other`.
///
/// TODO: this method should consider the types of the members as well as their names.
pub(super) fn is_sub_interface_of(&self, other: &Self) -> bool {
self.0
.keys()
.all(|member_name| other.0.contains_key(member_name))
}
/// Return `true` if any of the members of this protocol type contain any `Todo` types.
pub(super) fn contains_todo(&self, db: &'db dyn Db) -> bool {
self.members().any(|member| member.ty.contains_todo(db))
}
pub(super) fn normalized(self, db: &'db dyn Db) -> Self {
Self(
self.0
.into_iter()
.map(|(name, data)| (name, data.normalized(db)))
.collect(),
)
}
}
#[derive(Debug, PartialEq, Eq, Clone, Hash, salsa::Update)]
struct ProtocolMemberData<'db> {
ty: Type<'db>,
qualifiers: TypeQualifiers,
}
impl<'db> ProtocolMemberData<'db> {
fn normalized(self, db: &'db dyn Db) -> Self {
Self {
ty: self.ty.normalized(db),
qualifiers: self.qualifiers,
}
}
}
/// A single member of a protocol interface.
#[derive(Debug, PartialEq, Eq)]
pub(super) struct ProtocolMember<'a, 'db> {
name: &'a str,
ty: Type<'db>,
qualifiers: TypeQualifiers,
}
impl<'a, 'db> ProtocolMember<'a, 'db> {
pub(super) fn name(&self) -> &'a str {
self.name
}
pub(super) fn ty(&self) -> Type<'db> {
self.ty
}
pub(super) fn qualifiers(&self) -> TypeQualifiers {
self.qualifiers
}
}
/// Returns `true` if a declaration or binding to a given name in a protocol class body
/// should be excluded from the list of protocol members of that class.
///
/// The list of excluded members is subject to change between Python versions,
/// especially for dunders, but it probably doesn't matter *too* much if this
/// list goes out of date. It's up to date as of Python commit 87b1ea016b1454b1e83b9113fa9435849b7743aa
/// (<https://github.com/python/cpython/blob/87b1ea016b1454b1e83b9113fa9435849b7743aa/Lib/typing.py#L1776-L1791>)
fn excluded_from_proto_members(member: &str) -> bool {
matches!(
member,
"_is_protocol"
| "__non_callable_proto_members__"
| "__static_attributes__"
| "__orig_class__"
| "__match_args__"
| "__weakref__"
| "__doc__"
| "__parameters__"
| "__module__"
| "_MutableMapping__marker"
| "__slots__"
| "__dict__"
| "__new__"
| "__protocol_attrs__"
| "__init__"
| "__class_getitem__"
| "__firstlineno__"
| "__abstractmethods__"
| "__orig_bases__"
| "_is_runtime_protocol"
| "__subclasshook__"
| "__type_params__"
| "__annotations__"
| "__annotate__"
| "__annotate_func__"
| "__annotations_cache__"
)
}
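
// A minimal sketch (not part of the original change) of the behaviour the
// exclusion list above is meant to guarantee: special attributes such as
// `__init__` never count as protocol members, while ordinary names do.
#[cfg(test)]
mod excluded_member_tests {
    use super::excluded_from_proto_members;

    #[test]
    fn dunders_are_excluded_but_ordinary_names_are_kept() {
        assert!(excluded_from_proto_members("__init__"));
        assert!(excluded_from_proto_members("__slots__"));
        assert!(!excluded_from_proto_members("read"));
        assert!(!excluded_from_proto_members("__call__"));
    }
}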
/// Inner Salsa query for [`ProtocolClassLiteral::interface`].
#[salsa::tracked(return_ref, cycle_fn=proto_interface_cycle_recover, cycle_initial=proto_interface_cycle_initial)]
fn cached_protocol_interface<'db>(
db: &'db dyn Db,
class: ClassLiteral<'db>,
) -> ProtocolInterface<'db> {
let mut members = BTreeMap::default();
for parent_protocol in class
.iter_mro(db, None)
.filter_map(ClassBase::into_class)
.filter_map(|class| class.class_literal(db).0.into_protocol_class(db))
{
let parent_scope = parent_protocol.body_scope(db);
let use_def_map = use_def_map(db, parent_scope);
let symbol_table = symbol_table(db, parent_scope);
members.extend(
use_def_map
.all_public_declarations()
.flat_map(|(symbol_id, declarations)| {
symbol_from_declarations(db, declarations).map(|symbol| (symbol_id, symbol))
})
.filter_map(|(symbol_id, symbol)| {
symbol
.symbol
.ignore_possibly_unbound()
.map(|ty| (symbol_id, ty, symbol.qualifiers))
})
// Bindings in the class body that are not declared in the class body
// are not valid protocol members, and we plan to emit diagnostics for them
// elsewhere. Invalid or not, however, it's important that we still consider
// them to be protocol members. The implementation of `issubclass()` and
// `isinstance()` for runtime-checkable protocols considers them to be protocol
// members at runtime, and it's important that we accurately understand
// type narrowing that uses `isinstance()` or `issubclass()` with
// runtime-checkable protocols.
.chain(
use_def_map
.all_public_bindings()
.filter_map(|(symbol_id, bindings)| {
symbol_from_bindings(db, bindings)
.ignore_possibly_unbound()
.map(|ty| (symbol_id, ty, TypeQualifiers::default()))
}),
)
.map(|(symbol_id, member, qualifiers)| {
(symbol_table.symbol(symbol_id).name(), member, qualifiers)
})
.filter(|(name, _, _)| !excluded_from_proto_members(name))
.map(|(name, ty, qualifiers)| {
let member = ProtocolMemberData { ty, qualifiers };
(name.clone(), member)
}),
);
}
ProtocolInterface(members)
}
fn proto_interface_cycle_recover<'db>(
_db: &dyn Db,
_value: &ProtocolInterface<'db>,
_count: u32,
_class: ClassLiteral<'db>,
) -> salsa::CycleRecoveryAction<ProtocolInterface<'db>> {
salsa::CycleRecoveryAction::Iterate
}
fn proto_interface_cycle_initial<'db>(
_db: &dyn Db,
_class: ClassLiteral<'db>,
) -> ProtocolInterface<'db> {
ProtocolInterface::default()
}

File diff suppressed because it is too large


@ -0,0 +1,109 @@
use ruff_python_ast as ast;
use crate::db::Db;
use crate::symbol::{Boundness, Symbol};
use crate::types::class_base::ClassBase;
use crate::types::diagnostic::report_base_with_incompatible_slots;
use crate::types::{ClassLiteral, Type};
use super::InferContext;
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum SlotsKind {
/// `__slots__` is not found in the class.
NotSpecified,
/// `__slots__` is defined but empty: `__slots__ = ()`.
Empty,
/// `__slots__` is defined and is not empty: `__slots__ = ("a", "b")`.
NotEmpty,
/// `__slots__` is defined but its value is dynamic:
/// * `__slots__ = tuple(a for a in b)`
/// * `__slots__ = ["a", "b"]`
Dynamic,
}
impl SlotsKind {
fn from(db: &dyn Db, base: ClassLiteral) -> Self {
let Symbol::Type(slots_ty, bound) = base.own_class_member(db, None, "__slots__").symbol
else {
return Self::NotSpecified;
};
if matches!(bound, Boundness::PossiblyUnbound) {
return Self::Dynamic;
}
match slots_ty {
// __slots__ = ("a", "b")
Type::Tuple(tuple) => {
if tuple.elements(db).is_empty() {
Self::Empty
} else {
Self::NotEmpty
}
}
// __slots__ = "abc" # Same as `("abc",)`
Type::StringLiteral(_) => Self::NotEmpty,
_ => Self::Dynamic,
}
}
}
pub(super) fn check_class_slots(
context: &InferContext,
class: ClassLiteral,
node: &ast::StmtClassDef,
) {
let db = context.db();
let mut first_with_solid_base = None;
let mut common_solid_base = None;
let mut found_second = false;
for (index, base) in class.explicit_bases(db).iter().enumerate() {
let Type::ClassLiteral(base) = base else {
continue;
};
let solid_base = base.iter_mro(db, None).find_map(|current| {
let ClassBase::Class(current) = current else {
return None;
};
let (class_literal, _) = current.class_literal(db);
match SlotsKind::from(db, class_literal) {
SlotsKind::NotEmpty => Some(current),
SlotsKind::NotSpecified | SlotsKind::Empty => None,
SlotsKind::Dynamic => None,
}
});
if solid_base.is_none() {
continue;
}
let base_node = &node.bases()[index];
if first_with_solid_base.is_none() {
first_with_solid_base = Some(index);
common_solid_base = solid_base;
continue;
}
if solid_base == common_solid_base {
continue;
}
found_second = true;
report_base_with_incompatible_slots(context, base_node);
}
if found_second {
if let Some(index) = first_with_solid_base {
let base_node = &node.bases()[index];
report_base_with_incompatible_slots(context, base_node);
}
}
}
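
// An illustrative Python input (a sketch, not taken from the test suite)
// that this check is meant to flag: `A` and `B` are distinct solid bases
// because each declares a non-empty `__slots__`, so both base expressions
// of `C` receive a diagnostic:
//
//     class A:
//         __slots__ = ("a",)
//
//     class B:
//         __slots__ = ("b",)
//
//     class C(A, B): ...  # incompatible __slots__ layouts
//
// By contrast, a class whose bases all share the same solid base (e.g. `A`
// and a subclass of `A` that adds no new slots) is accepted.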


@ -0,0 +1,180 @@
use ruff_db::source::source_text;
use ruff_python_ast::{self as ast, ModExpression};
use ruff_python_parser::Parsed;
use ruff_text_size::Ranged;
use crate::declare_lint;
use crate::lint::{Level, LintStatus};
use super::context::InferContext;
declare_lint! {
/// ## What it does
/// Checks for f-strings in type annotation positions.
///
/// ## Why is this bad?
/// Static analysis tools like ty can't analyse type annotations that use f-string notation.
///
/// ## Examples
/// ```python
/// def test() -> f"int":
/// ...
/// ```
///
/// Use instead:
/// ```python
/// def test() -> "int":
/// ...
/// ```
pub(crate) static FSTRING_TYPE_ANNOTATION = {
summary: "detects F-strings in type annotation positions",
status: LintStatus::preview("1.0.0"),
default_level: Level::Error,
}
}
declare_lint! {
/// ## What it does
/// Checks for byte-strings in type annotation positions.
///
/// ## Why is this bad?
/// Static analysis tools like ty can't analyse type annotations that use byte-string notation.
///
/// ## Examples
/// ```python
/// def test() -> b"int":
/// ...
/// ```
///
/// Use instead:
/// ```python
/// def test() -> "int":
/// ...
/// ```
pub(crate) static BYTE_STRING_TYPE_ANNOTATION = {
summary: "detects byte strings in type annotation positions",
status: LintStatus::preview("1.0.0"),
default_level: Level::Error,
}
}
declare_lint! {
/// ## What it does
/// Checks for raw-strings in type annotation positions.
///
/// ## Why is this bad?
/// Static analysis tools like ty can't analyse type annotations that use raw-string notation.
///
/// ## Examples
/// ```python
/// def test() -> r"int":
/// ...
/// ```
///
/// Use instead:
/// ```python
/// def test() -> "int":
/// ...
/// ```
pub(crate) static RAW_STRING_TYPE_ANNOTATION = {
summary: "detects raw strings in type annotation positions",
status: LintStatus::preview("1.0.0"),
default_level: Level::Error,
}
}
declare_lint! {
/// ## What it does
/// Checks for implicit concatenated strings in type annotation positions.
///
/// ## Why is this bad?
/// Static analysis tools like ty can't analyse type annotations that use implicit concatenated strings.
///
/// ## Examples
/// ```python
/// def test() -> "Literal[" "5" "]":
/// ...
/// ```
///
/// Use instead:
/// ```python
/// def test() -> "Literal[5]":
/// ...
/// ```
pub(crate) static IMPLICIT_CONCATENATED_STRING_TYPE_ANNOTATION = {
summary: "detects implicit concatenated strings in type annotations",
status: LintStatus::preview("1.0.0"),
default_level: Level::Error,
}
}
declare_lint! {
/// TODO #14889
pub(crate) static INVALID_SYNTAX_IN_FORWARD_ANNOTATION = {
summary: "detects invalid syntax in forward annotations",
status: LintStatus::preview("1.0.0"),
default_level: Level::Error,
}
}
declare_lint! {
/// TODO #14889
pub(crate) static ESCAPE_CHARACTER_IN_FORWARD_ANNOTATION = {
summary: "detects forward type annotations with escape characters",
status: LintStatus::preview("1.0.0"),
default_level: Level::Error,
}
}
/// Parses the given expression as a string annotation.
pub(crate) fn parse_string_annotation(
context: &InferContext,
string_expr: &ast::ExprStringLiteral,
) -> Option<Parsed<ModExpression>> {
let file = context.file();
let db = context.db();
let _span = tracing::trace_span!("parse_string_annotation", string=?string_expr.range(), ?file)
.entered();
let source = source_text(db.upcast(), file);
if let Some(string_literal) = string_expr.as_single_part_string() {
let prefix = string_literal.flags.prefix();
if prefix.is_raw() {
if let Some(builder) = context.report_lint(&RAW_STRING_TYPE_ANNOTATION, string_literal)
{
builder.into_diagnostic("Type expressions cannot use raw string literal");
}
// Compare the raw contents (without quotes) of the expression with the parsed contents
// contained in the string literal.
} else if &source[string_literal.content_range()] == string_literal.as_str() {
match ruff_python_parser::parse_string_annotation(source.as_str(), string_literal) {
Ok(parsed) => return Some(parsed),
Err(parse_error) => {
if let Some(builder) =
context.report_lint(&INVALID_SYNTAX_IN_FORWARD_ANNOTATION, string_literal)
{
builder.into_diagnostic(format_args!(
"Syntax error in forward annotation: {}",
parse_error.error
));
}
}
}
} else if let Some(builder) =
context.report_lint(&ESCAPE_CHARACTER_IN_FORWARD_ANNOTATION, string_expr)
{
// The raw contents of the string don't match the parsed contents. This could be the
// case for annotations that contain escape sequences.
builder.into_diagnostic("Type expressions cannot contain escape characters");
}
} else if let Some(builder) =
context.report_lint(&IMPLICIT_CONCATENATED_STRING_TYPE_ANNOTATION, string_expr)
{
// String is implicitly concatenated.
builder.into_diagnostic("Type expressions cannot span multiple string literals");
}
None
}


@ -0,0 +1,168 @@
use crate::symbol::SymbolAndQualifiers;
use super::{ClassType, Db, DynamicType, KnownClass, MemberLookupPolicy, Type};
/// A type that represents `type[C]`, i.e. the class object `C` and class objects that are subclasses of `C`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, salsa::Update)]
pub struct SubclassOfType<'db> {
// Keep this field private, so that the only way of constructing the struct is through the `from` method.
subclass_of: SubclassOfInner<'db>,
}
impl<'db> SubclassOfType<'db> {
/// Construct a new [`Type`] instance representing a given class object (or a given dynamic type)
/// and all possible subclasses of that class object/dynamic type.
///
/// This method does not always return a [`Type::SubclassOf`] variant.
/// If the class object is known to be a final class,
/// this method will return a [`Type::ClassLiteral`] variant; this is a more precise type.
/// If the class object is `builtins.object`, `Type::NominalInstance(<builtins.type>)`
/// will be returned; this is no more precise, but it is exactly equivalent to `type[object]`.
///
/// The eager normalization here means that we do not need to worry elsewhere about distinguishing
/// between `@final` classes and other classes when dealing with [`Type::SubclassOf`] variants.
pub(crate) fn from(db: &'db dyn Db, subclass_of: impl Into<SubclassOfInner<'db>>) -> Type<'db> {
let subclass_of = subclass_of.into();
match subclass_of {
SubclassOfInner::Dynamic(_) => Type::SubclassOf(Self { subclass_of }),
SubclassOfInner::Class(class) => {
if class.is_final(db) {
Type::from(class)
} else if class.is_object(db) {
KnownClass::Type.to_instance(db)
} else {
Type::SubclassOf(Self { subclass_of })
}
}
}
}
/// Return a [`Type`] instance representing the type `type[Unknown]`.
pub(crate) const fn subclass_of_unknown() -> Type<'db> {
Type::SubclassOf(SubclassOfType {
subclass_of: SubclassOfInner::unknown(),
})
}
/// Return a [`Type`] instance representing the type `type[Any]`.
pub(crate) const fn subclass_of_any() -> Type<'db> {
Type::SubclassOf(SubclassOfType {
subclass_of: SubclassOfInner::Dynamic(DynamicType::Any),
})
}
/// Return the inner [`SubclassOfInner`] value wrapped by this `SubclassOfType`.
pub(crate) const fn subclass_of(self) -> SubclassOfInner<'db> {
self.subclass_of
}
pub(crate) const fn is_dynamic(self) -> bool {
// Unpack `self` so that we're forced to update this method if any more fields are added in the future.
let Self { subclass_of } = self;
subclass_of.is_dynamic()
}
pub(crate) const fn is_fully_static(self) -> bool {
!self.is_dynamic()
}
pub(crate) fn find_name_in_mro_with_policy(
self,
db: &'db dyn Db,
name: &str,
policy: MemberLookupPolicy,
) -> Option<SymbolAndQualifiers<'db>> {
Type::from(self.subclass_of).find_name_in_mro_with_policy(db, name, policy)
}
/// Return `true` if `self` is a subtype of `other`.
///
/// This can only return `true` if `self.subclass_of` is a [`SubclassOfInner::Class`] variant;
/// only fully static types participate in subtyping.
pub(crate) fn is_subtype_of(self, db: &'db dyn Db, other: SubclassOfType<'db>) -> bool {
match (self.subclass_of, other.subclass_of) {
// Non-fully-static types do not participate in subtyping
(SubclassOfInner::Dynamic(_), _) | (_, SubclassOfInner::Dynamic(_)) => false,
// For example, `type[bool]` describes all possible runtime subclasses of the class `bool`,
// and `type[int]` describes all possible runtime subclasses of the class `int`.
// The first set is a subset of the second set, because `bool` is itself a subclass of `int`.
(SubclassOfInner::Class(self_class), SubclassOfInner::Class(other_class)) => {
// N.B. The subclass relation is fully static
self_class.is_subclass_of(db, other_class)
}
}
}
pub(crate) fn to_instance(self, db: &'db dyn Db) -> Type<'db> {
match self.subclass_of {
SubclassOfInner::Class(class) => Type::instance(db, class),
SubclassOfInner::Dynamic(dynamic_type) => Type::Dynamic(dynamic_type),
}
}
}
/// An enumeration of the different kinds of `type[]` types that a [`SubclassOfType`] can represent:
///
/// 1. A "subclass of a class": `type[C]` for any class object `C`
/// 2. A "subclass of a dynamic type": `type[Any]`, `type[Unknown]` and `type[@Todo]`
///
/// In the long term, we may want to implement <https://github.com/astral-sh/ruff/issues/15381>.
/// Doing this would allow us to get rid of this enum,
/// since `type[Any]` would be represented as `type & Any`
/// rather than using the [`Type::SubclassOf`] variant at all;
/// [`SubclassOfType`] would then be a simple wrapper around [`ClassType`].
///
/// Note that this enum is similar to the [`super::ClassBase`] enum,
/// but does not include the `ClassBase::Protocol` and `ClassBase::Generic` variants
/// (`type[Protocol]` and `type[Generic]` are not valid types).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, salsa::Update)]
pub(crate) enum SubclassOfInner<'db> {
Class(ClassType<'db>),
Dynamic(DynamicType),
}
impl<'db> SubclassOfInner<'db> {
pub(crate) const fn unknown() -> Self {
Self::Dynamic(DynamicType::Unknown)
}
pub(crate) const fn is_dynamic(self) -> bool {
matches!(self, Self::Dynamic(_))
}
pub(crate) const fn into_class(self) -> Option<ClassType<'db>> {
match self {
Self::Class(class) => Some(class),
Self::Dynamic(_) => None,
}
}
pub(crate) fn try_from_type(db: &'db dyn Db, ty: Type<'db>) -> Option<Self> {
match ty {
Type::Dynamic(dynamic) => Some(Self::Dynamic(dynamic)),
Type::ClassLiteral(literal) => Some(if literal.is_known(db, KnownClass::Any) {
Self::Dynamic(DynamicType::Any)
} else {
Self::Class(literal.default_specialization(db))
}),
Type::GenericAlias(generic) => Some(Self::Class(ClassType::Generic(generic))),
_ => None,
}
}
}
impl<'db> From<ClassType<'db>> for SubclassOfInner<'db> {
fn from(value: ClassType<'db>) -> Self {
SubclassOfInner::Class(value)
}
}
impl<'db> From<SubclassOfInner<'db>> for Type<'db> {
fn from(value: SubclassOfInner<'db>) -> Self {
match value {
SubclassOfInner::Dynamic(dynamic) => Type::Dynamic(dynamic),
SubclassOfInner::Class(class) => class.into(),
}
}
}
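
// A small sketch (not part of the original change) exercising the `const`
// helpers above: `type[Unknown]` is dynamic and wraps no class literal.
#[cfg(test)]
mod tests {
    use super::SubclassOfInner;

    #[test]
    fn unknown_is_dynamic_and_has_no_class() {
        let unknown = SubclassOfInner::unknown();
        assert!(unknown.is_dynamic());
        assert!(unknown.into_class().is_none());
    }
}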


@ -0,0 +1,390 @@
use std::cmp::Ordering;
use crate::db::Db;
use super::{
class_base::ClassBase, subclass_of::SubclassOfInner, DynamicType, KnownInstanceType,
SuperOwnerKind, TodoType, Type,
};
/// Return an [`Ordering`] that describes the canonical order in which two types should appear
/// in an [`crate::types::IntersectionType`] or a [`crate::types::UnionType`] in order for them
/// to be compared for equivalence.
///
/// Two intersections are compared lexicographically. Element types in the intersection must
/// already be sorted. Two unions are never compared in this function because DNF does not permit
/// nested unions.
///
/// ## Why not just implement [`Ord`] on [`Type`]?
///
/// It would be fairly easy to slap `#[derive(PartialOrd, Ord)]` on [`Type`], and the ordering we
/// create here is not user-facing. However, it doesn't really "make sense" for `Type` to implement
/// [`Ord`] in terms of the semantics. There are many different ways in which you could plausibly
/// sort a list of types; this is only one (somewhat arbitrary, at times) possible ordering.
pub(super) fn union_or_intersection_elements_ordering<'db>(
db: &'db dyn Db,
left: &Type<'db>,
right: &Type<'db>,
) -> Ordering {
debug_assert_eq!(
*left,
left.normalized(db),
"`left` must be normalized before a meaningful ordering can be established"
);
debug_assert_eq!(
*right,
right.normalized(db),
"`right` must be normalized before a meaningful ordering can be established"
);
if left == right {
return Ordering::Equal;
}
match (left, right) {
(Type::Never, _) => Ordering::Less,
(_, Type::Never) => Ordering::Greater,
(Type::LiteralString, _) => Ordering::Less,
(_, Type::LiteralString) => Ordering::Greater,
(Type::BooleanLiteral(left), Type::BooleanLiteral(right)) => left.cmp(right),
(Type::BooleanLiteral(_), _) => Ordering::Less,
(_, Type::BooleanLiteral(_)) => Ordering::Greater,
(Type::IntLiteral(left), Type::IntLiteral(right)) => left.cmp(right),
(Type::IntLiteral(_), _) => Ordering::Less,
(_, Type::IntLiteral(_)) => Ordering::Greater,
(Type::StringLiteral(left), Type::StringLiteral(right)) => left.cmp(right),
(Type::StringLiteral(_), _) => Ordering::Less,
(_, Type::StringLiteral(_)) => Ordering::Greater,
(Type::BytesLiteral(left), Type::BytesLiteral(right)) => left.cmp(right),
(Type::BytesLiteral(_), _) => Ordering::Less,
(_, Type::BytesLiteral(_)) => Ordering::Greater,
(Type::SliceLiteral(left), Type::SliceLiteral(right)) => left.cmp(right),
(Type::SliceLiteral(_), _) => Ordering::Less,
(_, Type::SliceLiteral(_)) => Ordering::Greater,
(Type::FunctionLiteral(left), Type::FunctionLiteral(right)) => left.cmp(right),
(Type::FunctionLiteral(_), _) => Ordering::Less,
(_, Type::FunctionLiteral(_)) => Ordering::Greater,
(Type::BoundMethod(left), Type::BoundMethod(right)) => left.cmp(right),
(Type::BoundMethod(_), _) => Ordering::Less,
(_, Type::BoundMethod(_)) => Ordering::Greater,
(Type::MethodWrapper(left), Type::MethodWrapper(right)) => left.cmp(right),
(Type::MethodWrapper(_), _) => Ordering::Less,
(_, Type::MethodWrapper(_)) => Ordering::Greater,
(Type::WrapperDescriptor(left), Type::WrapperDescriptor(right)) => left.cmp(right),
(Type::WrapperDescriptor(_), _) => Ordering::Less,
(_, Type::WrapperDescriptor(_)) => Ordering::Greater,
(Type::DataclassDecorator(left), Type::DataclassDecorator(right)) => {
left.bits().cmp(&right.bits())
}
(Type::DataclassDecorator(_), _) => Ordering::Less,
(_, Type::DataclassDecorator(_)) => Ordering::Greater,
(Type::DataclassTransformer(left), Type::DataclassTransformer(right)) => {
left.bits().cmp(&right.bits())
}
(Type::DataclassTransformer(_), _) => Ordering::Less,
(_, Type::DataclassTransformer(_)) => Ordering::Greater,
(Type::Callable(left), Type::Callable(right)) => left.cmp(right),
(Type::Callable(_), _) => Ordering::Less,
(_, Type::Callable(_)) => Ordering::Greater,
(Type::Tuple(left), Type::Tuple(right)) => left.cmp(right),
(Type::Tuple(_), _) => Ordering::Less,
(_, Type::Tuple(_)) => Ordering::Greater,
(Type::ModuleLiteral(left), Type::ModuleLiteral(right)) => left.cmp(right),
(Type::ModuleLiteral(_), _) => Ordering::Less,
(_, Type::ModuleLiteral(_)) => Ordering::Greater,
(Type::ClassLiteral(left), Type::ClassLiteral(right)) => left.cmp(right),
(Type::ClassLiteral(_), _) => Ordering::Less,
(_, Type::ClassLiteral(_)) => Ordering::Greater,
(Type::GenericAlias(left), Type::GenericAlias(right)) => left.cmp(right),
(Type::GenericAlias(_), _) => Ordering::Less,
(_, Type::GenericAlias(_)) => Ordering::Greater,
(Type::SubclassOf(left), Type::SubclassOf(right)) => {
match (left.subclass_of(), right.subclass_of()) {
(SubclassOfInner::Class(left), SubclassOfInner::Class(right)) => left.cmp(&right),
(SubclassOfInner::Class(_), _) => Ordering::Less,
(_, SubclassOfInner::Class(_)) => Ordering::Greater,
(SubclassOfInner::Dynamic(left), SubclassOfInner::Dynamic(right)) => {
dynamic_elements_ordering(left, right)
}
}
}
(Type::SubclassOf(_), _) => Ordering::Less,
(_, Type::SubclassOf(_)) => Ordering::Greater,
(Type::NominalInstance(left), Type::NominalInstance(right)) => {
left.class().cmp(&right.class())
}
(Type::NominalInstance(_), _) => Ordering::Less,
(_, Type::NominalInstance(_)) => Ordering::Greater,
(Type::ProtocolInstance(left_proto), Type::ProtocolInstance(right_proto)) => {
left_proto.cmp(right_proto)
}
(Type::ProtocolInstance(_), _) => Ordering::Less,
(_, Type::ProtocolInstance(_)) => Ordering::Greater,
(Type::TypeVar(left), Type::TypeVar(right)) => left.cmp(right),
(Type::TypeVar(_), _) => Ordering::Less,
(_, Type::TypeVar(_)) => Ordering::Greater,
(Type::AlwaysTruthy, _) => Ordering::Less,
(_, Type::AlwaysTruthy) => Ordering::Greater,
(Type::AlwaysFalsy, _) => Ordering::Less,
(_, Type::AlwaysFalsy) => Ordering::Greater,
(Type::BoundSuper(left), Type::BoundSuper(right)) => {
(match (left.pivot_class(db), right.pivot_class(db)) {
(ClassBase::Class(left), ClassBase::Class(right)) => left.cmp(right),
(ClassBase::Class(_), _) => Ordering::Less,
(_, ClassBase::Class(_)) => Ordering::Greater,
(ClassBase::Protocol, _) => Ordering::Less,
(_, ClassBase::Protocol) => Ordering::Greater,
(ClassBase::Generic(left), ClassBase::Generic(right)) => left.cmp(right),
(ClassBase::Generic(_), _) => Ordering::Less,
(_, ClassBase::Generic(_)) => Ordering::Greater,
(ClassBase::Dynamic(left), ClassBase::Dynamic(right)) => {
dynamic_elements_ordering(*left, *right)
}
})
.then_with(|| match (left.owner(db), right.owner(db)) {
(SuperOwnerKind::Class(left), SuperOwnerKind::Class(right)) => left.cmp(right),
(SuperOwnerKind::Class(_), _) => Ordering::Less,
(_, SuperOwnerKind::Class(_)) => Ordering::Greater,
(SuperOwnerKind::Instance(left), SuperOwnerKind::Instance(right)) => {
left.class().cmp(&right.class())
}
(SuperOwnerKind::Instance(_), _) => Ordering::Less,
(_, SuperOwnerKind::Instance(_)) => Ordering::Greater,
(SuperOwnerKind::Dynamic(left), SuperOwnerKind::Dynamic(right)) => {
dynamic_elements_ordering(*left, *right)
}
})
}
(Type::BoundSuper(_), _) => Ordering::Less,
(_, Type::BoundSuper(_)) => Ordering::Greater,
(Type::KnownInstance(left_instance), Type::KnownInstance(right_instance)) => {
match (left_instance, right_instance) {
(KnownInstanceType::Any, _) => Ordering::Less,
(_, KnownInstanceType::Any) => Ordering::Greater,
(KnownInstanceType::Tuple, _) => Ordering::Less,
(_, KnownInstanceType::Tuple) => Ordering::Greater,
(KnownInstanceType::AlwaysFalsy, _) => Ordering::Less,
(_, KnownInstanceType::AlwaysFalsy) => Ordering::Greater,
(KnownInstanceType::AlwaysTruthy, _) => Ordering::Less,
(_, KnownInstanceType::AlwaysTruthy) => Ordering::Greater,
(KnownInstanceType::Annotated, _) => Ordering::Less,
(_, KnownInstanceType::Annotated) => Ordering::Greater,
(KnownInstanceType::Callable, _) => Ordering::Less,
(_, KnownInstanceType::Callable) => Ordering::Greater,
(KnownInstanceType::ChainMap, _) => Ordering::Less,
(_, KnownInstanceType::ChainMap) => Ordering::Greater,
(KnownInstanceType::ClassVar, _) => Ordering::Less,
(_, KnownInstanceType::ClassVar) => Ordering::Greater,
(KnownInstanceType::Concatenate, _) => Ordering::Less,
(_, KnownInstanceType::Concatenate) => Ordering::Greater,
(KnownInstanceType::Counter, _) => Ordering::Less,
(_, KnownInstanceType::Counter) => Ordering::Greater,
(KnownInstanceType::DefaultDict, _) => Ordering::Less,
(_, KnownInstanceType::DefaultDict) => Ordering::Greater,
(KnownInstanceType::Deque, _) => Ordering::Less,
(_, KnownInstanceType::Deque) => Ordering::Greater,
(KnownInstanceType::Dict, _) => Ordering::Less,
(_, KnownInstanceType::Dict) => Ordering::Greater,
(KnownInstanceType::Final, _) => Ordering::Less,
(_, KnownInstanceType::Final) => Ordering::Greater,
(KnownInstanceType::FrozenSet, _) => Ordering::Less,
(_, KnownInstanceType::FrozenSet) => Ordering::Greater,
(KnownInstanceType::TypeGuard, _) => Ordering::Less,
(_, KnownInstanceType::TypeGuard) => Ordering::Greater,
(KnownInstanceType::TypedDict, _) => Ordering::Less,
(_, KnownInstanceType::TypedDict) => Ordering::Greater,
(KnownInstanceType::List, _) => Ordering::Less,
(_, KnownInstanceType::List) => Ordering::Greater,
(KnownInstanceType::Literal, _) => Ordering::Less,
(_, KnownInstanceType::Literal) => Ordering::Greater,
(KnownInstanceType::LiteralString, _) => Ordering::Less,
(_, KnownInstanceType::LiteralString) => Ordering::Greater,
(KnownInstanceType::Optional, _) => Ordering::Less,
(_, KnownInstanceType::Optional) => Ordering::Greater,
(KnownInstanceType::OrderedDict, _) => Ordering::Less,
(_, KnownInstanceType::OrderedDict) => Ordering::Greater,
(KnownInstanceType::Generic(left), KnownInstanceType::Generic(right)) => {
left.cmp(right)
}
(KnownInstanceType::Generic(_), _) => Ordering::Less,
(_, KnownInstanceType::Generic(_)) => Ordering::Greater,
(KnownInstanceType::Protocol, _) => Ordering::Less,
(_, KnownInstanceType::Protocol) => Ordering::Greater,
(KnownInstanceType::NoReturn, _) => Ordering::Less,
(_, KnownInstanceType::NoReturn) => Ordering::Greater,
(KnownInstanceType::Never, _) => Ordering::Less,
(_, KnownInstanceType::Never) => Ordering::Greater,
(KnownInstanceType::Set, _) => Ordering::Less,
(_, KnownInstanceType::Set) => Ordering::Greater,
(KnownInstanceType::Type, _) => Ordering::Less,
(_, KnownInstanceType::Type) => Ordering::Greater,
(KnownInstanceType::TypeAlias, _) => Ordering::Less,
(_, KnownInstanceType::TypeAlias) => Ordering::Greater,
(KnownInstanceType::Unknown, _) => Ordering::Less,
(_, KnownInstanceType::Unknown) => Ordering::Greater,
(KnownInstanceType::Not, _) => Ordering::Less,
(_, KnownInstanceType::Not) => Ordering::Greater,
(KnownInstanceType::Intersection, _) => Ordering::Less,
(_, KnownInstanceType::Intersection) => Ordering::Greater,
(KnownInstanceType::TypeOf, _) => Ordering::Less,
(_, KnownInstanceType::TypeOf) => Ordering::Greater,
(KnownInstanceType::CallableTypeOf, _) => Ordering::Less,
(_, KnownInstanceType::CallableTypeOf) => Ordering::Greater,
(KnownInstanceType::Unpack, _) => Ordering::Less,
(_, KnownInstanceType::Unpack) => Ordering::Greater,
(KnownInstanceType::TypingSelf, _) => Ordering::Less,
(_, KnownInstanceType::TypingSelf) => Ordering::Greater,
(KnownInstanceType::Required, _) => Ordering::Less,
(_, KnownInstanceType::Required) => Ordering::Greater,
(KnownInstanceType::NotRequired, _) => Ordering::Less,
(_, KnownInstanceType::NotRequired) => Ordering::Greater,
(KnownInstanceType::TypeIs, _) => Ordering::Less,
(_, KnownInstanceType::TypeIs) => Ordering::Greater,
(KnownInstanceType::ReadOnly, _) => Ordering::Less,
(_, KnownInstanceType::ReadOnly) => Ordering::Greater,
(KnownInstanceType::Union, _) => Ordering::Less,
(_, KnownInstanceType::Union) => Ordering::Greater,
(
KnownInstanceType::TypeAliasType(left),
KnownInstanceType::TypeAliasType(right),
) => left.cmp(right),
(KnownInstanceType::TypeAliasType(_), _) => Ordering::Less,
(_, KnownInstanceType::TypeAliasType(_)) => Ordering::Greater,
(KnownInstanceType::TypeVar(left), KnownInstanceType::TypeVar(right)) => {
left.cmp(right)
}
}
}
(Type::KnownInstance(_), _) => Ordering::Less,
(_, Type::KnownInstance(_)) => Ordering::Greater,
(Type::PropertyInstance(left), Type::PropertyInstance(right)) => left.cmp(right),
(Type::PropertyInstance(_), _) => Ordering::Less,
(_, Type::PropertyInstance(_)) => Ordering::Greater,
(Type::Dynamic(left), Type::Dynamic(right)) => dynamic_elements_ordering(*left, *right),
(Type::Dynamic(_), _) => Ordering::Less,
(_, Type::Dynamic(_)) => Ordering::Greater,
(Type::Union(_), _) | (_, Type::Union(_)) => {
unreachable!("our type representation does not permit nested unions");
}
(Type::Intersection(left), Type::Intersection(right)) => {
// Lexicographically compare the elements of the two unequal intersections.
let left_positive = left.positive(db);
let right_positive = right.positive(db);
if left_positive.len() != right_positive.len() {
return left_positive.len().cmp(&right_positive.len());
}
let left_negative = left.negative(db);
let right_negative = right.negative(db);
if left_negative.len() != right_negative.len() {
return left_negative.len().cmp(&right_negative.len());
}
for (left, right) in left_positive.iter().zip(right_positive) {
let ordering = union_or_intersection_elements_ordering(db, left, right);
if ordering != Ordering::Equal {
return ordering;
}
}
for (left, right) in left_negative.iter().zip(right_negative) {
let ordering = union_or_intersection_elements_ordering(db, left, right);
if ordering != Ordering::Equal {
return ordering;
}
}
unreachable!("Two equal, normalized intersections should share the same Salsa ID")
}
}
}
/// Determine a canonical order for two instances of [`DynamicType`].
fn dynamic_elements_ordering(left: DynamicType, right: DynamicType) -> Ordering {
match (left, right) {
(DynamicType::Any, _) => Ordering::Less,
(_, DynamicType::Any) => Ordering::Greater,
(DynamicType::Unknown, _) => Ordering::Less,
(_, DynamicType::Unknown) => Ordering::Greater,
#[cfg(debug_assertions)]
(DynamicType::Todo(TodoType(left)), DynamicType::Todo(TodoType(right))) => left.cmp(right),
#[cfg(not(debug_assertions))]
(DynamicType::Todo(TodoType), DynamicType::Todo(TodoType)) => Ordering::Equal,
(DynamicType::SubscriptedProtocol, _) => Ordering::Less,
(_, DynamicType::SubscriptedProtocol) => Ordering::Greater,
}
}
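
// A minimal sketch (not part of the original change) pinning down the
// canonical order established above: `Any` sorts first and `Unknown`
// second, with the remaining dynamic variants after them.
#[cfg(test)]
mod tests {
    use std::cmp::Ordering;

    use super::{dynamic_elements_ordering, DynamicType};

    #[test]
    fn any_sorts_before_unknown() {
        assert_eq!(
            dynamic_elements_ordering(DynamicType::Any, DynamicType::Unknown),
            Ordering::Less
        );
        assert_eq!(
            dynamic_elements_ordering(DynamicType::Unknown, DynamicType::Any),
            Ordering::Greater
        );
    }
}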


@ -0,0 +1,315 @@
use std::borrow::Cow;
use std::cmp::Ordering;
use rustc_hash::FxHashMap;
use ruff_python_ast::{self as ast, AnyNodeRef};
use crate::semantic_index::ast_ids::{HasScopedExpressionId, ScopedExpressionId};
use crate::semantic_index::symbol::ScopeId;
use crate::types::{infer_expression_types, todo_type, Type, TypeCheckDiagnostics};
use crate::unpack::{UnpackKind, UnpackValue};
use crate::Db;
use super::context::InferContext;
use super::diagnostic::INVALID_ASSIGNMENT;
use super::{TupleType, UnionType};
/// Unpacks the value expression type to its respective targets.
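///
/// For example (illustrative), unpacking `(a, b) = (1, "foo")` records
/// `a -> Literal[1]` and `b -> Literal["foo"]` in `targets`.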
pub(crate) struct Unpacker<'db> {
context: InferContext<'db>,
target_scope: ScopeId<'db>,
value_scope: ScopeId<'db>,
targets: FxHashMap<ScopedExpressionId, Type<'db>>,
}
impl<'db> Unpacker<'db> {
pub(crate) fn new(
db: &'db dyn Db,
target_scope: ScopeId<'db>,
value_scope: ScopeId<'db>,
) -> Self {
Self {
context: InferContext::new(db, target_scope),
targets: FxHashMap::default(),
target_scope,
value_scope,
}
}
fn db(&self) -> &'db dyn Db {
self.context.db()
}
/// Unpack the value to the target expression.
pub(crate) fn unpack(&mut self, target: &ast::Expr, value: UnpackValue<'db>) {
debug_assert!(
matches!(target, ast::Expr::List(_) | ast::Expr::Tuple(_)),
"Unpacking target must be a list or tuple expression"
);
let value_type = infer_expression_types(self.db(), value.expression())
.expression_type(value.scoped_expression_id(self.db(), self.value_scope));
let value_type = match value.kind() {
UnpackKind::Assign => {
if self.context.in_stub()
&& value
.expression()
.node_ref(self.db())
.is_ellipsis_literal_expr()
{
Type::unknown()
} else {
value_type
}
}
UnpackKind::Iterable => value_type.try_iterate(self.db()).unwrap_or_else(|err| {
err.report_diagnostic(&self.context, value_type, value.as_any_node_ref(self.db()));
err.fallback_element_type(self.db())
}),
UnpackKind::ContextManager => value_type.try_enter(self.db()).unwrap_or_else(|err| {
err.report_diagnostic(&self.context, value_type, value.as_any_node_ref(self.db()));
err.fallback_enter_type(self.db())
}),
};
self.unpack_inner(target, value.as_any_node_ref(self.db()), value_type);
}
fn unpack_inner(
&mut self,
target: &ast::Expr,
value_expr: AnyNodeRef<'db>,
value_ty: Type<'db>,
) {
match target {
ast::Expr::Name(_) | ast::Expr::Attribute(_) => {
self.targets.insert(
target.scoped_expression_id(self.db(), self.target_scope),
value_ty,
);
}
ast::Expr::Starred(ast::ExprStarred { value, .. }) => {
self.unpack_inner(value, value_expr, value_ty);
}
ast::Expr::List(ast::ExprList { elts, .. })
| ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => {
// Initialize the vector of target types, one for each target.
//
// This is mainly useful for the union type where the target type at index `n` is
// going to be a union of types from every union type element at index `n`.
//
// For example, if the type is `tuple[int, int] | tuple[int, str]` and the target
// has two elements `(a, b)`, then
// * The type of `a` will be a union of `int` and `int` which are at index 0 in the
// first and second tuple respectively which resolves to an `int`.
// * Similarly, the type of `b` will be a union of `int` and `str` which are at
// index 1 in the first and second tuple respectively which will be `int | str`.
let mut target_types = vec![vec![]; elts.len()];
let unpack_types = match value_ty {
Type::Union(union_ty) => union_ty.elements(self.db()),
_ => std::slice::from_ref(&value_ty),
};
for ty in unpack_types.iter().copied() {
// Deconstruct certain types to delegate the inference back to the tuple type
// for correct handling of starred expressions.
let ty = match ty {
Type::StringLiteral(string_literal_ty) => {
// We could go further and deconstruct to an array of `StringLiteral`
// with each individual character, instead of just an array of
// `LiteralString`, but there would be a cost and it's not clear that
// it's worth it.
TupleType::from_elements(
self.db(),
std::iter::repeat_n(
Type::LiteralString,
string_literal_ty.python_len(self.db()),
),
)
}
_ => ty,
};
if let Some(tuple_ty) = ty.into_tuple() {
let tuple_ty_elements =
self.tuple_ty_elements(target, elts, tuple_ty, value_expr);
let length_mismatch =
match elts.len().cmp(&tuple_ty_elements.len()) {
Ordering::Less => {
if let Some(builder) =
self.context.report_lint(&INVALID_ASSIGNMENT, target)
{
let mut diag =
builder.into_diagnostic("Too many values to unpack");
diag.set_primary_message(format_args!(
"Expected {}",
elts.len(),
));
diag.annotate(self.context.secondary(value_expr).message(
format_args!("Got {}", tuple_ty_elements.len()),
));
}
true
}
Ordering::Greater => {
if let Some(builder) =
self.context.report_lint(&INVALID_ASSIGNMENT, target)
{
let mut diag =
builder.into_diagnostic("Not enough values to unpack");
diag.set_primary_message(format_args!(
"Expected {}",
elts.len(),
));
diag.annotate(self.context.secondary(value_expr).message(
format_args!("Got {}", tuple_ty_elements.len()),
));
}
true
}
Ordering::Equal => false,
};
for (index, ty) in tuple_ty_elements.iter().enumerate() {
if let Some(element_types) = target_types.get_mut(index) {
if length_mismatch {
element_types.push(Type::unknown());
} else {
element_types.push(*ty);
}
}
}
} else {
let ty = if ty.is_literal_string() {
Type::LiteralString
} else {
ty.try_iterate(self.db()).unwrap_or_else(|err| {
err.report_diagnostic(&self.context, ty, value_expr);
err.fallback_element_type(self.db())
})
};
for target_type in &mut target_types {
target_type.push(ty);
}
}
}
for (index, element) in elts.iter().enumerate() {
// SAFETY: `target_types` is initialized with the same length as `elts`.
let element_ty = match target_types[index].as_slice() {
[] => Type::unknown(),
types => UnionType::from_elements(self.db(), types),
};
self.unpack_inner(element, value_expr, element_ty);
}
}
_ => {}
}
}
/// Returns the [`Type`] elements inside the given [`TupleType`] taking into account that there
/// can be a starred expression in the `elements`.
///
/// `value_expr` is an AST reference to the value being unpacked. It is
/// only used for diagnostics.
fn tuple_ty_elements(
&self,
expr: &ast::Expr,
targets: &[ast::Expr],
tuple_ty: TupleType<'db>,
value_expr: AnyNodeRef<'_>,
) -> Cow<'_, [Type<'db>]> {
// If there is a starred expression, it will consume all of the types at that location.
let Some(starred_index) = targets.iter().position(ast::Expr::is_starred_expr) else {
// Otherwise, the types will be unpacked 1-1 to the targets.
return Cow::Borrowed(tuple_ty.elements(self.db()).as_ref());
};
if tuple_ty.len(self.db()) >= targets.len() - 1 {
// This branch is only taken when there are enough elements in the tuple type to
// combine for the starred expression. So, the arithmetic and indexing operations are
// safe to perform.
let mut element_types = Vec::with_capacity(targets.len());
// Insert all the elements before the starred expression.
element_types.extend_from_slice(
// SAFETY: Safe because of the length check above.
&tuple_ty.elements(self.db())[..starred_index],
);
// The number of target expressions that are remaining after the starred expression.
// For example, in `(a, *b, c, d) = ...`, the index of starred element `b` is 1 and the
// remaining elements after that are 2.
let remaining = targets.len() - (starred_index + 1);
// This index represents the position of the last element that belongs to the starred
// expression, in an exclusive manner. For example, in `(a, *b, c) = (1, 2, 3, 4)`, the
// starred expression `b` will consume the elements `Literal[2]` and `Literal[3]` and
// the index value would be 3.
let starred_end_index = tuple_ty.len(self.db()) - remaining;
// SAFETY: Safe because of the length check above.
let _starred_element_types =
&tuple_ty.elements(self.db())[starred_index..starred_end_index];
// TODO: Combine the types into a list type. If the
// starred_element_types is empty, then it should be `List[Any]`.
// combine_types(starred_element_types);
element_types.push(todo_type!("starred unpacking"));
// Insert the types remaining that aren't consumed by the starred expression.
element_types.extend_from_slice(
// SAFETY: Safe because of the length check above.
&tuple_ty.elements(self.db())[starred_end_index..],
);
Cow::Owned(element_types)
} else {
if let Some(builder) = self.context.report_lint(&INVALID_ASSIGNMENT, expr) {
let mut diag = builder.into_diagnostic("Not enough values to unpack");
diag.set_primary_message(format_args!("Expected {} or more", targets.len() - 1));
diag.annotate(
self.context
.secondary(value_expr)
.message(format_args!("Got {}", tuple_ty.len(self.db()))),
);
}
Cow::Owned(vec![Type::unknown(); targets.len()])
}
}
pub(crate) fn finish(mut self) -> UnpackResult<'db> {
self.targets.shrink_to_fit();
UnpackResult {
diagnostics: self.context.finish(),
targets: self.targets,
}
}
}
#[derive(Debug, Default, PartialEq, Eq, salsa::Update)]
pub(crate) struct UnpackResult<'db> {
targets: FxHashMap<ScopedExpressionId, Type<'db>>,
diagnostics: TypeCheckDiagnostics,
}
impl<'db> UnpackResult<'db> {
/// Returns the inferred type for a given sub-expression of the left-hand side target
/// of an unpacking assignment.
///
/// Panics if a scoped expression ID is passed in that does not correspond to a sub-
/// expression of the target.
#[track_caller]
pub(crate) fn expression_type(&self, expr_id: ScopedExpressionId) -> Type<'db> {
self.targets[&expr_id]
}
/// Returns the diagnostics in this unpacking assignment.
pub(crate) fn diagnostics(&self) -> &TypeCheckDiagnostics {
&self.diagnostics
}
}


@ -0,0 +1,130 @@
use ruff_db::files::File;
use ruff_python_ast::{self as ast, AnyNodeRef};
use ruff_text_size::{Ranged, TextRange};
use crate::ast_node_ref::AstNodeRef;
use crate::semantic_index::ast_ids::{HasScopedExpressionId, ScopedExpressionId};
use crate::semantic_index::expression::Expression;
use crate::semantic_index::symbol::{FileScopeId, ScopeId};
use crate::Db;
/// This ingredient represents a single unpacking.
///
/// This is required to make use of salsa to cache the complete unpacking of multiple
/// variables at once. It allows us to:
/// 1. Avoid repeating the structural match for each definition
/// 2. Avoid highlighting the same error multiple times
///
/// ## Module-local type
/// This type should not be used as part of any cross-module API because
/// it holds a reference to the AST node. Range-offset changes
/// then propagate through all usages, and deserialization requires
/// reparsing the entire module.
///
/// E.g. don't use this type in:
///
/// * a return type of a cross-module query
/// * a field of a type that is a return type of a cross-module query
/// * an argument of a cross-module query
#[salsa::tracked(debug)]
pub(crate) struct Unpack<'db> {
pub(crate) file: File,
pub(crate) value_file_scope: FileScopeId,
pub(crate) target_file_scope: FileScopeId,
/// The target expression that is being unpacked. For example, in `(a, b) = (1, 2)`, the target
/// expression is `(a, b)`.
#[no_eq]
#[return_ref]
#[tracked]
pub(crate) target: AstNodeRef<ast::Expr>,
/// The ingredient representing the value expression of the unpacking. For example, in
/// `(a, b) = (1, 2)`, the value expression is `(1, 2)`.
pub(crate) value: UnpackValue<'db>,
count: countme::Count<Unpack<'static>>,
}
impl<'db> Unpack<'db> {
/// Returns the scope to which the unpack value expression belongs.
///
/// The scopes to which the target and value expressions belong are usually the same,
/// except in generator expressions and comprehensions (list/dict/set), where the value
/// expression of the first generator is evaluated in the outer scope, while the ones in the
/// subsequent generators are evaluated in the comprehension scope.
pub(crate) fn value_scope(self, db: &'db dyn Db) -> ScopeId<'db> {
self.value_file_scope(db).to_scope_id(db, self.file(db))
}
/// Returns the scope to which the unpack target expression belongs.
pub(crate) fn target_scope(self, db: &'db dyn Db) -> ScopeId<'db> {
self.target_file_scope(db).to_scope_id(db, self.file(db))
}
/// Returns the range of the unpack target expression.
pub(crate) fn range(self, db: &'db dyn Db) -> TextRange {
self.target(db).range()
}
}
/// The expression that is being unpacked.
#[derive(Clone, Copy, Debug, Hash, salsa::Update)]
pub(crate) struct UnpackValue<'db> {
/// The kind of unpack expression
kind: UnpackKind,
/// The expression we are unpacking
expression: Expression<'db>,
}
impl<'db> UnpackValue<'db> {
pub(crate) fn new(kind: UnpackKind, expression: Expression<'db>) -> Self {
Self { kind, expression }
}
/// Returns the underlying [`Expression`] that is being unpacked.
pub(crate) const fn expression(self) -> Expression<'db> {
self.expression
}
/// Returns the [`ScopedExpressionId`] of the underlying expression.
pub(crate) fn scoped_expression_id(
self,
db: &'db dyn Db,
scope: ScopeId<'db>,
) -> ScopedExpressionId {
self.expression()
.node_ref(db)
.scoped_expression_id(db, scope)
}
/// Returns the expression as an [`AnyNodeRef`].
pub(crate) fn as_any_node_ref(self, db: &'db dyn Db) -> AnyNodeRef<'db> {
self.expression().node_ref(db).node().into()
}
pub(crate) const fn kind(self) -> UnpackKind {
self.kind
}
}
#[derive(Clone, Copy, Debug, Hash, salsa::Update)]
pub(crate) enum UnpackKind {
/// An iterable expression like the one in a `for` loop or a comprehension.
Iterable,
/// A context manager expression like the one in a `with` statement.
ContextManager,
/// An expression that is being assigned to a target.
Assign,
}
/// The position of the target element in an unpacking.
#[derive(Clone, Copy, Debug, Hash, PartialEq, salsa::Update)]
pub(crate) enum UnpackPosition {
/// The target element is in the first position of the unpacking.
First,
/// The target element is in the position other than the first position of the unpacking.
Other,
}


@ -0,0 +1 @@
pub(crate) mod subscript;


@ -0,0 +1,502 @@
//! This module provides utility functions for indexing (`PyIndex`) and slicing
//! operations (`PySlice`) on iterators, following the semantics of equivalent
//! operations in Python.
use itertools::Either;
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct OutOfBoundsError;
pub(crate) trait PyIndex {
type Item;
fn py_index(&mut self, index: i32) -> Result<Self::Item, OutOfBoundsError>;
}
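
// For example (mirroring Python's indexing semantics): on
// `['a', 'b'].into_iter()`, `py_index(-1)` yields `Ok('b')`, while
// `py_index(2)` yields `Err(OutOfBoundsError)` instead of panicking.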
fn from_nonnegative_i32(index: i32) -> usize {
static_assertions::const_assert!(usize::BITS >= 32);
debug_assert!(index >= 0);
usize::try_from(index)
.expect("Should only ever pass a non-negative integer to `from_nonnegative_i32`")
}
fn from_negative_i32(index: i32) -> usize {
static_assertions::const_assert!(usize::BITS >= 32);
index.checked_neg().map(from_nonnegative_i32).unwrap_or({
// `checked_neg` only fails for `i32::MIN`. We cannot
// represent -i32::MIN as an i32, but we can represent
// it as a usize, since usize is at least 32 bits.
from_nonnegative_i32(i32::MAX) + 1
})
}
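/// Where a resolved index lands relative to a sequence of known length.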
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
enum Position {
BeforeStart,
AtIndex(usize),
AfterEnd,
}
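/// A Python index split into counting from the start (`i >= 0`) or from the
/// end (`i < 0`, stored as the zero-based offset from the last element).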
enum Nth {
FromStart(usize),
FromEnd(usize),
}
impl Nth {
fn from_index(index: i32) -> Self {
if index >= 0 {
Nth::FromStart(from_nonnegative_i32(index))
} else {
Nth::FromEnd(from_negative_i32(index) - 1)
}
}
fn to_position(&self, len: usize) -> Position {
debug_assert!(len > 0);
match self {
Nth::FromStart(nth) => {
if *nth < len {
Position::AtIndex(*nth)
} else {
Position::AfterEnd
}
}
Nth::FromEnd(nth_rev) => {
if *nth_rev < len {
Position::AtIndex(len - 1 - *nth_rev)
} else {
Position::BeforeStart
}
}
}
}
}
impl<I, T> PyIndex for T
where
T: DoubleEndedIterator<Item = I>,
{
type Item = I;
fn py_index(&mut self, index: i32) -> Result<I, OutOfBoundsError> {
match Nth::from_index(index) {
Nth::FromStart(nth) => self.nth(nth).ok_or(OutOfBoundsError),
Nth::FromEnd(nth_rev) => self.nth_back(nth_rev).ok_or(OutOfBoundsError),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct StepSizeZeroError;
pub(crate) trait PySlice {
type Item;
fn py_slice(
&self,
start: Option<i32>,
stop: Option<i32>,
step: Option<i32>,
) -> Result<
Either<impl Iterator<Item = &Self::Item>, impl Iterator<Item = &Self::Item>>,
StepSizeZeroError,
>;
}
impl<T> PySlice for [T] {
type Item = T;
fn py_slice(
&self,
start: Option<i32>,
stop: Option<i32>,
step_int: Option<i32>,
) -> Result<
Either<impl Iterator<Item = &Self::Item>, impl Iterator<Item = &Self::Item>>,
StepSizeZeroError,
> {
let step_int = step_int.unwrap_or(1);
if step_int == 0 {
return Err(StepSizeZeroError);
}
let len = self.len();
if len == 0 {
// The iterator needs to have the same type as the step>0 case below,
// so we need to use `.skip(0)`.
#[allow(clippy::iter_skip_zero)]
return Ok(Either::Left(self.iter().skip(0).take(0).step_by(1)));
}
let to_position = |index| Nth::from_index(index).to_position(len);
if step_int.is_positive() {
let step = from_nonnegative_i32(step_int);
let start = start.map(to_position).unwrap_or(Position::BeforeStart);
let stop = stop.map(to_position).unwrap_or(Position::AfterEnd);
let (skip, take, step) = if start < stop {
let skip = match start {
Position::BeforeStart => 0,
Position::AtIndex(start_index) => start_index,
Position::AfterEnd => len,
};
let take = match stop {
Position::BeforeStart => 0,
Position::AtIndex(stop_index) => stop_index - skip,
Position::AfterEnd => len - skip,
};
(skip, take, step)
} else {
(0, 0, step)
};
Ok(Either::Left(
self.iter().skip(skip).take(take).step_by(step),
))
} else {
let step = from_negative_i32(step_int);
let start = start.map(to_position).unwrap_or(Position::AfterEnd);
let stop = stop.map(to_position).unwrap_or(Position::BeforeStart);
let (skip, take, step) = if start <= stop {
(0, 0, step)
} else {
let skip = match start {
Position::BeforeStart => len,
Position::AtIndex(start_index) => len - 1 - start_index,
Position::AfterEnd => 0,
};
let take = match stop {
Position::BeforeStart => len - skip,
Position::AtIndex(stop_index) => (len - 1) - skip - stop_index,
Position::AfterEnd => 0,
};
(skip, take, step)
};
Ok(Either::Right(
self.iter().rev().skip(skip).take(take).step_by(step),
))
}
}
}
#[cfg(test)]
#[allow(clippy::redundant_clone)]
mod tests {
use crate::util::subscript::{OutOfBoundsError, StepSizeZeroError};
use super::{PyIndex, PySlice};
use itertools::assert_equal;
#[test]
fn py_index_empty() {
let iter = std::iter::empty::<char>();
assert_eq!(iter.clone().py_index(0), Err(OutOfBoundsError));
assert_eq!(iter.clone().py_index(1), Err(OutOfBoundsError));
assert_eq!(iter.clone().py_index(-1), Err(OutOfBoundsError));
assert_eq!(iter.clone().py_index(i32::MIN), Err(OutOfBoundsError));
assert_eq!(iter.clone().py_index(i32::MAX), Err(OutOfBoundsError));
}
#[test]
fn py_index_single_element() {
let iter = ['a'].into_iter();
assert_eq!(iter.clone().py_index(0), Ok('a'));
assert_eq!(iter.clone().py_index(1), Err(OutOfBoundsError));
assert_eq!(iter.clone().py_index(-1), Ok('a'));
assert_eq!(iter.clone().py_index(-2), Err(OutOfBoundsError));
}
#[test]
fn py_index_more_elements() {
let iter = ['a', 'b', 'c', 'd', 'e'].into_iter();
assert_eq!(iter.clone().py_index(0), Ok('a'));
assert_eq!(iter.clone().py_index(1), Ok('b'));
assert_eq!(iter.clone().py_index(4), Ok('e'));
assert_eq!(iter.clone().py_index(5), Err(OutOfBoundsError));
assert_eq!(iter.clone().py_index(-1), Ok('e'));
assert_eq!(iter.clone().py_index(-2), Ok('d'));
assert_eq!(iter.clone().py_index(-5), Ok('a'));
assert_eq!(iter.clone().py_index(-6), Err(OutOfBoundsError));
}
#[test]
fn py_index_uses_full_index_range() {
let iter = 0..=u32::MAX;
// u32::MAX - |i32::MIN| + 1 = 2^32 - 1 - 2^31 + 1 = 2^31
assert_eq!(iter.clone().py_index(i32::MIN), Ok(2u32.pow(31)));
assert_eq!(iter.clone().py_index(-2), Ok(u32::MAX - 2 + 1));
assert_eq!(iter.clone().py_index(-1), Ok(u32::MAX - 1 + 1));
assert_eq!(iter.clone().py_index(0), Ok(0));
assert_eq!(iter.clone().py_index(1), Ok(1));
assert_eq!(iter.clone().py_index(i32::MAX), Ok(i32::MAX as u32));
}
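    // A small extra check (not part of the original change) for the
    // `i32::MIN` edge case documented in `from_negative_i32`: its magnitude
    // (2^31) overflows `i32` but fits in a `usize`.
    #[test]
    fn from_negative_i32_handles_i32_min() {
        assert_eq!(super::from_negative_i32(-1), 1);
        assert_eq!(super::from_negative_i32(i32::MIN), 1usize << 31);
    }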
#[track_caller]
fn assert_eq_slice<const N: usize, const M: usize>(
input: &[char; N],
start: Option<i32>,
stop: Option<i32>,
step: Option<i32>,
expected: &[char; M],
) {
assert_equal(input.py_slice(start, stop, step).unwrap(), expected.iter());
}
#[test]
fn py_slice_empty_input() {
let input = [];
assert_eq_slice(&input, None, None, None, &[]);
assert_eq_slice(&input, Some(0), None, None, &[]);
assert_eq_slice(&input, None, Some(0), None, &[]);
assert_eq_slice(&input, Some(0), Some(0), None, &[]);
assert_eq_slice(&input, Some(-5), Some(-5), None, &[]);
assert_eq_slice(&input, None, None, Some(-1), &[]);
assert_eq_slice(&input, None, None, Some(2), &[]);
}
#[test]
fn py_slice_single_element_input() {
let input = ['a'];
assert_eq_slice(&input, None, None, None, &['a']);
assert_eq_slice(&input, Some(0), None, None, &['a']);
assert_eq_slice(&input, None, Some(0), None, &[]);
assert_eq_slice(&input, Some(0), Some(0), None, &[]);
assert_eq_slice(&input, Some(0), Some(1), None, &['a']);
assert_eq_slice(&input, Some(0), Some(2), None, &['a']);
assert_eq_slice(&input, Some(-1), None, None, &['a']);
assert_eq_slice(&input, Some(-1), Some(-1), None, &[]);
assert_eq_slice(&input, Some(-1), Some(0), None, &[]);
assert_eq_slice(&input, Some(-1), Some(1), None, &['a']);
assert_eq_slice(&input, Some(-1), Some(2), None, &['a']);
assert_eq_slice(&input, None, Some(-1), None, &[]);
assert_eq_slice(&input, Some(-2), None, None, &['a']);
assert_eq_slice(&input, Some(-2), Some(-1), None, &[]);
assert_eq_slice(&input, Some(-2), Some(0), None, &[]);
assert_eq_slice(&input, Some(-2), Some(1), None, &['a']);
assert_eq_slice(&input, Some(-2), Some(2), None, &['a']);
}
#[test]
fn py_slice_nonnegative_indices() {
let input = ['a', 'b', 'c', 'd', 'e'];
assert_eq_slice(&input, None, Some(0), None, &[]);
assert_eq_slice(&input, None, Some(1), None, &['a']);
assert_eq_slice(&input, None, Some(4), None, &['a', 'b', 'c', 'd']);
assert_eq_slice(&input, None, Some(5), None, &['a', 'b', 'c', 'd', 'e']);
assert_eq_slice(&input, None, Some(6), None, &['a', 'b', 'c', 'd', 'e']);
assert_eq_slice(&input, None, None, None, &['a', 'b', 'c', 'd', 'e']);
assert_eq_slice(&input, Some(0), Some(0), None, &[]);
assert_eq_slice(&input, Some(0), Some(1), None, &['a']);
assert_eq_slice(&input, Some(0), Some(4), None, &['a', 'b', 'c', 'd']);
assert_eq_slice(&input, Some(0), Some(5), None, &['a', 'b', 'c', 'd', 'e']);
assert_eq_slice(&input, Some(0), Some(6), None, &['a', 'b', 'c', 'd', 'e']);
assert_eq_slice(&input, Some(0), None, None, &['a', 'b', 'c', 'd', 'e']);
assert_eq_slice(&input, Some(1), Some(0), None, &[]);
assert_eq_slice(&input, Some(1), Some(1), None, &[]);
assert_eq_slice(&input, Some(1), Some(2), None, &['b']);
assert_eq_slice(&input, Some(1), Some(4), None, &['b', 'c', 'd']);
assert_eq_slice(&input, Some(1), Some(5), None, &['b', 'c', 'd', 'e']);
assert_eq_slice(&input, Some(1), Some(6), None, &['b', 'c', 'd', 'e']);
assert_eq_slice(&input, Some(1), None, None, &['b', 'c', 'd', 'e']);
assert_eq_slice(&input, Some(4), Some(0), None, &[]);
assert_eq_slice(&input, Some(4), Some(4), None, &[]);
assert_eq_slice(&input, Some(4), Some(5), None, &['e']);
assert_eq_slice(&input, Some(4), Some(6), None, &['e']);
assert_eq_slice(&input, Some(4), None, None, &['e']);
assert_eq_slice(&input, Some(5), Some(0), None, &[]);
assert_eq_slice(&input, Some(5), Some(5), None, &[]);
assert_eq_slice(&input, Some(5), Some(6), None, &[]);
assert_eq_slice(&input, Some(5), None, None, &[]);
assert_eq_slice(&input, Some(6), Some(0), None, &[]);
assert_eq_slice(&input, Some(6), Some(6), None, &[]);
assert_eq_slice(&input, Some(6), None, None, &[]);
}
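
    // As in CPython's `slice.indices`, out-of-range nonnegative bounds clamp to
    // the sequence length above: `input[5:6]` and `input[6:]` are both empty for
    // a five-element input. The next test covers negative bounds, where an index
    // `i < 0` is interpreted as `len + i` before clamping.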
#[test]
    fn py_slice_negative_indices() {
let input = ['a', 'b', 'c', 'd', 'e'];
assert_eq_slice(&input, Some(-6), None, None, &['a', 'b', 'c', 'd', 'e']);
assert_eq_slice(&input, Some(-6), Some(-1), None, &['a', 'b', 'c', 'd']);
assert_eq_slice(&input, Some(-6), Some(-4), None, &['a']);
assert_eq_slice(&input, Some(-6), Some(-5), None, &[]);
assert_eq_slice(&input, Some(-6), Some(-6), None, &[]);
assert_eq_slice(&input, Some(-6), Some(-10), None, &[]);
assert_eq_slice(&input, Some(-5), None, None, &['a', 'b', 'c', 'd', 'e']);
assert_eq_slice(&input, Some(-5), Some(-1), None, &['a', 'b', 'c', 'd']);
assert_eq_slice(&input, Some(-5), Some(-4), None, &['a']);
assert_eq_slice(&input, Some(-5), Some(-5), None, &[]);
assert_eq_slice(&input, Some(-5), Some(-6), None, &[]);
assert_eq_slice(&input, Some(-5), Some(-10), None, &[]);
assert_eq_slice(&input, Some(-4), None, None, &['b', 'c', 'd', 'e']);
assert_eq_slice(&input, Some(-4), Some(-1), None, &['b', 'c', 'd']);
assert_eq_slice(&input, Some(-4), Some(-3), None, &['b']);
assert_eq_slice(&input, Some(-4), Some(-4), None, &[]);
assert_eq_slice(&input, Some(-4), Some(-10), None, &[]);
assert_eq_slice(&input, Some(-1), None, None, &['e']);
assert_eq_slice(&input, Some(-1), Some(-1), None, &[]);
assert_eq_slice(&input, Some(-1), Some(-10), None, &[]);
assert_eq_slice(&input, None, Some(-1), None, &['a', 'b', 'c', 'd']);
assert_eq_slice(&input, None, Some(-4), None, &['a']);
assert_eq_slice(&input, None, Some(-5), None, &[]);
assert_eq_slice(&input, None, Some(-6), None, &[]);
}
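
    // A reference point from CPython, matching the `Some(-4), Some(-1)` case
    // above:
    //
    //     >>> ['a', 'b', 'c', 'd', 'e'][-4:-1]
    //     ['b', 'c', 'd']
    //
    // Mixing signs changes nothing: each bound is normalized independently
    // (negative bounds become `len + i`) before the two are compared.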
#[test]
fn py_slice_mixed_positive_negative_indices() {
let input = ['a', 'b', 'c', 'd', 'e'];
assert_eq_slice(&input, Some(0), Some(-1), None, &['a', 'b', 'c', 'd']);
assert_eq_slice(&input, Some(1), Some(-1), None, &['b', 'c', 'd']);
assert_eq_slice(&input, Some(3), Some(-1), None, &['d']);
assert_eq_slice(&input, Some(4), Some(-1), None, &[]);
assert_eq_slice(&input, Some(5), Some(-1), None, &[]);
assert_eq_slice(&input, Some(0), Some(-4), None, &['a']);
assert_eq_slice(&input, Some(1), Some(-4), None, &[]);
assert_eq_slice(&input, Some(3), Some(-4), None, &[]);
assert_eq_slice(&input, Some(0), Some(-5), None, &[]);
assert_eq_slice(&input, Some(1), Some(-5), None, &[]);
assert_eq_slice(&input, Some(3), Some(-5), None, &[]);
assert_eq_slice(&input, Some(0), Some(-6), None, &[]);
assert_eq_slice(&input, Some(1), Some(-6), None, &[]);
assert_eq_slice(&input, Some(-6), Some(6), None, &['a', 'b', 'c', 'd', 'e']);
assert_eq_slice(&input, Some(-6), Some(5), None, &['a', 'b', 'c', 'd', 'e']);
assert_eq_slice(&input, Some(-6), Some(4), None, &['a', 'b', 'c', 'd']);
assert_eq_slice(&input, Some(-6), Some(1), None, &['a']);
assert_eq_slice(&input, Some(-6), Some(0), None, &[]);
assert_eq_slice(&input, Some(-5), Some(6), None, &['a', 'b', 'c', 'd', 'e']);
assert_eq_slice(&input, Some(-5), Some(5), None, &['a', 'b', 'c', 'd', 'e']);
assert_eq_slice(&input, Some(-5), Some(4), None, &['a', 'b', 'c', 'd']);
assert_eq_slice(&input, Some(-5), Some(1), None, &['a']);
assert_eq_slice(&input, Some(-5), Some(0), None, &[]);
assert_eq_slice(&input, Some(-4), Some(6), None, &['b', 'c', 'd', 'e']);
assert_eq_slice(&input, Some(-4), Some(5), None, &['b', 'c', 'd', 'e']);
assert_eq_slice(&input, Some(-4), Some(4), None, &['b', 'c', 'd']);
assert_eq_slice(&input, Some(-4), Some(2), None, &['b']);
assert_eq_slice(&input, Some(-4), Some(1), None, &[]);
assert_eq_slice(&input, Some(-4), Some(0), None, &[]);
assert_eq_slice(&input, Some(-1), Some(6), None, &['e']);
assert_eq_slice(&input, Some(-1), Some(5), None, &['e']);
assert_eq_slice(&input, Some(-1), Some(4), None, &[]);
assert_eq_slice(&input, Some(-1), Some(1), None, &[]);
}

    #[test]
fn py_slice_step_forward() {
// indices: 0 1 2 3 4 5 6
let input = ['a', 'b', 'c', 'd', 'e', 'f', 'g'];
// Step size zero is invalid:
assert!(matches!(
input.py_slice(None, None, Some(0)),
Err(StepSizeZeroError)
));
assert!(matches!(
input.py_slice(Some(0), Some(5), Some(0)),
Err(StepSizeZeroError)
));
assert!(matches!(
input.py_slice(Some(0), Some(0), Some(0)),
Err(StepSizeZeroError)
));
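        // CPython raises `ValueError: slice step cannot be zero` for these;
        // this API instead surfaces the condition as a recoverable
        // `Err(StepSizeZeroError)` that callers can match on.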
assert_eq_slice(&input, Some(0), Some(8), Some(2), &['a', 'c', 'e', 'g']);
assert_eq_slice(&input, Some(0), Some(7), Some(2), &['a', 'c', 'e', 'g']);
assert_eq_slice(&input, Some(0), Some(6), Some(2), &['a', 'c', 'e']);
assert_eq_slice(&input, Some(0), Some(5), Some(2), &['a', 'c', 'e']);
assert_eq_slice(&input, Some(0), Some(4), Some(2), &['a', 'c']);
assert_eq_slice(&input, Some(0), Some(3), Some(2), &['a', 'c']);
assert_eq_slice(&input, Some(0), Some(2), Some(2), &['a']);
assert_eq_slice(&input, Some(0), Some(1), Some(2), &['a']);
assert_eq_slice(&input, Some(0), Some(0), Some(2), &[]);
assert_eq_slice(&input, Some(1), Some(5), Some(2), &['b', 'd']);
assert_eq_slice(&input, Some(0), Some(7), Some(3), &['a', 'd', 'g']);
assert_eq_slice(&input, Some(0), Some(6), Some(3), &['a', 'd']);
assert_eq_slice(&input, Some(0), None, Some(10), &['a']);
}
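
    // For a negative step the defaults flip: an omitted start begins at the
    // last element and an omitted stop runs through index 0, mirroring
    // CPython's `input[::-1]` reversal. A start past the end (such as `Some(7)`
    // on a seven-element input) clamps to the last element.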
#[test]
fn py_slice_step_backward() {
// indices: 0 1 2 3 4 5 6
let input = ['a', 'b', 'c', 'd', 'e', 'f', 'g'];
assert_eq_slice(&input, Some(7), Some(0), Some(-2), &['g', 'e', 'c']);
assert_eq_slice(&input, Some(6), Some(0), Some(-2), &['g', 'e', 'c']);
assert_eq_slice(&input, Some(5), Some(0), Some(-2), &['f', 'd', 'b']);
assert_eq_slice(&input, Some(4), Some(0), Some(-2), &['e', 'c']);
assert_eq_slice(&input, Some(3), Some(0), Some(-2), &['d', 'b']);
assert_eq_slice(&input, Some(2), Some(0), Some(-2), &['c']);
assert_eq_slice(&input, Some(1), Some(0), Some(-2), &['b']);
assert_eq_slice(&input, Some(0), Some(0), Some(-2), &[]);
assert_eq_slice(&input, Some(7), None, Some(-2), &['g', 'e', 'c', 'a']);
assert_eq_slice(&input, None, None, Some(-2), &['g', 'e', 'c', 'a']);
assert_eq_slice(&input, None, Some(0), Some(-2), &['g', 'e', 'c']);
assert_eq_slice(&input, Some(5), Some(1), Some(-2), &['f', 'd']);
assert_eq_slice(&input, Some(5), Some(2), Some(-2), &['f', 'd']);
assert_eq_slice(&input, Some(5), Some(3), Some(-2), &['f']);
assert_eq_slice(&input, Some(5), Some(4), Some(-2), &['f']);
assert_eq_slice(&input, Some(5), Some(5), Some(-2), &[]);
assert_eq_slice(&input, Some(6), None, Some(-3), &['g', 'd', 'a']);
assert_eq_slice(&input, Some(6), Some(0), Some(-3), &['g', 'd']);
assert_eq_slice(&input, Some(7), None, Some(-10), &['g']);
assert_eq_slice(&input, Some(-6), Some(-9), Some(-1), &['b', 'a']);
assert_eq_slice(&input, Some(-6), Some(-8), Some(-1), &['b', 'a']);
assert_eq_slice(&input, Some(-6), Some(-7), Some(-1), &['b']);
assert_eq_slice(&input, Some(-6), Some(-6), Some(-1), &[]);
assert_eq_slice(&input, Some(-7), Some(-9), Some(-1), &['a']);
assert_eq_slice(&input, Some(-8), Some(-9), Some(-1), &[]);
assert_eq_slice(&input, Some(-9), Some(-9), Some(-1), &[]);
assert_eq_slice(&input, Some(-6), Some(-2), Some(-1), &[]);
assert_eq_slice(&input, Some(-9), Some(-6), Some(-1), &[]);
}
}