Add experimental pyo3-wrapper feature (#41)

* Fix pyo3 unit type value error

* Add experimental pyo3-wrapper feature

* location support
This commit is contained in:
Jeong, YunWon 2023-05-16 23:45:31 +09:00 committed by GitHub
parent ff17f6e178
commit e820928f11
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 11108 additions and 445 deletions

View file

@ -11,7 +11,7 @@ include = ["LICENSE", "Cargo.toml", "src/**/*.rs"]
[workspace]
resolver = "2"
members = [
"ast", "core", "format", "literal", "parser",
"ast", "core", "format", "literal", "parser", "parser-pyo3",
"ruff_text_size", "ruff_source_location",
]
@ -19,6 +19,8 @@ members = [
rustpython-ast = { path = "ast", default-features = false }
rustpython-parser-core = { path = "core", features = [] }
rustpython-literal = { path = "literal" }
rustpython-format = { path = "format" }
rustpython-parser = { path = "parser" }
ahash = "0.7.6"
anyhow = "1.0.45"

View file

@ -17,6 +17,10 @@ visitor = []
all-nodes-with-ranges = []
pyo3 = ["dep:pyo3", "num-complex", "once_cell"]
# This feature is experimental
# It reimplements AST types, but currently both slower than python AST types and limited to use in other API
pyo3-wrapper = ["pyo3"]
[dependencies]
rustpython-parser-core = { workspace = true }
rustpython-literal = { workspace = true, optional = true }

View file

@ -42,6 +42,7 @@ RUST_KEYWORDS = {
"yield",
"in",
"mod",
"type",
}
@ -212,7 +213,7 @@ class EmitVisitor(asdl.VisitorBase, TypeInfoMixin):
def emit(self, line, depth):
if line:
line = (" " * TABSIZE * depth) + line
line = (" " * TABSIZE * depth) + textwrap.dedent(line)
self.file.write(line + "\n")
@ -277,10 +278,9 @@ class FindUserDataTypesVisitor(asdl.VisitorBase):
def rust_field(field_name):
if field_name == "type":
return "type_"
else:
return field_name
if field_name in RUST_KEYWORDS:
field_name += "_"
return field_name
class StructVisitor(EmitVisitor):
@ -954,19 +954,24 @@ class ToPyo3AstVisitor(EmitVisitor):
else:
assert False, self.namespace
@property
def location(self):
# lineno, col_offset
pass
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
def visitType(self, type, depth=0):
self.visit(type.value, type.name, depth)
def visitType(self, type):
self.visit(type.value, type)
def visitProduct(self, product, name, depth=0):
rust_name = rust_type_name(name)
self.emit_to_pyo3_with_fields(product, rust_name)
def visitProduct(self, product, type):
rust_name = rust_type_name(type.name)
self.emit_to_pyo3_with_fields(product, type, rust_name)
def visitSum(self, sum, name, depth=0):
rust_name = rust_type_name(name)
def visitSum(self, sum, type):
rust_name = rust_type_name(type.name)
simple = is_simple(sum)
if is_simple(sum):
return
@ -983,7 +988,7 @@ class ToPyo3AstVisitor(EmitVisitor):
for cons in sum.types:
self.emit(
f"""crate::{rust_name}::{cons.name}(cons) => cons.to_pyo3_ast(py)?,""",
depth,
1,
)
self.emit(
"""
@ -996,52 +1001,372 @@ class ToPyo3AstVisitor(EmitVisitor):
)
for cons in sum.types:
self.visit(cons, rust_name, depth)
self.visit(cons, type)
def visitConstructor(self, cons, parent, depth):
self.emit_to_pyo3_with_fields(cons, f"{parent}{cons.name}")
def visitConstructor(self, cons, type):
parent = rust_type_name(type.name)
self.emit_to_pyo3_with_fields(cons, type, f"{parent}{cons.name}")
def emit_to_pyo3_with_fields(self, cons, name):
if cons.fields:
self.emit(
f"""
impl ToPyo3Ast for crate::{name}{self.generics} {{
#[inline]
def emit_to_pyo3_with_fields(self, cons, type, name):
type_info = self.type_info[type.name]
self.emit(
f"""
impl ToPyo3Ast for crate::{name}{self.generics} {{
#[inline]
fn to_pyo3_ast(&self, py: Python) -> PyResult<Py<PyAny>> {{
let cache = Self::py_type_cache().get().unwrap();
let instance = cache.0.call1(py, (
""",
0,
""",
0,
)
if cons.fields:
field_names = ", ".join(rust_field(f.name) for f in cons.fields)
if not type_info.is_simple:
field_names += ", range: _range"
self.emit(
f"let Self {{ {field_names} }} = self;",
1,
)
self.emit(
"let instance = cache.0.call1(py, (",
1,
)
for field in cons.fields:
self.emit(
f"self.{rust_field(field.name)}.to_pyo3_ast(py)?,",
f"{rust_field(field.name)}.to_pyo3_ast(py)?,",
3,
)
self.emit(
"""
))?;
Ok(instance)
}
}
""",
0,
)
else:
self.emit(
"let instance = cache.0.call0(py)?;",
1,
)
self.emit(
"let Self { range: _range } = self;",
1,
)
if type.value.attributes and self.namespace == "located":
self.emit(
"""
let cache = ast_key_cache().get().unwrap();
instance.setattr(py, cache.lineno.as_ref(py), _range.start.row.get())?;
instance.setattr(py, cache.col_offset.as_ref(py), _range.start.column.get())?;
if let Some(end) = _range.end {
instance.setattr(py, cache.end_lineno.as_ref(py), end.row.get())?;
instance.setattr(py, cache.end_col_offset.as_ref(py), end.column.get())?;
}
""",
1,
)
self.emit(
"""
Ok(instance)
}
}
""",
0,
)
class Pyo3StructVisitor(EmitVisitor):
"""Visitor to generate type-defs for AST."""
def __init__(self, namespace, *args, **kw):
self.namespace = namespace
self.borrow = True
super().__init__(*args, **kw)
@property
def generics(self):
if self.namespace == "ranged":
return "<TextRange>"
elif self.namespace == "located":
return "<SourceRange>"
else:
assert False, self.namespace
@property
def module_name(self):
name = f"rustpython_ast.{self.namespace}"
return name
@property
def ref_def(self):
return "&'static " if self.borrow else ""
@property
def ref(self):
return "&" if self.borrow else ""
def emit_class(self, name, rust_name, simple, base="super::AST"):
info = self.type_info[name]
if simple:
generics = ""
else:
generics = self.generics
if info.is_sum:
subclass = ", subclass"
body = ""
into = f"{rust_name}"
else:
subclass = ""
body = f"(pub {self.ref_def} crate::{rust_name}{generics})"
into = f"{rust_name}(node)"
self.emit(
textwrap.dedent(
f"""
impl ToPyo3Ast for crate::{name}{self.generics} {{
#[inline]
fn to_pyo3_ast(&self, py: Python) -> PyResult<Py<PyAny>> {{
let cache = Self::py_type_cache().get().unwrap();
let instance = cache.0.call0(py)?;
Ok(instance)
#[pyclass(module="{self.module_name}", name="_{name}", extends={base}, frozen{subclass})]
#[derive(Clone, Debug)]
pub struct {rust_name} {body};
impl From<{self.ref_def} crate::{rust_name}{generics}> for {rust_name} {{
fn from({"" if body else "_"}node: {self.ref_def} crate::{rust_name}{generics}) -> Self {{
{into}
}}
}}
"""
),
0,
)
if subclass:
self.emit(
textwrap.dedent(
f"""
#[pymethods]
impl {rust_name} {{
#[new]
fn new() -> PyClassInitializer<Self> {{
PyClassInitializer::from(AST)
.add_subclass(Self)
}}
}}
impl ToPyObject for {rust_name} {{
fn to_object(&self, py: Python) -> PyObject {{
let initializer = Self::new();
Py::new(py, initializer).unwrap().into_py(py)
}}
}}
"""
),
0,
)
else:
if base != "super::AST":
add_subclass = f".add_subclass({base})"
else:
add_subclass = ""
self.emit(
textwrap.dedent(
f"""
impl ToPyObject for {rust_name} {{
fn to_object(&self, py: Python) -> PyObject {{
let initializer = PyClassInitializer::from(AST)
{add_subclass}
.add_subclass(self.clone());
Py::new(py, initializer).unwrap().into_py(py)
}}
}}
"""
),
0,
)
if not subclass:
self.emit_wrapper(rust_name)
def emit_getter(self, owner, type_name):
self.emit(
textwrap.dedent(
f"""
#[pymethods]
impl {type_name} {{
"""
),
0,
)
for field in owner.fields:
self.emit(
textwrap.dedent(
f"""
#[getter]
#[inline]
fn get_{field.name}(&self, py: Python) -> PyResult<PyObject> {{
self.0.{rust_field(field.name)}.to_pyo3_wrapper(py)
}}
"""
),
3,
)
self.emit(
textwrap.dedent(
"""
}
"""
),
0,
)
def emit_getattr(self, owner, type_name):
self.emit(
textwrap.dedent(
f"""
#[pymethods]
impl {type_name} {{
fn __getattr__(&self, py: Python, key: &str) -> PyResult<PyObject> {{
let object: Py<PyAny> = match key {{
"""
),
0,
)
for field in owner.fields:
self.emit(
f'"{field.name}" => self.0.{rust_field(field.name)}.to_pyo3_wrapper(py)?,',
3,
)
self.emit(
textwrap.dedent(
"""
_ => todo!(),
};
Ok(object)
}
}
"""
),
0,
)
def emit_wrapper(self, rust_name):
self.emit(
f"""
impl ToPyo3Wrapper for crate::{rust_name}{self.generics} {{
#[inline]
fn to_pyo3_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {{
Ok({rust_name}(self).to_object(py))
}}
}}
""",
0,
)
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
def visitType(self, type, depth=0):
self.visit(type.value, type, depth)
def visitSum(self, sum, type, depth=0):
rust_name = rust_type_name(type.name)
simple = is_simple(sum)
self.emit_class(type.name, rust_name, simple)
if not simple:
self.emit(
f"""
impl ToPyo3Wrapper for crate::{rust_name}{self.generics} {{
#[inline]
fn to_pyo3_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {{
match &self {{
""",
0,
)
for cons in sum.types:
self.emit(f"Self::{cons.name}(cons) => cons.to_pyo3_wrapper(py),", 3)
self.emit(
"""
}
}
}
""",
0,
)
for cons in sum.types:
self.visit(cons, rust_name, simple, depth + 1)
def visitProduct(self, product, type, depth=0):
rust_name = rust_type_name(type.name)
self.emit_class(type.name, rust_name, False)
if self.borrow:
self.emit_getter(product, rust_name)
def visitConstructor(self, cons, parent, simple, depth):
if simple:
self.emit(
f"""
#[pyclass(module="{self.module_name}", name="_{cons.name}", extends={parent})]
pub struct {parent}{cons.name};
impl ToPyObject for {parent}{cons.name} {{
fn to_object(&self, py: Python) -> PyObject {{
let initializer = PyClassInitializer::from(AST)
.add_subclass({parent})
.add_subclass(Self);
Py::new(py, initializer).unwrap().into_py(py)
}}
}}
""",
depth,
)
else:
self.emit_class(
cons.name,
f"{parent}{cons.name}",
simple=False,
base=parent,
)
if self.borrow:
self.emit_getter(cons, f"{parent}{cons.name}")
class Pyo3PymoduleVisitor(EmitVisitor):
def __init__(self, namespace, *args, **kw):
self.namespace = namespace
super().__init__(*args, **kw)
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
def visitType(self, type, depth=0):
self.visit(type.value, type.name, depth)
def visitProduct(self, product, name, depth=0):
rust_name = rust_type_name(name)
self.emit_fields(name, rust_name, False)
def visitSum(self, sum, name, depth):
rust_name = rust_type_name(name)
simple = is_simple(sum)
self.emit_fields(name, rust_name, True)
for cons in sum.types:
self.visit(cons, name, simple, depth)
def visitConstructor(self, cons, parent, simple, depth):
rust_name = rust_type_name(parent) + rust_type_name(cons.name)
self.emit_fields(cons.name, rust_name, simple)
def emit_fields(self, name, rust_name, simple):
self.emit(
f"super::init_type::<{rust_name}, crate::generic::{rust_name}>(py, m)?;", 1
)
class StdlibClassDefVisitor(EmitVisitor):
def visitModule(self, mod):
@ -1412,7 +1737,7 @@ def write_to_pyo3(mod, type_info, f):
f.write(
"""
pub fn init(py: Python) -> PyResult<()> {
fn init_types(py: Python) -> PyResult<()> {
let ast_module = PyModule::import(py, "_ast")?;
"""
)
@ -1453,6 +1778,58 @@ def write_to_pyo3_simple(type_info, f):
)
def write_pyo3_wrapper(mod, type_info, namespace, f):
Pyo3StructVisitor(namespace, f, type_info).visit(mod)
if namespace == "located":
for type_info in type_info.values():
if not type_info.is_simple or not type_info.is_sum:
continue
rust_name = type_info.rust_sum_name
f.write(
f"""
impl ToPyo3Wrapper for crate::generic::{rust_name} {{
#[inline]
fn to_pyo3_wrapper(&self, py: Python) -> PyResult<Py<PyAny>> {{
match &self {{
""",
)
for cons in type_info.type.value.types:
f.write(
f"Self::{cons.name} => Ok({rust_name}{cons.name}.to_object(py)),",
)
f.write(
"""
}
}
}
""",
)
for cons in type_info.type.value.types:
f.write(
f"""
impl ToPyo3Wrapper for crate::generic::{rust_name}{cons.name} {{
#[inline]
fn to_pyo3_wrapper(&self, py: Python) -> PyResult<Py<PyAny>> {{
Ok({rust_name}{cons.name}.to_object(py))
}}
}}
"""
)
f.write(
"""
pub fn add_to_module(py: Python, m: &PyModule) -> PyResult<()> {
super::init_module(py, m)?;
"""
)
Pyo3PymoduleVisitor(namespace, f, type_info).visit(mod)
f.write("Ok(())\n}")
def write_ast_mod(mod, type_info, f):
f.write(
textwrap.dedent(
@ -1499,6 +1876,8 @@ def main(
("located", p(write_located_def, mod, type_info)),
("visitor", p(write_visitor_def, mod, type_info)),
("to_pyo3", p(write_to_pyo3, mod, type_info)),
("pyo3_wrapper_located", p(write_pyo3_wrapper, mod, type_info, "located")),
("pyo3_wrapper_ranged", p(write_pyo3_wrapper, mod, type_info, "ranged")),
]:
with (ast_dir / f"{filename}.rs").open("w") as f:
f.write(auto_gen_msg)

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -50,3 +50,5 @@ pub use optimizer::ConstantOptimizer;
#[cfg(feature = "pyo3")]
pub mod pyo3;
#[cfg(feature = "pyo3-wrapper")]
pub mod pyo3_wrapper;

View file

@ -1,8 +1,10 @@
use crate::{source_code::SourceRange, text_size::TextRange, ConversionFlag, Node};
use num_complex::Complex64;
use once_cell::sync::OnceCell;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyList, PyTuple};
use pyo3::{
prelude::*,
types::{PyBytes, PyList, PyString, PyTuple},
};
pub trait Pyo3Node {
fn py_type_cache() -> &'static OnceCell<(Py<PyAny>, Py<PyAny>)> {
@ -113,11 +115,39 @@ impl AST {
}
fn cache_py_type<N: Pyo3Node + Node>(ast_module: &PyAny) -> PyResult<()> {
let class = ast_module.getattr(N::NAME).unwrap();
let base = class.getattr("__new__").unwrap();
let class = ast_module.getattr(N::NAME)?;
let base = if std::mem::size_of::<N>() == 0 {
class.call0()?
} else {
class.getattr("__new__")?
};
N::py_type_cache().get_or_init(|| (class.into(), base.into()));
Ok(())
}
struct AstKeyCache {
lineno: Py<PyString>,
col_offset: Py<PyString>,
end_lineno: Py<PyString>,
end_col_offset: Py<PyString>,
}
fn ast_key_cache() -> &'static OnceCell<AstKeyCache> {
{
static PY_TYPE: OnceCell<AstKeyCache> = OnceCell::new();
&PY_TYPE
}
}
pub fn init(py: Python) -> PyResult<()> {
ast_key_cache().get_or_init(|| AstKeyCache {
lineno: pyo3::intern!(py, "lineno").into_py(py),
col_offset: pyo3::intern!(py, "col_offset").into_py(py),
end_lineno: pyo3::intern!(py, "end_lineno").into_py(py),
end_col_offset: pyo3::intern!(py, "end_col_offset").into_py(py),
});
init_types(py)
}
include!("gen/to_pyo3.rs");

128
ast/src/pyo3_wrapper.rs Normal file
View file

@ -0,0 +1,128 @@
use crate::pyo3::{Pyo3Node, AST};
use crate::{source_code::SourceRange, text_size::TextRange, ConversionFlag, Node};
use num_complex::Complex64;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyList, PyTuple};
pub trait ToPyo3Wrapper {
fn to_pyo3_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>>;
}
impl<T: ToPyo3Wrapper> ToPyo3Wrapper for Box<T> {
#[inline]
fn to_pyo3_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {
(**self).to_pyo3_wrapper(py)
}
}
impl<T: ToPyo3Wrapper> ToPyo3Wrapper for Option<T> {
#[inline]
fn to_pyo3_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {
match self {
Some(ast) => ast.to_pyo3_wrapper(py),
None => Ok(py.None()),
}
}
}
impl ToPyo3Wrapper for crate::Identifier {
#[inline]
fn to_pyo3_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {
Ok(self.as_str().to_object(py))
}
}
impl ToPyo3Wrapper for crate::String {
#[inline]
fn to_pyo3_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {
Ok(self.as_str().to_object(py))
}
}
impl ToPyo3Wrapper for crate::Int {
#[inline]
fn to_pyo3_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {
Ok((self.to_u32()).to_object(py))
}
}
impl ToPyo3Wrapper for bool {
#[inline]
fn to_pyo3_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {
Ok((*self as u32).to_object(py))
}
}
impl ToPyo3Wrapper for ConversionFlag {
#[inline]
fn to_pyo3_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {
Ok((*self as i8).to_object(py))
}
}
impl ToPyo3Wrapper for crate::Constant {
fn to_pyo3_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {
let value = match self {
crate::Constant::None => py.None(),
crate::Constant::Bool(bool) => bool.to_object(py),
crate::Constant::Str(string) => string.to_object(py),
crate::Constant::Bytes(bytes) => PyBytes::new(py, bytes).into(),
crate::Constant::Int(int) => int.to_object(py),
crate::Constant::Tuple(elts) => {
let elts: PyResult<Vec<_>> = elts.iter().map(|c| c.to_pyo3_wrapper(py)).collect();
PyTuple::new(py, elts?).into()
}
crate::Constant::Float(f64) => f64.to_object(py),
crate::Constant::Complex { real, imag } => Complex64::new(*real, *imag).to_object(py),
crate::Constant::Ellipsis => py.Ellipsis(),
};
Ok(value)
}
}
impl<T: ToPyo3Wrapper> ToPyo3Wrapper for Vec<T> {
fn to_pyo3_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {
let list = PyList::empty(py);
for item in self {
let py_item = item.to_pyo3_wrapper(py)?;
list.append(py_item)?;
}
Ok(list.into())
}
}
pub mod located {
use super::*;
pub use crate::pyo3::AST;
include!("gen/pyo3_wrapper_located.rs");
}
pub mod ranged {
use super::*;
pub use crate::pyo3::AST;
include!("gen/pyo3_wrapper_ranged.rs");
}
fn init_type<P: pyo3::PyClass, N: Pyo3Node + Node>(py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<P>()?;
let node = m.getattr(P::NAME)?;
if P::NAME != N::NAME {
// TODO: no idea how to escape rust keyword on #[pyclass]
m.setattr(P::NAME, node)?;
}
let names: Vec<&'static str> = N::FIELD_NAMES.to_vec();
let fields = PyTuple::new(py, names);
node.setattr("_fields", fields)?;
Ok(())
}
/// A Python module implemented in Rust.
fn init_module(py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<AST>()?;
let ast = m.getattr("AST")?;
let fields = PyTuple::empty(py);
ast.setattr("_fields", fields)?;
Ok(())
}