Fix end location of nodes containing body

This commit is contained in:
harupy 2022-12-11 12:35:28 +09:00
parent 2b91ffb3ae
commit 2d75aeb276
3 changed files with 9 additions and 6 deletions

View file

@ -492,9 +492,10 @@ WithItem: ast::Withitem = {
}; };
FuncDef: ast::Stmt = { FuncDef: ast::Stmt = {
<decorator_list:Decorator*> <location:@L> <is_async:"async"?> "def" <name:Identifier> <args:Parameters> <r:("->" Test)?> ":" <body:Suite> <end_location:@R> => { <decorator_list:Decorator*> <location:@L> <is_async:"async"?> "def" <name:Identifier> <args:Parameters> <r:("->" Test)?> ":" <body:Suite> => {
let args = Box::new(args); let args = Box::new(args);
let returns = r.map(|x| Box::new(x.1)); let returns = r.map(|x| Box::new(x.1));
let end_location = body.last().unwrap().end_location.unwrap();
let type_comment = None; let type_comment = None;
let node = if is_async.is_some() { let node = if is_async.is_some() {
ast::StmtKind::AsyncFunctionDef { name, args, body, decorator_list, returns, type_comment } ast::StmtKind::AsyncFunctionDef { name, args, body, decorator_list, returns, type_comment }
@ -646,15 +647,16 @@ KwargParameter<ArgType>: Option<Box<ast::Arg>> = {
}; };
ClassDef: ast::Stmt = { ClassDef: ast::Stmt = {
<decorator_list:Decorator*> <location:@L> "class" <name:Identifier> <a:("(" ArgumentList ")")?> ":" <body:Suite> <end_location:@R> => { <decorator_list:Decorator*> <location:@L> "class" <name:Identifier> <a:("(" ArgumentList ")")?> ":" <body:Suite> => {
let (bases, keywords) = match a { let (bases, keywords) = match a {
Some((_, arg, _)) => (arg.args, arg.keywords), Some((_, arg, _)) => (arg.args, arg.keywords),
None => (vec![], vec![]), None => (vec![], vec![]),
}; };
let end_location = body.last().unwrap().end_location;
ast::Stmt { ast::Stmt {
custom: (), custom: (),
location, location,
end_location: Some(end_location), end_location,
node: ast::StmtKind::ClassDef { node: ast::StmtKind::ClassDef {
name, name,
bases, bases,

View file

@ -173,7 +173,8 @@ class Foo(A, B):
def __init__(self): def __init__(self):
pass pass
def method_with_default(self, arg='default'): def method_with_default(self, arg='default'):
pass"; pass
";
insta::assert_debug_snapshot!(parse_program(source, "<test>").unwrap()); insta::assert_debug_snapshot!(parse_program(source, "<test>").unwrap());
} }

View file

@ -62,8 +62,8 @@ expression: "parse_program(source, \"<test>\").unwrap()"
}, },
end_location: Some( end_location: Some(
Location { Location {
row: 4, row: 3,
column: 1, column: 6,
}, },
), ),
custom: (), custom: (),