GH-98831: Implement basic cache effects (#99313)

This commit is contained in:
Guido van Rossum 2022-11-15 19:59:19 -08:00 committed by GitHub
parent 4636df9feb
commit e37744f289
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 201 additions and 127 deletions

View file

@ -56,11 +56,28 @@ class Block(Node):
tokens: list[lx.Token]
@dataclass
class Effect(Node):
pass
@dataclass
class StackEffect(Effect):
name: str
# TODO: type, condition
@dataclass
class CacheEffect(Effect):
name: str
size: int
@dataclass
class InstHeader(Node):
name: str
inputs: list[str]
outputs: list[str]
inputs: list[Effect]
outputs: list[Effect]
@dataclass
@ -69,16 +86,17 @@ class InstDef(Node):
block: Block
@property
def name(self):
def name(self) -> str:
return self.header.name
@property
def inputs(self):
def inputs(self) -> list[Effect]:
return self.header.inputs
@property
def outputs(self):
return self.header.outputs
def outputs(self) -> list[StackEffect]:
# This is always true
return [x for x in self.header.outputs if isinstance(x, StackEffect)]
@dataclass
@ -90,6 +108,7 @@ class Super(Node):
@dataclass
class Family(Node):
name: str
size: str # Variable giving the cache size in code units
members: list[str]
@ -123,18 +142,16 @@ class Parser(PLexer):
return InstHeader(name, [], [])
return None
def check_overlaps(self, inp: list[str], outp: list[str]):
def check_overlaps(self, inp: list[Effect], outp: list[Effect]):
for i, name in enumerate(inp):
try:
j = outp.index(name)
except ValueError:
continue
else:
if i != j:
raise self.make_syntax_error(
f"Input {name!r} at pos {i} repeated in output at different pos {j}")
for j, name2 in enumerate(outp):
if name == name2:
if i != j:
raise self.make_syntax_error(
f"Input {name!r} at pos {i} repeated in output at different pos {j}")
break
def stack_effect(self) -> tuple[list[str], list[str]]:
def stack_effect(self) -> tuple[list[Effect], list[Effect]]:
# '(' [inputs] '--' [outputs] ')'
if self.expect(lx.LPAREN):
inp = self.inputs() or []
@ -144,8 +161,8 @@ class Parser(PLexer):
return inp, outp
raise self.make_syntax_error("Expected stack effect")
def inputs(self) -> list[str] | None:
# input (, input)*
def inputs(self) -> list[Effect] | None:
# input (',' input)*
here = self.getpos()
if inp := self.input():
near = self.getpos()
@ -157,27 +174,25 @@ class Parser(PLexer):
self.setpos(here)
return None
def input(self) -> str | None:
# IDENTIFIER
@contextual
def input(self) -> Effect | None:
# IDENTIFIER '/' INTEGER (CacheEffect)
# IDENTIFIER (StackEffect)
if (tkn := self.expect(lx.IDENTIFIER)):
if self.expect(lx.LBRACKET):
if arg := self.expect(lx.IDENTIFIER):
if self.expect(lx.RBRACKET):
return f"{tkn.text}[{arg.text}]"
if self.expect(lx.TIMES):
if num := self.expect(lx.NUMBER):
if self.expect(lx.RBRACKET):
return f"{tkn.text}[{arg.text}*{num.text}]"
raise self.make_syntax_error("Expected argument in brackets", tkn)
if self.expect(lx.DIVIDE):
if num := self.expect(lx.NUMBER):
try:
size = int(num.text)
except ValueError:
raise self.make_syntax_error(
f"Expected integer, got {num.text!r}")
else:
return CacheEffect(tkn.text, size)
raise self.make_syntax_error("Expected integer")
else:
return StackEffect(tkn.text)
return tkn.text
if self.expect(lx.CONDOP):
while self.expect(lx.CONDOP):
pass
return "??"
return None
def outputs(self) -> list[str] | None:
def outputs(self) -> list[Effect] | None:
# output (, output)*
here = self.getpos()
if outp := self.output():
@ -190,8 +205,10 @@ class Parser(PLexer):
self.setpos(here)
return None
def output(self) -> str | None:
return self.input() # TODO: They're not quite the same.
@contextual
def output(self) -> Effect | None:
if (tkn := self.expect(lx.IDENTIFIER)):
return StackEffect(tkn.text)
@contextual
def super_def(self) -> Super | None:
@ -216,24 +233,35 @@ class Parser(PLexer):
@contextual
def family_def(self) -> Family | None:
if (tkn := self.expect(lx.IDENTIFIER)) and tkn.text == "family":
size = None
if self.expect(lx.LPAREN):
if (tkn := self.expect(lx.IDENTIFIER)):
if self.expect(lx.COMMA):
if not (size := self.expect(lx.IDENTIFIER)):
raise self.make_syntax_error(
"Expected identifier")
if self.expect(lx.RPAREN):
if self.expect(lx.EQUALS):
if not self.expect(lx.LBRACE):
raise self.make_syntax_error("Expected {")
if members := self.members():
if self.expect(lx.SEMI):
return Family(tkn.text, members)
if self.expect(lx.RBRACE) and self.expect(lx.SEMI):
return Family(tkn.text, size.text if size else "", members)
return None
def members(self) -> list[str] | None:
here = self.getpos()
if tkn := self.expect(lx.IDENTIFIER):
near = self.getpos()
if self.expect(lx.COMMA):
if rest := self.members():
return [tkn.text] + rest
self.setpos(near)
return [tkn.text]
members = [tkn.text]
while self.expect(lx.COMMA):
if tkn := self.expect(lx.IDENTIFIER):
members.append(tkn.text)
else:
break
peek = self.peek()
if not peek or peek.kind != lx.RBRACE:
raise self.make_syntax_error("Expected comma or right paren")
return members
self.setpos(here)
return None
@ -274,5 +302,5 @@ if __name__ == "__main__":
filename = None
src = "if (x) { x.foo; // comment\n}"
parser = Parser(src, filename)
x = parser.inst_def()
x = parser.inst_def() or parser.super_def() or parser.family_def()
print(x)