// syntax/ast/expr_ext.rs

//! Various extension methods for AST `Expr` nodes that are hard to code-generate.
//!
//! These methods should only do simple, shallow tasks related to the syntax of the node itself.
4
5use crate::{
6    AstToken,
7    SyntaxKind::{self, *},
8    SyntaxNode, SyntaxToken, T,
9    ast::{
10        self, ArgList, AstChildren, AstNode, BlockExpr, ClosureExpr, Const, Expr, Fn,
11        FormatArgsArg, FormatArgsExpr, MacroDef, Static, TokenTree,
12        operators::{ArithOp, BinaryOp, CmpOp, LogicOp, Ordering, RangeOp, UnaryOp},
13        support,
14    },
15};
16
17use super::RangeItem;
18
19impl ast::HasAttrs for ast::Expr {}
20
21impl ast::Expr {
22    pub fn is_block_like(&self) -> bool {
23        matches!(
24            self,
25            ast::Expr::IfExpr(_)
26                | ast::Expr::LoopExpr(_)
27                | ast::Expr::ForExpr(_)
28                | ast::Expr::WhileExpr(_)
29                | ast::Expr::BlockExpr(_)
30                | ast::Expr::MatchExpr(_)
31        )
32    }
33}
34
35#[derive(Debug, Clone, PartialEq, Eq)]
36pub enum ElseBranch {
37    Block(ast::BlockExpr),
38    IfExpr(ast::IfExpr),
39}
40
41impl From<ast::BlockExpr> for ElseBranch {
42    fn from(block_expr: ast::BlockExpr) -> Self {
43        Self::Block(block_expr)
44    }
45}
46
47impl From<ast::IfExpr> for ElseBranch {
48    fn from(if_expr: ast::IfExpr) -> Self {
49        Self::IfExpr(if_expr)
50    }
51}
52
53impl AstNode for ElseBranch {
54    fn can_cast(kind: SyntaxKind) -> bool {
55        ast::BlockExpr::can_cast(kind) || ast::IfExpr::can_cast(kind)
56    }
57
58    fn cast(syntax: SyntaxNode) -> Option<Self> {
59        if let Some(block_expr) = ast::BlockExpr::cast(syntax.clone()) {
60            Some(Self::Block(block_expr))
61        } else {
62            ast::IfExpr::cast(syntax).map(Self::IfExpr)
63        }
64    }
65
66    fn syntax(&self) -> &SyntaxNode {
67        match self {
68            ElseBranch::Block(block_expr) => block_expr.syntax(),
69            ElseBranch::IfExpr(if_expr) => if_expr.syntax(),
70        }
71    }
72}
73
74impl ast::IfExpr {
75    pub fn condition(&self) -> Option<ast::Expr> {
76        // If the condition is a BlockExpr, check if the then body is missing.
77        // If it is assume the condition is the expression that is missing instead.
78        let mut exprs = support::children(self.syntax());
79        let first = exprs.next();
80        match first {
81            Some(ast::Expr::BlockExpr(_)) => exprs.next().and(first),
82            first => first,
83        }
84    }
85
86    pub fn then_branch(&self) -> Option<ast::BlockExpr> {
87        match support::children(self.syntax()).nth(1)? {
88            ast::Expr::BlockExpr(block) => Some(block),
89            _ => None,
90        }
91    }
92
93    pub fn else_branch(&self) -> Option<ElseBranch> {
94        match support::children(self.syntax()).nth(2)? {
95            ast::Expr::BlockExpr(block) => Some(ElseBranch::Block(block)),
96            ast::Expr::IfExpr(elif) => Some(ElseBranch::IfExpr(elif)),
97            _ => None,
98        }
99    }
100}
101
#[test]
fn if_block_condition() {
    let parse = ast::SourceFile::parse(
        r#"
        fn test() {
            if { true } { "if" }
            else if { false } { "first elif" }
            else if true { "second elif" }
            else if (true) { "third elif" }
            else { "else" }
        }
        "#,
        parser::Edition::CURRENT,
    );
    let mut current = parse.tree().syntax().descendants().find_map(ast::IfExpr::cast).unwrap();
    // Walk down the `else if` chain, checking each `then` block on the way.
    for expected in [r#"{ "if" }"#, r#"{ "first elif" }"#, r#"{ "second elif" }"#] {
        assert_eq!(current.then_branch().unwrap().syntax().text(), expected);
        current = match current.else_branch().unwrap() {
            ElseBranch::IfExpr(elif) => elif,
            ElseBranch::Block(_) => panic!("should be `else if`"),
        };
    }
    assert_eq!(current.then_branch().unwrap().syntax().text(), r#"{ "third elif" }"#);
    let else_ = match current.else_branch().unwrap() {
        ElseBranch::Block(else_) => else_,
        ElseBranch::IfExpr(_) => panic!("should be `else`"),
    };
    assert_eq!(else_.syntax().text(), r#"{ "else" }"#);
}
139
#[test]
fn if_condition_with_if_inside() {
    let source = r#"
        fn test() {
            if if true { true } else { false } { "if" }
            else { "else" }
        }
        "#;
    let parse = ast::SourceFile::parse(source, parser::Edition::CURRENT);
    // `find_map` must pick the outer `if`, whose then branch is `{ "if" }`.
    let outer = parse.tree().syntax().descendants().find_map(ast::IfExpr::cast).unwrap();
    assert_eq!(outer.then_branch().unwrap().syntax().text(), r#"{ "if" }"#);
    let ElseBranch::Block(else_) = outer.else_branch().unwrap() else {
        panic!("should be `else`")
    };
    assert_eq!(else_.syntax().text(), r#"{ "else" }"#);
}
159
160impl ast::PrefixExpr {
161    pub fn op_kind(&self) -> Option<UnaryOp> {
162        let res = match self.op_token()?.kind() {
163            T![*] => UnaryOp::Deref,
164            T![!] => UnaryOp::Not,
165            T![-] => UnaryOp::Neg,
166            _ => return None,
167        };
168        Some(res)
169    }
170
171    pub fn op_token(&self) -> Option<SyntaxToken> {
172        self.syntax().first_child_or_token()?.into_token()
173    }
174}
175
176impl ast::BinExpr {
177    pub fn op_details(&self) -> Option<(SyntaxToken, BinaryOp)> {
178        self.syntax().children_with_tokens().filter_map(|it| it.into_token()).find_map(|c| {
179            #[rustfmt::skip]
180            let bin_op = match c.kind() {
181                T![||] => BinaryOp::LogicOp(LogicOp::Or),
182                T![&&] => BinaryOp::LogicOp(LogicOp::And),
183
184                T![==] => BinaryOp::CmpOp(CmpOp::Eq { negated: false }),
185                T![!=] => BinaryOp::CmpOp(CmpOp::Eq { negated: true }),
186                T![<=] => BinaryOp::CmpOp(CmpOp::Ord { ordering: Ordering::Less,    strict: false }),
187                T![>=] => BinaryOp::CmpOp(CmpOp::Ord { ordering: Ordering::Greater, strict: false }),
188                T![<]  => BinaryOp::CmpOp(CmpOp::Ord { ordering: Ordering::Less,    strict: true }),
189                T![>]  => BinaryOp::CmpOp(CmpOp::Ord { ordering: Ordering::Greater, strict: true }),
190
191                T![+]  => BinaryOp::ArithOp(ArithOp::Add),
192                T![*]  => BinaryOp::ArithOp(ArithOp::Mul),
193                T![-]  => BinaryOp::ArithOp(ArithOp::Sub),
194                T![/]  => BinaryOp::ArithOp(ArithOp::Div),
195                T![%]  => BinaryOp::ArithOp(ArithOp::Rem),
196                T![<<] => BinaryOp::ArithOp(ArithOp::Shl),
197                T![>>] => BinaryOp::ArithOp(ArithOp::Shr),
198                T![^]  => BinaryOp::ArithOp(ArithOp::BitXor),
199                T![|]  => BinaryOp::ArithOp(ArithOp::BitOr),
200                T![&]  => BinaryOp::ArithOp(ArithOp::BitAnd),
201
202                T![=]   => BinaryOp::Assignment { op: None },
203                T![+=]  => BinaryOp::Assignment { op: Some(ArithOp::Add) },
204                T![*=]  => BinaryOp::Assignment { op: Some(ArithOp::Mul) },
205                T![-=]  => BinaryOp::Assignment { op: Some(ArithOp::Sub) },
206                T![/=]  => BinaryOp::Assignment { op: Some(ArithOp::Div) },
207                T![%=]  => BinaryOp::Assignment { op: Some(ArithOp::Rem) },
208                T![<<=] => BinaryOp::Assignment { op: Some(ArithOp::Shl) },
209                T![>>=] => BinaryOp::Assignment { op: Some(ArithOp::Shr) },
210                T![^=]  => BinaryOp::Assignment { op: Some(ArithOp::BitXor) },
211                T![|=]  => BinaryOp::Assignment { op: Some(ArithOp::BitOr) },
212                T![&=]  => BinaryOp::Assignment { op: Some(ArithOp::BitAnd) },
213
214                _ => return None,
215            };
216            Some((c, bin_op))
217        })
218    }
219
220    pub fn op_kind(&self) -> Option<BinaryOp> {
221        self.op_details().map(|t| t.1)
222    }
223
224    pub fn op_token(&self) -> Option<SyntaxToken> {
225        self.op_details().map(|t| t.0)
226    }
227
228    pub fn lhs(&self) -> Option<ast::Expr> {
229        support::children(self.syntax()).next()
230    }
231
232    pub fn rhs(&self) -> Option<ast::Expr> {
233        support::children(self.syntax()).nth(1)
234    }
235
236    pub fn sub_exprs(&self) -> (Option<ast::Expr>, Option<ast::Expr>) {
237        let mut children = support::children(self.syntax());
238        let first = children.next();
239        let second = children.next();
240        (first, second)
241    }
242}
243
244impl ast::RangeExpr {
245    fn op_details(&self) -> Option<(usize, SyntaxToken, RangeOp)> {
246        self.syntax().children_with_tokens().enumerate().find_map(|(ix, child)| {
247            let token = child.into_token()?;
248            let bin_op = match token.kind() {
249                T![..] => RangeOp::Exclusive,
250                T![..=] => RangeOp::Inclusive,
251                _ => return None,
252            };
253            Some((ix, token, bin_op))
254        })
255    }
256
257    pub fn is_range_full(&self) -> bool {
258        support::children::<Expr>(&self.syntax).next().is_none()
259    }
260}
261
262impl RangeItem for ast::RangeExpr {
263    type Bound = ast::Expr;
264
265    fn start(&self) -> Option<ast::Expr> {
266        let op_ix = self.op_details()?.0;
267        self.syntax()
268            .children_with_tokens()
269            .take(op_ix)
270            .find_map(|it| ast::Expr::cast(it.into_node()?))
271    }
272
273    fn end(&self) -> Option<ast::Expr> {
274        let op_ix = self.op_details()?.0;
275        self.syntax()
276            .children_with_tokens()
277            .skip(op_ix + 1)
278            .find_map(|it| ast::Expr::cast(it.into_node()?))
279    }
280
281    fn op_token(&self) -> Option<SyntaxToken> {
282        self.op_details().map(|t| t.1)
283    }
284
285    fn op_kind(&self) -> Option<RangeOp> {
286        self.op_details().map(|t| t.2)
287    }
288}
289
290impl ast::IndexExpr {
291    pub fn base(&self) -> Option<ast::Expr> {
292        support::children(self.syntax()).next()
293    }
294    pub fn index(&self) -> Option<ast::Expr> {
295        support::children(self.syntax()).nth(1)
296    }
297}
298
299pub enum ArrayExprKind {
300    Repeat { initializer: Option<ast::Expr>, repeat: Option<ast::Expr> },
301    ElementList(AstChildren<ast::Expr>),
302}
303
304impl ast::ArrayExpr {
305    pub fn kind(&self) -> ArrayExprKind {
306        if self.is_repeat() {
307            ArrayExprKind::Repeat {
308                initializer: support::children(self.syntax()).next(),
309                repeat: support::children(self.syntax()).nth(1),
310            }
311        } else {
312            ArrayExprKind::ElementList(support::children(self.syntax()))
313        }
314    }
315
316    fn is_repeat(&self) -> bool {
317        self.semicolon_token().is_some()
318    }
319}
320
321#[derive(Clone, Debug, PartialEq, Eq, Hash)]
322pub enum LiteralKind {
323    String(ast::String),
324    ByteString(ast::ByteString),
325    CString(ast::CString),
326    IntNumber(ast::IntNumber),
327    FloatNumber(ast::FloatNumber),
328    Char(ast::Char),
329    Byte(ast::Byte),
330    Bool(bool),
331}
332
333impl ast::Literal {
334    pub fn token(&self) -> SyntaxToken {
335        self.syntax()
336            .children_with_tokens()
337            .find(|e| e.kind() != ATTR && !e.kind().is_trivia())
338            .and_then(|e| e.into_token())
339            .unwrap()
340    }
341
342    pub fn kind(&self) -> LiteralKind {
343        let token = self.token();
344
345        if let Some(t) = ast::IntNumber::cast(token.clone()) {
346            return LiteralKind::IntNumber(t);
347        }
348        if let Some(t) = ast::FloatNumber::cast(token.clone()) {
349            return LiteralKind::FloatNumber(t);
350        }
351        if let Some(t) = ast::String::cast(token.clone()) {
352            return LiteralKind::String(t);
353        }
354        if let Some(t) = ast::ByteString::cast(token.clone()) {
355            return LiteralKind::ByteString(t);
356        }
357        if let Some(t) = ast::CString::cast(token.clone()) {
358            return LiteralKind::CString(t);
359        }
360        if let Some(t) = ast::Char::cast(token.clone()) {
361            return LiteralKind::Char(t);
362        }
363        if let Some(t) = ast::Byte::cast(token.clone()) {
364            return LiteralKind::Byte(t);
365        }
366
367        match token.kind() {
368            T![true] => LiteralKind::Bool(true),
369            T![false] => LiteralKind::Bool(false),
370            _ => unreachable!(),
371        }
372    }
373}
374
375pub enum BlockModifier {
376    Async(SyntaxToken),
377    Unsafe(SyntaxToken),
378    Try(SyntaxToken),
379    Const(SyntaxToken),
380    AsyncGen(SyntaxToken),
381    Gen(SyntaxToken),
382    Label(ast::Label),
383}
384
385impl ast::BlockExpr {
386    pub fn modifier(&self) -> Option<BlockModifier> {
387        self.gen_token()
388            .map(|v| {
389                if self.async_token().is_some() {
390                    BlockModifier::AsyncGen(v)
391                } else {
392                    BlockModifier::Gen(v)
393                }
394            })
395            .or_else(|| self.async_token().map(BlockModifier::Async))
396            .or_else(|| self.unsafe_token().map(BlockModifier::Unsafe))
397            .or_else(|| self.try_token().map(BlockModifier::Try))
398            .or_else(|| self.const_token().map(BlockModifier::Const))
399            .or_else(|| self.label().map(BlockModifier::Label))
400    }
401    /// false if the block is an intrinsic part of the syntax and can't be
402    /// replaced with arbitrary expression.
403    ///
404    /// ```not_rust
405    /// fn foo() { not_stand_alone }
406    /// const FOO: () = { stand_alone };
407    /// ```
408    pub fn is_standalone(&self) -> bool {
409        let parent = match self.syntax().parent() {
410            Some(it) => it,
411            None => return true,
412        };
413        match parent.kind() {
414            FOR_EXPR | IF_EXPR => parent
415                .children()
416                .find(|it| ast::Expr::can_cast(it.kind()))
417                .is_none_or(|it| it == *self.syntax()),
418            LET_ELSE | FN | WHILE_EXPR | LOOP_EXPR | CONST_BLOCK_PAT => false,
419            _ => true,
420        }
421    }
422}
423
#[test]
fn test_literal_with_attr() {
    let source = r#"const _: &str = { #[attr] "Hello" };"#;
    let parse = ast::SourceFile::parse(source, parser::Edition::CURRENT);
    // The attribute sits inside the literal node; `token()` must skip past it.
    let literal = parse.tree().syntax().descendants().find_map(ast::Literal::cast).unwrap();
    let token = literal.token();
    assert_eq!(token.text(), r#""Hello""#);
}
431
432impl ast::RecordExprField {
433    pub fn parent_record_lit(&self) -> ast::RecordExpr {
434        self.syntax().ancestors().find_map(ast::RecordExpr::cast).unwrap()
435    }
436}
437
438#[derive(Debug, Clone, PartialEq, Eq, Hash)]
439pub enum CallableExpr {
440    Call(ast::CallExpr),
441    MethodCall(ast::MethodCallExpr),
442}
443
444impl ast::HasAttrs for CallableExpr {}
445impl ast::HasArgList for CallableExpr {}
446
447impl AstNode for CallableExpr {
448    fn can_cast(kind: parser::SyntaxKind) -> bool
449    where
450        Self: Sized,
451    {
452        ast::CallExpr::can_cast(kind) || ast::MethodCallExpr::can_cast(kind)
453    }
454
455    fn cast(syntax: SyntaxNode) -> Option<Self>
456    where
457        Self: Sized,
458    {
459        if let Some(it) = ast::CallExpr::cast(syntax.clone()) {
460            Some(Self::Call(it))
461        } else {
462            ast::MethodCallExpr::cast(syntax).map(Self::MethodCall)
463        }
464    }
465
466    fn syntax(&self) -> &SyntaxNode {
467        match self {
468            Self::Call(it) => it.syntax(),
469            Self::MethodCall(it) => it.syntax(),
470        }
471    }
472}
473
474impl MacroDef {
475    fn tts(&self) -> (Option<ast::TokenTree>, Option<ast::TokenTree>) {
476        let mut types = support::children(self.syntax());
477        let first = types.next();
478        let second = types.next();
479        (first, second)
480    }
481
482    pub fn args(&self) -> Option<TokenTree> {
483        match self.tts() {
484            (Some(args), Some(_)) => Some(args),
485            _ => None,
486        }
487    }
488
489    pub fn body(&self) -> Option<TokenTree> {
490        match self.tts() {
491            (Some(body), None) | (_, Some(body)) => Some(body),
492            _ => None,
493        }
494    }
495}
496
497impl ClosureExpr {
498    pub fn body(&self) -> Option<Expr> {
499        support::child(&self.syntax)
500    }
501}
502impl Const {
503    pub fn body(&self) -> Option<Expr> {
504        support::child(&self.syntax)
505    }
506}
507impl Fn {
508    pub fn body(&self) -> Option<BlockExpr> {
509        support::child(&self.syntax)
510    }
511}
512impl Static {
513    pub fn body(&self) -> Option<Expr> {
514        support::child(&self.syntax)
515    }
516}
517impl FormatArgsExpr {
518    pub fn args(&self) -> AstChildren<FormatArgsArg> {
519        support::children(&self.syntax)
520    }
521}
522impl ArgList {
523    pub fn args(&self) -> AstChildren<Expr> {
524        support::children(&self.syntax)
525    }
526}