GH-131498: Cases generator: Parse down to C statement level. (GH-131948)

* Parse down to statement level in the cases generator

* Add handling for `#if` preprocessor macros, treating them much like normal `if` statements.
This commit is contained in:
Mark Shannon 2025-04-02 16:31:59 +01:00 committed by GitHub
parent 6e91d1f9aa
commit ad053d8d6a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 795 additions and 959 deletions

View file

@ -1,10 +1,12 @@
"""Parser for bytecodes.inst."""
from dataclasses import dataclass, field
from typing import NamedTuple, Callable, TypeVar, Literal, cast
from typing import NamedTuple, Callable, TypeVar, Literal, cast, Iterator
from io import StringIO
import lexer as lx
from plexer import PLexer
from cwriter import CWriter
P = TypeVar("P", bound="Parser")
@ -66,12 +68,181 @@ class Node:
assert context is not None
return context.owner.tokens[context.begin]
# Statements
# A visitor is any callable that takes a single Stmt and returns None.
# The concrete accept() methods below call it on the node itself and then
# pass it on to each child statement.
Visitor = Callable[["Stmt"], None]
class Stmt:
    """Base class for the statement nodes built by the parser.

    print(), accept() and tokens() only raise NotImplementedError here;
    the concrete statement classes provide the real implementations.
    """

    def __repr__(self) -> str:
        # Render the statement through a CWriter into an in-memory buffer
        # and use the resulting text as the repr.
        buffer = StringIO()
        writer = CWriter(buffer, 0, False)
        self.print(writer)
        return buffer.getvalue()

    def print(self, out: CWriter) -> None:
        """Write this statement to ``out``."""
        raise NotImplementedError

    def accept(self, visitor: Visitor) -> None:
        """Apply ``visitor`` to this statement."""
        raise NotImplementedError

    def tokens(self) -> Iterator[lx.Token]:
        """Iterate over the tokens of this statement."""
        raise NotImplementedError
@dataclass
class Block(Node):
# Marker node with no fields of its own.
# This just holds a context which has the list of tokens.
# NOTE(review): the context presumably comes from the Node base class - confirm.
pass
@dataclass
class IfStmt(Stmt):
    """A C ``if`` statement with an optional ``else`` branch.

    Declared as a dataclass, like the other statement classes, so that it
    can be built positionally as
    ``IfStmt(if_, condition, body, else_, else_body)``.
    """

    # The ``if`` keyword token.
    if_: lx.Token
    # Condition tokens, starting with the opening parenthesis.
    condition: list[lx.Token]
    # Statement executed when the condition holds.
    body: Stmt
    # The ``else`` keyword token, or None when there is no else branch.
    else_: lx.Token | None
    # Statement of the else branch (a block or a chained ``if``), or None.
    else_body: Stmt | None

    def print(self, out: CWriter) -> None:
        """Write the statement to ``out``.

        The ``if`` body is printed exactly once; the else branch, when
        present, prints ``else`` followed by the else body.
        """
        out.emit(self.if_)
        for tkn in self.condition:
            out.emit(tkn)
        self.body.print(out)
        if self.else_ is not None:
            out.emit(self.else_)
        # Print the else body here, never the ``if`` body a second time.
        if self.else_body is not None:
            self.else_body.print(out)

    def accept(self, visitor: Visitor) -> None:
        """Call ``visitor`` on this node, then on the body and else body."""
        visitor(self)
        self.body.accept(visitor)
        if self.else_body is not None:
            self.else_body.accept(visitor)

    def tokens(self) -> Iterator[lx.Token]:
        """Yield the ``if`` token, the condition, the body tokens and,
        when present, the ``else`` token and the else-body tokens."""
        yield self.if_
        yield from self.condition
        yield from self.body.tokens()
        if self.else_ is not None:
            yield self.else_
        if self.else_body is not None:
            yield from self.else_body.tokens()
