mirror of
https://github.com/astral-sh/ruff.git
synced 2025-10-02 22:55:08 +00:00
[red-knot] feat: add StringLiteral
and LiteralString
comparison (#13634)
## Summary Implements string literal comparisons and fallbacks to `str` instance for `LiteralString`. Completes an item in #13618 ## Test Plan - Adds a dedicated test with non exhaustive cases --------- Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
This commit is contained in:
parent
f1205177fd
commit
8108f83810
1 changed files with 83 additions and 5 deletions
|
@ -2535,9 +2535,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
||||||
ast::CmpOp::In
|
ast::CmpOp::In
|
||||||
| ast::CmpOp::NotIn
|
| ast::CmpOp::NotIn
|
||||||
| ast::CmpOp::Is
|
| ast::CmpOp::Is
|
||||||
| ast::CmpOp::IsNot => {
|
| ast::CmpOp::IsNot => KnownClass::Bool.to_instance(self.db),
|
||||||
builtins_symbol_ty(self.db, "bool").to_instance(self.db)
|
|
||||||
}
|
|
||||||
// Other operators can return arbitrary types
|
// Other operators can return arbitrary types
|
||||||
_ => Type::Unknown,
|
_ => Type::Unknown,
|
||||||
}
|
}
|
||||||
|
@ -2573,14 +2571,14 @@ impl<'db> TypeInferenceBuilder<'db> {
|
||||||
ast::CmpOp::GtE => Some(Type::BooleanLiteral(n >= m)),
|
ast::CmpOp::GtE => Some(Type::BooleanLiteral(n >= m)),
|
||||||
ast::CmpOp::Is => {
|
ast::CmpOp::Is => {
|
||||||
if n == m {
|
if n == m {
|
||||||
Some(builtins_symbol_ty(self.db, "bool").to_instance(self.db))
|
Some(KnownClass::Bool.to_instance(self.db))
|
||||||
} else {
|
} else {
|
||||||
Some(Type::BooleanLiteral(false))
|
Some(Type::BooleanLiteral(false))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ast::CmpOp::IsNot => {
|
ast::CmpOp::IsNot => {
|
||||||
if n == m {
|
if n == m {
|
||||||
Some(builtins_symbol_ty(self.db, "bool").to_instance(self.db))
|
Some(KnownClass::Bool.to_instance(self.db))
|
||||||
} else {
|
} else {
|
||||||
Some(Type::BooleanLiteral(true))
|
Some(Type::BooleanLiteral(true))
|
||||||
}
|
}
|
||||||
|
@ -2594,6 +2592,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
||||||
(Type::Instance(_), Type::IntLiteral(_)) => {
|
(Type::Instance(_), Type::IntLiteral(_)) => {
|
||||||
self.infer_binary_type_comparison(left, op, KnownClass::Int.to_instance(self.db))
|
self.infer_binary_type_comparison(left, op, KnownClass::Int.to_instance(self.db))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Booleans are coded as integers (False = 0, True = 1)
|
// Booleans are coded as integers (False = 0, True = 1)
|
||||||
(Type::IntLiteral(n), Type::BooleanLiteral(b)) => self.infer_binary_type_comparison(
|
(Type::IntLiteral(n), Type::BooleanLiteral(b)) => self.infer_binary_type_comparison(
|
||||||
Type::IntLiteral(n),
|
Type::IntLiteral(n),
|
||||||
|
@ -2611,6 +2610,49 @@ impl<'db> TypeInferenceBuilder<'db> {
|
||||||
op,
|
op,
|
||||||
Type::IntLiteral(i64::from(b)),
|
Type::IntLiteral(i64::from(b)),
|
||||||
),
|
),
|
||||||
|
|
||||||
|
(Type::StringLiteral(salsa_s1), Type::StringLiteral(salsa_s2)) => {
|
||||||
|
let s1 = salsa_s1.value(self.db);
|
||||||
|
let s2 = salsa_s2.value(self.db);
|
||||||
|
match op {
|
||||||
|
ast::CmpOp::Eq => Some(Type::BooleanLiteral(s1 == s2)),
|
||||||
|
ast::CmpOp::NotEq => Some(Type::BooleanLiteral(s1 != s2)),
|
||||||
|
ast::CmpOp::Lt => Some(Type::BooleanLiteral(s1 < s2)),
|
||||||
|
ast::CmpOp::LtE => Some(Type::BooleanLiteral(s1 <= s2)),
|
||||||
|
ast::CmpOp::Gt => Some(Type::BooleanLiteral(s1 > s2)),
|
||||||
|
ast::CmpOp::GtE => Some(Type::BooleanLiteral(s1 >= s2)),
|
||||||
|
ast::CmpOp::In => Some(Type::BooleanLiteral(s2.contains(s1.as_ref()))),
|
||||||
|
ast::CmpOp::NotIn => Some(Type::BooleanLiteral(!s2.contains(s1.as_ref()))),
|
||||||
|
ast::CmpOp::Is => {
|
||||||
|
if s1 == s2 {
|
||||||
|
Some(KnownClass::Bool.to_instance(self.db))
|
||||||
|
} else {
|
||||||
|
Some(Type::BooleanLiteral(false))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ast::CmpOp::IsNot => {
|
||||||
|
if s1 == s2 {
|
||||||
|
Some(KnownClass::Bool.to_instance(self.db))
|
||||||
|
} else {
|
||||||
|
Some(Type::BooleanLiteral(true))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(Type::StringLiteral(_), _) => {
|
||||||
|
self.infer_binary_type_comparison(KnownClass::Str.to_instance(self.db), op, right)
|
||||||
|
}
|
||||||
|
(_, Type::StringLiteral(_)) => {
|
||||||
|
self.infer_binary_type_comparison(left, op, KnownClass::Str.to_instance(self.db))
|
||||||
|
}
|
||||||
|
|
||||||
|
(Type::LiteralString, _) => {
|
||||||
|
self.infer_binary_type_comparison(KnownClass::Str.to_instance(self.db), op, right)
|
||||||
|
}
|
||||||
|
(_, Type::LiteralString) => {
|
||||||
|
self.infer_binary_type_comparison(left, op, KnownClass::Str.to_instance(self.db))
|
||||||
|
}
|
||||||
|
|
||||||
// Lookup the rich comparison `__dunder__` methods on instances
|
// Lookup the rich comparison `__dunder__` methods on instances
|
||||||
(Type::Instance(left_class_ty), Type::Instance(right_class_ty)) => match op {
|
(Type::Instance(left_class_ty), Type::Instance(right_class_ty)) => match op {
|
||||||
ast::CmpOp::Lt => {
|
ast::CmpOp::Lt => {
|
||||||
|
@ -4110,6 +4152,42 @@ mod tests {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn comparison_string_literals() -> anyhow::Result<()> {
|
||||||
|
let mut db = setup_db();
|
||||||
|
db.write_dedented(
|
||||||
|
"src/a.py",
|
||||||
|
r#"
|
||||||
|
def str_instance() -> str: ...
|
||||||
|
a = "abc" == "abc"
|
||||||
|
b = "ab_cd" <= "ab_ce"
|
||||||
|
c = "abc" in "ab cd"
|
||||||
|
d = "" not in "hello"
|
||||||
|
e = "--" is "--"
|
||||||
|
f = "A" is "B"
|
||||||
|
g = "--" is not "--"
|
||||||
|
h = "A" is not "B"
|
||||||
|
i = str_instance() < "..."
|
||||||
|
j = "ab" < "ab_cd"
|
||||||
|
"#,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
assert_public_ty(&db, "src/a.py", "a", "Literal[True]");
|
||||||
|
assert_public_ty(&db, "src/a.py", "b", "Literal[True]");
|
||||||
|
assert_public_ty(&db, "src/a.py", "c", "Literal[False]");
|
||||||
|
assert_public_ty(&db, "src/a.py", "d", "Literal[False]");
|
||||||
|
assert_public_ty(&db, "src/a.py", "e", "bool");
|
||||||
|
assert_public_ty(&db, "src/a.py", "f", "Literal[False]");
|
||||||
|
assert_public_ty(&db, "src/a.py", "g", "bool");
|
||||||
|
assert_public_ty(&db, "src/a.py", "h", "Literal[True]");
|
||||||
|
assert_public_ty(&db, "src/a.py", "i", "bool");
|
||||||
|
// Very cornercase test ensuring we're not comparing the interned salsa symbols, which
|
||||||
|
// compare by order of declaration
|
||||||
|
assert_public_ty(&db, "src/a.py", "j", "Literal[True]");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn comparison_unsupported_operators() -> anyhow::Result<()> {
|
fn comparison_unsupported_operators() -> anyhow::Result<()> {
|
||||||
let mut db = setup_db();
|
let mut db = setup_db();
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue