1use crate::{
4 AstToken, Direction, SyntaxElement, SyntaxKind, SyntaxNode, SyntaxToken, T,
5 algo::neighbor,
6 ast::{
7 self, AstNode, Fn, GenericParam, HasGenericParams, HasName, edit::IndentLevel, make,
8 syntax_factory::SyntaxFactory,
9 },
10 syntax_editor::{Position, SyntaxEditor},
11};
12
13impl SyntaxEditor {
14 pub fn add_generic_param(&mut self, function: &Fn, new_param: GenericParam) {
16 match function.generic_param_list() {
17 Some(generic_param_list) => match generic_param_list.generic_params().last() {
18 Some(last_param) => {
19 let position = generic_param_list.r_angle_token().map_or_else(
21 || Position::last_child_of(function.syntax()),
22 Position::before,
23 );
24
25 if last_param
26 .syntax()
27 .next_sibling_or_token()
28 .is_some_and(|it| it.kind() == SyntaxKind::COMMA)
29 {
30 self.insert(
31 Position::after(last_param.syntax()),
32 new_param.syntax().clone(),
33 );
34 self.insert(
35 Position::after(last_param.syntax()),
36 make::token(SyntaxKind::WHITESPACE),
37 );
38 self.insert(
39 Position::after(last_param.syntax()),
40 make::token(SyntaxKind::COMMA),
41 );
42 } else {
43 let elements = vec![
44 make::token(SyntaxKind::COMMA).into(),
45 make::token(SyntaxKind::WHITESPACE).into(),
46 new_param.syntax().clone().into(),
47 ];
48 self.insert_all(position, elements);
49 }
50 }
51 None => {
52 let position = Position::after(generic_param_list.l_angle_token().unwrap());
54 self.insert(position, new_param.syntax());
55 }
56 },
57 None => {
58 let position = if let Some(name) = function.name() {
60 Position::after(name.syntax)
61 } else if let Some(fn_token) = function.fn_token() {
62 Position::after(fn_token)
63 } else if let Some(param_list) = function.param_list() {
64 Position::before(param_list.syntax)
65 } else {
66 Position::last_child_of(function.syntax())
67 };
68 let elements = vec![
69 make::token(SyntaxKind::L_ANGLE).into(),
70 new_param.syntax().clone().into(),
71 make::token(SyntaxKind::R_ANGLE).into(),
72 ];
73 self.insert_all(position, elements);
74 }
75 }
76 }
77}
78
79fn get_or_insert_comma_after(editor: &mut SyntaxEditor, syntax: &SyntaxNode) -> SyntaxToken {
80 let make = SyntaxFactory::without_mappings();
81 match syntax
82 .siblings_with_tokens(Direction::Next)
83 .filter_map(|it| it.into_token())
84 .find(|it| it.kind() == T![,])
85 {
86 Some(it) => it,
87 None => {
88 let comma = make.token(T![,]);
89 editor.insert(Position::after(syntax), &comma);
90 comma
91 }
92 }
93}
94
95impl ast::VariantList {
96 pub fn add_variant(&self, editor: &mut SyntaxEditor, variant: &ast::Variant) {
97 let make = SyntaxFactory::without_mappings();
98 let (indent, position) = match self.variants().last() {
99 Some(last_item) => (
100 IndentLevel::from_node(last_item.syntax()),
101 Position::after(get_or_insert_comma_after(editor, last_item.syntax())),
102 ),
103 None => match self.l_curly_token() {
104 Some(l_curly) => {
105 normalize_ws_between_braces(editor, self.syntax());
106 (IndentLevel::from_token(&l_curly) + 1, Position::after(&l_curly))
107 }
108 None => (IndentLevel::single(), Position::last_child_of(self.syntax())),
109 },
110 };
111 let elements: Vec<SyntaxElement> = vec![
112 make.whitespace(&format!("{}{indent}", "\n")).into(),
113 variant.syntax().clone().into(),
114 make.token(T![,]).into(),
115 ];
116 editor.insert_all(position, elements);
117 }
118}
119
120fn normalize_ws_between_braces(editor: &mut SyntaxEditor, node: &SyntaxNode) -> Option<()> {
121 let make = SyntaxFactory::without_mappings();
122 let l = node
123 .children_with_tokens()
124 .filter_map(|it| it.into_token())
125 .find(|it| it.kind() == T!['{'])?;
126 let r = node
127 .children_with_tokens()
128 .filter_map(|it| it.into_token())
129 .find(|it| it.kind() == T!['}'])?;
130
131 let indent = IndentLevel::from_node(node);
132
133 match l.next_sibling_or_token() {
134 Some(ws) if ws.kind() == SyntaxKind::WHITESPACE => {
135 if ws.next_sibling_or_token()?.into_token()? == r {
136 editor.replace(ws, make.whitespace(&format!("\n{indent}")));
137 }
138 }
139 Some(ws) if ws.kind() == T!['}'] => {
140 editor.insert(Position::after(l), make.whitespace(&format!("\n{indent}")));
141 }
142 _ => (),
143 }
144 Some(())
145}
146
/// An AST node that can schedule its own removal in a [`SyntaxEditor`].
pub trait Removable: AstNode {
    /// Schedules the deletion of `self` in `editor`.
    ///
    /// Implementations may also delete or adjust neighbouring elements, such as trailing
    /// whitespace or separators, so that the remaining text stays tidy.
    fn remove(&self, editor: &mut SyntaxEditor);
}
150
151impl Removable for ast::Use {
152 fn remove(&self, editor: &mut SyntaxEditor) {
153 let make = SyntaxFactory::without_mappings();
154
155 let next_ws = self
156 .syntax()
157 .next_sibling_or_token()
158 .and_then(|it| it.into_token())
159 .and_then(ast::Whitespace::cast);
160 if let Some(next_ws) = next_ws {
161 let ws_text = next_ws.syntax().text();
162 if let Some(rest) = ws_text.strip_prefix('\n') {
163 if rest.is_empty() {
164 editor.delete(next_ws.syntax());
165 } else {
166 editor.replace(next_ws.syntax(), make.whitespace(rest));
167 }
168 }
169 }
170
171 editor.delete(self.syntax());
172 }
173}
174
175impl Removable for ast::UseTree {
176 fn remove(&self, editor: &mut SyntaxEditor) {
177 for dir in [Direction::Next, Direction::Prev] {
178 if let Some(next_use_tree) = neighbor(self, dir) {
179 let separators = self
180 .syntax()
181 .siblings_with_tokens(dir)
182 .skip(1)
183 .take_while(|it| it.as_node() != Some(next_use_tree.syntax()));
184 for sep in separators {
185 editor.delete(sep);
186 }
187 break;
188 }
189 }
190 editor.delete(self.syntax());
191 }
192}
193
#[cfg(test)]
mod tests {
    use parser::Edition;
    use stdx::trim_indent;
    use test_utils::assert_eq_text;