@dataclass
class ForStmt(Stmt):
    """A C ``for`` loop: keyword token, header tokens and body statement."""

    for_: lx.Token
    # Header tokens, starting with the opening parenthesis.
    header: list[lx.Token]
    body: Stmt

    def print(self, out: CWriter) -> None:
        """Emit the keyword and header tokens, then print the body."""
        out.emit(self.for_)
        for header_token in self.header:
            out.emit(header_token)
        self.body.print(out)

    def accept(self, visitor: Visitor) -> None:
        """Call ``visitor`` on this node, then hand it to the body."""
        visitor(self)
        self.body.accept(visitor)

    def tokens(self) -> Iterator[lx.Token]:
        """Yield the keyword, then the header tokens, then the body tokens."""
        yield self.for_
        for header_token in self.header:
            yield header_token
        for body_token in self.body.tokens():
            yield body_token
@dataclass
class WhileStmt(Stmt):
    """A C ``while`` loop: keyword token, condition tokens and body statement."""

    while_: lx.Token
    # Condition tokens, starting with the opening parenthesis.
    condition: list[lx.Token]
    body: Stmt

    def print(self, out: CWriter) -> None:
        """Emit the keyword and condition tokens, then print the body."""
        out.emit(self.while_)
        for cond_token in self.condition:
            out.emit(cond_token)
        self.body.print(out)

    def accept(self, visitor: Visitor) -> None:
        """Call ``visitor`` on this node, then hand it to the body."""
        visitor(self)
        self.body.accept(visitor)

    def tokens(self) -> Iterator[lx.Token]:
        """Yield the keyword, then the condition tokens, then the body tokens."""
        yield self.while_
        for cond_token in self.condition:
            yield cond_token
        for body_token in self.body.tokens():
            yield body_token
@dataclass
class MacroIfStmt(Stmt):
    """A preprocessor conditional: ``#if`` ... [``#else`` ...] ``#endif``."""

    # The token that opens the conditional.
    condition: lx.Token
    # Statements between the opening token and ``#else`` / ``#endif``.
    body: list[Stmt]
    # The ``#else`` token, or None when there is no else part.
    else_: lx.Token | None
    # Statements of the else part, or None when there is no else part.
    else_body: list[Stmt] | None
    # The closing ``#endif`` token.
    endif: lx.Token

    def print(self, out: CWriter) -> None:
        """Write the whole conditional to ``out``, including the closing
        ``#endif`` token so that the printed text is balanced."""
        out.emit(self.condition)
        for stmt in self.body:
            stmt.print(out)
        if self.else_body is not None:
            out.emit("#else\n")
            for stmt in self.else_body:
                stmt.print(out)
        # Without this the rendered text would be an unterminated #if.
        out.emit(self.endif)

    def accept(self, visitor: Visitor) -> None:
        """Call ``visitor`` on this node, then on every statement of the
        body and of the else part (if any)."""
        visitor(self)
        for stmt in self.body:
            stmt.accept(visitor)
        if self.else_body is not None:
            for stmt in self.else_body:
                stmt.accept(visitor)

    def tokens(self) -> Iterator[lx.Token]:
        """Yield the opening token followed by the tokens of the body and
        else-part statements.

        NOTE(review): the ``#else`` and ``#endif`` tokens are not yielded;
        this is left unchanged because callers may depend on it - confirm
        whether that is intentional.
        """
        yield self.condition
        for stmt in self.body:
            yield from stmt.tokens()
        if self.else_body is not None:
            for stmt in self.else_body:
                yield from stmt.tokens()
@dataclass
class BlockStmt(Stmt):
    """A brace-delimited sequence of statements."""

    # Opening brace token.
    open: lx.Token
    # Statements between the braces, in order.
    body: list[Stmt]
    # Closing brace token.
    close: lx.Token

    def print(self, out: CWriter) -> None:
        """Emit the opening brace, print each child, then start a new
        line and emit the closing brace."""
        out.emit(self.open)
        for child in self.body:
            child.print(out)
        out.start_line()
        out.emit(self.close)

    def accept(self, visitor: Visitor) -> None:
        """Call ``visitor`` on this node, then hand it to every child."""
        visitor(self)
        for child in self.body:
            child.accept(visitor)

    def tokens(self) -> Iterator[lx.Token]:
        """Yield the opening brace, every child's tokens, then the closing brace."""
        yield self.open
        for child in self.body:
            for token in child.tokens():
                yield token
        yield self.close
