mirror of
https://github.com/astral-sh/ruff.git
synced 2025-10-26 18:06:36 +00:00
Add support for unions to our Python builtins type system (#6541)
## Summary Fixes some TODOs introduced in https://github.com/astral-sh/ruff/pull/6538. In short, given an expression like `1 if x > 0 else "Hello, world!"`, we now return a union type that says the expression can resolve to either an `int` or a `str`. The system remains very limited, it only works for obvious primitive types, and there's no attempt to do inference on any more complex variables. (If any expression yields `Unknown` or `TypeError`, we propagate that result throughout and abort on the client's end.)
This commit is contained in:
parent
eb24f5a0b9
commit
768686148f
9 changed files with 529 additions and 101 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
|
@ -2423,6 +2423,7 @@ dependencies = [
|
|||
"num-traits",
|
||||
"ruff_index",
|
||||
"ruff_python_ast",
|
||||
"ruff_python_parser",
|
||||
"ruff_python_stdlib",
|
||||
"ruff_source_file",
|
||||
"ruff_text_size",
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ print("foo %(foo)d bar %(bar)d" % {"foo": "1", "bar": "2"})
|
|||
"%(key)d" % {"key": []}
|
||||
print("%d" % ("%s" % ("nested",),))
|
||||
"%d" % ((1, 2, 3),)
|
||||
"%d" % (1 if x > 0 else [])
|
||||
|
||||
# False negatives
|
||||
WORD = "abc"
|
||||
|
|
@ -55,3 +56,4 @@ r'\%03o' % (ord(c),)
|
|||
"%d" % (len(foo),)
|
||||
'(%r, %r, %r, %r)' % (hostname, address, username, '$PASSWORD')
|
||||
'%r' % ({'server_school_roles': server_school_roles, 'is_school_multiserver_domain': is_school_multiserver_domain}, )
|
||||
"%d" % (1 if x > 0 else 2)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ use rustc_hash::FxHashMap;
|
|||
use ruff_diagnostics::{Diagnostic, Violation};
|
||||
use ruff_macros::{derive_message_formats, violation};
|
||||
use ruff_python_ast::str::{leading_quote, trailing_quote};
|
||||
use ruff_python_semantic::analyze::type_inference::PythonType;
|
||||
use ruff_python_semantic::analyze::type_inference::{NumberLike, PythonType, ResolvedPythonType};
|
||||
|
||||
use crate::checkers::ast::Checker;
|
||||
|
||||
|
|
@ -59,14 +59,16 @@ impl FormatType {
|
|||
| PythonType::Set
|
||||
| PythonType::Tuple
|
||||
| PythonType::Generator
|
||||
| PythonType::Complex
|
||||
| PythonType::Bool
|
||||
| PythonType::Ellipsis
|
||||
| PythonType::None => matches!(
|
||||
self,
|
||||
FormatType::Unknown | FormatType::String | FormatType::Repr
|
||||
),
|
||||
PythonType::Integer => matches!(
|
||||
PythonType::Number(NumberLike::Complex | NumberLike::Bool) => matches!(
|
||||
self,
|
||||
FormatType::Unknown | FormatType::String | FormatType::Repr
|
||||
),
|
||||
PythonType::Number(NumberLike::Integer) => matches!(
|
||||
self,
|
||||
FormatType::Unknown
|
||||
| FormatType::String
|
||||
|
|
@ -75,7 +77,7 @@ impl FormatType {
|
|||
| FormatType::Float
|
||||
| FormatType::Number
|
||||
),
|
||||
PythonType::Float => matches!(
|
||||
PythonType::Number(NumberLike::Float) => matches!(
|
||||
self,
|
||||
FormatType::Unknown
|
||||
| FormatType::String
|
||||
|
|
@ -83,7 +85,6 @@ impl FormatType {
|
|||
| FormatType::Float
|
||||
| FormatType::Number
|
||||
),
|
||||
PythonType::Unknown => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -118,16 +119,22 @@ fn collect_specs(formats: &[CFormatStrOrBytes<String>]) -> Vec<&CFormatSpec> {
|
|||
|
||||
/// Return `true` if the format string is equivalent to the constant type
|
||||
fn equivalent(format: &CFormatSpec, value: &Expr) -> bool {
|
||||
let format: FormatType = format.format_char.into();
|
||||
let constant: PythonType = value.into();
|
||||
format.is_compatible_with(constant)
|
||||
let format = FormatType::from(format.format_char);
|
||||
match ResolvedPythonType::from(value) {
|
||||
ResolvedPythonType::Atom(atom) => format.is_compatible_with(atom),
|
||||
ResolvedPythonType::Union(atoms) => {
|
||||
atoms.iter().all(|atom| format.is_compatible_with(*atom))
|
||||
}
|
||||
ResolvedPythonType::Unknown => true,
|
||||
ResolvedPythonType::TypeError => true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return `true` if the [`Constnat`] aligns with the format type.
|
||||
/// Return `true` if the [`Constant`] aligns with the format type.
|
||||
fn is_valid_constant(formats: &[CFormatStrOrBytes<String>], value: &Expr) -> bool {
|
||||
let formats = collect_specs(formats);
|
||||
// If there is more than one format, this is not valid python and we should
|
||||
// return true so that no error is reported
|
||||
// If there is more than one format, this is not valid Python and we should
|
||||
// return true so that no error is reported.
|
||||
let [format] = formats.as_slice() else {
|
||||
return true;
|
||||
};
|
||||
|
|
@ -242,8 +249,7 @@ pub(crate) fn bad_string_format_type(checker: &mut Checker, expr: &Expr, right:
|
|||
values,
|
||||
range: _,
|
||||
}) => is_valid_dict(&format_strings, keys, values),
|
||||
Expr::Constant(_) => is_valid_constant(&format_strings, right),
|
||||
_ => true,
|
||||
_ => is_valid_constant(&format_strings, right),
|
||||
};
|
||||
if !is_valid {
|
||||
checker
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use ruff_diagnostics::{Diagnostic, Violation};
|
||||
use ruff_macros::{derive_message_formats, violation};
|
||||
use ruff_python_ast::{self as ast, Ranged};
|
||||
use ruff_python_semantic::analyze::type_inference::PythonType;
|
||||
use ruff_python_semantic::analyze::type_inference::{PythonType, ResolvedPythonType};
|
||||
|
||||
use crate::checkers::ast::Checker;
|
||||
|
||||
|
|
@ -46,8 +46,8 @@ pub(crate) fn invalid_envvar_value(checker: &mut Checker, call: &ast::ExprCall)
|
|||
};
|
||||
|
||||
if matches!(
|
||||
PythonType::from(expr),
|
||||
PythonType::String | PythonType::Unknown
|
||||
ResolvedPythonType::from(expr),
|
||||
ResolvedPythonType::Unknown | ResolvedPythonType::Atom(PythonType::String)
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ use ruff_python_ast::{Ranged, Stmt};
|
|||
use ruff_diagnostics::{Diagnostic, Violation};
|
||||
use ruff_macros::{derive_message_formats, violation};
|
||||
use ruff_python_ast::{helpers::ReturnStatementVisitor, statement_visitor::StatementVisitor};
|
||||
use ruff_python_semantic::analyze::type_inference::PythonType;
|
||||
use ruff_python_semantic::analyze::type_inference::{PythonType, ResolvedPythonType};
|
||||
|
||||
use crate::checkers::ast::Checker;
|
||||
|
||||
|
|
@ -42,8 +42,8 @@ pub(crate) fn invalid_str_return(checker: &mut Checker, name: &str, body: &[Stmt
|
|||
for stmt in returns {
|
||||
if let Some(value) = stmt.value.as_deref() {
|
||||
if !matches!(
|
||||
PythonType::from(value),
|
||||
PythonType::String | PythonType::Unknown
|
||||
ResolvedPythonType::from(value),
|
||||
ResolvedPythonType::Unknown | ResolvedPythonType::Atom(PythonType::String)
|
||||
) {
|
||||
checker
|
||||
.diagnostics
|
||||
|
|
|
|||
|
|
@ -69,6 +69,16 @@ bad_string_format_type.py:10:1: PLE1307 Format type does not match argument type
|
|||
12 | "%d" % ([],)
|
||||
|
|
||||
|
||||
bad_string_format_type.py:11:1: PLE1307 Format type does not match argument type
|
||||
|
|
||||
9 | "%x" % 1.1
|
||||
10 | "%(key)x" % {"key": 1.1}
|
||||
11 | "%d" % []
|
||||
| ^^^^^^^^^ PLE1307
|
||||
12 | "%d" % ([],)
|
||||
13 | "%(key)d" % {"key": []}
|
||||
|
|
||||
|
||||
bad_string_format_type.py:12:1: PLE1307 Format type does not match argument type
|
||||
|
|
||||
10 | "%(key)x" % {"key": 1.1}
|
||||
|
|
@ -96,6 +106,7 @@ bad_string_format_type.py:14:7: PLE1307 Format type does not match argument type
|
|||
14 | print("%d" % ("%s" % ("nested",),))
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ PLE1307
|
||||
15 | "%d" % ((1, 2, 3),)
|
||||
16 | "%d" % (1 if x > 0 else [])
|
||||
|
|
||||
|
||||
bad_string_format_type.py:15:1: PLE1307 Format type does not match argument type
|
||||
|
|
@ -104,8 +115,17 @@ bad_string_format_type.py:15:1: PLE1307 Format type does not match argument type
|
|||
14 | print("%d" % ("%s" % ("nested",),))
|
||||
15 | "%d" % ((1, 2, 3),)
|
||||
| ^^^^^^^^^^^^^^^^^^^ PLE1307
|
||||
16 |
|
||||
17 | # False negatives
|
||||
16 | "%d" % (1 if x > 0 else [])
|
||||
|
|
||||
|
||||
bad_string_format_type.py:16:1: PLE1307 Format type does not match argument type
|
||||
|
|
||||
14 | print("%d" % ("%s" % ("nested",),))
|
||||
15 | "%d" % ((1, 2, 3),)
|
||||
16 | "%d" % (1 if x > 0 else [])
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^ PLE1307
|
||||
17 |
|
||||
18 | # False negatives
|
||||
|
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -31,4 +31,24 @@ invalid_envvar_value.py:8:11: PLE1507 Invalid type for initial `os.getenv` argum
|
|||
10 | os.getenv(key=f"foo", default="bar")
|
||||
|
|
||||
|
||||
invalid_envvar_value.py:12:15: PLE1507 Invalid type for initial `os.getenv` argument; expected `str`
|
||||
|
|
||||
10 | os.getenv(key=f"foo", default="bar")
|
||||
11 | os.getenv(key="foo" + "bar", default=1)
|
||||
12 | os.getenv(key=1 + "bar", default=1) # [invalid-envvar-value]
|
||||
| ^^^^^^^^^ PLE1507
|
||||
13 | os.getenv("PATH_TEST" if using_clear_path else "PATH_ORIG")
|
||||
14 | os.getenv(1 if using_clear_path else "PATH_ORIG")
|
||||
|
|
||||
|
||||
invalid_envvar_value.py:14:11: PLE1507 Invalid type for initial `os.getenv` argument; expected `str`
|
||||
|
|
||||
12 | os.getenv(key=1 + "bar", default=1) # [invalid-envvar-value]
|
||||
13 | os.getenv("PATH_TEST" if using_clear_path else "PATH_ORIG")
|
||||
14 | os.getenv(1 if using_clear_path else "PATH_ORIG")
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ PLE1507
|
||||
15 |
|
||||
16 | AA = "aa"
|
||||
|
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ bitflags = { workspace = true }
|
|||
is-macro = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
rustc-hash = { workspace = true }
|
||||
|
||||
|
||||
smallvec = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
ruff_python_parser = { path = "../ruff_python_parser" }
|
||||
|
|
|
|||
|
|
@ -1,7 +1,317 @@
|
|||
//! Analysis rules to perform basic type inference on individual expressions.
|
||||
|
||||
use rustc_hash::FxHashSet;
|
||||
|
||||
use ruff_python_ast as ast;
|
||||
use ruff_python_ast::{Constant, Expr, Operator};
|
||||
use ruff_python_ast::{Constant, Expr, Operator, UnaryOp};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ResolvedPythonType {
|
||||
/// The expression resolved to a single known type, like `str` or `int`.
|
||||
Atom(PythonType),
|
||||
/// The expression resolved to a union of known types, like `str | int`.
|
||||
Union(FxHashSet<PythonType>),
|
||||
/// The expression resolved to an unknown type, like a variable or function call.
|
||||
Unknown,
|
||||
/// The expression resolved to a `TypeError`, like `1 + "hello"`.
|
||||
TypeError,
|
||||
}
|
||||
|
||||
impl ResolvedPythonType {
|
||||
#[must_use]
|
||||
pub fn union(self, other: Self) -> Self {
|
||||
match (self, other) {
|
||||
(Self::TypeError, _) | (_, Self::TypeError) => Self::TypeError,
|
||||
(Self::Unknown, _) | (_, Self::Unknown) => Self::Unknown,
|
||||
(Self::Atom(a), Self::Atom(b)) => {
|
||||
if a == b {
|
||||
Self::Atom(a)
|
||||
} else {
|
||||
Self::Union(FxHashSet::from_iter([a, b]))
|
||||
}
|
||||
}
|
||||
(Self::Atom(a), Self::Union(mut b)) => {
|
||||
b.insert(a);
|
||||
Self::Union(b)
|
||||
}
|
||||
(Self::Union(mut a), Self::Atom(b)) => {
|
||||
a.insert(b);
|
||||
Self::Union(a)
|
||||
}
|
||||
(Self::Union(mut a), Self::Union(b)) => {
|
||||
a.extend(b);
|
||||
Self::Union(a)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&Expr> for ResolvedPythonType {
|
||||
fn from(expr: &Expr) -> Self {
|
||||
match expr {
|
||||
// Primitives.
|
||||
Expr::Dict(_) => ResolvedPythonType::Atom(PythonType::Dict),
|
||||
Expr::DictComp(_) => ResolvedPythonType::Atom(PythonType::Dict),
|
||||
Expr::Set(_) => ResolvedPythonType::Atom(PythonType::Set),
|
||||
Expr::SetComp(_) => ResolvedPythonType::Atom(PythonType::Set),
|
||||
Expr::List(_) => ResolvedPythonType::Atom(PythonType::List),
|
||||
Expr::ListComp(_) => ResolvedPythonType::Atom(PythonType::List),
|
||||
Expr::Tuple(_) => ResolvedPythonType::Atom(PythonType::Tuple),
|
||||
Expr::GeneratorExp(_) => ResolvedPythonType::Atom(PythonType::Generator),
|
||||
Expr::FString(_) => ResolvedPythonType::Atom(PythonType::String),
|
||||
Expr::Constant(ast::ExprConstant { value, .. }) => match value {
|
||||
Constant::Str(_) => ResolvedPythonType::Atom(PythonType::String),
|
||||
Constant::Int(_) => {
|
||||
ResolvedPythonType::Atom(PythonType::Number(NumberLike::Integer))
|
||||
}
|
||||
Constant::Float(_) => {
|
||||
ResolvedPythonType::Atom(PythonType::Number(NumberLike::Float))
|
||||
}
|
||||
Constant::Bool(_) => ResolvedPythonType::Atom(PythonType::Number(NumberLike::Bool)),
|
||||
Constant::Complex { .. } => {
|
||||
ResolvedPythonType::Atom(PythonType::Number(NumberLike::Complex))
|
||||
}
|
||||
Constant::None => ResolvedPythonType::Atom(PythonType::None),
|
||||
Constant::Ellipsis => ResolvedPythonType::Atom(PythonType::Ellipsis),
|
||||
Constant::Bytes(_) => ResolvedPythonType::Atom(PythonType::Bytes),
|
||||
},
|
||||
// Simple container expressions.
|
||||
Expr::NamedExpr(ast::ExprNamedExpr { value, .. }) => {
|
||||
ResolvedPythonType::from(value.as_ref())
|
||||
}
|
||||
Expr::IfExp(ast::ExprIfExp { body, orelse, .. }) => {
|
||||
let body = ResolvedPythonType::from(body.as_ref());
|
||||
let orelse = ResolvedPythonType::from(orelse.as_ref());
|
||||
body.union(orelse)
|
||||
}
|
||||
|
||||
// Boolean operators.
|
||||
Expr::BoolOp(ast::ExprBoolOp { values, .. }) => values
|
||||
.iter()
|
||||
.map(ResolvedPythonType::from)
|
||||
.reduce(ResolvedPythonType::union)
|
||||
.unwrap_or(ResolvedPythonType::Unknown),
|
||||
|
||||
// Unary operators.
|
||||
Expr::UnaryOp(ast::ExprUnaryOp { operand, op, .. }) => match op {
|
||||
UnaryOp::Invert => {
|
||||
return match ResolvedPythonType::from(operand.as_ref()) {
|
||||
ResolvedPythonType::Atom(PythonType::Number(
|
||||
NumberLike::Bool | NumberLike::Integer,
|
||||
)) => ResolvedPythonType::Atom(PythonType::Number(NumberLike::Integer)),
|
||||
ResolvedPythonType::Atom(_) => ResolvedPythonType::TypeError,
|
||||
_ => ResolvedPythonType::Unknown,
|
||||
}
|
||||
}
|
||||
// Ex) `not 1.0`
|
||||
UnaryOp::Not => ResolvedPythonType::Atom(PythonType::Number(NumberLike::Bool)),
|
||||
// Ex) `+1` or `-1`
|
||||
UnaryOp::UAdd | UnaryOp::USub => {
|
||||
return match ResolvedPythonType::from(operand.as_ref()) {
|
||||
ResolvedPythonType::Atom(PythonType::Number(number)) => {
|
||||
ResolvedPythonType::Atom(PythonType::Number(
|
||||
if number == NumberLike::Bool {
|
||||
NumberLike::Integer
|
||||
} else {
|
||||
number
|
||||
},
|
||||
))
|
||||
}
|
||||
ResolvedPythonType::Atom(_) => ResolvedPythonType::TypeError,
|
||||
_ => ResolvedPythonType::Unknown,
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
// Binary operators.
|
||||
Expr::BinOp(ast::ExprBinOp {
|
||||
left, op, right, ..
|
||||
}) => {
|
||||
match op {
|
||||
Operator::Add => {
|
||||
match (
|
||||
ResolvedPythonType::from(left.as_ref()),
|
||||
ResolvedPythonType::from(right.as_ref()),
|
||||
) {
|
||||
// Ex) `"Hello" + "world"`
|
||||
(
|
||||
ResolvedPythonType::Atom(PythonType::String),
|
||||
ResolvedPythonType::Atom(PythonType::String),
|
||||
) => return ResolvedPythonType::Atom(PythonType::String),
|
||||
// Ex) `b"Hello" + b"world"`
|
||||
(
|
||||
ResolvedPythonType::Atom(PythonType::Bytes),
|
||||
ResolvedPythonType::Atom(PythonType::Bytes),
|
||||
) => return ResolvedPythonType::Atom(PythonType::Bytes),
|
||||
// Ex) `[1] + [2]`
|
||||
(
|
||||
ResolvedPythonType::Atom(PythonType::List),
|
||||
ResolvedPythonType::Atom(PythonType::List),
|
||||
) => return ResolvedPythonType::Atom(PythonType::List),
|
||||
// Ex) `(1, 2) + (3, 4)`
|
||||
(
|
||||
ResolvedPythonType::Atom(PythonType::Tuple),
|
||||
ResolvedPythonType::Atom(PythonType::Tuple),
|
||||
) => return ResolvedPythonType::Atom(PythonType::Tuple),
|
||||
// Ex) `1 + 1.0`
|
||||
(
|
||||
ResolvedPythonType::Atom(PythonType::Number(left)),
|
||||
ResolvedPythonType::Atom(PythonType::Number(right)),
|
||||
) => {
|
||||
return ResolvedPythonType::Atom(PythonType::Number(
|
||||
left.coerce(right),
|
||||
));
|
||||
}
|
||||
// Ex) `"a" + 1`
|
||||
(ResolvedPythonType::Atom(_), ResolvedPythonType::Atom(_)) => {
|
||||
return ResolvedPythonType::TypeError;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
Operator::Sub => {
|
||||
match (
|
||||
ResolvedPythonType::from(left.as_ref()),
|
||||
ResolvedPythonType::from(right.as_ref()),
|
||||
) {
|
||||
// Ex) `1 - 1`
|
||||
(
|
||||
ResolvedPythonType::Atom(PythonType::Number(left)),
|
||||
ResolvedPythonType::Atom(PythonType::Number(right)),
|
||||
) => {
|
||||
return ResolvedPythonType::Atom(PythonType::Number(
|
||||
left.coerce(right),
|
||||
));
|
||||
}
|
||||
// Ex) `{1, 2} - {2}`
|
||||
(
|
||||
ResolvedPythonType::Atom(PythonType::Set),
|
||||
ResolvedPythonType::Atom(PythonType::Set),
|
||||
) => return ResolvedPythonType::Atom(PythonType::Set),
|
||||
// Ex) `"a" - "b"`
|
||||
(ResolvedPythonType::Atom(_), ResolvedPythonType::Atom(_)) => {
|
||||
return ResolvedPythonType::TypeError;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
// Ex) "a" % "b"
|
||||
Operator::Mod => match (
|
||||
ResolvedPythonType::from(left.as_ref()),
|
||||
ResolvedPythonType::from(right.as_ref()),
|
||||
) {
|
||||
// Ex) `"Hello" % "world"`
|
||||
(ResolvedPythonType::Atom(PythonType::String), _) => {
|
||||
return ResolvedPythonType::Atom(PythonType::String)
|
||||
}
|
||||
// Ex) `b"Hello" % b"world"`
|
||||
(ResolvedPythonType::Atom(PythonType::Bytes), _) => {
|
||||
return ResolvedPythonType::Atom(PythonType::Bytes)
|
||||
}
|
||||
// Ex) `1 % 2`
|
||||
(
|
||||
ResolvedPythonType::Atom(PythonType::Number(left)),
|
||||
ResolvedPythonType::Atom(PythonType::Number(right)),
|
||||
) => {
|
||||
return ResolvedPythonType::Atom(PythonType::Number(
|
||||
left.coerce(right),
|
||||
));
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
// Standard arithmetic operators, which coerce to the "highest" number type.
|
||||
Operator::Mult | Operator::FloorDiv | Operator::Pow => match (
|
||||
ResolvedPythonType::from(left.as_ref()),
|
||||
ResolvedPythonType::from(right.as_ref()),
|
||||
) {
|
||||
// Ex) `1 - 2`
|
||||
(
|
||||
ResolvedPythonType::Atom(PythonType::Number(left)),
|
||||
ResolvedPythonType::Atom(PythonType::Number(right)),
|
||||
) => {
|
||||
return ResolvedPythonType::Atom(PythonType::Number(
|
||||
left.coerce(right),
|
||||
));
|
||||
}
|
||||
(ResolvedPythonType::Atom(_), ResolvedPythonType::Atom(_)) => {
|
||||
return ResolvedPythonType::TypeError;
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
// Division, which returns at least `float`.
|
||||
Operator::Div => match (
|
||||
ResolvedPythonType::from(left.as_ref()),
|
||||
ResolvedPythonType::from(right.as_ref()),
|
||||
) {
|
||||
// Ex) `1 / 2`
|
||||
(
|
||||
ResolvedPythonType::Atom(PythonType::Number(left)),
|
||||
ResolvedPythonType::Atom(PythonType::Number(right)),
|
||||
) => {
|
||||
let resolved = left.coerce(right);
|
||||
return ResolvedPythonType::Atom(PythonType::Number(
|
||||
if resolved == NumberLike::Integer {
|
||||
NumberLike::Float
|
||||
} else {
|
||||
resolved
|
||||
},
|
||||
));
|
||||
}
|
||||
(ResolvedPythonType::Atom(_), ResolvedPythonType::Atom(_)) => {
|
||||
return ResolvedPythonType::TypeError;
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
// Bitwise operators, which only work on `int` and `bool`.
|
||||
Operator::BitAnd
|
||||
| Operator::BitOr
|
||||
| Operator::BitXor
|
||||
| Operator::LShift
|
||||
| Operator::RShift => {
|
||||
match (
|
||||
ResolvedPythonType::from(left.as_ref()),
|
||||
ResolvedPythonType::from(right.as_ref()),
|
||||
) {
|
||||
// Ex) `1 & 2`
|
||||
(
|
||||
ResolvedPythonType::Atom(PythonType::Number(left)),
|
||||
ResolvedPythonType::Atom(PythonType::Number(right)),
|
||||
) => {
|
||||
let resolved = left.coerce(right);
|
||||
return if resolved == NumberLike::Integer {
|
||||
ResolvedPythonType::Atom(PythonType::Number(
|
||||
NumberLike::Integer,
|
||||
))
|
||||
} else {
|
||||
ResolvedPythonType::TypeError
|
||||
};
|
||||
}
|
||||
(ResolvedPythonType::Atom(_), ResolvedPythonType::Atom(_)) => {
|
||||
return ResolvedPythonType::TypeError;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
Operator::MatMult => {}
|
||||
}
|
||||
ResolvedPythonType::Unknown
|
||||
}
|
||||
Expr::Lambda(_)
|
||||
| Expr::Await(_)
|
||||
| Expr::Yield(_)
|
||||
| Expr::YieldFrom(_)
|
||||
| Expr::Compare(_)
|
||||
| Expr::Call(_)
|
||||
| Expr::FormattedValue(_)
|
||||
| Expr::Attribute(_)
|
||||
| Expr::Subscript(_)
|
||||
| Expr::Starred(_)
|
||||
| Expr::Name(_)
|
||||
| Expr::Slice(_)
|
||||
| Expr::IpyEscapeCommand(_) => ResolvedPythonType::Unknown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An extremely simple type inference system for individual expressions.
|
||||
///
|
||||
|
|
@ -9,20 +319,14 @@ use ruff_python_ast::{Constant, Expr, Operator};
|
|||
/// such as strings, integers, floats, and containers. It cannot infer the
|
||||
/// types of variables or expressions that are not statically known from
|
||||
/// individual AST nodes alone.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, is_macro::Is)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum PythonType {
|
||||
/// A string literal, such as `"hello"`.
|
||||
String,
|
||||
/// A bytes literal, such as `b"hello"`.
|
||||
Bytes,
|
||||
/// An integer literal, such as `1` or `0x1`.
|
||||
Integer,
|
||||
/// A floating-point literal, such as `1.0` or `1e10`.
|
||||
Float,
|
||||
/// A complex literal, such as `1j` or `1+1j`.
|
||||
Complex,
|
||||
/// A boolean literal, such as `True` or `False`.
|
||||
Bool,
|
||||
/// An integer, float, or complex literal, such as `1` or `1.0`.
|
||||
Number(NumberLike),
|
||||
/// A `None` literal, such as `None`.
|
||||
None,
|
||||
/// An ellipsis literal, such as `...`.
|
||||
|
|
@ -37,75 +341,149 @@ pub enum PythonType {
|
|||
Tuple,
|
||||
/// A generator expression, such as `(x for x in range(10))`.
|
||||
Generator,
|
||||
/// An unknown type, such as a variable or function call.
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl From<&Expr> for PythonType {
|
||||
fn from(expr: &Expr) -> Self {
|
||||
match expr {
|
||||
Expr::NamedExpr(ast::ExprNamedExpr { value, .. }) => (value.as_ref()).into(),
|
||||
Expr::UnaryOp(ast::ExprUnaryOp { operand, .. }) => (operand.as_ref()).into(),
|
||||
Expr::Dict(_) => PythonType::Dict,
|
||||
Expr::DictComp(_) => PythonType::Dict,
|
||||
Expr::Set(_) => PythonType::Set,
|
||||
Expr::SetComp(_) => PythonType::Set,
|
||||
Expr::List(_) => PythonType::List,
|
||||
Expr::ListComp(_) => PythonType::List,
|
||||
Expr::Tuple(_) => PythonType::Tuple,
|
||||
Expr::GeneratorExp(_) => PythonType::Generator,
|
||||
Expr::FString(_) => PythonType::String,
|
||||
Expr::IfExp(ast::ExprIfExp { body, orelse, .. }) => {
|
||||
let body = PythonType::from(body.as_ref());
|
||||
let orelse = PythonType::from(orelse.as_ref());
|
||||
// TODO(charlie): If we have two known types, we should return a union. As-is,
|
||||
// callers that ignore the `Unknown` type will allow invalid expressions (e.g.,
|
||||
// if you're testing for strings, you may accept `String` or `Unknown`, and you'd
|
||||
// now accept, e.g., `1 if True else "a"`, which resolves to `Unknown`).
|
||||
if body == orelse {
|
||||
body
|
||||
} else {
|
||||
PythonType::Unknown
|
||||
}
|
||||
}
|
||||
Expr::BinOp(ast::ExprBinOp {
|
||||
left, op, right, ..
|
||||
}) => {
|
||||
match op {
|
||||
// Ex) "a" + "b"
|
||||
Operator::Add => {
|
||||
match (
|
||||
PythonType::from(left.as_ref()),
|
||||
PythonType::from(right.as_ref()),
|
||||
) {
|
||||
(PythonType::String, PythonType::String) => return PythonType::String,
|
||||
(PythonType::Bytes, PythonType::Bytes) => return PythonType::Bytes,
|
||||
// TODO(charlie): If we have two known types, they may be incompatible.
|
||||
// Return an error (e.g., for `1 + "a"`).
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
// Ex) "a" % "b"
|
||||
Operator::Mod => match PythonType::from(left.as_ref()) {
|
||||
PythonType::String => return PythonType::String,
|
||||
PythonType::Bytes => return PythonType::Bytes,
|
||||
_ => {}
|
||||
},
|
||||
_ => {}
|
||||
}
|
||||
PythonType::Unknown
|
||||
}
|
||||
Expr::Constant(ast::ExprConstant { value, .. }) => match value {
|
||||
Constant::Str(_) => PythonType::String,
|
||||
Constant::Int(_) => PythonType::Integer,
|
||||
Constant::Float(_) => PythonType::Float,
|
||||
Constant::Bool(_) => PythonType::Bool,
|
||||
Constant::Complex { .. } => PythonType::Complex,
|
||||
Constant::None => PythonType::None,
|
||||
Constant::Ellipsis => PythonType::Ellipsis,
|
||||
Constant::Bytes(_) => PythonType::Bytes,
|
||||
},
|
||||
_ => PythonType::Unknown,
|
||||
/// A numeric type, or a type that can be trivially coerced to a numeric type.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum NumberLike {
|
||||
/// An integer literal, such as `1` or `0x1`.
|
||||
Integer,
|
||||
/// A floating-point literal, such as `1.0` or `1e10`.
|
||||
Float,
|
||||
/// A complex literal, such as `1j` or `1+1j`.
|
||||
Complex,
|
||||
/// A boolean literal, such as `True` or `False`.
|
||||
Bool,
|
||||
}
|
||||
|
||||
impl NumberLike {
|
||||
/// Coerces two number-like types to the "highest" number-like type.
|
||||
#[must_use]
|
||||
pub fn coerce(self, other: NumberLike) -> NumberLike {
|
||||
match (self, other) {
|
||||
(NumberLike::Complex, _) | (_, NumberLike::Complex) => NumberLike::Complex,
|
||||
(NumberLike::Float, _) | (_, NumberLike::Float) => NumberLike::Float,
|
||||
_ => NumberLike::Integer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rustc_hash::FxHashSet;
|
||||
|
||||
use ruff_python_ast::Expr;
|
||||
use ruff_python_parser::parse_expression;
|
||||
|
||||
use crate::analyze::type_inference::{NumberLike, PythonType, ResolvedPythonType};
|
||||
|
||||
fn parse(expression: &str) -> Expr {
|
||||
parse_expression(expression, "").unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn type_inference() {
|
||||
// Atoms.
|
||||
assert_eq!(
|
||||
ResolvedPythonType::from(&parse("1")),
|
||||
ResolvedPythonType::Atom(PythonType::Number(NumberLike::Integer))
|
||||
);
|
||||
assert_eq!(
|
||||
ResolvedPythonType::from(&parse("'Hello, world'")),
|
||||
ResolvedPythonType::Atom(PythonType::String)
|
||||
);
|
||||
assert_eq!(
|
||||
ResolvedPythonType::from(&parse("b'Hello, world'")),
|
||||
ResolvedPythonType::Atom(PythonType::Bytes)
|
||||
);
|
||||
assert_eq!(
|
||||
ResolvedPythonType::from(&parse("'Hello' % 'world'")),
|
||||
ResolvedPythonType::Atom(PythonType::String)
|
||||
);
|
||||
|
||||
// Boolean operators.
|
||||
assert_eq!(
|
||||
ResolvedPythonType::from(&parse("1 and 2")),
|
||||
ResolvedPythonType::Atom(PythonType::Number(NumberLike::Integer))
|
||||
);
|
||||
assert_eq!(
|
||||
ResolvedPythonType::from(&parse("1 and True")),
|
||||
ResolvedPythonType::Union(FxHashSet::from_iter([
|
||||
PythonType::Number(NumberLike::Integer),
|
||||
PythonType::Number(NumberLike::Bool)
|
||||
]))
|
||||
);
|
||||
|
||||
// Binary operators.
|
||||
assert_eq!(
|
||||
ResolvedPythonType::from(&parse("1.0 * 2")),
|
||||
ResolvedPythonType::Atom(PythonType::Number(NumberLike::Float))
|
||||
);
|
||||
assert_eq!(
|
||||
ResolvedPythonType::from(&parse("2 * 1.0")),
|
||||
ResolvedPythonType::Atom(PythonType::Number(NumberLike::Float))
|
||||
);
|
||||
assert_eq!(
|
||||
ResolvedPythonType::from(&parse("1.0 * 2j")),
|
||||
ResolvedPythonType::Atom(PythonType::Number(NumberLike::Complex))
|
||||
);
|
||||
assert_eq!(
|
||||
ResolvedPythonType::from(&parse("1 / True")),
|
||||
ResolvedPythonType::Atom(PythonType::Number(NumberLike::Float))
|
||||
);
|
||||
assert_eq!(
|
||||
ResolvedPythonType::from(&parse("1 / 2")),
|
||||
ResolvedPythonType::Atom(PythonType::Number(NumberLike::Float))
|
||||
);
|
||||
assert_eq!(
|
||||
ResolvedPythonType::from(&parse("{1, 2} - {2}")),
|
||||
ResolvedPythonType::Atom(PythonType::Set)
|
||||
);
|
||||
|
||||
// Unary operators.
|
||||
assert_eq!(
|
||||
ResolvedPythonType::from(&parse("-1")),
|
||||
ResolvedPythonType::Atom(PythonType::Number(NumberLike::Integer))
|
||||
);
|
||||
assert_eq!(
|
||||
ResolvedPythonType::from(&parse("-1.0")),
|
||||
ResolvedPythonType::Atom(PythonType::Number(NumberLike::Float))
|
||||
);
|
||||
assert_eq!(
|
||||
ResolvedPythonType::from(&parse("-1j")),
|
||||
ResolvedPythonType::Atom(PythonType::Number(NumberLike::Complex))
|
||||
);
|
||||
assert_eq!(
|
||||
ResolvedPythonType::from(&parse("-True")),
|
||||
ResolvedPythonType::Atom(PythonType::Number(NumberLike::Integer))
|
||||
);
|
||||
assert_eq!(
|
||||
ResolvedPythonType::from(&parse("not 'Hello'")),
|
||||
ResolvedPythonType::Atom(PythonType::Number(NumberLike::Bool))
|
||||
);
|
||||
assert_eq!(
|
||||
ResolvedPythonType::from(&parse("not x.y.z")),
|
||||
ResolvedPythonType::Atom(PythonType::Number(NumberLike::Bool))
|
||||
);
|
||||
|
||||
// Conditional expressions.
|
||||
assert_eq!(
|
||||
ResolvedPythonType::from(&parse("1 if True else 2")),
|
||||
ResolvedPythonType::Atom(PythonType::Number(NumberLike::Integer))
|
||||
);
|
||||
assert_eq!(
|
||||
ResolvedPythonType::from(&parse("1 if True else 2.0")),
|
||||
ResolvedPythonType::Union(FxHashSet::from_iter([
|
||||
PythonType::Number(NumberLike::Integer),
|
||||
PythonType::Number(NumberLike::Float)
|
||||
]))
|
||||
);
|
||||
assert_eq!(
|
||||
ResolvedPythonType::from(&parse("1 if True else False")),
|
||||
ResolvedPythonType::Union(FxHashSet::from_iter([
|
||||
PythonType::Number(NumberLike::Integer),
|
||||
PythonType::Number(NumberLike::Bool)
|
||||
]))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue