rustpython_ast + pyo3 (#25)

This commit is contained in:
Jeong, YunWon 2023-05-16 18:06:54 +09:00 committed by GitHub
parent 53de75efc3
commit 611dcc2e9b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 3658 additions and 0 deletions

View file

@ -937,6 +937,111 @@ class LocatedDefVisitor(EmitVisitor):
)
class ToPyo3AstVisitor(EmitVisitor):
"""Visitor to generate type-defs for AST."""
def __init__(self, namespace, *args, **kw):
super().__init__(*args, **kw)
self.namespace = namespace
@property
def generics(self):
if self.namespace == "ranged":
return "<TextRange>"
elif self.namespace == "located":
return "<SourceRange>"
else:
assert False, self.namespace
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
def visitType(self, type, depth=0):
self.visit(type.value, type.name, depth)
def visitProduct(self, product, name, depth=0):
rust_name = rust_type_name(name)
self.emit_to_pyo3_with_fields(product, rust_name)
def visitSum(self, sum, name, depth=0):
rust_name = rust_type_name(name)
simple = is_simple(sum)
if is_simple(sum):
return
self.emit(
f"""
impl ToPyo3Ast for crate::generic::{rust_name}{self.generics} {{
#[inline]
fn to_pyo3_ast(&self, {"_" if simple else ""}py: Python) -> PyResult<Py<PyAny>> {{
let instance = match &self {{
""",
0,
)
for cons in sum.types:
self.emit(
f"""crate::{rust_name}::{cons.name}(cons) => cons.to_pyo3_ast(py)?,""",
depth,
)
self.emit(
"""
};
Ok(instance)
}
}
""",
0,
)
for cons in sum.types:
self.visit(cons, rust_name, depth)
def visitConstructor(self, cons, parent, depth):
self.emit_to_pyo3_with_fields(cons, f"{parent}{cons.name}")
def emit_to_pyo3_with_fields(self, cons, name):
if cons.fields:
self.emit(
f"""
impl ToPyo3Ast for crate::{name}{self.generics} {{
#[inline]
fn to_pyo3_ast(&self, py: Python) -> PyResult<Py<PyAny>> {{
let cache = Self::py_type_cache().get().unwrap();
let instance = cache.0.call1(py, (
""",
0,
)
for field in cons.fields:
self.emit(
f"self.{rust_field(field.name)}.to_pyo3_ast(py)?,",
3,
)
self.emit(
"""
))?;
Ok(instance)
}
}
""",
0,
)
else:
self.emit(
f"""
impl ToPyo3Ast for crate::{name}{self.generics} {{
#[inline]
fn to_pyo3_ast(&self, py: Python) -> PyResult<Py<PyAny>> {{
let cache = Self::py_type_cache().get().unwrap();
let instance = cache.0.call0(py)?;
Ok(instance)
}}
}}
""",
0,
)
class StdlibClassDefVisitor(EmitVisitor):
def visitModule(self, mod):
for dfn in mod.dfns:
@ -1271,6 +1376,82 @@ def write_located_def(mod, type_info, f):
LocatedDefVisitor(f, type_info).visit(mod)
def write_pyo3_node(type_info, f):
def write(info: TypeInfo):
rust_name = info.rust_sum_name
if info.is_simple:
generics = ""
else:
generics = "<R>"
f.write(
textwrap.dedent(
f"""
impl{generics} Pyo3Node for crate::generic::{rust_name}{generics} {{
#[inline]
fn py_type_cache() -> &'static OnceCell<(Py<PyAny>, Py<PyAny>)> {{
static PY_TYPE: OnceCell<(Py<PyAny>, Py<PyAny>)> = OnceCell::new();
&PY_TYPE
}}
}}
"""
),
)
for info in type_info.values():
write(info)
def write_to_pyo3(mod, type_info, f):
write_pyo3_node(type_info, f)
write_to_pyo3_simple(type_info, f)
for namespace in ("ranged", "located"):
ToPyo3AstVisitor(namespace, f, type_info).visit(mod)
f.write(
"""
pub fn init(py: Python) -> PyResult<()> {
let ast_module = PyModule::import(py, "_ast")?;
"""
)
for info in type_info.values():
rust_name = info.rust_sum_name
f.write(f"cache_py_type::<crate::generic::{rust_name}>(ast_module)?;\n")
f.write("Ok(())\n}")
def write_to_pyo3_simple(type_info, f):
for type_info in type_info.values():
if not type_info.is_sum:
continue
if not type_info.is_simple:
continue
rust_name = type_info.rust_sum_name
f.write(
f"""
impl ToPyo3Ast for crate::generic::{rust_name} {{
#[inline]
fn to_pyo3_ast(&self, _py: Python) -> PyResult<Py<PyAny>> {{
let cell = match &self {{
""",
)
for cons in type_info.type.value.types:
f.write(
f"""crate::{rust_name}::{cons.name} => crate::{rust_name}{cons.name}::py_type_cache(),""",
)
f.write(
"""
};
Ok(cell.get().unwrap().1.clone())
}
}
""",
)
def write_ast_mod(mod, type_info, f):
f.write(
textwrap.dedent(
@ -1316,6 +1497,7 @@ def main(
("ranged", p(write_ranged_def, mod, type_info)),
("located", p(write_located_def, mod, type_info)),
("visitor", p(write_visitor_def, mod, type_info)),
("to_pyo3", p(write_to_pyo3, mod, type_info)),
]:
with (ast_dir / f"{filename}.rs").open("w") as f:
f.write(auto_gen_msg)