@dataclass
class SimpleStmt(Stmt):
    """A statement kept as a flat list of tokens, with no child statements."""

    contents: list[lx.Token]

    # @dataclass generates __eq__, which would otherwise make instances
    # unhashable; keep the default identity-based hash.
    __hash__ = object.__hash__

    def print(self, out: CWriter) -> None:
        """Emit every token of the statement to ``out``."""
        for token in self.contents:
            out.emit(token)

    def accept(self, visitor: Visitor) -> None:
        """Call ``visitor`` on this node; there are no children to visit."""
        visitor(self)

    def tokens(self) -> Iterator[lx.Token]:
        """Yield the statement's tokens in order."""
        for token in self.contents:
            yield token
@dataclass
class StackEffect(Node):
@ -124,7 +295,7 @@ class InstDef(Node):
name: str
inputs: list[InputEffect]
outputs: list[OutputEffect]
block: Block
block: BlockStmt
@dataclass
@ -153,7 +324,7 @@ class Pseudo(Node):
class LabelDef(Node):
name: str
spilled: bool
block: Block
block: BlockStmt
AstNode = InstDef | Macro | Pseudo | Family | LabelDef
@ -183,23 +354,22 @@ class Parser(PLexer):
if self.expect(lx.LPAREN):
if tkn := self.expect(lx.IDENTIFIER):
if self.expect(lx.RPAREN):
if block := self.block():
return LabelDef(tkn.text, spilled, block)
block = self.block()
return LabelDef(tkn.text, spilled, block)
return None
@contextual
def inst_def(self) -> InstDef | None:
    """Parse an instruction definition: an instruction header followed by a block.

    Returns None when no instruction header is found at the current position.
    """
    hdr = self.inst_header()
    if not hdr:
        return None
    # block() is annotated to always return a BlockStmt, so the former
    # `if block := ...` test and its "Expected block" fallback are not
    # needed; only the single live path is kept.
    block = self.block()
    return InstDef(
        hdr.annotations,
        hdr.kind,
        hdr.name,
        hdr.inputs,
        hdr.outputs,
        block,
    )
@contextual
@ -473,28 +643,85 @@ class Parser(PLexer):
self.setpos(here)
return None
def block(self) -> BlockStmt:
    """Parse a brace-delimited block into a BlockStmt.

    Requires an opening brace, then parses statements with stmt() until
    a closing brace is found. The result carries both brace tokens and
    the list of parsed statements.
    """
    # Named lbrace/rbrace rather than open/close to avoid shadowing the
    # builtin open().
    lbrace = self.require(lx.LBRACE)
    stmts: list[Stmt] = []
    while not (rbrace := self.expect(lx.RBRACE)):
        stmts.append(self.stmt())
    return BlockStmt(lbrace, stmts, rbrace)
def c_blob(self) -> list[lx.Token]:
    """Collect raw tokens while tracking bracket nesting.

    Every token read is appended to the result. Reading stops after a
    closing brace/paren/bracket brings the nesting depth to zero or
    below, or when next() yields nothing more.
    """
    openers = (lx.LBRACE, lx.LPAREN, lx.LBRACKET)
    closers = (lx.RBRACE, lx.RPAREN, lx.RBRACKET)
    collected: list[lx.Token] = []
    depth = 0
    while True:
        tkn = self.next(raw=True)
        if not tkn:
            break
        collected.append(tkn)
        if tkn.kind in openers:
            depth += 1
        elif tkn.kind in closers:
            depth -= 1
            if depth <= 0:
                break
    return collected