    use crate::SourceFile;

    use super::*;

    /// Parses `text` and returns a detached copy of the first `N` node found in it.
    ///
    /// Panics when `text` contains no `N` node. The returned node is re-rooted, so its range
    /// starts at offset 0.
    fn ast_from_text<N: AstNode>(text: &str) -> N {
        let parse = SourceFile::parse(text, Edition::CURRENT);
        let Some(found) = parse.tree().syntax().descendants().find_map(N::cast) else {
            let node = std::any::type_name::<N>();
            panic!("Failed to make ast node `{node}` from text {text}")
        };
        let detached = found.clone_subtree();
        assert_eq!(detached.syntax().text_range().start(), 0.into());
        detached
    }

    /// Adds `variant` to the first enum in `before` and compares the result with `expected`.
    ///
    /// Both texts are trimmed and de-indented before the comparison.
    fn check_add_variant(before: &str, expected: &str, variant: ast::Variant) {
        let enum_ = ast_from_text::<ast::Enum>(before);
        let mut editor = SyntaxEditor::new(enum_.syntax().clone());
        if let Some(variant_list) = enum_.variant_list() {
            variant_list.add_variant(&mut editor, &variant);
        }
        let after = editor.finish().new_root.to_string();
        assert_eq_text!(&trim_indent(expected.trim()), &trim_indent(after.trim()));
    }

    #[test]
    fn add_variant_to_empty_enum() {
        let make = SyntaxFactory::without_mappings();
        check_add_variant(
            r#"
enum Foo {}
"#,
            r#"
enum Foo {
    Bar,
}
"#,
            make.variant(None, make.name("Bar"), None, None),
        );
    }

    #[test]
    fn add_variant_to_non_empty_enum() {
        let make = SyntaxFactory::without_mappings();
        check_add_variant(
            r#"
enum Foo {
    Bar,
}
"#,
            r#"
enum Foo {
    Bar,
    Baz,
}
"#,
            make.variant(None, make.name("Baz"), None, None),
        );
    }

    #[test]
    fn add_variant_with_tuple_field_list() {
        let make = SyntaxFactory::without_mappings();
        let fields = make.tuple_field_list([make.tuple_field(None, make.ty("bool"))]);
        check_add_variant(
            r#"
enum Foo {
    Bar,
}
"#,
            r#"
enum Foo {
    Bar,
    Baz(bool),
}
"#,
            make.variant(None, make.name("Baz"), Some(fields.into()), None),
        );
    }

    #[test]
    fn add_variant_with_record_field_list() {
        let make = SyntaxFactory::without_mappings();
        let fields =
            make.record_field_list([make.record_field(None, make.name("x"), make.ty("bool"))]);
        check_add_variant(
            r#"
enum Foo {
    Bar,
}
"#,
            r#"
enum Foo {
    Bar,
    Baz { x: bool },
}
"#,
            make.variant(None, make.name("Baz"), Some(fields.into()), None),
        );
    }
}