diff --git a/crates/red_knot_python_semantic/resources/mdtest/boolean/short_circuit.md b/crates/red_knot_python_semantic/resources/mdtest/boolean/short_circuit.md index 6ad75f185b..f77eea2d31 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/boolean/short_circuit.md +++ b/crates/red_knot_python_semantic/resources/mdtest/boolean/short_circuit.md @@ -43,7 +43,7 @@ if True and (x := 1): ```py def _(flag: bool): - flag or (x := 1) or reveal_type(x) # revealed: Literal[1] + flag or (x := 1) or reveal_type(x) # revealed: Never # error: [unresolved-reference] flag or reveal_type(y) or (y := 1) # revealed: Unknown diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/boolean.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/boolean.md index c0e1af2f3d..566ec10a78 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/boolean.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/boolean.md @@ -223,3 +223,15 @@ def _(x: str | None, y: str | None): if y is not x: reveal_type(y) # revealed: str | None ``` + +## Assignment expressions + +```py +def f() -> bool: + return True + +if x := f(): + reveal_type(x) # revealed: Literal[True] +else: + reveal_type(x) # revealed: Literal[False] +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/elif_else.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/elif_else.md index 76eae880ef..376c24f1e9 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/elif_else.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/elif_else.md @@ -47,3 +47,16 @@ def _(flag1: bool, flag2: bool): # TODO should be Never reveal_type(x) # revealed: Literal[1, 2] ``` + +## Assignment expressions + +```py +def f() -> int | str | None: ... + +if isinstance(x := f(), int): + reveal_type(x) # revealed: int +elif isinstance(x, str): + reveal_type(x) # revealed: str & ~int +else: + reveal_type(x) # revealed: None +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/in.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/in.md index dad0374702..865eb48788 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/in.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/in.md @@ -78,3 +78,17 @@ def _(x: Literal[1, "a", "b", "c", "d"]): else: reveal_type(x) # revealed: Literal[1, "d"] ``` + +## Assignment expressions + +```py +from typing import Literal + +def f() -> Literal[1, 2, 3]: + return 1 + +if (x := f()) in (1,): + reveal_type(x) # revealed: Literal[1] +else: + reveal_type(x) # revealed: Literal[2, 3] +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/is.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/is.md index 8a95bfc278..c7d99c48b2 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/is.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/is.md @@ -100,3 +100,16 @@ def _(flag: bool): else: reveal_type(x) # revealed: Literal[42] ``` + +## Assignment expressions + +```py +from typing import Literal + +def f() -> Literal[1, 2] | None: ... + +if (x := f()) is None: + reveal_type(x) # revealed: None +else: + reveal_type(x) # revealed: Literal[1, 2] +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/is_not.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/is_not.md index 980a66a68d..fba62e9213 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/is_not.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/is_not.md @@ -82,3 +82,14 @@ def _(x_flag: bool, y_flag: bool): reveal_type(x) # revealed: bool reveal_type(y) # revealed: bool ``` + +## Assignment expressions + +```py +def f() -> int | str | None: ... + +if (x := f()) is not None: + reveal_type(x) # revealed: int | str +else: + reveal_type(x) # revealed: None +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/not_eq.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/not_eq.md index abe0c4d5aa..20f25d9ed4 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/not_eq.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/not_eq.md @@ -89,3 +89,18 @@ def _(flag1: bool, flag2: bool, a: int): else: reveal_type(x) # revealed: Literal[1, 2] ``` + +## Assignment expressions + +```py +from typing import Literal + +def f() -> Literal[1, 2, 3]: + return 1 + +if (x := f()) != 1: + reveal_type(x) # revealed: Literal[2, 3] +else: + # TODO should be Literal[1] + reveal_type(x) # revealed: Literal[1, 2, 3] +``` diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index cf5431b47e..04ca2ead84 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -275,7 +275,8 @@ impl<'db> NarrowingConstraintsBuilder<'db> { self.evaluate_expression_node_predicate(&unary_op.operand, expression, !is_positive) } ast::Expr::BoolOp(bool_op) => self.evaluate_bool_op(bool_op, expression, is_positive), - _ => None, // TODO other test expression kinds + ast::Expr::Named(expr_named) => self.evaluate_expr_named(expr_named, is_positive), + _ => None, } } @@ -343,6 +344,18 @@ impl<'db> NarrowingConstraintsBuilder<'db> { NarrowingConstraints::from_iter([(symbol, ty)]) } + fn evaluate_expr_named( + &mut self, + expr_named: &ast::ExprNamed, + is_positive: bool, + ) -> Option> { + if let ast::Expr::Name(expr_name) = expr_named.target.as_ref() { + Some(self.evaluate_expr_name(expr_name, is_positive)) + } else { + None + } + } + fn evaluate_expr_in(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option> { if lhs_ty.is_single_valued(self.db) || lhs_ty.is_union_of_single_valued(self.db) { match rhs_ty { @@ -365,6 +378,44 @@ impl<'db> NarrowingConstraintsBuilder<'db> { } } + fn evaluate_expr_compare_op( + &mut self, + lhs_ty: Type<'db>, + rhs_ty: Type<'db>, + op: ast::CmpOp, + ) -> Option> { + match op { + ast::CmpOp::IsNot => { + if rhs_ty.is_singleton(self.db) { + let ty = IntersectionBuilder::new(self.db) + .add_negative(rhs_ty) + .build(); + Some(ty) + } else { + // Non-singletons cannot be safely narrowed using `is not` + None + } + } + ast::CmpOp::Is => Some(rhs_ty), + ast::CmpOp::NotEq => { + if rhs_ty.is_single_valued(self.db) { + let ty = IntersectionBuilder::new(self.db) + .add_negative(rhs_ty) + .build(); + Some(ty) + } else { + None + } + } + ast::CmpOp::Eq if lhs_ty.is_literal_string() => Some(rhs_ty), + ast::CmpOp::In => self.evaluate_expr_in(lhs_ty, rhs_ty), + ast::CmpOp::NotIn => self + .evaluate_expr_in(lhs_ty, rhs_ty) + .map(|ty| ty.negate(self.db)), + _ => None, + } + } + fn evaluate_expr_compare( &mut self, expr_compare: &ast::ExprCompare, @@ -372,7 +423,10 @@ impl<'db> NarrowingConstraintsBuilder<'db> { is_positive: bool, ) -> Option> { fn is_narrowing_target_candidate(expr: &ast::Expr) -> bool { - matches!(expr, ast::Expr::Name(_) | ast::Expr::Call(_)) + matches!( + expr, + ast::Expr::Name(_) | ast::Expr::Call(_) | ast::Expr::Named(_) + ) } let ast::ExprCompare { @@ -423,43 +477,24 @@ impl<'db> NarrowingConstraintsBuilder<'db> { }) => { let symbol = self.expect_expr_name_symbol(id); - match if is_positive { *op } else { op.negate() } { - ast::CmpOp::IsNot => { - if rhs_ty.is_singleton(self.db) { - let ty = IntersectionBuilder::new(self.db) - .add_negative(rhs_ty) - .build(); - constraints.insert(symbol, ty); - } else { - // Non-singletons cannot be safely narrowed using `is not` - } - } - ast::CmpOp::Is => { - constraints.insert(symbol, rhs_ty); - } - ast::CmpOp::NotEq => { - if rhs_ty.is_single_valued(self.db) { - let ty = IntersectionBuilder::new(self.db) - .add_negative(rhs_ty) - .build(); - constraints.insert(symbol, ty); - } - } - ast::CmpOp::Eq if lhs_ty.is_literal_string() => { - constraints.insert(symbol, rhs_ty); - } - ast::CmpOp::In => { - if let Some(ty) = self.evaluate_expr_in(lhs_ty, rhs_ty) { - constraints.insert(symbol, ty); - } - } - ast::CmpOp::NotIn => { - if let Some(ty) = self.evaluate_expr_in(lhs_ty, rhs_ty) { - constraints.insert(symbol, ty.negate(self.db)); - } - } - _ => { - // TODO other comparison types + let op = if is_positive { *op } else { op.negate() }; + + if let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, op) { + constraints.insert(symbol, ty); + } + } + ast::Expr::Named(ast::ExprNamed { + range: _, + target, + value: _, + }) => { + if let ast::Expr::Name(ast::ExprName { id, .. }) = target.as_ref() { + let symbol = self.expect_expr_name_symbol(id); + + let op = if is_positive { *op } else { op.negate() }; + + if let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, op) { + constraints.insert(symbol, ty); } } } @@ -535,10 +570,16 @@ impl<'db> NarrowingConstraintsBuilder<'db> { Type::FunctionLiteral(function_type) if expr_call.arguments.keywords.is_empty() => { let function = function_type.known(self.db)?.into_constraint_function()?; - let [ast::Expr::Name(ast::ExprName { id, .. }), class_info] = - &*expr_call.arguments.args - else { - return None; + let (id, class_info) = match &*expr_call.arguments.args { + [first, class_info] => match first { + ast::Expr::Named(ast::ExprNamed { target, .. }) => match target.as_ref() { + ast::Expr::Name(ast::ExprName { id, .. }) => (id, class_info), + _ => return None, + }, + ast::Expr::Name(ast::ExprName { id, .. }) => (id, class_info), + _ => return None, + }, + _ => return None, }; let symbol = self.expect_expr_name_symbol(id);