def stmt(self) -> Stmt:
    """Parse a single statement at the current position.

    Dispatches on the next token: ``if``, ``{``, ``for``, ``while`` and
    ``#if`` go to their dedicated parsers; a stray ``#else`` / ``#endif``
    or a ``switch`` raises a syntax error; any other preprocessor line
    becomes a one-token SimpleStmt; everything else is consumed up to
    the next semicolon as a SimpleStmt.
    """
    if if_tkn := self.expect(lx.IF):
        return self.if_stmt(if_tkn)
    if self.expect(lx.LBRACE):
        # block() requires the opening brace itself, so step back first.
        self.backup()
        return self.block()
    if for_tkn := self.expect(lx.FOR):
        return self.for_stmt(for_tkn)
    if while_tkn := self.expect(lx.WHILE):
        return self.while_stmt(while_tkn)
    if macro_tkn := self.expect(lx.CMACRO_IF):
        return self.macro_if(macro_tkn)
    if self.expect(lx.CMACRO_ELSE):
        raise self.make_syntax_error("Unexpected #else")
    if self.expect(lx.CMACRO_ENDIF):
        raise self.make_syntax_error("Unexpected #endif")
    if other_tkn := self.expect(lx.CMACRO_OTHER):
        return SimpleStmt([other_tkn])
    if self.expect(lx.SWITCH):
        raise self.make_syntax_error(
            "switch statements are not supported due to their complex flow control. Sorry."
        )
    return SimpleStmt(self.consume_to(lx.SEMI))
def if_stmt(self, if_: lx.Token) -> IfStmt:
    """Parse the rest of an ``if`` statement whose keyword is ``if_``.

    Reads the parenthesised condition and a block body. If ``else``
    follows, the else branch is either a chained ``if`` statement or
    another block.
    """
    lparen = self.require(lx.LPAREN)
    condition = [lparen] + self.consume_to(lx.RPAREN)
    then_body = self.block()
    else_body: Stmt | None = None
    else_: lx.Token | None = self.expect(lx.ELSE)
    if else_:
        nested_if = self.expect(lx.IF)
        if nested_if:
            # `else if`: parse the chained if as the else branch.
            else_body = self.if_stmt(nested_if)
        else:
            else_body = self.block()
    return IfStmt(if_, condition, then_body, else_, else_body)
def macro_if(self, cond: lx.Token) -> MacroIfStmt:
    """Parse the statements of a preprocessor conditional opened by ``cond``.

    Statements go into the main body until ``#else`` is seen, and into
    the else body afterwards. Parsing ends at ``#endif``. A second
    ``#else`` raises a syntax error.
    """
    then_stmts: list[Stmt] = []
    else_tkn = None
    else_stmts: list[Stmt] | None = None
    while True:
        endif_tkn = self.expect(lx.CMACRO_ENDIF)
        if endif_tkn:
            return MacroIfStmt(cond, then_stmts, else_tkn, else_stmts, endif_tkn)
        found_else = self.expect(lx.CMACRO_ELSE)
        if found_else:
            # else_stmts is only set once an #else has already been seen.
            if else_stmts is not None:
                raise self.make_syntax_error("Multiple #else")
            else_tkn = found_else
            else_stmts = []
            continue
        target = then_stmts if else_stmts is None else else_stmts
        target.append(self.stmt())
def for_stmt(self, for_: lx.Token) -> ForStmt:
    """Parse the rest of a ``for`` loop whose keyword is ``for_``:
    the parenthesised header followed by a block body."""
    header = [self.require(lx.LPAREN)]
    header.extend(self.consume_to(lx.RPAREN))
    loop_body = self.block()
    return ForStmt(for_, header, loop_body)
def while_stmt(self, while_: lx.Token) -> WhileStmt:
    """Parse the rest of a ``while`` loop whose keyword is ``while_``:
    the parenthesised condition followed by a block body."""
    condition = [self.require(lx.LPAREN)]
    condition.extend(self.consume_to(lx.RPAREN))
    loop_body = self.block()
    return WhileStmt(while_, condition, loop_body)
if __name__ == "__main__":
import sys
import pprint
if sys.argv[1:]:
filename = sys.argv[1]
@ -512,5 +739,5 @@ if __name__ == "__main__":
filename = "<default>"
src = "if (x) { x.foo; // comment\n}"
parser = Parser(src, filename)
x = parser.definition()
print(x)
while node := parser.definition():
pprint.pprint(node)