[red-knot] Type narrowing for assertions (take 2) (#17345)

## Summary

Fixes #17147.

This was landed in #17149 and then reverted in #17335 because it caused
cycle panics in checking pybind11. #17456 fixed the cause of that panic.

## Test Plan

Add new narrow/assert.md test file

Co-authored-by: Matthew Mckee <matthewmckee04@yahoo.co.uk>
This commit is contained in:
Carl Meyer 2025-04-18 08:11:07 -07:00 committed by GitHub
parent 1918c61623
commit e4e405d2a1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 65 additions and 2 deletions

View file

@ -0,0 +1,53 @@
# Narrowing with assert statements
## `assert` a value `is None` or `is not None`
```py
def _(x: str | None, y: str | None):
assert x is not None
reveal_type(x) # revealed: str
assert y is None
reveal_type(y) # revealed: None
```
## `assert` a value is truthy or falsy
```py
def _(x: bool, y: bool):
assert x
reveal_type(x) # revealed: Literal[True]
assert not y
reveal_type(y) # revealed: Literal[False]
```
## `assert` with `is` and `==` for literals
```py
from typing import Literal
def _(x: Literal[1, 2, 3], y: Literal[1, 2, 3]):
assert x is 2
reveal_type(x) # revealed: Literal[2]
assert y == 2
reveal_type(y) # revealed: Literal[1, 2, 3]
```
## `assert` with `isinstance`
```py
def _(x: int | str):
assert isinstance(x, int)
reveal_type(x) # revealed: int
```
## `assert` a value `in` a tuple
```py
from typing import Literal
def _(x: Literal[1, 2, 3], y: Literal[1, 2, 3]):
assert x in (1, 2)
reveal_type(x) # revealed: Literal[1, 2]
assert y not in (1, 2)
reveal_type(y) # revealed: Literal[3]
```

View file

@ -569,7 +569,6 @@ impl<'db> SemanticIndexBuilder<'db> {
}
/// Records a visibility constraint by applying it to all live bindings and declarations.
#[must_use = "A visibility constraint must always be negated after it is added"]
fn record_visibility_constraint(
&mut self,
predicate: Predicate<'db>,
@ -1323,6 +1322,17 @@ where
);
}
}
ast::Stmt::Assert(node) => {
self.visit_expr(&node.test);
let predicate = self.record_expression_narrowing_constraint(&node.test);
self.record_visibility_constraint(predicate);
if let Some(msg) = &node.msg {
self.visit_expr(msg);
}
}
ast::Stmt::Assign(node) => {
debug_assert_eq!(&self.current_assignments, &[]);

View file

@ -3294,7 +3294,7 @@ impl<'db> TypeInferenceBuilder<'db> {
msg,
} = assert;
let test_ty = self.infer_expression(test);
let test_ty = self.infer_standalone_expression(test);
if let Err(err) = test_ty.try_bool(self.db()) {
err.report_diagnostic(&self.context, &**test);