diff --git a/compiler/erg_common/opcode311.rs b/compiler/erg_common/opcode311.rs index 1ce9e1d4..a1ff8e81 100644 --- a/compiler/erg_common/opcode311.rs +++ b/compiler/erg_common/opcode311.rs @@ -102,6 +102,8 @@ impl_u8_enum! {Opcode311; PRECALL = 166, CALL = 171, KW_NAMES = 172, + POP_JUMP_BACKWARD_IF_FALSE = 175, + POP_JUMP_BACKWARD_IF_TRUE = 176, // Erg-specific opcodes (must have a unary `ERG_`) // Define in descending order from 219, 255 ERG_POP_NTH = 196, diff --git a/compiler/erg_compiler/codegen.rs b/compiler/erg_compiler/codegen.rs index 16c279aa..7d8fac3c 100644 --- a/compiler/erg_compiler/codegen.rs +++ b/compiler/erg_compiler/codegen.rs @@ -1446,6 +1446,14 @@ impl PyCodeGenerator { self.write_arg(name.idx); } + fn emit_not_instr(&mut self, mut args: Args) { + log!(info "entered {}", fn_name!()); + let expr = args.remove_left_or_key("b").unwrap(); + self.emit_expr(expr); + self.write_instr(UNARY_NOT); + self.write_arg(0); + } + fn emit_discard_instr(&mut self, mut args: Args) { log!(info "entered {}", fn_name!()); while let Some(arg) = args.try_remove(0) { @@ -1524,12 +1532,15 @@ impl PyCodeGenerator { // cannot detect where to jump to at this moment, so put as 0 self.write_arg(0); let lambda = enum_unwrap!(args.remove(0), Expr::Lambda); + // If there is nothing on the stack at the start, init_stack_len == 2 (an iterator and iterator value) let init_stack_len = self.stack_len(); let params = self.gen_param_names(&lambda.params); + // store the iterator value, stack_len == 1 or 2 in the end self.emit_frameless_block(lambda.body, params); - if self.stack_len() >= init_stack_len { + if self.stack_len() > init_stack_len - 1 { self.emit_pop_top(); } + debug_assert_eq!(self.stack_len(), init_stack_len - 1); // the iterator is remained match self.py_version.minor { Some(11) => { self.write_instr(Opcode311::JUMP_BACKWARD); @@ -1567,16 +1578,22 @@ impl PyCodeGenerator { self.emit_pop_top(); } self.emit_expr(cond); - self.write_instr(Opcode310::POP_JUMP_IF_TRUE); - let arg = if self.py_version.minor >= Some(10) { - (idx_while + 2) / 2 + let arg = if self.py_version.minor >= Some(11) { + let arg = self.lasti() - (idx_while + 2); + self.write_instr(Opcode311::POP_JUMP_BACKWARD_IF_TRUE); + arg / 2 + 1 } else { - idx_while + 2 + self.write_instr(Opcode310::POP_JUMP_IF_TRUE); + if self.py_version.minor >= Some(10) { + (idx_while + 2) / 2 + } else { + idx_while + 2 + } }; self.write_arg(arg); self.stack_dec(); let idx_end = if self.py_version.minor >= Some(11) { - self.lasti() - idx_while + self.lasti() - idx_while - 1 } else { self.lasti() }; @@ -1863,6 +1880,7 @@ impl PyCodeGenerator { match &local.inspect()[..] { "assert" => self.emit_assert_instr(args), "Del" => self.emit_del_instr(args), + "not" => self.emit_not_instr(args), "discard" => self.emit_discard_instr(args), "for" | "for!" => self.emit_for_instr(args), "while!" => self.emit_while_instr(args), diff --git a/compiler/erg_compiler/context/initialize/mod.rs b/compiler/erg_compiler/context/initialize/mod.rs index 96d28c29..8cd19d63 100644 --- a/compiler/erg_compiler/context/initialize/mod.rs +++ b/compiler/erg_compiler/context/initialize/mod.rs @@ -1740,6 +1740,9 @@ impl Context { ], NoneType, ); + // e.g. not(b: Bool!): Bool! + let B = mono_q("B", subtypeof(Bool)); + let t_not = nd_func(vec![kw("b", B.clone())], None, B).quantify(); let t_oct = nd_func(vec![kw("x", Int)], None, Str); let t_ord = nd_func(vec![kw("c", Str)], None, Nat); let t_panic = nd_func(vec![kw("err_message", Str)], None, Never); @@ -1792,6 +1795,7 @@ impl Context { ); self.register_builtin_py_impl("len", t_len, Immutable, Private, Some("len")); self.register_builtin_py_impl("log", t_log, Immutable, Private, Some("print")); + self.register_builtin_py_impl("not", t_not, Immutable, Private, None); // `not` is not a function in Python self.register_builtin_py_impl("oct", t_oct, Immutable, Private, Some("oct")); self.register_builtin_py_impl("ord", t_ord, Immutable, Private, Some("ord")); self.register_builtin_py_impl("panic", t_panic, Immutable, Private, Some("quit")); @@ -1934,7 +1938,7 @@ impl Context { let t_locals = proc(vec![], None, vec![], dict! { Str => Obj }.into()); let t_while = nd_proc( vec![ - kw("cond", mono("Bool!")), + kw("cond", Bool), // not Bool! type because `cond` may be the result of evaluation of a mutable object's method returns Bool. kw("p", nd_proc(vec![], None, NoneType)), ], None, diff --git a/compiler/erg_compiler/ty/codeobj.rs b/compiler/erg_compiler/ty/codeobj.rs index efc2210a..15646efa 100644 --- a/compiler/erg_compiler/ty/codeobj.rs +++ b/compiler/erg_compiler/ty/codeobj.rs @@ -601,6 +601,9 @@ impl CodeObj { Opcode311::POP_JUMP_FORWARD_IF_FALSE | Opcode311::POP_JUMP_FORWARD_IF_TRUE => { write!(instrs, "{arg} (to {})", idx + *arg as usize * 2 + 2).unwrap(); } + Opcode311::POP_JUMP_BACKWARD_IF_FALSE | Opcode311::POP_JUMP_BACKWARD_IF_TRUE => { + write!(instrs, "{arg} (to {})", idx - *arg as usize * 2 + 2).unwrap(); + } Opcode311::JUMP_BACKWARD => { write!(instrs, "{arg} (to {})", idx - *arg as usize * 2 + 2).unwrap(); } diff --git a/compiler/erg_compiler/ty/mod.rs b/compiler/erg_compiler/ty/mod.rs index 0ba1f39e..af1cb6bd 100644 --- a/compiler/erg_compiler/ty/mod.rs +++ b/compiler/erg_compiler/ty/mod.rs @@ -2105,7 +2105,7 @@ impl Type { // At least in situations where this function is needed, self cannot be Quantified. Self::Quantified(quant) => { if quant.return_t().unwrap().is_generalized() { - todo!("quantified return type") + todo!("quantified return type (recursive function type inference)") } quant.return_t() } diff --git a/tests/control.er b/tests/control.er new file mode 100644 index 00000000..fbcebb5b --- /dev/null +++ b/tests/control.er @@ -0,0 +1,25 @@ +cond = True +s = if cond: + do "then block" + do "else block" +assert s == "then block" + +if! cond: + do!: + print! "then block!" + do!: + print! "else block!" + +a = 1 +_ = match a: + (i: Int) -> i + _ -> panic "unknown object" + +for! 0..<10, i => + print! "i = {i}" + +counter = !10 +print! counter +while! not(counter == 0), do!: + print! "counter = {counter}" + counter.update!(i -> i - 1) diff --git a/tests/test.rs b/tests/test.rs index 2a779a5d..5af796e2 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -29,6 +29,11 @@ fn exec_class() -> Result<(), ()> { expect_success("examples/class.er") } +#[test] +fn exec_control() -> Result<(), ()> { + expect_success("tests/control.er") +} + #[test] fn exec_dict() -> Result<(), ()> { expect_success("examples/dict.er")