mirror of
https://github.com/python/cpython.git
synced 2025-08-04 08:59:19 +00:00
GH-131498: Cases generator: Parse down to C statement level. (GH-131948)
* Parse down to statement level in the cases generator * Add handling for #if macros, treating them much like normal ifs.
This commit is contained in:
parent
6e91d1f9aa
commit
ad053d8d6a
16 changed files with 795 additions and 959 deletions
|
@ -1,10 +1,12 @@
|
|||
"""Parser for bytecodes.inst."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import NamedTuple, Callable, TypeVar, Literal, cast
|
||||
from typing import NamedTuple, Callable, TypeVar, Literal, cast, Iterator
|
||||
from io import StringIO
|
||||
|
||||
import lexer as lx
|
||||
from plexer import PLexer
|
||||
from cwriter import CWriter
|
||||
|
||||
|
||||
P = TypeVar("P", bound="Parser")
|
||||
|
@ -66,12 +68,181 @@ class Node:
|
|||
assert context is not None
|
||||
return context.owner.tokens[context.begin]
|
||||
|
||||
# Statements
|
||||
|
||||
Visitor = Callable[["Stmt"], None]
|
||||
|
||||
class Stmt:
|
||||
|
||||
def __repr__(self) -> str:
|
||||
io = StringIO()
|
||||
out = CWriter(io, 0, False)
|
||||
self.print(out)
|
||||
return io.getvalue()
|
||||
|
||||
def print(self, out:CWriter) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def accept(self, visitor: Visitor) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def tokens(self) -> Iterator[lx.Token]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class Block(Node):
|
||||
# This just holds a context which has the list of tokens.
|
||||
pass
|
||||
class IfStmt(Stmt):
|
||||
if_: lx.Token
|
||||
condition: list[lx.Token]
|
||||
body: Stmt
|
||||
else_: lx.Token | None
|
||||
else_body: Stmt | None
|
||||
|
||||
def print(self, out:CWriter) -> None:
|
||||
out.emit(self.if_)
|
||||
for tkn in self.condition:
|
||||
out.emit(tkn)
|
||||
self.body.print(out)
|
||||
if self.else_ is not None:
|
||||
out.emit(self.else_)
|
||||
self.body.print(out)
|
||||
if self.else_body is not None:
|
||||
self.else_body.print(out)
|
||||
|
||||
def accept(self, visitor: Visitor) -> None:
|
||||
visitor(self)
|
||||
self.body.accept(visitor)
|
||||
if self.else_body is not None:
|
||||
self.else_body.accept(visitor)
|
||||
|
||||
def tokens(self) -> Iterator[lx.Token]:
|
||||
yield self.if_
|
||||
yield from self.condition
|
||||
yield from self.body.tokens()
|
||||
if self.else_ is not None:
|
||||
yield self.else_
|
||||
if self.else_body is not None:
|
||||
yield from self.else_body.tokens()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ForStmt(Stmt):
|
||||
for_: lx.Token
|
||||
header: list[lx.Token]
|
||||
body: Stmt
|
||||
|
||||
def print(self, out:CWriter) -> None:
|
||||
out.emit(self.for_)
|
||||
for tkn in self.header:
|
||||
out.emit(tkn)
|
||||
self.body.print(out)
|
||||
|
||||
def accept(self, visitor: Visitor) -> None:
|
||||
visitor(self)
|
||||
self.body.accept(visitor)
|
||||
|
||||
def tokens(self) -> Iterator[lx.Token]:
|
||||
yield self.for_
|
||||
yield from self.header
|
||||
yield from self.body.tokens()
|
||||
|
||||
|
||||
@dataclass
|
||||
class WhileStmt(Stmt):
|
||||
while_: lx.Token
|
||||
condition: list[lx.Token]
|
||||
body: Stmt
|
||||
|
||||
def print(self, out:CWriter) -> None:
|
||||
out.emit(self.while_)
|
||||
for tkn in self.condition:
|
||||
out.emit(tkn)
|
||||
self.body.print(out)
|
||||
|
||||
def accept(self, visitor: Visitor) -> None:
|
||||
visitor(self)
|
||||
self.body.accept(visitor)
|
||||
|
||||
def tokens(self) -> Iterator[lx.Token]:
|
||||
yield self.while_
|
||||
yield from self.condition
|
||||
yield from self.body.tokens()
|
||||
|
||||
|
||||
@dataclass
|
||||
class MacroIfStmt(Stmt):
|
||||
condition: lx.Token
|
||||
body: list[Stmt]
|
||||
else_: lx.Token | None
|
||||
else_body: list[Stmt] | None
|
||||
endif: lx.Token
|
||||
|
||||
def print(self, out:CWriter) -> None:
|
||||
out.emit(self.condition)
|
||||
for stmt in self.body:
|
||||
stmt.print(out)
|
||||
if self.else_body is not None:
|
||||
out.emit("#else\n")
|
||||
for stmt in self.else_body:
|
||||
stmt.print(out)
|
||||
|
||||
def accept(self, visitor: Visitor) -> None:
|
||||
visitor(self)
|
||||
for stmt in self.body:
|
||||
stmt.accept(visitor)
|
||||
if self.else_body is not None:
|
||||
for stmt in self.else_body:
|
||||
stmt.accept(visitor)
|
||||
|
||||
def tokens(self) -> Iterator[lx.Token]:
|
||||
yield self.condition
|
||||
for stmt in self.body:
|
||||
yield from stmt.tokens()
|
||||
if self.else_body is not None:
|
||||
for stmt in self.else_body:
|
||||
yield from stmt.tokens()
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockStmt(Stmt):
|
||||
open: lx.Token
|
||||
body: list[Stmt]
|
||||
close: lx.Token
|
||||
|
||||
def print(self, out:CWriter) -> None:
|
||||
out.emit(self.open)
|
||||
for stmt in self.body:
|
||||
stmt.print(out)
|
||||
out.start_line()
|
||||
out.emit(self.close)
|
||||
|
||||
def accept(self, visitor: Visitor) -> None:
|
||||
visitor(self)
|
||||
for stmt in self.body:
|
||||
stmt.accept(visitor)
|
||||
|
||||
def tokens(self) -> Iterator[lx.Token]:
|
||||
yield self.open
|
||||
for stmt in self.body:
|
||||
yield from stmt.tokens()
|
||||
yield self.close
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimpleStmt(Stmt):
|
||||
contents: list[lx.Token]
|
||||
|
||||
def print(self, out:CWriter) -> None:
|
||||
for tkn in self.contents:
|
||||
out.emit(tkn)
|
||||
|
||||
def tokens(self) -> Iterator[lx.Token]:
|
||||
yield from self.contents
|
||||
|
||||
def accept(self, visitor: Visitor) -> None:
|
||||
visitor(self)
|
||||
|
||||
__hash__ = object.__hash__
|
||||
|
||||
@dataclass
|
||||
class StackEffect(Node):
|
||||
|
@ -124,7 +295,7 @@ class InstDef(Node):
|
|||
name: str
|
||||
inputs: list[InputEffect]
|
||||
outputs: list[OutputEffect]
|
||||
block: Block
|
||||
block: BlockStmt
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -153,7 +324,7 @@ class Pseudo(Node):
|
|||
class LabelDef(Node):
|
||||
name: str
|
||||
spilled: bool
|
||||
block: Block
|
||||
block: BlockStmt
|
||||
|
||||
|
||||
AstNode = InstDef | Macro | Pseudo | Family | LabelDef
|
||||
|
@ -183,23 +354,22 @@ class Parser(PLexer):
|
|||
if self.expect(lx.LPAREN):
|
||||
if tkn := self.expect(lx.IDENTIFIER):
|
||||
if self.expect(lx.RPAREN):
|
||||
if block := self.block():
|
||||
return LabelDef(tkn.text, spilled, block)
|
||||
block = self.block()
|
||||
return LabelDef(tkn.text, spilled, block)
|
||||
return None
|
||||
|
||||
@contextual
|
||||
def inst_def(self) -> InstDef | None:
|
||||
if hdr := self.inst_header():
|
||||
if block := self.block():
|
||||
return InstDef(
|
||||
hdr.annotations,
|
||||
hdr.kind,
|
||||
hdr.name,
|
||||
hdr.inputs,
|
||||
hdr.outputs,
|
||||
block,
|
||||
)
|
||||
raise self.make_syntax_error("Expected block")
|
||||
block = self.block()
|
||||
return InstDef(
|
||||
hdr.annotations,
|
||||
hdr.kind,
|
||||
hdr.name,
|
||||
hdr.inputs,
|
||||
hdr.outputs,
|
||||
block,
|
||||
)
|
||||
return None
|
||||
|
||||
@contextual
|
||||
|
@ -473,28 +643,85 @@ class Parser(PLexer):
|
|||
self.setpos(here)
|
||||
return None
|
||||
|
||||
@contextual
|
||||
def block(self) -> Block | None:
|
||||
if self.c_blob():
|
||||
return Block()
|
||||
return None
|
||||
def block(self) -> BlockStmt:
|
||||
open = self.require(lx.LBRACE)
|
||||
stmts: list[Stmt] = []
|
||||
while not (close := self.expect(lx.RBRACE)):
|
||||
stmts.append(self.stmt())
|
||||
return BlockStmt(open, stmts, close)
|
||||
|
||||
def c_blob(self) -> list[lx.Token]:
|
||||
tokens: list[lx.Token] = []
|
||||
level = 0
|
||||
while tkn := self.next(raw=True):
|
||||
tokens.append(tkn)
|
||||
if tkn.kind in (lx.LBRACE, lx.LPAREN, lx.LBRACKET):
|
||||
level += 1
|
||||
elif tkn.kind in (lx.RBRACE, lx.RPAREN, lx.RBRACKET):
|
||||
level -= 1
|
||||
if level <= 0:
|
||||
break
|
||||
return tokens
|
||||
def stmt(self) -> Stmt:
|
||||
if tkn := self.expect(lx.IF):
|
||||
return self.if_stmt(tkn)
|
||||
elif self.expect(lx.LBRACE):
|
||||
self.backup()
|
||||
return self.block()
|
||||
elif tkn := self.expect(lx.FOR):
|
||||
return self.for_stmt(tkn)
|
||||
elif tkn := self.expect(lx.WHILE):
|
||||
return self.while_stmt(tkn)
|
||||
elif tkn := self.expect(lx.CMACRO_IF):
|
||||
return self.macro_if(tkn)
|
||||
elif tkn := self.expect(lx.CMACRO_ELSE):
|
||||
msg = "Unexpected #else"
|
||||
raise self.make_syntax_error(msg)
|
||||
elif tkn := self.expect(lx.CMACRO_ENDIF):
|
||||
msg = "Unexpected #endif"
|
||||
raise self.make_syntax_error(msg)
|
||||
elif tkn := self.expect(lx.CMACRO_OTHER):
|
||||
return SimpleStmt([tkn])
|
||||
elif tkn := self.expect(lx.SWITCH):
|
||||
msg = "switch statements are not supported due to their complex flow control. Sorry."
|
||||
raise self.make_syntax_error(msg)
|
||||
tokens = self.consume_to(lx.SEMI)
|
||||
return SimpleStmt(tokens)
|
||||
|
||||
def if_stmt(self, if_: lx.Token) -> IfStmt:
|
||||
lparen = self.require(lx.LPAREN)
|
||||
condition = [lparen] + self.consume_to(lx.RPAREN)
|
||||
body = self.block()
|
||||
else_body: Stmt | None = None
|
||||
else_: lx.Token | None = None
|
||||
if else_ := self.expect(lx.ELSE):
|
||||
if inner := self.expect(lx.IF):
|
||||
else_body = self.if_stmt(inner)
|
||||
else:
|
||||
else_body = self.block()
|
||||
return IfStmt(if_, condition, body, else_, else_body)
|
||||
|
||||
def macro_if(self, cond: lx.Token) -> MacroIfStmt:
|
||||
else_ = None
|
||||
body: list[Stmt] = []
|
||||
else_body: list[Stmt] | None = None
|
||||
part = body
|
||||
while True:
|
||||
if tkn := self.expect(lx.CMACRO_ENDIF):
|
||||
return MacroIfStmt(cond, body, else_, else_body, tkn)
|
||||
elif tkn := self.expect(lx.CMACRO_ELSE):
|
||||
if part is else_body:
|
||||
raise self.make_syntax_error("Multiple #else")
|
||||
else_ = tkn
|
||||
else_body = []
|
||||
part = else_body
|
||||
else:
|
||||
part.append(self.stmt())
|
||||
|
||||
def for_stmt(self, for_: lx.Token) -> ForStmt:
|
||||
lparen = self.require(lx.LPAREN)
|
||||
header = [lparen] + self.consume_to(lx.RPAREN)
|
||||
body = self.block()
|
||||
return ForStmt(for_, header, body)
|
||||
|
||||
def while_stmt(self, while_: lx.Token) -> WhileStmt:
|
||||
lparen = self.require(lx.LPAREN)
|
||||
cond = [lparen] + self.consume_to(lx.RPAREN)
|
||||
body = self.block()
|
||||
return WhileStmt(while_, cond, body)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
import pprint
|
||||
|
||||
if sys.argv[1:]:
|
||||
filename = sys.argv[1]
|
||||
|
@ -512,5 +739,5 @@ if __name__ == "__main__":
|
|||
filename = "<default>"
|
||||
src = "if (x) { x.foo; // comment\n}"
|
||||
parser = Parser(src, filename)
|
||||
x = parser.definition()
|
||||
print(x)
|
||||
while node := parser.definition():
|
||||
pprint.pprint(node)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue