From 31fd625321493473551b72b19c68aa70f6798bf1 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Wed, 26 Apr 2023 10:24:34 -0600 Subject: [PATCH] Add `Located::start`, `Located::end` and impl `Deref` --- ast/asdl_rs.py | 18 ++++++++++++++ ast/src/ast_gen.rs | 18 ++++++++++++++ parser/python.lalrpop | 55 ++++++++++++++++++------------------------ parser/src/function.rs | 10 ++++---- 4 files changed, 65 insertions(+), 36 deletions(-) diff --git a/ast/asdl_rs.py b/ast/asdl_rs.py index 8caeca9..c9b819d 100755 --- a/ast/asdl_rs.py +++ b/ast/asdl_rs.py @@ -671,6 +671,24 @@ def write_ast_def(mod, typeinfo, f): pub fn new(location: Location, end_location: Location, node: T) -> Self { Self { location, end_location: Some(end_location), custom: (), node } } + + pub const fn start(&self) -> Location { + self.location + } + + /// Returns the node's [`end_location`](Located::end_location) or [`location`](Located::start) if + /// [`end_location`](Located::end_location) is `None`. + pub fn end(&self) -> Location { + self.end_location.unwrap_or(self.location) + } + } + + impl std::ops::Deref for Located { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.node + } } \n """.lstrip() diff --git a/ast/src/ast_gen.rs b/ast/src/ast_gen.rs index c11fd87..6771dd0 100644 --- a/ast/src/ast_gen.rs +++ b/ast/src/ast_gen.rs @@ -24,6 +24,24 @@ impl Located { node, } } + + pub const fn start(&self) -> Location { + self.location + } + + /// Returns the node's [`end_location`](Located::end_location) or [`location`](Located::start) if + /// [`end_location`](Located::end_location) is `None`. + pub fn end(&self) -> Location { + self.end_location.unwrap_or(self.location) + } +} + +impl std::ops::Deref for Located { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.node + } } #[derive(Clone, Debug, PartialEq)] diff --git a/parser/python.lalrpop b/parser/python.lalrpop index 32c0880..45113a3 100644 --- a/parser/python.lalrpop +++ b/parser/python.lalrpop @@ -348,8 +348,7 @@ MatchStatement: ast::Stmt = { .body .last() .unwrap() - .end_location - .unwrap(); + .end(); ast::Stmt::new( location, end_location, @@ -366,8 +365,7 @@ MatchStatement: ast::Stmt = { .body .last() .unwrap() - .end_location - .unwrap(); + .end(); ast::Stmt::new( location, end_location, @@ -384,8 +382,7 @@ MatchStatement: ast::Stmt = { .body .last() .unwrap() - .end_location - .unwrap(); + .end(); let mut subjects = subjects; subjects.insert(0, subject); ast::Stmt::new( @@ -803,8 +800,7 @@ IfStatement: ast::Stmt = { .or_else(|| s2.last().and_then(|last| last.4.last())) .or_else(|| body.last()) .unwrap() - .end_location - .unwrap(); + .end(); // handle elif: for i in s2.into_iter().rev() { let x = ast::Stmt::new( @@ -830,8 +826,7 @@ WhileStatement: ast::Stmt = { .last() .or_else(|| body.last()) .unwrap() - .end_location - .unwrap(); + .end(); ast::Stmt::new( location, end_location, @@ -851,8 +846,7 @@ ForStatement: ast::Stmt = { .last() .or_else(|| body.last()) .unwrap() - .end_location - .unwrap(); + .end(); let target = Box::new(set_context(target, ast::ExprContext::Store)); let iter = Box::new(iter); let type_comment = None; @@ -871,9 +865,9 @@ TryStatement: ast::Stmt = { let finalbody = finally.map(|s| s.2).unwrap_or_default(); let end_location = finalbody .last() - .and_then(|last| last.end_location) - .or_else(|| orelse.last().and_then(|last| last.end_location)) - .or_else(|| handlers.last().and_then(|last| last.end_location)) + .map(|last| last.end()) + .or_else(|| orelse.last().map(|last| last.end())) + .or_else(|| handlers.last().map(|last| last.end())) .unwrap(); ast::Stmt::new( location, @@ -892,8 +886,8 @@ TryStatement: ast::Stmt = { let end_location = finalbody .last() .or_else(|| orelse.last()) - .and_then(|last| last.end_location) - .or_else(|| handlers.last().and_then(|last| last.end_location)) + .map(|last| last.end()) + .or_else(|| handlers.last().map(|last| last.end())) .unwrap(); ast::Stmt::new( location, @@ -910,7 +904,7 @@ TryStatement: ast::Stmt = { let handlers = vec![]; let orelse = vec![]; let finalbody = finally.2; - let end_location = finalbody.last().unwrap().end_location.unwrap(); + let end_location = finalbody.last().unwrap().end(); ast::Stmt::new( location, end_location, @@ -926,7 +920,7 @@ TryStatement: ast::Stmt = { ExceptStarClause: ast::Excepthandler = { "except" "*" > ":" => { - let end_location = body.last().unwrap().end_location.unwrap(); + let end_location = body.last().unwrap().end(); ast::Excepthandler::new( location, end_location, @@ -938,7 +932,7 @@ ExceptStarClause: ast::Excepthandler = { ) }, "except" "*" "as" Identifier)> ":" => { - let end_location = body.last().unwrap().end_location.unwrap(); + let end_location = body.last().unwrap().end(); ast::Excepthandler::new( location, end_location, @@ -954,7 +948,7 @@ ExceptStarClause: ast::Excepthandler = { ExceptClause: ast::Excepthandler = { "except" ?> ":" => { - let end_location = body.last().unwrap().end_location.unwrap(); + let end_location = body.last().unwrap().end(); ast::Excepthandler::new( location, end_location, @@ -966,7 +960,7 @@ ExceptClause: ast::Excepthandler = { ) }, "except" "as" Identifier)> ":" => { - let end_location = body.last().unwrap().end_location.unwrap(); + let end_location = body.last().unwrap().end(); ast::Excepthandler::new( location, end_location, @@ -981,7 +975,7 @@ ExceptClause: ast::Excepthandler = { WithStatement: ast::Stmt = { "with" ":" => { - let end_location = body.last().unwrap().end_location.unwrap(); + let end_location = body.last().unwrap().end(); let type_comment = None; let node = if is_async.is_some() { ast::StmtKind::AsyncWith { items, body, type_comment } @@ -1022,7 +1016,7 @@ FuncDef: ast::Stmt = { "def" " Test<"all">)?> ":" => { let args = Box::new(args); let returns = r.map(|x| Box::new(x.1)); - let end_location = body.last().unwrap().end_location.unwrap(); + let end_location = body.last().unwrap().end(); let type_comment = None; let node = if is_async.is_some() { ast::StmtKind::AsyncFunctionDef { name, args, body, decorator_list, returns, type_comment } @@ -1197,7 +1191,7 @@ ClassDef: ast::Stmt = { Some((_, arg, _)) => (arg.args, arg.keywords), None => (vec![], vec![]), }; - let end_location = body.last().unwrap().end_location.unwrap(); + let end_location = body.last().unwrap().end(); ast::Stmt::new( location, end_location, @@ -1253,11 +1247,10 @@ NamedExpressionTest: ast::Expr = { NamedExpression: ast::Expr = { ":=" > => { - ast::Expr { + ast::Expr::new( location, - end_location: value.end_location, - custom: (), - node: ast::ExprKind::NamedExpr { + value.end(), + ast::ExprKind::NamedExpr { target: Box::new(ast::Expr::new( location, end_location, @@ -1265,7 +1258,7 @@ NamedExpression: ast::Expr = { )), value: Box::new(value), } - } + ) }, }; @@ -1564,7 +1557,7 @@ Atom: ast::Expr = { if matches!(mid.node, ast::ExprKind::Starred { .. }) { Err(LexicalError{ error: LexicalErrorType::OtherError("cannot use starred expression here".to_string()), - location: mid.location, + location: mid.start(), })? } Ok(mid) diff --git a/parser/src/function.rs b/parser/src/function.rs index 0f580e7..17b9882 100644 --- a/parser/src/function.rs +++ b/parser/src/function.rs @@ -35,12 +35,12 @@ pub(crate) fn validate_arguments( let mut all_arg_names = FxHashSet::with_hasher(Default::default()); for arg in all_args { - let arg_name = &arg.node.arg; + let arg_name = &arg.arg; // Check for duplicate arguments in the function definition. if !all_arg_names.insert(arg_name) { return Err(LexicalError { error: LexicalErrorType::DuplicateArgumentError(arg_name.to_string()), - location: arg.location, + location: arg.start(), }); } } @@ -64,7 +64,7 @@ pub(crate) fn parse_params( // have defaults. return Err(LexicalError { error: LexicalErrorType::DefaultArgumentError, - location: name.location, + location: name.start(), }); } Ok(()) @@ -126,14 +126,14 @@ pub(crate) fn parse_args(func_args: Vec) -> Result