diff --git a/crates/erg_compiler/codegen.rs b/crates/erg_compiler/codegen.rs index 4cc8836f..cfc59d17 100644 --- a/crates/erg_compiler/codegen.rs +++ b/crates/erg_compiler/codegen.rs @@ -320,6 +320,14 @@ impl PyCodeGenerator { self.units.last().unwrap() } + fn captured_vars(&self) -> Vec<&str> { + let mut caps = vec![]; + for unit in self.units.iter() { + caps.extend(unit.captured_vars.iter().map(|s| &**s)); + } + caps + } + #[inline] fn mut_cur_block(&mut self) -> &mut PyCodeGenUnit { self.units.last_mut().unwrap() @@ -638,23 +646,27 @@ impl PyCodeGenerator { } } if let Some(idx) = self - .cur_block_codeobj() - .names - .iter() - .position(|n| &**n == name) - { - Some(Name::local(idx)) - } else if let Some(idx) = self .cur_block_codeobj() .varnames .iter() .position(|v| &**v == name) { - if self.cur_block().captured_vars.contains(&Str::rc(name)) { + if self.captured_vars().contains(&name) { Some(Name::deref(idx)) } else { Some(Name::fast(idx)) } + } else if let Some(idx) = self + .cur_block_codeobj() + .names + .iter() + .position(|n| &**n == name) + { + if self.captured_vars().contains(&name) { + None + } else { + Some(Name::local(idx)) + } } else { self.cur_block_codeobj() .freevars @@ -3706,6 +3718,9 @@ impl PyCodeGenerator { let mut cells = vec![]; if self.py_version.minor >= Some(11) { for captured in captured_names { + self.mut_cur_block() + .captured_vars + .push(captured.inspect().clone()); self.write_instr(Opcode311::MAKE_CELL); cells.push((captured, self.lasti())); self.write_arg(0); diff --git a/tests/should_ok/closure.er b/tests/should_ok/closure.er index ce4dadf0..95d033f5 100644 --- a/tests/should_ok/closure.er +++ b/tests/should_ok/closure.er @@ -37,3 +37,19 @@ Versions!. vs = Versions!.new() _ = vs.insert! "foo", SemVer.from_str "1.0.0" _ = vs.insert! "foo", SemVer.from_str "1.0.1" + +Triple = Class { .version = SemVer; } +Triple. + new version = Triple { .version; } +.Version! = Class Dict! { Str: Array!(Triple) } +.Version!. + new!() = .Version! !{ "a" : ![Triple.new(SemVer.from_str("0.1.0"))] } + insert!(ref! self, name: Str, version: SemVer) = + if! all(map((triple) -> not(triple.version.compatible_with(version)), self::base[name])), do!: + self::base[name].push!(Triple.new(version)) + +f!() = + vers = .Version!.new!() + vers.insert!("a", SemVer.from_str("0.2.0")) + +f!()