Respect mixed return and raise cases in return-type analysis (#9310)

## Summary

Given:

```python
from somewhere import get_cfg

def lookup_cfg(cfg_description):
    cfg = get_cfg(cfg_description)
    if cfg is not None:
        return cfg
    raise AttributeError(f"No cfg found matching {cfg_description}")
```

Previously, we analyzed the function body from the last statement to the
first. So we saw the `raise` and assumed the function _always_ raised. In
reality, though, it _might_ return. This PR improves the branch analysis to
respect these mixed cases.

Closes https://github.com/astral-sh/ruff/issues/9269.
Closes https://github.com/astral-sh/ruff/issues/9304.
This commit is contained in:
Charlie Marsh 2023-12-29 12:46:37 -04:00 committed by GitHub
parent 00f3c7d1d5
commit 2895e7d126
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 482 additions and 203 deletions

View file

@ -921,206 +921,6 @@ where
}
}
/// Describes how the control flow paths through a function end.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Terminal {
    /// All control flow paths through the function terminate in a `raise` statement.
    Raise,
    /// All control flow paths through the function terminate in a `return` (or `raise`)
    /// statement.
    Return,
}
impl Terminal {
/// Returns the [`Terminal`] behavior of the function, if it can be determined, or `None` if the
/// function contains at least one control flow path that does not end with a `return` or `raise`
/// statement.
pub fn from_function(function: &ast::StmtFunctionDef) -> Option<Terminal> {
/// Returns `true` if the body may break via a `break` statement.
fn sometimes_breaks(stmts: &[Stmt]) -> bool {
for stmt in stmts {
match stmt {
Stmt::For(ast::StmtFor { body, orelse, .. }) => {
if returns(body).is_some() {
return false;
}
if sometimes_breaks(orelse) {
return true;
}
}
Stmt::While(ast::StmtWhile { body, orelse, .. }) => {
if returns(body).is_some() {
return false;
}
if sometimes_breaks(orelse) {
return true;
}
}
Stmt::If(ast::StmtIf {
body,
elif_else_clauses,
..
}) => {
if std::iter::once(body)
.chain(elif_else_clauses.iter().map(|clause| &clause.body))
.any(|body| sometimes_breaks(body))
{
return true;
}
}
Stmt::Match(ast::StmtMatch { cases, .. }) => {
if cases.iter().any(|case| sometimes_breaks(&case.body)) {
return true;
}
}
Stmt::Try(ast::StmtTry {
body,
handlers,
orelse,
finalbody,
..
}) => {
if sometimes_breaks(body)
|| handlers.iter().any(|handler| {
let ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler {
body,
..
}) = handler;
sometimes_breaks(body)
})
|| sometimes_breaks(orelse)
|| sometimes_breaks(finalbody)
{
return true;
}
}
Stmt::With(ast::StmtWith { body, .. }) => {
if sometimes_breaks(body) {
return true;
}
}
Stmt::Break(_) => return true,
Stmt::Return(_) => return false,
Stmt::Raise(_) => return false,
_ => {}
}
}
false
}
/// Returns `true` if the body may break via a `break` statement.
fn always_breaks(stmts: &[Stmt]) -> bool {
for stmt in stmts {
match stmt {
Stmt::Break(_) => return true,
Stmt::Return(_) => return false,
Stmt::Raise(_) => return false,
_ => {}
}
}
false
}
/// Returns `true` if the body contains a branch that ends without an explicit `return` or
/// `raise` statement.
fn returns(stmts: &[Stmt]) -> Option<Terminal> {
for stmt in stmts.iter().rev() {
match stmt {
Stmt::For(ast::StmtFor { body, orelse, .. })
| Stmt::While(ast::StmtWhile { body, orelse, .. }) => {
if always_breaks(body) {
return None;
}
if let Some(terminal) = returns(body) {
return Some(terminal);
}
if !sometimes_breaks(body) {
if let Some(terminal) = returns(orelse) {
return Some(terminal);
}
}
}
Stmt::If(ast::StmtIf {
body,
elif_else_clauses,
..
}) => {
if elif_else_clauses.iter().any(|clause| clause.test.is_none()) {
match Terminal::combine(std::iter::once(returns(body)).chain(
elif_else_clauses.iter().map(|clause| returns(&clause.body)),
)) {
Some(Terminal::Raise) => return Some(Terminal::Raise),
Some(Terminal::Return) => return Some(Terminal::Return),
_ => {}
}
}
}
Stmt::Match(ast::StmtMatch { cases, .. }) => {
// Note: we assume the `match` is exhaustive.
match Terminal::combine(cases.iter().map(|case| returns(&case.body))) {
Some(Terminal::Raise) => return Some(Terminal::Raise),
Some(Terminal::Return) => return Some(Terminal::Return),
_ => {}
}
}
Stmt::Try(ast::StmtTry {
body,
handlers,
orelse,
finalbody,
..
}) => {
// If the `finally` block returns, the `try` block must also return.
if let Some(terminal) = returns(finalbody) {
return Some(terminal);
}
// If the body returns, the `try` block must also return.
if returns(body) == Some(Terminal::Return) {
return Some(Terminal::Return);
}
// If the else block and all the handlers return, the `try` block must also
// return.
if let Some(terminal) =
Terminal::combine(std::iter::once(returns(orelse)).chain(
handlers.iter().map(|handler| {
let ExceptHandler::ExceptHandler(
ast::ExceptHandlerExceptHandler { body, .. },
) = handler;
returns(body)
}),
))
{
return Some(terminal);
}
}
Stmt::With(ast::StmtWith { body, .. }) => {
if let Some(terminal) = returns(body) {
return Some(terminal);
}
}
Stmt::Return(_) => return Some(Terminal::Return),
Stmt::Raise(_) => return Some(Terminal::Raise),
_ => {}
}
}
None
}
returns(&function.body)
}
/// Combine a series of [`Terminal`] operators.
fn combine(iter: impl Iterator<Item = Option<Terminal>>) -> Option<Terminal> {
iter.reduce(|acc, terminal| match (acc, terminal) {
(Some(Self::Raise), Some(Self::Raise)) => Some(Self::Raise),
(Some(_), Some(Self::Return)) => Some(Self::Return),
(Some(Self::Return), Some(_)) => Some(Self::Return),
_ => None,
})
.flatten()
}
}
/// A [`StatementVisitor`] that collects all `raise` statements in a function or method.
#[derive(Default)]
pub struct RaiseStatementVisitor<'a> {