[red-knot] Type narrowing for assertions (#17149)

## Summary

Fixes #17147 

## Test Plan

Add new narrow/assert.md test file

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
Matthew Mckee 2025-04-10 15:15:52 +01:00 committed by GitHub
parent fd9882a1f4
commit 907b6ed7b5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 65 additions and 2 deletions

View file

@ -0,0 +1,53 @@
# Narrowing with assert statements
## `assert` a value `is None` or `is not None`
```py
def _(x: str | None, y: str | None):
assert x is not None
reveal_type(x) # revealed: str
assert y is None
reveal_type(y) # revealed: None
```
## `assert` a value is truthy or falsy
```py
def _(x: bool, y: bool):
assert x
reveal_type(x) # revealed: Literal[True]
assert not y
reveal_type(y) # revealed: Literal[False]
```
## `assert` with `is` and `==` for literals
```py
from typing import Literal
def _(x: Literal[1, 2, 3], y: Literal[1, 2, 3]):
assert x is 2
reveal_type(x) # revealed: Literal[2]
assert y == 2
reveal_type(y) # revealed: Literal[1, 2, 3]
```
## `assert` with `isinstance`
```py
def _(x: int | str):
assert isinstance(x, int)
reveal_type(x) # revealed: int
```
## `assert` a value `in` a tuple
```py
from typing import Literal
def _(x: Literal[1, 2, 3], y: Literal[1, 2, 3]):
assert x in (1, 2)
reveal_type(x) # revealed: Literal[1, 2]
assert y not in (1, 2)
reveal_type(y) # revealed: Literal[3]
```

View file

@ -534,7 +534,6 @@ impl<'db> SemanticIndexBuilder<'db> {
}
/// Records a visibility constraint by applying it to all live bindings and declarations.
#[must_use = "A visibility constraint must always be negated after it is added"]
fn record_visibility_constraint(
&mut self,
predicate: Predicate<'db>,
@ -1292,6 +1291,17 @@ where
);
}
}
ast::Stmt::Assert(node) => {
self.visit_expr(&node.test);
let predicate = self.record_expression_narrowing_constraint(&node.test);
self.record_visibility_constraint(predicate);
if let Some(msg) = &node.msg {
self.visit_expr(msg);
}
}
ast::Stmt::Assign(node) => {
debug_assert_eq!(&self.current_assignments, &[]);

View file

@ -3188,7 +3188,7 @@ impl<'db> TypeInferenceBuilder<'db> {
msg,
} = assert;
let test_ty = self.infer_expression(test);
let test_ty = self.infer_standalone_expression(test);
if let Err(err) = test_ty.try_bool(self.db()) {
err.report_diagnostic(&self.context, &**test);