syntax/syntax_editor/edits.rs

1//! Structural editing for ast using `SyntaxEditor`
2
3use crate::{
4    AstToken, Direction, SyntaxElement, SyntaxKind, SyntaxNode, SyntaxToken, T,
5    algo::neighbor,
6    ast::{
7        self, AstNode, Fn, GenericParam, HasGenericParams, HasName, edit::IndentLevel, make,
8        syntax_factory::SyntaxFactory,
9    },
10    syntax_editor::{Position, SyntaxEditor},
11};
12
13impl SyntaxEditor {
14    /// Adds a new generic param to the function using `SyntaxEditor`
15    pub fn add_generic_param(&mut self, function: &Fn, new_param: GenericParam) {
16        match function.generic_param_list() {
17            Some(generic_param_list) => match generic_param_list.generic_params().last() {
18                Some(last_param) => {
19                    // There exists a generic param list and it's not empty
20                    let position = generic_param_list.r_angle_token().map_or_else(
21                        || Position::last_child_of(function.syntax()),
22                        Position::before,
23                    );
24
25                    if last_param
26                        .syntax()
27                        .next_sibling_or_token()
28                        .is_some_and(|it| it.kind() == SyntaxKind::COMMA)
29                    {
30                        self.insert(
31                            Position::after(last_param.syntax()),
32                            new_param.syntax().clone(),
33                        );
34                        self.insert(
35                            Position::after(last_param.syntax()),
36                            make::token(SyntaxKind::WHITESPACE),
37                        );
38                        self.insert(
39                            Position::after(last_param.syntax()),
40                            make::token(SyntaxKind::COMMA),
41                        );
42                    } else {
43                        let elements = vec![
44                            make::token(SyntaxKind::COMMA).into(),
45                            make::token(SyntaxKind::WHITESPACE).into(),
46                            new_param.syntax().clone().into(),
47                        ];
48                        self.insert_all(position, elements);
49                    }
50                }
51                None => {
52                    // There exists a generic param list but it's empty
53                    let position = Position::after(generic_param_list.l_angle_token().unwrap());
54                    self.insert(position, new_param.syntax());
55                }
56            },
57            None => {
58                // There was no generic param list
59                let position = if let Some(name) = function.name() {
60                    Position::after(name.syntax)
61                } else if let Some(fn_token) = function.fn_token() {
62                    Position::after(fn_token)
63                } else if let Some(param_list) = function.param_list() {
64                    Position::before(param_list.syntax)
65                } else {
66                    Position::last_child_of(function.syntax())
67                };
68                let elements = vec![
69                    make::token(SyntaxKind::L_ANGLE).into(),
70                    new_param.syntax().clone().into(),
71                    make::token(SyntaxKind::R_ANGLE).into(),
72                ];
73                self.insert_all(position, elements);
74            }
75        }
76    }
77}
78
79fn get_or_insert_comma_after(editor: &mut SyntaxEditor, syntax: &SyntaxNode) -> SyntaxToken {
80    let make = SyntaxFactory::without_mappings();
81    match syntax
82        .siblings_with_tokens(Direction::Next)
83        .filter_map(|it| it.into_token())
84        .find(|it| it.kind() == T![,])
85    {
86        Some(it) => it,
87        None => {
88            let comma = make.token(T![,]);
89            editor.insert(Position::after(syntax), &comma);
90            comma
91        }
92    }
93}
94
95impl ast::VariantList {
96    pub fn add_variant(&self, editor: &mut SyntaxEditor, variant: &ast::Variant) {
97        let make = SyntaxFactory::without_mappings();
98        let (indent, position) = match self.variants().last() {
99            Some(last_item) => (
100                IndentLevel::from_node(last_item.syntax()),
101                Position::after(get_or_insert_comma_after(editor, last_item.syntax())),
102            ),
103            None => match self.l_curly_token() {
104                Some(l_curly) => {
105                    normalize_ws_between_braces(editor, self.syntax());
106                    (IndentLevel::from_token(&l_curly) + 1, Position::after(&l_curly))
107                }
108                None => (IndentLevel::single(), Position::last_child_of(self.syntax())),
109            },
110        };
111        let elements: Vec<SyntaxElement> = vec![
112            make.whitespace(&format!("{}{indent}", "\n")).into(),
113            variant.syntax().clone().into(),
114            make.token(T![,]).into(),
115        ];
116        editor.insert_all(position, elements);
117    }
118}
119
120fn normalize_ws_between_braces(editor: &mut SyntaxEditor, node: &SyntaxNode) -> Option<()> {
121    let make = SyntaxFactory::without_mappings();
122    let l = node
123        .children_with_tokens()
124        .filter_map(|it| it.into_token())
125        .find(|it| it.kind() == T!['{'])?;
126    let r = node
127        .children_with_tokens()
128        .filter_map(|it| it.into_token())
129        .find(|it| it.kind() == T!['}'])?;
130
131    let indent = IndentLevel::from_node(node);
132
133    match l.next_sibling_or_token() {
134        Some(ws) if ws.kind() == SyntaxKind::WHITESPACE => {
135            if ws.next_sibling_or_token()?.into_token()? == r {
136                editor.replace(ws, make.whitespace(&format!("\n{indent}")));
137            }
138        }
139        Some(ws) if ws.kind() == T!['}'] => {
140            editor.insert(Position::after(l), make.whitespace(&format!("\n{indent}")));
141        }
142        _ => (),
143    }
144    Some(())
145}
146
147pub trait Removable: AstNode {
148    fn remove(&self, editor: &mut SyntaxEditor);
149}
150
151impl Removable for ast::Use {
152    fn remove(&self, editor: &mut SyntaxEditor) {
153        let make = SyntaxFactory::without_mappings();
154
155        let next_ws = self
156            .syntax()
157            .next_sibling_or_token()
158            .and_then(|it| it.into_token())
159            .and_then(ast::Whitespace::cast);
160        if let Some(next_ws) = next_ws {
161            let ws_text = next_ws.syntax().text();
162            if let Some(rest) = ws_text.strip_prefix('\n') {
163                if rest.is_empty() {
164                    editor.delete(next_ws.syntax());
165                } else {
166                    editor.replace(next_ws.syntax(), make.whitespace(rest));
167                }
168            }
169        }
170
171        editor.delete(self.syntax());
172    }
173}
174
175impl Removable for ast::UseTree {
176    fn remove(&self, editor: &mut SyntaxEditor) {
177        for dir in [Direction::Next, Direction::Prev] {
178            if let Some(next_use_tree) = neighbor(self, dir) {
179                let separators = self
180                    .syntax()
181                    .siblings_with_tokens(dir)
182                    .skip(1)
183                    .take_while(|it| it.as_node() != Some(next_use_tree.syntax()));
184                for sep in separators {
185                    editor.delete(sep);
186                }
187                break;
188            }
189        }
190        editor.delete(self.syntax());
191    }
192}
193
#[cfg(test)]
mod tests {
    use parser::Edition;
    use stdx::trim_indent;
    use test_utils::assert_eq_text;

