[red-knot] support fstring expressions (#13511)

<!--
Thank you for contributing to Ruff! To help us out with reviewing,
please consider the following:

- Does this pull request include a summary of the change? (See below.)
- Does this pull request include a descriptive title?
- Does this pull request include references to any relevant issues?
-->

## Summary

Implement type inference for f-string expressions; contributes to #12701.

### First Implementation

When looking at the way `mypy` handles things, I noticed the following:
- No variables (e.g. `f"hello"`) ⇒ `LiteralString`
- Any variable (e.g. `f"number {1}"`) ⇒ `str`

My first commit (1ba5d0f13fdf70ed8b2b1a41433b32fc9085add2) implements
exactly this logic, except that string literals are handled the same way as in
`infer_string_literal_expression` (if below `MAX_STRING_LITERAL_SIZE`,
show `Literal["exact string"]`).

### Second Implementation

My second commit (90326ce9af5549af7b4efae89cd074ddf68ada14) pushes
things a bit further to handle cases where the expressions within the
`f-string` are all literal values (i.e. their string representation is known
statically).

Here's an example of when this could happen in code:
```python
BASE_URL = "https://httpbin.org"
VERSION = "v1"
endpoint = f"{BASE_URL}/{VERSION}/post"  # Literal["https://httpbin.org/v1/post"]
```
As this can be slightly more costly (additional allocations), I don't
know if we want this feature.

## Test Plan

- Added a test `fstring_expression` covering all cases I can think of

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
Simon 2024-09-27 19:29:21 +02:00 committed by GitHub
parent f3e464ea4c
commit 1639488082
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 217 additions and 37 deletions

View file

@ -380,6 +380,10 @@ impl<'db> Type<'db> {
} }
} }
/// Look up the symbol named `str` through `builtins_symbol_ty` and return the resulting type.
///
/// NOTE(review): presumably this is the class object for `builtins.str`; callers in this file
/// follow it with `.to_instance(db)` to obtain the instance type.
pub fn builtin_str(db: &'db dyn Db) -> Self {
builtins_symbol_ty(db, "str")
}
pub fn is_stdlib_symbol(&self, db: &'db dyn Db, module_name: &str, name: &str) -> bool { pub fn is_stdlib_symbol(&self, db: &'db dyn Db, module_name: &str, name: &str) -> bool {
match self { match self {
Type::Class(class) => class.is_stdlib_symbol(db, module_name, name), Type::Class(class) => class.is_stdlib_symbol(db, module_name, name),
@ -721,6 +725,44 @@ impl<'db> Type<'db> {
Type::Tuple(_) => builtins_symbol_ty(db, "tuple"), Type::Tuple(_) => builtins_symbol_ty(db, "tuple"),
} }
} }
/// Return the type of the string obtained by converting a value of this type to a string,
/// i.e. what its `__str__` method would provide.
///
/// When `__str__` is not available, this should fall back to the value of [`Type::repr`].
///
/// Note: this method is used in the builtins `format`, `print`, `str.format` and `f-strings`.
#[must_use]
pub fn str(&self, db: &'db dyn Db) -> Type<'db> {
    match self {
        // Already a string type: the conversion leaves it unchanged.
        Type::StringLiteral(_) | Type::LiteralString => *self,
        // For these literals the `str` form is delegated to `repr`.
        Type::IntLiteral(_) | Type::BooleanLiteral(_) => self.repr(db),
        // TODO: handle more complex types
        _ => Type::builtin_str(db).to_instance(db),
    }
}
/// Return the type of the string that the `__repr__` method would produce at runtime for a
/// value of this type.
///
/// Integer, boolean and string literals yield a string literal type; `LiteralString` stays
/// `LiteralString`; everything else falls back to an instance of the builtin `str`.
#[must_use]
pub fn repr(&self, db: &'db dyn Db) -> Type<'db> {
    // Compute the literal text first; the non-literal cases return early.
    let text: Box<str> = match self {
        Type::IntLiteral(number) => number.to_string().into_boxed_str(),
        Type::BooleanLiteral(value) => (if *value { "True" } else { "False" }).into(),
        // NOTE(review): Rust's `escape_default` is not Python's `repr` algorithm (it always
        // escapes both quote characters and all non-ASCII text, and this arm always wraps in
        // single quotes) -- confirm whether the result should follow Python's quoting rules.
        Type::StringLiteral(literal) => {
            format!("'{}'", literal.value(db).escape_default()).into()
        }
        Type::LiteralString => return Type::LiteralString,
        // TODO: handle more complex types
        _ => return Type::builtin_str(db).to_instance(db),
    };
    Type::StringLiteral(StringLiteralType::new(db, text))
}
} }
impl<'db> From<&Type<'db>> for Type<'db> { impl<'db> From<&Type<'db>> for Type<'db> {
@ -1198,12 +1240,13 @@ mod tests {
/// A test representation of a type that can be transformed unambiguously into a real Type, /// A test representation of a type that can be transformed unambiguously into a real Type,
/// given a db. /// given a db.
#[derive(Debug)] #[derive(Debug, Clone)]
enum Ty { enum Ty {
Never, Never,
Unknown, Unknown,
Any, Any,
IntLiteral(i64), IntLiteral(i64),
BoolLiteral(bool),
StringLiteral(&'static str), StringLiteral(&'static str),
LiteralString, LiteralString,
BytesLiteral(&'static str), BytesLiteral(&'static str),
@ -1222,6 +1265,7 @@ mod tests {
Ty::StringLiteral(s) => { Ty::StringLiteral(s) => {
Type::StringLiteral(StringLiteralType::new(db, (*s).into())) Type::StringLiteral(StringLiteralType::new(db, (*s).into()))
} }
Ty::BoolLiteral(b) => Type::BooleanLiteral(b),
Ty::LiteralString => Type::LiteralString, Ty::LiteralString => Type::LiteralString,
Ty::BytesLiteral(s) => { Ty::BytesLiteral(s) => {
Type::BytesLiteral(BytesLiteralType::new(db, s.as_bytes().into())) Type::BytesLiteral(BytesLiteralType::new(db, s.as_bytes().into()))
@ -1331,4 +1375,28 @@ mod tests {
let db = setup_db(); let db = setup_db();
assert_eq!(ty.into_type(&db).bool(&db), Truthiness::Ambiguous); assert_eq!(ty.into_type(&db).bool(&db), Truthiness::Ambiguous);
} }
#[test_case(Ty::IntLiteral(1), Ty::StringLiteral("1"))]
#[test_case(Ty::BoolLiteral(true), Ty::StringLiteral("True"))]
#[test_case(Ty::BoolLiteral(false), Ty::StringLiteral("False"))]
#[test_case(Ty::StringLiteral("ab'cd"), Ty::StringLiteral("ab'cd"))] // no quotes
#[test_case(Ty::LiteralString, Ty::LiteralString)]
#[test_case(Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"))]
fn has_correct_str(ty: Ty, expected: Ty) {
    // Check that `Type::str` maps each input type to the expected string type.
    let db = setup_db();
    let actual = ty.into_type(&db).str(&db);
    let wanted = expected.into_type(&db);
    assert_eq!(actual, wanted);
}
#[test_case(Ty::IntLiteral(1), Ty::StringLiteral("1"))]
#[test_case(Ty::BoolLiteral(true), Ty::StringLiteral("True"))]
#[test_case(Ty::BoolLiteral(false), Ty::StringLiteral("False"))]
#[test_case(Ty::StringLiteral("ab'cd"), Ty::StringLiteral("'ab\\'cd'"))] // single quotes
#[test_case(Ty::LiteralString, Ty::LiteralString)]
#[test_case(Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"))]
fn has_correct_repr(ty: Ty, expected: Ty) {
    // Check that `Type::repr` maps each input type to the expected string type.
    let db = setup_db();
    let actual = ty.into_type(&db).repr(&db);
    let wanted = expected.into_type(&db);
    assert_eq!(actual, wanted);
}
} }

View file

@ -1653,50 +1653,50 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_fstring_expression(&mut self, fstring: &ast::ExprFString) -> Type<'db> { fn infer_fstring_expression(&mut self, fstring: &ast::ExprFString) -> Type<'db> {
let ast::ExprFString { range: _, value } = fstring; let ast::ExprFString { range: _, value } = fstring;
let mut collector = StringPartsCollector::new();
for part in value { for part in value {
// Make sure we iter through every parts to infer all sub-expressions. The `collector`
// struct ensures we don't allocate unnecessary strings.
match part { match part {
ast::FStringPart::Literal(_) => { ast::FStringPart::Literal(literal) => {
// TODO string literal type collector.push_str(&literal.value);
} }
ast::FStringPart::FString(fstring) => { ast::FStringPart::FString(fstring) => {
let ast::FString { for element in &fstring.elements {
range: _, match element {
elements, ast::FStringElement::Expression(expression) => {
flags: _, let ast::FStringExpressionElement {
} = fstring; range: _,
for element in elements { expression,
self.infer_fstring_element(element); debug_text: _,
} conversion,
} format_spec,
} } = expression;
} let ty = self.infer_expression(expression);
// TODO str type // TODO: handle format specifiers by calling a method
Type::Unknown // (`Type::format`?) that handles the `__format__` method.
} // Conversion flags should be handled before calling `__format__`.
// https://docs.python.org/3/library/string.html#format-string-syntax
fn infer_fstring_element(&mut self, element: &ast::FStringElement) { if !conversion.is_none() || format_spec.is_some() {
match element { collector.add_expression();
ast::FStringElement::Literal(_) => { } else {
// TODO string literal type if let Type::StringLiteral(literal) = ty.str(self.db) {
} collector.push_str(literal.value(self.db));
ast::FStringElement::Expression(expr_element) => { } else {
let ast::FStringExpressionElement { collector.add_expression();
range: _, }
expression, }
debug_text: _, }
conversion: _, ast::FStringElement::Literal(literal) => {
format_spec, collector.push_str(&literal.value);
} = expr_element; }
self.infer_expression(expression); }
if let Some(format_spec) = format_spec {
for spec_element in &format_spec.elements {
self.infer_fstring_element(spec_element);
} }
} }
} }
} }
collector.ty(self.db)
} }
fn infer_ellipsis_literal_expression( fn infer_ellipsis_literal_expression(
@ -2659,6 +2659,53 @@ enum ModuleNameResolutionError {
TooManyDots, TooManyDots,
} }
/// Accumulates the pieces of a formatted string while its type is being inferred.
///
/// The resulting type is:
/// - an instance of `builtins.str` if any expression with a representation unknown at compile
///   time was added,
/// - a string literal if all pieces were known and the concatenation stayed within
///   `TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE`,
/// - a literal string otherwise.
struct StringPartsCollector {
    /// The text gathered so far; `None` once it grew past the size limit or an unknown
    /// expression was added.
    concatenated: Option<String>,
    /// Set once an expression with an unknown string representation has been added.
    expression: bool,
}

impl StringPartsCollector {
    /// Create a collector holding an empty string and no unknown expression.
    fn new() -> Self {
        Self {
            concatenated: Some(String::new()),
            expression: false,
        }
    }

    /// Append `literal` to the gathered text, or drop the text entirely if the result would
    /// exceed the size limit. Does nothing once the text has been dropped.
    fn push_str(&mut self, literal: &str) {
        if let Some(buffer) = self.concatenated.as_mut() {
            let new_len = buffer.len().saturating_add(literal.len());
            if new_len <= TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE {
                buffer.push_str(literal);
            } else {
                self.concatenated = None;
            }
        }
    }

    /// Record an expression whose string representation is not known; the gathered text is
    /// dropped since it can no longer describe the whole string.
    fn add_expression(&mut self) {
        self.expression = true;
        self.concatenated = None;
    }

    /// Consume the collector and produce the inferred type of the formatted string.
    fn ty(self, db: &dyn Db) -> Type {
        match (self.expression, self.concatenated) {
            (true, _) => Type::builtin_str(db).to_instance(db),
            (false, Some(text)) => {
                Type::StringLiteral(StringLiteralType::new(db, text.into_boxed_str()))
            }
            (false, None) => Type::LiteralString,
        }
    }
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
@ -3593,6 +3640,71 @@ mod tests {
Ok(()) Ok(())
} }
// Checks the public types inferred for f-strings built from literals and from names bound to
// literal values (`x`, `z`) or to a non-literal `str()` result (`y`).
#[test]
fn fstring_expression() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
x = 0
y = str()
z = False
a = f'hello'
b = f'h {x}'
c = 'one ' f'single ' f'literal'
d = 'first ' f'second({b})' f' third'
e = f'-{y}-'
f = f'-{y}-' f'--' '--'
g = f'{z} == {False} is {True}'
",
)?;
// Literal-only f-strings and f-strings interpolating literal values are string literals.
assert_public_ty(&db, "src/a.py", "a", "Literal[\"hello\"]");
assert_public_ty(&db, "src/a.py", "b", "Literal[\"h 0\"]");
assert_public_ty(&db, "src/a.py", "c", "Literal[\"one single literal\"]");
assert_public_ty(&db, "src/a.py", "d", "Literal[\"first second(h 0) third\"]");
// Interpolating the non-literal `y` makes the whole expression a plain `str`.
assert_public_ty(&db, "src/a.py", "e", "str");
assert_public_ty(&db, "src/a.py", "f", "str");
assert_public_ty(&db, "src/a.py", "g", "Literal[\"False == False is True\"]");
Ok(())
}
// Pins the current result for a replacement field with a conversion flag (`!r`): the
// expression is inferred as `str`; the trailing comment below records the intended result.
#[test]
fn fstring_expression_with_conversion_flags() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
string = 'hello'
a = f'{string!r}'
",
)?;
assert_public_ty(&db, "src/a.py", "a", "str"); // Should be `Literal["'hello'"]`
Ok(())
}
// Pins the current result for a replacement field with a format specifier (`:02`): the
// expression is inferred as `str`; the trailing comment below records the intended result.
#[test]
fn fstring_expression_with_format_specifier() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
a = f'{1:02}'
",
)?;
assert_public_ty(&db, "src/a.py", "a", "str"); // Should be `Literal["01"]`
Ok(())
}
#[test] #[test]
fn basic_call_expression() -> anyhow::Result<()> { fn basic_call_expression() -> anyhow::Result<()> {
let mut db = setup_db(); let mut db = setup_db();