[ruff] Add API for splicing into an existing import statement

Basically, given a `from module import name1, name2, ...` statement,
we'd like to be able to insert another name in that list.

This new `Insertion::existing_import` API provides such
functionality. There isn't much to it, although we are careful
to try and avoid inserting nonsense for import statements
that are already invalid.
This commit is contained in:
Andrew Gallant 2025-09-16 10:09:23 -04:00 committed by Andrew Gallant
parent a47a50e6e2
commit da5eb85087
3 changed files with 336 additions and 1 deletions

View file

@ -128,6 +128,57 @@ impl<'a> Insertion<'a> {
}
}
/// Create an [`Insertion`] to insert an additional member to import
/// into a `from <module> import member1, member2, ...` statement.
///
/// For example, given the following code:
///
/// ```python
/// """Hello, world!"""
///
/// from collections import Counter
///
///
/// def foo():
/// pass
/// ```
///
/// The insertion returned will begin after `Counter` but before the
/// newline terminator. Callers can then call [`Insertion::into_edit`]
/// with the additional member to add. A comma delimiter is handled
/// automatically.
///
/// The statement itself is assumed to be at the top-level of the module.
///
/// This returns `None` when `stmt` isn't a `from ... import ...`
/// statement.
pub fn existing_import(stmt: &Stmt, tokens: &Tokens) -> Option<Insertion<'static>> {
let Stmt::ImportFrom(ref import_from) = *stmt else {
return None;
};
if let Some(at) = import_from.names.last().map(Ranged::end) {
return Some(Insertion::inline(", ", at, ""));
}
// Our AST can deal with partial `from ... import`
// statements, so we might not have any members
// yet. In this case, we don't need the comma.
//
// ... however, unless we can be certain that
// inserting this name leads to a valid AST, we
// give up.
let at = import_from.end();
if !matches!(
tokens
.before(at)
.last()
.map(ruff_python_parser::Token::kind),
Some(TokenKind::Import)
) {
return None;
}
Some(Insertion::inline(" ", at, ""))
}
/// Create an [`Insertion`] to insert (e.g.) an import statement at the start of a given
/// block, along with a prefix and suffix to use for the insertion.
///
@ -314,7 +365,7 @@ mod tests {
use ruff_python_codegen::Stylist;
use ruff_python_parser::parse_module;
use ruff_source_file::LineEnding;
use ruff_text_size::TextSize;
use ruff_text_size::{Ranged, TextSize};
use super::Insertion;
@ -473,4 +524,286 @@ if True:
Insertion::indented("", TextSize::from(9), "\n", " ")
);
}
#[test]
fn existing_import_works() {
fn snapshot(content: &str, member: &str) -> String {
let parsed = parse_module(content).unwrap();
let edit = Insertion::existing_import(parsed.suite().first().unwrap(), parsed.tokens())
.unwrap()
.into_edit(member);
let insert_text = edit.content().expect("edit should be non-empty");
let mut content = content.to_string();
content.replace_range(edit.range().to_std_range(), insert_text);
content
}
let source = r#"
from collections import Counter
"#;
insta::assert_snapshot!(
snapshot(source, "defaultdict"),
@r"
from collections import Counter, defaultdict
",
);
let source = r#"
from collections import Counter, OrderedDict
"#;
insta::assert_snapshot!(
snapshot(source, "defaultdict"),
@r"
from collections import Counter, OrderedDict, defaultdict
",
);
let source = r#"
from collections import (Counter)
"#;
insta::assert_snapshot!(
snapshot(source, "defaultdict"),
@"from collections import (Counter, defaultdict)",
);
let source = r#"
from collections import (Counter, OrderedDict)
"#;
insta::assert_snapshot!(
snapshot(source, "defaultdict"),
@"from collections import (Counter, OrderedDict, defaultdict)",
);
let source = r#"
from collections import (Counter,)
"#;
insta::assert_snapshot!(
snapshot(source, "defaultdict"),
@"from collections import (Counter, defaultdict,)",
);
let source = r#"
from collections import (Counter, OrderedDict,)
"#;
insta::assert_snapshot!(
snapshot(source, "defaultdict"),
@"from collections import (Counter, OrderedDict, defaultdict,)",
);
let source = r#"
from collections import (
Counter
)
"#;
insta::assert_snapshot!(
snapshot(source, "defaultdict"),
@r"
from collections import (
Counter, defaultdict
)
",
);
let source = r#"
from collections import (
Counter,
)
"#;
insta::assert_snapshot!(
snapshot(source, "defaultdict"),
@r"
from collections import (
Counter, defaultdict,
)
",
);
let source = r#"
from collections import (
Counter,
OrderedDict
)
"#;
insta::assert_snapshot!(
snapshot(source, "defaultdict"),
@r"
from collections import (
Counter,
OrderedDict, defaultdict
)
",
);
let source = r#"
from collections import (
Counter,
OrderedDict,
)
"#;
insta::assert_snapshot!(
snapshot(source, "defaultdict"),
@r"
from collections import (
Counter,
OrderedDict, defaultdict,
)
",
);
let source = r#"
from collections import \
Counter
"#;
insta::assert_snapshot!(
snapshot(source, "defaultdict"),
@r"
from collections import \
Counter, defaultdict
",
);
let source = r#"
from collections import \
Counter, OrderedDict
"#;
insta::assert_snapshot!(
snapshot(source, "defaultdict"),
@r"
from collections import \
Counter, OrderedDict, defaultdict
",
);
let source = r#"
from collections import \
Counter, \
OrderedDict
"#;
insta::assert_snapshot!(
snapshot(source, "defaultdict"),
@r"
from collections import \
Counter, \
OrderedDict, defaultdict
",
);
/*
from collections import (
Collector # comment
)
from collections import (
Collector, # comment
)
from collections import (
Collector # comment
,
)
from collections import (
Collector
# comment
,
)
*/
let source = r#"
from collections import (
Counter # comment
)
"#;
insta::assert_snapshot!(
snapshot(source, "defaultdict"),
@r"
from collections import (
Counter, defaultdict # comment
)
",
);
let source = r#"
from collections import (
Counter, # comment
)
"#;
insta::assert_snapshot!(
snapshot(source, "defaultdict"),
@r"
from collections import (
Counter, defaultdict, # comment
)
",
);
let source = r#"
from collections import (
Counter # comment
,
)
"#;
insta::assert_snapshot!(
snapshot(source, "defaultdict"),
@r"
from collections import (
Counter, defaultdict # comment
,
)
",
);
let source = r#"
from collections import (
Counter
# comment
,
)
"#;
insta::assert_snapshot!(
snapshot(source, "defaultdict"),
@r"
from collections import (
Counter, defaultdict
# comment
,
)
",
);
let source = r#"
from collections import (
# comment 1
Counter # comment 2
# comment 3
)
"#;
insta::assert_snapshot!(
snapshot(source, "defaultdict"),
@r"
from collections import (
# comment 1
Counter, defaultdict # comment 2
# comment 3
)
",
);
let source = r#"
from collections import Counter # comment
"#;
insta::assert_snapshot!(
snapshot(source, "defaultdict"),
@"from collections import Counter, defaultdict # comment",
);
let source = r#"
from collections import Counter, OrderedDict # comment
"#;
insta::assert_snapshot!(
snapshot(source, "defaultdict"),
@"from collections import Counter, OrderedDict, defaultdict # comment",
);
}
}