This commit is contained in:
Jeong YunWon 2023-05-15 20:17:41 +09:00
parent 269a9a98da
commit c89a2b2378
3 changed files with 943 additions and 928 deletions

View file

@ -10,7 +10,7 @@ import re
from argparse import ArgumentParser from argparse import ArgumentParser
from pathlib import Path from pathlib import Path
from typing import Optional, Dict from typing import Optional, Dict, Any
import asdl import asdl
@ -87,6 +87,7 @@ class TypeInfo:
is_simple: bool is_simple: bool
empty_field: bool empty_field: bool
children: set children: set
fields: Optional[Any]
boxed: bool boxed: bool
product: bool product: bool
has_expr: bool = False has_expr: bool = False
@ -99,6 +100,7 @@ class TypeInfo:
self.is_simple = False self.is_simple = False
self.empty_field = False self.empty_field = False
self.children = set() self.children = set()
self.fields = None
self.boxed = False self.boxed = False
self.product = False self.product = False
self.product_has_expr = False self.product_has_expr = False
@ -212,6 +214,7 @@ class FindUserDataTypesVisitor(asdl.VisitorBase):
t_info = TypeInfo(t.name) t_info = TypeInfo(t.name)
t_info.enum_name = name t_info.enum_name = name
t_info.empty_field = not t.fields t_info.empty_field = not t.fields
t_info.fields = t.fields
self.type_info[t.name] = t_info self.type_info[t.name] = t_info
self.add_children(t.name, t.fields) self.add_children(t.name, t.fields)
if len(sum.types) > 1: if len(sum.types) > 1:
@ -226,6 +229,7 @@ class FindUserDataTypesVisitor(asdl.VisitorBase):
def visitProduct(self, product, name): def visitProduct(self, product, name):
info = self.type_info[name] info = self.type_info[name]
info.fields = product.fields
if product.attributes: if product.attributes:
# attributes means located, which has the `range: R` field # attributes means located, which has the `range: R` field
info.has_user_data = True info.has_user_data = True
@ -982,6 +986,22 @@ class Pyo3StructVisitor(EmitVisitor):
def ref(self): def ref(self):
return "&" if self.borrow else "" return "&" if self.borrow else ""
def emit_class_cache(self, rust_name, fields):
self.emit(
textwrap.dedent(
f"""
impl {rust_name} {{
#[inline]
pub fn py_type_cell() -> &'static OnceCell<Py<PyAny>> {{
static PY_TYPE: OnceCell<Py<PyAny>> = OnceCell::new();
&PY_TYPE
}}
}}
"""
),
0,
)
def emit_class(self, name, rust_name, subclass, base="super::AST"): def emit_class(self, name, rust_name, subclass, base="super::AST"):
if subclass: if subclass:
subclass = ", subclass" subclass = ", subclass"
@ -995,18 +1015,10 @@ class Pyo3StructVisitor(EmitVisitor):
self.emit( self.emit(
textwrap.dedent( textwrap.dedent(
f""" f"""
#[pyclass(module="{self.module_name}", name="_{name}", extends={base}{subclass})] #[pyclass(module="{self.module_name}", name="_{name}", extends={base}, frozen{subclass})]
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct {rust_name} {body}; pub struct {rust_name} {body};
impl {rust_name} {{
#[inline]
pub fn py_type_cell() -> &'static OnceCell<Py<PyAny>> {{
static PY_TYPE: OnceCell<Py<PyAny>> = OnceCell::new();
&PY_TYPE
}}
}}
impl From<{self.ref_def} crate::{self.namespace}::{rust_name}> for {rust_name} {{ impl From<{self.ref_def} crate::{self.namespace}::{rust_name}> for {rust_name} {{
fn from({"" if body else "_"}node: {self.ref_def} crate::{self.namespace}::{rust_name}) -> Self {{ fn from({"" if body else "_"}node: {self.ref_def} crate::{self.namespace}::{rust_name}) -> Self {{
{into} {into}
@ -1153,6 +1165,7 @@ class Pyo3StructVisitor(EmitVisitor):
def visitSum(self, sum, name, depth=0): def visitSum(self, sum, name, depth=0):
rust_name = rust_type_name(name) rust_name = rust_type_name(name)
self.emit_class(name, rust_name, True) self.emit_class(name, rust_name, True)
self.emit_class_cache(rust_name, [])
simple = is_simple(sum) simple = is_simple(sum)
@ -1193,6 +1206,7 @@ class Pyo3StructVisitor(EmitVisitor):
def visitProduct(self, product, name, depth=0): def visitProduct(self, product, name, depth=0):
rust_name = rust_type_name(name) rust_name = rust_type_name(name)
self.emit_class(name, rust_name, False) self.emit_class(name, rust_name, False)
self.emit_class_cache(rust_name, product.fields)
if self.borrow: if self.borrow:
self.emit_getter(product, rust_name) self.emit_getter(product, rust_name)
@ -1203,14 +1217,6 @@ class Pyo3StructVisitor(EmitVisitor):
#[pyclass(module="{self.module_name}", name="_{cons.name}", extends={parent})] #[pyclass(module="{self.module_name}", name="_{cons.name}", extends={parent})]
pub struct {parent}{cons.name}; pub struct {parent}{cons.name};
impl {parent}{cons.name} {{
#[inline]
pub fn py_type_cell() -> &'static OnceCell<Py<PyAny>> {{
static PY_TYPE: OnceCell<Py<PyAny>> = OnceCell::new();
&PY_TYPE
}}
}}
impl ToPyObject for {parent}{cons.name} {{ impl ToPyObject for {parent}{cons.name} {{
fn to_object(&self, py: Python) -> PyObject {{ fn to_object(&self, py: Python) -> PyObject {{
let initializer = PyClassInitializer::from(AST) let initializer = PyClassInitializer::from(AST)
@ -1222,10 +1228,15 @@ impl ToPyObject for {parent}{cons.name} {{
""", """,
depth, depth,
) )
self.emit_class_cache(f"{parent}{cons.name}", [])
else: else:
self.emit_class( self.emit_class(
cons.name, f"{parent}{cons.name}", subclass=False, base=parent cons.name,
f"{parent}{cons.name}",
subclass=False,
base=parent,
) )
self.emit_class_cache(f"{parent}{cons.name}", cons.fields)
if self.borrow: if self.borrow:
self.emit_getter(cons, f"{parent}{cons.name}") self.emit_getter(cons, f"{parent}{cons.name}")
@ -1244,21 +1255,21 @@ class Pyo3PymoduleVisitor(EmitVisitor):
def visitProduct(self, product, name, depth=0): def visitProduct(self, product, name, depth=0):
rust_name = rust_type_name(name) rust_name = rust_type_name(name)
self.emit_fields(name, rust_name, False, depth) self.emit_fields(name, rust_name, product.fields, False, depth)
def visitSum(self, sum, name, depth): def visitSum(self, sum, name, depth):
rust_name = rust_type_name(name) rust_name = rust_type_name(name)
simple = is_simple(sum) simple = is_simple(sum)
self.emit_fields(name, rust_name, True, depth) self.emit_fields(name, rust_name, [], True, depth)
for cons in sum.types: for cons in sum.types:
self.visit(cons, name, simple, depth) self.visit(cons, name, simple, depth)
def visitConstructor(self, cons, parent, simple, depth): def visitConstructor(self, cons, parent, simple, depth):
rust_name = rust_type_name(parent) + rust_type_name(cons.name) rust_name = rust_type_name(parent) + rust_type_name(cons.name)
self.emit_fields(cons.name, rust_name, simple, depth) self.emit_fields(cons.name, rust_name, cons.fields, simple, depth)
def emit_fields(self, name, rust_name, simple, depth): def emit_fields(self, name, rust_name, fields, simple, depth):
if simple: if simple:
call = ".call0().unwrap()" call = ".call0().unwrap()"
else: else:
@ -1413,7 +1424,10 @@ class StdlibTraitImplVisitor(EmitVisitor):
depth, depth,
) )
self.emit("};", depth + 3) self.emit("};", depth + 3)
self.emit("NodeAst.into_ref_with_type(vm, node_type.to_owned()).unwrap().into()", depth + 2) self.emit(
"NodeAst.into_ref_with_type(vm, node_type.to_owned()).unwrap().into()",
depth + 2,
)
else: else:
self.emit("match self {", depth + 2) self.emit("match self {", depth + 2)
for cons in sum.types: for cons in sum.types:
@ -1619,7 +1633,7 @@ def write_located_def(mod, type_info, f):
LocatedDefVisitor(f, type_info).visit(mod) LocatedDefVisitor(f, type_info).visit(mod)
def write_ast_pyo3(mod, type_info, namespace, f): def write_pyo3_ast(mod, type_info, namespace, f):
ToPyo3AstVisitor(namespace, f, type_info).visit(mod) ToPyo3AstVisitor(namespace, f, type_info).visit(mod)
@ -1636,6 +1650,7 @@ def write_pyo3_def(mod, type_info, namespace, borrow, f):
let ast_module = PyModule::import(py, "_ast")?; let ast_module = PyModule::import(py, "_ast")?;
""" """
) )
Pyo3PymoduleVisitor(namespace, f, type_info).visit(mod) Pyo3PymoduleVisitor(namespace, f, type_info).visit(mod)
f.write("Ok(())\n}") f.write("Ok(())\n}")
@ -1686,8 +1701,8 @@ def main(
("ranged", p(write_ranged_def, mod, type_info)), ("ranged", p(write_ranged_def, mod, type_info)),
("located", p(write_located_def, mod, type_info)), ("located", p(write_located_def, mod, type_info)),
("visitor", p(write_visitor_def, mod, type_info)), ("visitor", p(write_visitor_def, mod, type_info)),
("to_pyo3_located", p(write_ast_pyo3, mod, type_info, "located")), ("to_pyo3_located", p(write_pyo3_ast, mod, type_info, "located")),
("to_pyo3_ranged", p(write_ast_pyo3, mod, type_info, "ranged")), ("to_pyo3_ranged", p(write_pyo3_ast, mod, type_info, "ranged")),
# ("pyo3_located", p(write_pyo3_def, mod, type_info, "located", True)), # ("pyo3_located", p(write_pyo3_def, mod, type_info, "located", True)),
("pyo3_ranged", p(write_pyo3_def, mod, type_info, "ranged", True)), ("pyo3_ranged", p(write_pyo3_def, mod, type_info, "ranged", True)),
]: ]:

File diff suppressed because it is too large Load diff

View file

@ -1,7 +1,7 @@
use crate::Node; use crate::Node;
use num_complex::Complex64; use num_complex::Complex64;
use pyo3::prelude::*; use pyo3::{intern, prelude::*};
use pyo3::types::{PyBytes, PyList, PyTuple}; use pyo3::types::{PyBytes, PyList, PyString, PyTuple};
pub trait ToPyo3Ast { pub trait ToPyo3Ast {
fn to_pyo3_ast(&self, py: Python) -> PyResult<Py<PyAny>>; fn to_pyo3_ast(&self, py: Python) -> PyResult<Py<PyAny>>;