to_pyo3_ast to return &'py

This commit is contained in:
Jeong YunWon 2023-05-22 23:55:39 +09:00
parent e1f02fced7
commit b81273e9bc
3 changed files with 1289 additions and 1288 deletions

View file

@ -300,10 +300,13 @@ class StructVisitor(EmitVisitor):
def visitModule(self, mod):
self.emit_attrs(0)
self.emit("""
self.emit(
"""
#[derive(is_macro::Is)]
pub enum Ast<R=TextRange> {
""", 0)
""",
0,
)
for dfn in mod.dfns:
rust_name = rust_type_name(dfn.name)
generics = "" if self.type_info[dfn.name].is_simple else "<R>"
@ -315,23 +318,29 @@ class StructVisitor(EmitVisitor):
# "ast_" prefix to everywhere seems less useful.
self.emit('#[is(name = "module")]', 1)
self.emit(f"{rust_name}({rust_name}{generics}),", 1)
self.emit("""
}
impl<R> Node for Ast<R> {
const NAME: &'static str = "AST";
const FIELD_NAMES: &'static [&'static str] = &[];
}
""", 0)
self.emit(
"""
}
impl<R> Node for Ast<R> {
const NAME: &'static str = "AST";
const FIELD_NAMES: &'static [&'static str] = &[];
}
""",
0,
)
for dfn in mod.dfns:
rust_name = rust_type_name(dfn.name)
generics = "" if self.type_info[dfn.name].is_simple else "<R>"
self.emit(f"""
self.emit(
f"""
impl<R> From<{rust_name}{generics}> for Ast<R> {{
fn from(node: {rust_name}{generics}) -> Self {{
Ast::{rust_name}(node)
}}
}}
""", 0)
""",
0,
)
for dfn in mod.dfns:
self.visit(dfn)
@ -663,9 +672,7 @@ class FoldImplVisitor(EmitVisitor):
cons_type_name = f"{enum_name}{cons.name}"
self.emit(
f"impl<T, U> Foldable<T, U> for {cons_type_name}{apply_t} {{", depth
)
self.emit(f"impl<T, U> Foldable<T, U> for {cons_type_name}{apply_t} {{", depth)
self.emit(f"type Mapped = {cons_type_name}{apply_u};", depth + 1)
self.emit(
"fn fold<F: Fold<T, TargetU = U> + ?Sized>(self, folder: &mut F) -> Result<Self::Mapped, F::Error> {",
@ -1097,7 +1104,7 @@ class ToPyo3AstVisitor(EmitVisitor):
f"""
impl ToPyo3Ast for crate::generic::{rust_name}{self.generics} {{
#[inline]
fn to_pyo3_ast(&self, {"_" if simple else ""}py: Python) -> PyResult<Py<PyAny>> {{
fn to_pyo3_ast<'py>(&self, {"_" if simple else ""}py: Python<'py>) -> PyResult<&'py PyAny> {{
let instance = match &self {{
""",
0,
@ -1130,7 +1137,7 @@ class ToPyo3AstVisitor(EmitVisitor):
f"""
impl ToPyo3Ast for crate::{name}{self.generics} {{
#[inline]
fn to_pyo3_ast(&self, py: Python) -> PyResult<Py<PyAny>> {{
fn to_pyo3_ast<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> {{
let cache = Self::py_type_cache().get().unwrap();
""",
0,
@ -1144,10 +1151,37 @@ class ToPyo3AstVisitor(EmitVisitor):
1,
)
self.emit(
"let instance = cache.0.call1(py, (",
"""
let instance = Py::<PyAny>::as_ref(&cache.0, py).call1((
""",
1,
)
for field in cons.fields:
if field.type == "constant":
self.emit(
f"{rust_field(field.name)}.to_object(py),",
3,
)
continue
if field.type == "int":
if field.name == "level":
assert field.opt
self.emit(
f"{rust_field(field.name)}.map_or_else(|| py.None(), |level| level.to_u32().to_object(py)),",
3,
)
continue
if field.name in (
"lineno",
"col_offset",
"end_lineno",
"end_col_offset",
):
self.emit(
f"{rust_field(field.name)}.to_u32().to_object(py),",
3,
)
continue
self.emit(
f"{rust_field(field.name)}.to_pyo3_ast(py)?,",
3,
@ -1158,7 +1192,7 @@ class ToPyo3AstVisitor(EmitVisitor):
)
else:
self.emit(
"let instance = cache.0.call0(py)?;",
"let instance = Py::<PyAny>::as_ref(&cache.0, py).call0()?;",
1,
)
self.emit(
@ -1168,12 +1202,12 @@ class ToPyo3AstVisitor(EmitVisitor):
if type.value.attributes and self.namespace == "located":
self.emit(
"""
let cache = ast_key_cache().get().unwrap();
instance.setattr(py, cache.lineno.as_ref(py), _range.start.row.get())?;
instance.setattr(py, cache.col_offset.as_ref(py), _range.start.column.get())?;
let cache = ast_cache();
instance.setattr(cache.lineno.as_ref(py), _range.start.row.get())?;
instance.setattr(cache.col_offset.as_ref(py), _range.start.column.get())?;
if let Some(end) = _range.end {
instance.setattr(py, cache.end_lineno.as_ref(py), end.row.get())?;
instance.setattr(py, cache.end_col_offset.as_ref(py), end.column.get())?;
instance.setattr(cache.end_lineno.as_ref(py), end.row.get())?;
instance.setattr(cache.end_col_offset.as_ref(py), end.column.get())?;
}
""",
1,
@ -1858,7 +1892,7 @@ def write_to_pyo3_simple(type_info, f):
f"""
impl ToPyo3Ast for crate::generic::{rust_name} {{
#[inline]
fn to_pyo3_ast(&self, _py: Python) -> PyResult<Py<PyAny>> {{
fn to_pyo3_ast<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> {{
let cell = match &self {{
""",
)
@ -1869,7 +1903,7 @@ def write_to_pyo3_simple(type_info, f):
f.write(
"""
};
Ok(cell.get().unwrap().1.clone())
Ok(Py::<PyAny>::as_ref(&cell.get().unwrap().1, py))
}
}
""",