Merge pull request #465 from erg-lang/closure2

fix closure codegen bug
This commit is contained in:
Shunsuke Shibayama 2023-10-24 09:33:28 +09:00 committed by GitHub
commit a2a26e4584
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 162 additions and 36 deletions

View file

@ -52,11 +52,12 @@ macro_rules! impl_u8_enum {
$crate::impl_display_from_debug!($Enum);
impl From<u8> for $Enum {
fn from(byte: u8) -> Self {
impl TryFrom<u8> for $Enum {
type Error = ();
fn try_from(byte: u8) -> Result<Self, Self::Error> {
match byte {
$(v if v == $Enum::$Variant as u8 => $Enum::$Variant,)*
_ => todo!("unknown: {byte}"),
$($val => Ok($Enum::$Variant),)*
_ => Err(()),
}
}
}
@ -76,11 +77,12 @@ macro_rules! impl_u8_enum {
$crate::impl_display_from_debug!($Enum);
impl From<u8> for $Enum {
fn from(byte: u8) -> Self {
impl TryFrom<u8> for $Enum {
type Error = ();
fn try_from(byte: u8) -> Result<Self, Self::Error> {
match byte {
$($val => $Enum::$Variant,)*
_ => todo!("unknown opcode: {byte}"),
$($val => Ok($Enum::$Variant),)*
_ => Err(()),
}
}
}
@ -106,11 +108,12 @@ macro_rules! impl_u8_enum {
$crate::impl_display_from_debug!($Enum);
impl From<$size> for $Enum {
fn from(byte: $size) -> Self {
impl TryFrom<$size> for $Enum {
type Error = ();
fn try_from(byte: $size) -> Result<Self, Self::Error> {
match byte {
$($val => $Enum::$Variant,)*
_ => todo!("unknown opcode: {byte}"),
$($val => Ok($Enum::$Variant),)*
_ => Err(()),
}
}
}

View file

@ -601,6 +601,16 @@ impl PyCodeGenerator {
}
fn local_search(&self, name: &str, _acc_kind: AccessKind) -> Option<Name> {
if self.py_version.minor < Some(11) {
if let Some(idx) = self
.cur_block_codeobj()
.cellvars
.iter()
.position(|v| &**v == name)
{
return Some(Name::deref(idx));
}
}
if let Some(idx) = self
.cur_block_codeobj()
.names
@ -1398,10 +1408,11 @@ impl PyCodeGenerator {
sig.params.guards,
Some(name.clone()),
params,
sig.captured_names.clone(),
flags,
);
// code.flags += CodeObjFlags::Optimized as u32;
self.register_cellvars(&mut make_function_flag);
self.enclose_vars(&mut make_function_flag);
let n_decos = sig.decorators.len();
for deco in sig.decorators {
self.emit_expr(deco);
@ -1462,9 +1473,10 @@ impl PyCodeGenerator {
lambda.params.guards,
Some(format!("<lambda_{}>", lambda.id).into()),
params,
lambda.captured_names.clone(),
flags,
);
self.register_cellvars(&mut make_function_flag);
self.enclose_vars(&mut make_function_flag);
self.rewrite_captured_fast(&code);
self.emit_load_const(code);
if self.py_version.minor < Some(11) {
@ -1481,7 +1493,7 @@ impl PyCodeGenerator {
}
}
fn register_cellvars(&mut self, flag: &mut usize) {
fn enclose_vars(&mut self, flag: &mut usize) {
if !self.cur_block_codeobj().cellvars.is_empty() {
let cellvars_len = self.cur_block_codeobj().cellvars.len();
let cellvars = self.cur_block_codeobj().cellvars.clone();
@ -1494,8 +1506,6 @@ impl PyCodeGenerator {
.iter()
.position(|n| n == name)
.unwrap();
self.write_instr(Opcode311::MAKE_CELL);
self.write_arg(idx);
self.write_instr(Opcode311::LOAD_CLOSURE);
self.write_arg(idx);
} else {
@ -1537,19 +1547,33 @@ impl PyCodeGenerator {
let cellvars = self.cur_block_codeobj().cellvars.clone();
for cellvar in cellvars {
if code.freevars.iter().any(|n| n == &cellvar) {
let old_idx = self
.cur_block_codeobj()
.varnames
.iter()
.position(|n| n == &cellvar)
.unwrap();
let new_idx = self
.cur_block_codeobj()
.cellvars
.iter()
.position(|n| n == &cellvar)
.unwrap();
self.mut_cur_block().captured_vars.push(cellvar);
let mut op_idx = 0;
while let Some([op, _arg]) = self
while let Some([op, arg]) = self
.mut_cur_block_codeobj()
.code
.get_mut(op_idx..=op_idx + 1)
{
match Opcode310::try_from(*op) {
Ok(Opcode310::LOAD_FAST) => {
Ok(Opcode310::LOAD_FAST) if *arg == old_idx as u8 => {
*op = Opcode310::LOAD_DEREF as u8;
*arg = new_idx as u8;
}
Ok(Opcode310::STORE_FAST) => {
Ok(Opcode310::STORE_FAST) if *arg == old_idx as u8 => {
*op = Opcode310::STORE_DEREF as u8;
*arg = new_idx as u8;
}
_ => {}
}
@ -2991,7 +3015,7 @@ impl PyCodeGenerator {
/// Emits independent code blocks (e.g., linked other modules)
fn emit_code(&mut self, code: Block) {
let mut gen = self.inherit();
let code = gen.emit_block(code, vec![], None, vec![], 0);
let code = gen.emit_block(code, vec![], None, vec![], vec![], 0);
self.emit_load_const(code);
}
@ -3379,6 +3403,7 @@ impl PyCodeGenerator {
bounds,
params,
sig.t_spec_with_op().cloned(),
vec![],
);
let mut attrs = vec![];
match new_first_param.map(|pt| pt.typ()) {
@ -3464,6 +3489,7 @@ impl PyCodeGenerator {
bounds,
params,
sig.t_spec_with_op().cloned(),
vec![],
);
let arg = PosArg::new(Expr::Accessor(Accessor::private_with_line(
param_name, line,
@ -3481,6 +3507,7 @@ impl PyCodeGenerator {
bounds,
params,
sig.t_spec_with_op().cloned(),
vec![],
);
let call = class_new.call_expr(Args::empty());
let block = Block::new(vec![call]);
@ -3495,6 +3522,7 @@ impl PyCodeGenerator {
guards: Vec<GuardClause>,
opt_name: Option<Str>,
params: Vec<Str>,
captured_names: Vec<Identifier>,
flags: u32,
) -> CodeObj {
log!(info "entered {}", fn_name!());
@ -3527,6 +3555,14 @@ impl PyCodeGenerator {
} else {
0
};
let mut cells = vec![];
if self.py_version.minor >= Some(11) {
for captured in captured_names {
self.write_instr(Opcode311::MAKE_CELL);
cells.push((captured, self.lasti()));
self.write_arg(0);
}
}
let init_stack_len = self.stack_len();
for guard in guards {
if let GuardClause::Bind(bind) = guard {
@ -3576,6 +3612,18 @@ impl PyCodeGenerator {
debug_assert_eq!(code, Some(&(Opcode311::COPY_FREE_VARS as u8)));
self.edit_code(idx_copy_free_vars, CommonOpcode::NOP as usize);
}
for (cell, placeholder) in cells {
let name = escape_ident(cell);
let Some(idx) = self
.cur_block_codeobj()
.varnames
.iter()
.position(|v| v == &name)
else {
continue;
};
self.edit_code(placeholder, idx);
}
// end of flagging
let unit = self.units.pop().unwrap();
// increase lineno

View file

@ -1185,6 +1185,7 @@ impl Context {
TypeBoundSpecs::empty(),
params,
None,
vec![],
);
let sig = Signature::Subr(sig);
let call = Identifier::private("p!").call(Args::empty());

View file

@ -38,6 +38,7 @@ use erg_parser::token::Token;
use crate::context::instantiate::TyVarCache;
use crate::context::instantiate_spec::ConstTemplate;
use crate::error::{TyCheckError, TyCheckErrors};
use crate::hir::Identifier;
use crate::module::SharedModuleGraph;
use crate::module::{
SharedCompilerResource, SharedModuleCache, SharedModuleIndex, SharedPromises, SharedTraitImpls,
@ -543,6 +544,7 @@ pub struct Context {
pub(crate) higher_order_caller: Vec<Str>,
pub(crate) guards: Vec<GuardType>,
pub(crate) erg_to_py_names: Dict<Str, Str>,
pub(crate) captured_names: Vec<Identifier>,
pub(crate) level: usize,
}
@ -742,6 +744,7 @@ impl Context {
higher_order_caller: vec![],
guards: vec![],
erg_to_py_names: Dict::default(),
captured_names: vec![],
level,
}
}

View file

@ -346,8 +346,14 @@ impl ASTLowerer {
} else {
None
};
let sig =
hir::SubrSignature::new(decorators, ident, subr.bounds, params, ret_t_spec);
let sig = hir::SubrSignature::new(
decorators,
ident,
subr.bounds,
params,
ret_t_spec,
vec![],
);
Ok(hir::Signature::Subr(sig))
}
}
@ -522,6 +528,7 @@ impl ASTLowerer {
params,
lambda.op,
return_t_spec,
vec![],
body,
Type::Failure,
))

View file

@ -394,6 +394,14 @@ impl Args {
.find(|kw| kw.keyword.inspect() == keyword)
.map(|kw| &kw.expr)
}
/// Returns an iterator over the expression of every argument, in order:
/// the entries of `pos_args`, then those of `var_args`, then those of `kw_args`.
pub fn iter(&self) -> impl Iterator<Item = &Expr> {
    let positional = self.pos_args.iter().map(|arg| &arg.expr);
    let variadic = self.var_args.iter().map(|arg| &arg.expr);
    let keyword = self.kw_args.iter().map(|arg| &arg.expr);
    positional.chain(variadic).chain(keyword)
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
@ -1946,6 +1954,7 @@ pub struct SubrSignature {
pub bounds: TypeBoundSpecs,
pub params: Params,
pub return_t_spec: Option<TypeSpecWithOp>,
pub captured_names: Vec<Identifier>,
}
impl NestedDisplay for SubrSignature {
@ -1999,6 +2008,7 @@ impl SubrSignature {
bounds: TypeBoundSpecs,
params: Params,
return_t_spec: Option<TypeSpecWithOp>,
captured_names: Vec<Identifier>,
) -> Self {
Self {
decorators,
@ -2006,6 +2016,7 @@ impl SubrSignature {
bounds,
params,
return_t_spec,
captured_names,
}
}
@ -2023,6 +2034,7 @@ pub struct Lambda {
pub params: Params,
pub op: Token,
pub return_t_spec: Option<TypeSpec>,
pub captured_names: Vec<Identifier>,
pub body: Block,
pub id: usize,
pub t: Type,
@ -2056,6 +2068,7 @@ impl Lambda {
params: Params,
op: Token,
return_t_spec: Option<TypeSpec>,
captured_names: Vec<Identifier>,
body: Block,
t: Type,
) -> Self {
@ -2064,6 +2077,7 @@ impl Lambda {
params,
op,
return_t_spec,
captured_names,
body,
t,
}

View file

@ -915,6 +915,12 @@ impl ASTLowerer {
}
}
let ident = hir::Identifier::new(ident, __name__, vi);
if !ident.vi.is_toplevel()
&& ident.vi.def_namespace() != &self.module.context.name
&& ident.vi.kind.can_capture()
{
self.module.context.captured_names.push(ident.clone());
}
Ok(ident)
}
@ -1653,6 +1659,7 @@ impl ASTLowerer {
let default_param_tys = default_params
.map(|(name, vi)| ParamTy::kw(name.as_ref().unwrap().inspect().clone(), vi.t.clone()))
.collect();
let captured_names = mem::take(&mut self.module.context.captured_names);
if in_statement {
// For example, `i` in `for i in ...` is a parameter,
// but should be treated as a local variable in the later analysis, so move it to locals
@ -1702,6 +1709,7 @@ impl ASTLowerer {
params,
lambda.op,
return_t_spec,
captured_names,
body,
t,
))
@ -1936,8 +1944,14 @@ impl ASTLowerer {
} else {
None
};
let captured_names = mem::take(&mut self.module.context.captured_names);
let sig = hir::SubrSignature::new(
decorators, ident, sig.bounds, params, ret_t_spec,
decorators,
ident,
sig.bounds,
params,
ret_t_spec,
captured_names,
);
let body = hir::DefBody::new(body.op, block, body.id);
Ok(hir::Def::new(hir::Signature::Subr(sig), body))
@ -1964,8 +1978,14 @@ impl ASTLowerer {
} else {
None
};
let captured_names = mem::take(&mut self.module.context.captured_names);
let sig = hir::SubrSignature::new(
decorators, ident, sig.bounds, params, ret_t_spec,
decorators,
ident,
sig.bounds,
params,
ret_t_spec,
captured_names,
);
let block =
hir::Block::new(vec![hir::Expr::Dummy(hir::Dummy::new(vec![]))]);
@ -1994,8 +2014,15 @@ impl ASTLowerer {
} else {
None
};
let sig =
hir::SubrSignature::new(decorators, ident, sig.bounds, params, ret_t_spec);
let captured_names = mem::take(&mut self.module.context.captured_names);
let sig = hir::SubrSignature::new(
decorators,
ident,
sig.bounds,
params,
ret_t_spec,
captured_names,
);
let body = hir::DefBody::new(body.op, block, body.id);
Ok(hir::Def::new(hir::Signature::Subr(sig), body))
}

View file

@ -39,9 +39,9 @@ pub fn consts_into_bytes(consts: Vec<ValueObj>, python_ver: PythonVersion) -> Ve
pub fn jump_abs_addr(minor_ver: u8, op: u8, idx: usize, arg: usize) -> usize {
match minor_ver {
7..=9 => jump_abs_addr_309(Opcode309::from(op), idx, arg),
10 => jump_abs_addr_310(Opcode310::from(op), idx, arg),
11 => jump_abs_addr_311(Opcode311::from(op), idx, arg),
7..=9 => jump_abs_addr_309(Opcode309::try_from(op).unwrap(), idx, arg),
10 => jump_abs_addr_310(Opcode310::try_from(op).unwrap(), idx, arg),
11 => jump_abs_addr_311(Opcode311::try_from(op).unwrap(), idx, arg),
n => todo!("unsupported version: {n}"),
}
}
@ -601,7 +601,7 @@ impl CodeObj {
}
fn read_instr_308(&self, op: &u8, arg: usize, idx: usize, instrs: &mut String) {
let op308 = Opcode308::from(*op);
let op308 = Opcode308::try_from(*op).unwrap();
let s_op = op308.to_string();
write!(instrs, "{idx:>15} {s_op:<25}").unwrap();
if let Ok(op) = CommonOpcode::try_from(*op) {
@ -642,7 +642,7 @@ impl CodeObj {
}
fn read_instr_309(&self, op: &u8, arg: usize, idx: usize, instrs: &mut String) {
let op309 = Opcode309::from(*op);
let op309 = Opcode309::try_from(*op).unwrap();
let s_op = op309.to_string();
write!(instrs, "{idx:>15} {s_op:<25}").unwrap();
if let Ok(op) = CommonOpcode::try_from(*op) {
@ -683,7 +683,7 @@ impl CodeObj {
}
fn read_instr_310(&self, op: &u8, arg: usize, idx: usize, instrs: &mut String) {
let op310 = Opcode310::from(*op);
let op310 = Opcode310::try_from(*op).unwrap();
let s_op = op310.to_string();
write!(instrs, "{idx:>15} {s_op:<25}").unwrap();
if let Ok(op) = CommonOpcode::try_from(*op) {
@ -734,7 +734,7 @@ impl CodeObj {
}
fn read_instr_311(&self, op: &u8, arg: usize, idx: usize, instrs: &mut String) {
let op311 = Opcode311::from(*op);
let op311 = Opcode311::try_from(*op).unwrap();
let s_op = op311.to_string();
write!(instrs, "{idx:>15} {s_op:<26}").unwrap();
if let Ok(op) = CommonOpcode::try_from(*op) {
@ -770,7 +770,12 @@ impl CodeObj {
write!(instrs, "{arg} ({})", self.consts.get(arg).unwrap()).unwrap();
}
Opcode311::BINARY_OP => {
write!(instrs, "{arg} ({:?})", BinOpCode::from(arg as u8)).unwrap();
write!(
instrs,
"{arg} ({:?})",
BinOpCode::try_from(arg as u8).unwrap()
)
.unwrap();
}
_ => {}
}

View file

@ -83,6 +83,13 @@ impl VarKind {
matches!(self, Self::Defined(_))
}
/// Returns `true` only for the `Defined`, `Declared` and `Parameter` kinds.
pub const fn can_capture(&self) -> bool {
    matches!(self, Self::Parameter { .. } | Self::Declared | Self::Defined(_))
}
/// Returns `true` exactly when this kind is `DoesNotExist`.
pub const fn does_not_exist(&self) -> bool {
    matches!(*self, Self::DoesNotExist)
}
@ -446,7 +453,8 @@ impl VarInfo {
}
pub fn is_toplevel(&self) -> bool {
self.vis.def_namespace.split_with(&[".", "::"]).len() == 1
let ns = Str::rc(self.vis.def_namespace.trim_start_matches("./"));
ns.split_with(&[".", "::"]).len() == 1
}
pub fn is_fast_value(&self) -> bool {

View file

@ -2,3 +2,13 @@ func vers: Array(Int), version: Int =
all map(v -> v == version, vers)
assert func([1, 1], 1)
# Closure test: `f!` pushes the parameter `version` onto the mutable array
# `arr`; `func2!` then calls `f!` and yields `arr`.
# NOTE(review): indentation is not visible in this view; `f!` is presumably
# nested inside `func2!` and capturing `arr`/`version` — confirm in the file.
func2! version: Int =
arr = ![]
f!() =
arr.push! version
f!()
arr
arr = func2!(1)
assert arr[0] == 1