    use crate::SourceFile;

    use super::*;

    /// Parses `text` and returns the first `N` found in it, cloned into a standalone tree
    /// whose range starts at offset 0.
    fn ast_from_text<N: AstNode>(text: &str) -> N {
        let parse = SourceFile::parse(text, Edition::CURRENT);
        let Some(found) = parse.tree().syntax().descendants().find_map(N::cast) else {
            let node = std::any::type_name::<N>();
            panic!("Failed to make ast node `{node}` from text {text}")
        };
        let detached = found.clone_subtree();
        assert_eq!(detached.syntax().text_range().start(), 0.into());
        detached
    }

    /// Builds a variant called `name` with no visibility, fields, or discriminant.
    fn unit_variant(make: &SyntaxFactory, name: &str) -> ast::Variant {
        make.variant(None, make.name(name), None, None)
    }

    /// Appends `variant` to the enum parsed from `before` and compares the result with
    /// `expected`, both trimmed and de-indented.
    fn check_add_variant(before: &str, expected: &str, variant: ast::Variant) {
        let enum_ = ast_from_text::<ast::Enum>(before);
        let mut editor = SyntaxEditor::new(enum_.syntax().clone());
        if let Some(variant_list) = enum_.variant_list() {
            variant_list.add_variant(&mut editor, &variant);
        }
        let actual = editor.finish().new_root.to_string();
        assert_eq_text!(&trim_indent(expected.trim()), &trim_indent(actual.trim()));
    }

    #[test]
    fn add_variant_to_empty_enum() {
        let make = SyntaxFactory::without_mappings();
        check_add_variant(
            r#"
enum Foo {}
"#,
            r#"
enum Foo {
    Bar,
}
"#,
            unit_variant(&make, "Bar"),
        );
    }

    #[test]
    fn add_variant_to_non_empty_enum() {
        let make = SyntaxFactory::without_mappings();
        check_add_variant(
            r#"
enum Foo {
    Bar,
}
"#,
            r#"
enum Foo {
    Bar,
    Baz,
}
"#,
            unit_variant(&make, "Baz"),
        );
    }

    #[test]
    fn add_variant_with_tuple_field_list() {
        let make = SyntaxFactory::without_mappings();
        let fields = make.tuple_field_list([make.tuple_field(None, make.ty("bool"))]);
        let variant = make.variant(None, make.name("Baz"), Some(fields.into()), None);

        check_add_variant(
            r#"
enum Foo {
    Bar,
}
"#,
            r#"
enum Foo {
    Bar,
    Baz(bool),
}
"#,
            variant,
        );
    }

    #[test]
    fn add_variant_with_record_field_list() {
        let make = SyntaxFactory::without_mappings();
        let fields =
            make.record_field_list([make.record_field(None, make.name("x"), make.ty("bool"))]);
        let variant = make.variant(None, make.name("Baz"), Some(fields.into()), None);

        check_add_variant(
            r#"
enum Foo {
    Bar,
}
"#,
            r#"
enum Foo {
    Bar,
    Baz { x: bool },
}
"#,
            variant,
        );
    }
}