diff --git a/crates/erg_compiler/context/compare.rs b/crates/erg_compiler/context/compare.rs index 5188323f..45377aef 100644 --- a/crates/erg_compiler/context/compare.rs +++ b/crates/erg_compiler/context/compare.rs @@ -383,9 +383,9 @@ impl Context { (Some((_, l_sup)), Some((r_sub, _))) => self.supertype_of(&l_sup, &r_sub), _ => { if lfv.is_linked() { - self.supertype_of(&lfv.crack(), rhs) + self.supertype_of(lfv.unsafe_crack(), rhs) } else if rfv.is_linked() { - self.supertype_of(lhs, &rfv.crack()) + self.supertype_of(lhs, rfv.unsafe_crack()) } else { false } @@ -1418,7 +1418,8 @@ impl Context { Not(t) => *t.clone(), Refinement(r) => Type::Refinement(r.clone().invert()), Guard(guard) => Type::Guard(GuardType::new( - guard.var.clone(), + guard.namespace.clone(), + guard.target.clone(), self.complement(&guard.to), )), Or(l, r) => self.intersection(&self.complement(l), &self.complement(r)), diff --git a/crates/erg_compiler/context/generalize.rs b/crates/erg_compiler/context/generalize.rs index a7f0597b..f06ac46b 100644 --- a/crates/erg_compiler/context/generalize.rs +++ b/crates/erg_compiler/context/generalize.rs @@ -1092,9 +1092,9 @@ impl Context { .unbound_name() .map_or(false, |name| !qnames.contains(&name)) { - let t = mem::take(acc.ref_mut_t()); + let t = mem::take(acc.ref_mut_t().unwrap()); let mut dereferencer = Dereferencer::simple(self, qnames, acc); - *acc.ref_mut_t() = dereferencer.deref_tyvar(t)?; + *acc.ref_mut_t().unwrap() = dereferencer.deref_tyvar(t)?; } if let hir::Accessor::Attr(attr) = acc { self.resolve_expr_t(&mut attr.obj, qnames)?; @@ -1181,10 +1181,10 @@ impl Context { let mut dereferencer = Dereferencer::simple(self, qnames, record); record.t = dereferencer.deref_tyvar(t)?; for attr in record.attrs.iter_mut() { - let t = mem::take(attr.sig.ref_mut_t()); + let t = mem::take(attr.sig.ref_mut_t().unwrap()); let mut dereferencer = Dereferencer::simple(self, qnames, &attr.sig); let t = dereferencer.deref_tyvar(t)?; - *attr.sig.ref_mut_t() = t; + *attr.sig.ref_mut_t().unwrap() = t; for chunk in attr.body.block.iter_mut() { self.resolve_expr_t(chunk, qnames)?; } @@ -1232,9 +1232,9 @@ impl Context { } else { qnames.clone() }; - let t = mem::take(def.sig.ref_mut_t()); + let t = mem::take(def.sig.ref_mut_t().unwrap()); let mut dereferencer = Dereferencer::simple(self, &qnames, &def.sig); - *def.sig.ref_mut_t() = dereferencer.deref_tyvar(t)?; + *def.sig.ref_mut_t().unwrap() = dereferencer.deref_tyvar(t)?; if let Some(params) = def.sig.params_mut() { self.resolve_params_t(params, &qnames)?; } diff --git a/crates/erg_compiler/context/inquire.rs b/crates/erg_compiler/context/inquire.rs index 619edff4..90e7ceb9 100644 --- a/crates/erg_compiler/context/inquire.rs +++ b/crates/erg_compiler/context/inquire.rs @@ -2873,6 +2873,14 @@ impl Context { } } + pub(crate) fn rec_get_guards(&self) -> Vec<&GuardType> { + if let Some(outer) = self.get_outer() { + [self.guards.iter().collect(), outer.rec_get_guards()].concat() + } else { + self.guards.iter().collect() + } + } + // TODO: `Override` decorator should also be used /// e.g. /// ```erg @@ -3203,9 +3211,9 @@ impl Context { return Err(TyCheckErrors::from(TyCheckError::invalid_type_cast_error( self.cfg.input.clone(), line!() as usize, - guard.var.loc(), + guard.target.loc(), self.caused_by(), - &guard.var.to_string(), + &guard.target.to_string(), base, &guard.to, None, @@ -3229,9 +3237,9 @@ impl Context { Err(TyCheckErrors::from(TyCheckError::invalid_type_cast_error( self.cfg.input.clone(), line!() as usize, - guard.var.loc(), + guard.target.loc(), self.caused_by(), - &guard.var.to_string(), + &guard.target.to_string(), base, &guard.to, None, diff --git a/crates/erg_compiler/context/register.rs b/crates/erg_compiler/context/register.rs index dae7a10f..a083c972 100644 --- a/crates/erg_compiler/context/register.rs +++ b/crates/erg_compiler/context/register.rs @@ -35,7 +35,7 @@ use crate::ty::free::{Constraint, HasLevel}; use crate::ty::typaram::TyParam; use crate::ty::value::{GenTypeObj, TypeObj, ValueObj}; use crate::ty::{ - Field, GuardType, HasType, ParamTy, SubrType, Type, Variable, Visibility, VisibilityModifier, + CastTarget, Field, GuardType, HasType, ParamTy, SubrType, Type, Visibility, VisibilityModifier, }; use crate::build_hir::HIRBuilder; @@ -2272,49 +2272,61 @@ impl Context { } } + pub(crate) fn get_casted_type(&self, expr: &ast::Expr) -> Option { + for guard in self.rec_get_guards() { + if !self.name.starts_with(&guard.namespace[..]) { + continue; + } + if let CastTarget::Expr(target) = &guard.target { + if expr == target.as_ref() { + return Some(*guard.to.clone()); + } + } + } + None + } + pub(crate) fn cast( &mut self, guard: GuardType, overwritten: &mut Vec<(VarName, VarInfo)>, ) -> TyCheckResult<()> { - if let Variable::Var { - namespace, name, .. - } = &guard.var - { - if !self.name.starts_with(&namespace[..]) { - return Ok(()); - } - let vi = if let Some((name, vi)) = self.locals.remove_entry(name) { - overwritten.push((name, vi.clone())); - vi - } else if let Some((n, vi)) = self.get_var_kv(name) { - overwritten.push((n.clone(), vi.clone())); - vi.clone() - } else { - VarInfo::nd_parameter( - *guard.to.clone(), - self.absolutize(().loc()), - self.name.clone(), - ) - }; - match self.recover_typarams(&vi.t, &guard) { - Ok(t) => { - self.locals - .insert(VarName::from_str(name.clone()), VarInfo { t, ..vi }); + match &guard.target { + CastTarget::Var { name, .. } => { + if !self.name.starts_with(&guard.namespace[..]) { + return Ok(()); } - Err(errs) => { - self.locals.insert(VarName::from_str(name.clone()), vi); - return Err(errs); + let vi = if let Some((name, vi)) = self.locals.remove_entry(name) { + overwritten.push((name, vi.clone())); + vi + } else if let Some((n, vi)) = self.get_var_kv(name) { + overwritten.push((n.clone(), vi.clone())); + vi.clone() + } else { + VarInfo::nd_parameter( + *guard.to.clone(), + self.absolutize(().loc()), + self.name.clone(), + ) + }; + match self.recover_typarams(&vi.t, &guard) { + Ok(t) => { + self.locals + .insert(VarName::from_str(name.clone()), VarInfo { t, ..vi }); + } + Err(errs) => { + self.locals.insert(VarName::from_str(name.clone()), vi); + return Err(errs); + } } } - } /* else { - return Err(TyCheckErrors::from(TyCheckError::feature_error( - self.cfg.input.clone(), - guard.var.loc(), - &format!("casting {}", guard.var), - self.caused_by(), - ))); - } */ + CastTarget::Param { .. } => { + // TODO: + } + CastTarget::Expr(_) => { + self.guards.push(guard); + } + } Ok(()) } diff --git a/crates/erg_compiler/hir.rs b/crates/erg_compiler/hir.rs index f8458757..5f62c451 100644 --- a/crates/erg_compiler/hir.rs +++ b/crates/erg_compiler/hir.rs @@ -413,8 +413,8 @@ impl HasType for Identifier { &self.vi.t } #[inline] - fn ref_mut_t(&mut self) -> &mut Type { - &mut self.vi.t + fn ref_mut_t(&mut self) -> Option<&mut Type> { + Some(&mut self.vi.t) } #[inline] fn signature_t(&self) -> Option<&Type> { @@ -545,7 +545,7 @@ impl HasType for Attribute { self.ident.ref_t() } #[inline] - fn ref_mut_t(&mut self) -> &mut Type { + fn ref_mut_t(&mut self) -> Option<&mut Type> { self.ident.ref_mut_t() } #[inline] @@ -1221,8 +1221,8 @@ impl HasType for BinOp { fn ref_t(&self) -> &Type { self.info.t.return_t().unwrap() } - fn ref_mut_t(&mut self) -> &mut Type { - self.info.t.mut_return_t().unwrap() + fn ref_mut_t(&mut self) -> Option<&mut Type> { + self.info.t.mut_return_t() } #[inline] fn lhs_t(&self) -> &Type { @@ -1268,8 +1268,8 @@ impl HasType for UnaryOp { fn ref_t(&self) -> &Type { self.info.t.return_t().unwrap() } - fn ref_mut_t(&mut self) -> &mut Type { - self.info.t.mut_return_t().unwrap() + fn ref_mut_t(&mut self) -> Option<&mut Type> { + self.info.t.mut_return_t() } #[inline] fn lhs_t(&self) -> &Type { @@ -1357,11 +1357,11 @@ impl HasType for Call { } } #[inline] - fn ref_mut_t(&mut self) -> &mut Type { + fn ref_mut_t(&mut self) -> Option<&mut Type> { if let Some(attr) = self.attr_name.as_mut() { - attr.ref_mut_t().mut_return_t().unwrap() + attr.ref_mut_t()?.mut_return_t() } else { - self.obj.ref_mut_t().mut_return_t().unwrap() + self.obj.ref_mut_t()?.mut_return_t() } } #[inline] @@ -1391,12 +1391,12 @@ impl HasType for Call { #[inline] fn signature_mut_t(&mut self) -> Option<&mut Type> { if let Some(attr) = self.attr_name.as_mut() { - Some(attr.ref_mut_t()) + attr.ref_mut_t() } else { if let Expr::Call(call) = self.obj.as_ref() { call.return_t()?; } - Some(self.obj.ref_mut_t()) + self.obj.ref_mut_t() } } } @@ -1466,8 +1466,8 @@ impl HasType for Block { .unwrap_or(Type::FAILURE) } #[inline] - fn ref_mut_t(&mut self) -> &mut Type { - self.last_mut().unwrap().ref_mut_t() + fn ref_mut_t(&mut self) -> Option<&mut Type> { + self.last_mut()?.ref_mut_t() } #[inline] fn t(&self) -> Type { @@ -1517,8 +1517,8 @@ impl HasType for Dummy { Type::FAILURE } #[inline] - fn ref_mut_t(&mut self) -> &mut Type { - todo!() + fn ref_mut_t(&mut self) -> Option<&mut Type> { + None } #[inline] fn t(&self) -> Type { @@ -1583,7 +1583,7 @@ impl HasType for VarSignature { self.ident.ref_t() } #[inline] - fn ref_mut_t(&mut self) -> &mut Type { + fn ref_mut_t(&mut self) -> Option<&mut Type> { self.ident.ref_mut_t() } #[inline] @@ -1874,7 +1874,7 @@ impl HasType for SubrSignature { self.ident.ref_t() } #[inline] - fn ref_mut_t(&mut self) -> &mut Type { + fn ref_mut_t(&mut self) -> Option<&mut Type> { self.ident.ref_mut_t() } #[inline] @@ -2105,8 +2105,8 @@ impl HasType for Def { Type::NONE } #[inline] - fn ref_mut_t(&mut self) -> &mut Type { - todo!() + fn ref_mut_t(&mut self) -> Option<&mut Type> { + None } #[inline] fn signature_t(&self) -> Option<&Type> { @@ -2188,8 +2188,8 @@ impl HasType for Methods { Type::NONE } #[inline] - fn ref_mut_t(&mut self) -> &mut Type { - todo!() + fn ref_mut_t(&mut self) -> Option<&mut Type> { + None } #[inline] fn signature_t(&self) -> Option<&Type> { @@ -2242,8 +2242,8 @@ impl HasType for ClassDef { Type::NONE } #[inline] - fn ref_mut_t(&mut self) -> &mut Type { - todo!() + fn ref_mut_t(&mut self) -> Option<&mut Type> { + None } #[inline] fn signature_t(&self) -> Option<&Type> { @@ -2312,8 +2312,8 @@ impl HasType for PatchDef { Type::NONE } #[inline] - fn ref_mut_t(&mut self) -> &mut Type { - todo!() + fn ref_mut_t(&mut self) -> Option<&mut Type> { + None } #[inline] fn signature_t(&self) -> Option<&Type> { @@ -2368,8 +2368,8 @@ impl HasType for ReDef { Type::NONE } #[inline] - fn ref_mut_t(&mut self) -> &mut Type { - todo!() + fn ref_mut_t(&mut self) -> Option<&mut Type> { + None } #[inline] fn signature_t(&self) -> Option<&Type> { @@ -2455,9 +2455,9 @@ impl HasType for TypeAscription { } } #[inline] - fn ref_mut_t(&mut self) -> &mut Type { + fn ref_mut_t(&mut self) -> Option<&mut Type> { if self.spec.kind().is_force_cast() { - &mut self.spec.spec_t + Some(&mut self.spec.spec_t) } else { self.expr.ref_mut_t() } diff --git a/crates/erg_compiler/lower.rs b/crates/erg_compiler/lower.rs index 6ea09360..4b1ecee6 100644 --- a/crates/erg_compiler/lower.rs +++ b/crates/erg_compiler/lower.rs @@ -31,7 +31,7 @@ use crate::ty::constructors::{ use crate::ty::free::Constraint; use crate::ty::typaram::TyParam; use crate::ty::value::{GenTypeObj, TypeObj, ValueObj}; -use crate::ty::{GuardType, HasType, ParamTy, Predicate, Type, Variable, VisibilityModifier}; +use crate::ty::{CastTarget, GuardType, HasType, ParamTy, Predicate, Type, VisibilityModifier}; use crate::context::{ ClassDefType, Context, ContextKind, ContextProvider, ControlKind, ModuleContext, @@ -50,26 +50,13 @@ use crate::{feature_error, unreachable_error}; use VisibilityModifier::*; -pub fn acc_to_variable(namespace: Str, acc: &ast::Accessor) -> Option { - match acc { - ast::Accessor::Ident(ident) => Some(Variable::Var { - namespace, +pub fn expr_to_cast_target(expr: &ast::Expr) -> CastTarget { + match expr { + ast::Expr::Accessor(ast::Accessor::Ident(ident)) => CastTarget::Var { name: ident.inspect().clone(), loc: ident.loc(), - }), - ast::Accessor::Attr(attr) => Some(Variable::attr( - expr_to_variable(namespace, &attr.obj)?, - attr.ident.inspect().clone(), - attr.loc(), - )), - _ => None, - } -} - -pub fn expr_to_variable(namespace: Str, expr: &ast::Expr) -> Option { - match expr { - ast::Expr::Accessor(acc) => acc_to_variable(namespace, acc), - _ => None, + }, + _ => CastTarget::expr(expr.clone()), } } @@ -746,28 +733,29 @@ impl ASTLowerer { } fn get_guard_type(&self, op: &Token, lhs: &ast::Expr, rhs: &ast::Expr) -> Option { - let var = if op.kind == TokenKind::ContainsOp { - expr_to_variable(self.module.context.name.clone(), rhs)? + let target = if op.kind == TokenKind::ContainsOp { + expr_to_cast_target(rhs) } else { - expr_to_variable(self.module.context.name.clone(), lhs)? + expr_to_cast_target(lhs) }; + let namespace = self.module.context.name.clone(); match op.kind { // l in T -> T contains l TokenKind::ContainsOp => { let to = self.module.context.expr_to_type(lhs.clone())?; - Some(guard(var, to)) + Some(guard(namespace, target, to)) } TokenKind::Symbol if &op.content[..] == "isinstance" => { let to = self.module.context.expr_to_type(rhs.clone())?; - Some(guard(var, to)) + Some(guard(namespace, target, to)) } TokenKind::IsOp | TokenKind::DblEq => { let value = self.module.context.expr_to_value(rhs.clone())?; - Some(guard(var, v_enum(set! { value }))) + Some(guard(namespace, target, v_enum(set! { value }))) } TokenKind::IsNotOp | TokenKind::NotEq => { let value = self.module.context.expr_to_value(rhs.clone())?; - let ty = guard(var, v_enum(set! { value })); + let ty = guard(namespace, target, v_enum(set! { value })); Some(self.module.context.complement(&ty)) } TokenKind::Gre => { @@ -776,7 +764,7 @@ impl ASTLowerer { let varname = self.fresh_gen.fresh_varname(); let pred = Predicate::gt(varname.clone(), TyParam::value(value)); let refine = refinement(varname, t, pred); - Some(guard(var, refine)) + Some(guard(namespace, target, refine)) } TokenKind::GreEq => { let value = self.module.context.expr_to_value(rhs.clone())?; @@ -784,7 +772,7 @@ impl ASTLowerer { let varname = self.fresh_gen.fresh_varname(); let pred = Predicate::ge(varname.clone(), TyParam::value(value)); let refine = refinement(varname, t, pred); - Some(guard(var, refine)) + Some(guard(namespace, target, refine)) } TokenKind::Less => { let value = self.module.context.expr_to_value(rhs.clone())?; @@ -792,7 +780,7 @@ impl ASTLowerer { let varname = self.fresh_gen.fresh_varname(); let pred = Predicate::lt(varname.clone(), TyParam::value(value)); let refine = refinement(varname, t, pred); - Some(guard(var, refine)) + Some(guard(namespace, target, refine)) } TokenKind::LessEq => { let value = self.module.context.expr_to_value(rhs.clone())?; @@ -800,7 +788,7 @@ impl ASTLowerer { let varname = self.fresh_gen.fresh_varname(); let pred = Predicate::le(varname.clone(), TyParam::value(value)); let refine = refinement(varname, t, pred); - Some(guard(var, refine)) + Some(guard(namespace, target, refine)) } _ => None, } @@ -930,7 +918,8 @@ impl ASTLowerer { } 1 if kind.is_if() => { let guard = GuardType::new( - guard.var.clone(), + guard.namespace.clone(), + guard.target.clone(), self.module.context.complement(&guard.to), ); self.module.context.guards.push(guard); @@ -1015,10 +1004,10 @@ impl ASTLowerer { } else { if let hir::Expr::Call(call) = &obj { if call.return_t().is_some() { - *obj.ref_mut_t() = vi.t; + *obj.ref_mut_t().unwrap() = vi.t; } } else { - *obj.ref_mut_t() = vi.t; + *obj.ref_mut_t().unwrap() = vi.t; } None }; @@ -1964,7 +1953,7 @@ impl ASTLowerer { ) { Err(err) => self.errs.push(err), Ok(_) => { - *attr.ref_mut_t() = derefined.clone(); + *attr.ref_mut_t().unwrap() = derefined.clone(); if let hir::Accessor::Ident(ident) = &attr { if let Some(vi) = self .module @@ -2426,27 +2415,36 @@ impl ASTLowerer { // so turn off type checking (check=false) fn lower_expr(&mut self, expr: ast::Expr) -> LowerResult { log!(info "entered {}", fn_name!()); - match expr { - ast::Expr::Literal(lit) => Ok(hir::Expr::Lit(self.lower_literal(lit)?)), - ast::Expr::Array(arr) => Ok(hir::Expr::Array(self.lower_array(arr)?)), - ast::Expr::Tuple(tup) => Ok(hir::Expr::Tuple(self.lower_tuple(tup)?)), - ast::Expr::Record(rec) => Ok(hir::Expr::Record(self.lower_record(rec)?)), - ast::Expr::Set(set) => Ok(hir::Expr::Set(self.lower_set(set)?)), - ast::Expr::Dict(dict) => Ok(hir::Expr::Dict(self.lower_dict(dict)?)), - ast::Expr::Accessor(acc) => Ok(hir::Expr::Accessor(self.lower_acc(acc)?)), - ast::Expr::BinOp(bin) => Ok(hir::Expr::BinOp(self.lower_bin(bin))), - ast::Expr::UnaryOp(unary) => Ok(hir::Expr::UnaryOp(self.lower_unary(unary))), - ast::Expr::Call(call) => Ok(hir::Expr::Call(self.lower_call(call)?)), - ast::Expr::DataPack(pack) => Ok(hir::Expr::Call(self.lower_pack(pack)?)), - ast::Expr::Lambda(lambda) => Ok(hir::Expr::Lambda(self.lower_lambda(lambda)?)), - ast::Expr::TypeAscription(tasc) => Ok(hir::Expr::TypeAsc(self.lower_type_asc(tasc)?)), + let casted = self.module.context.get_casted_type(&expr); + let mut expr = match expr { + ast::Expr::Literal(lit) => hir::Expr::Lit(self.lower_literal(lit)?), + ast::Expr::Array(arr) => hir::Expr::Array(self.lower_array(arr)?), + ast::Expr::Tuple(tup) => hir::Expr::Tuple(self.lower_tuple(tup)?), + ast::Expr::Record(rec) => hir::Expr::Record(self.lower_record(rec)?), + ast::Expr::Set(set) => hir::Expr::Set(self.lower_set(set)?), + ast::Expr::Dict(dict) => hir::Expr::Dict(self.lower_dict(dict)?), + ast::Expr::Accessor(acc) => hir::Expr::Accessor(self.lower_acc(acc)?), + ast::Expr::BinOp(bin) => hir::Expr::BinOp(self.lower_bin(bin)), + ast::Expr::UnaryOp(unary) => hir::Expr::UnaryOp(self.lower_unary(unary)), + ast::Expr::Call(call) => hir::Expr::Call(self.lower_call(call)?), + ast::Expr::DataPack(pack) => hir::Expr::Call(self.lower_pack(pack)?), + ast::Expr::Lambda(lambda) => hir::Expr::Lambda(self.lower_lambda(lambda)?), + ast::Expr::TypeAscription(tasc) => hir::Expr::TypeAsc(self.lower_type_asc(tasc)?), // Checking is also performed for expressions in Dummy. However, it has no meaning in code generation - ast::Expr::Dummy(dummy) => Ok(hir::Expr::Dummy(self.lower_dummy(dummy)?)), + ast::Expr::Dummy(dummy) => hir::Expr::Dummy(self.lower_dummy(dummy)?), other => { log!(err "unreachable: {other}"); - unreachable_error!(LowerErrors, LowerError, self.module.context) + return unreachable_error!(LowerErrors, LowerError, self.module.context); + } + }; + if let Some(casted) = casted { + if self.module.context.subtype_of(&casted, expr.ref_t()) { + if let Some(ref_mut_t) = expr.ref_mut_t() { + *ref_mut_t = casted; + } } } + Ok(expr) } /// The meaning of TypeAscription changes between chunk and expr. diff --git a/crates/erg_compiler/ty/codeobj.rs b/crates/erg_compiler/ty/codeobj.rs index 156bae77..71124c2e 100644 --- a/crates/erg_compiler/ty/codeobj.rs +++ b/crates/erg_compiler/ty/codeobj.rs @@ -210,8 +210,8 @@ impl HasType for CodeObj { fn ref_t(&self) -> &Type { &Type::Code } - fn ref_mut_t(&mut self) -> &mut Type { - todo!() + fn ref_mut_t(&mut self) -> Option<&mut Type> { + None } fn signature_t(&self) -> Option<&Type> { None diff --git a/crates/erg_compiler/ty/constructors.rs b/crates/erg_compiler/ty/constructors.rs index c5ddb766..ee132193 100644 --- a/crates/erg_compiler/ty/constructors.rs +++ b/crates/erg_compiler/ty/constructors.rs @@ -474,8 +474,8 @@ pub fn not(ty: Type) -> Type { Type::Not(Box::new(ty)) } -pub fn guard(var: Variable, to: Type) -> Type { - Type::Guard(GuardType::new(var, to)) +pub fn guard(namespace: Str, target: CastTarget, to: Type) -> Type { + Type::Guard(GuardType::new(namespace, target, to)) } pub fn bounded(sub: Type, sup: Type) -> Type { diff --git a/crates/erg_compiler/ty/mod.rs b/crates/erg_compiler/ty/mod.rs index f48fd5a3..0bdf1afd 100644 --- a/crates/erg_compiler/ty/mod.rs +++ b/crates/erg_compiler/ty/mod.rs @@ -31,6 +31,7 @@ use erg_common::set::Set; use erg_common::traits::{LimitedDisplay, Locational, StructuralEq}; use erg_common::{enum_unwrap, fmt_option, ref_addr_eq, set, Str}; +use erg_parser::ast::Expr; use erg_parser::token::TokenKind; pub use const_subr::*; @@ -59,8 +60,7 @@ pub trait HasType { // 関数呼び出しの場合、.ref_t()は戻り値を返し、signature_t()は関数全体の型を返す fn signature_t(&self) -> Option<&Type>; // 最後にHIR全体の型変数を消すために使う - /// `x.ref_mut_t()` may panic, in which case `x` is `Call` and `x.ref_t() == Type::Failure`. - fn ref_mut_t(&mut self) -> &mut Type; + fn ref_mut_t(&mut self) -> Option<&mut Type>; fn signature_mut_t(&mut self) -> Option<&mut Type>; #[inline] fn t(&self) -> Type { @@ -89,8 +89,8 @@ macro_rules! impl_t { &self.t } #[inline] - fn ref_mut_t(&mut self) -> &mut Type { - &mut self.t + fn ref_mut_t(&mut self) -> Option<&mut Type> { + Some(&mut self.t) } #[inline] fn signature_t(&self) -> Option<&Type> { @@ -109,7 +109,7 @@ macro_rules! impl_t { &self.$attr.ref_t() } #[inline] - fn ref_mut_t(&mut self) -> &mut Type { + fn ref_mut_t(&mut self) -> Option<&mut Type> { self.$attr.ref_mut_t() } #[inline] @@ -133,7 +133,7 @@ macro_rules! impl_t_for_enum { $($Enum::$Variant(v) => v.ref_t(),)* } } - fn ref_mut_t(&mut self) -> &mut Type { + fn ref_mut_t(&mut self) -> Option<&mut Type> { match self { $($Enum::$Variant(v) => v.ref_mut_t(),)* } @@ -752,80 +752,74 @@ impl ArgsOwnership { } #[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum Variable { +pub enum CastTarget { Param { nth: usize, name: Str, loc: Location, }, Var { - namespace: Str, name: Str, loc: Location, }, - Attr { - receiver: Box, - attr: Str, - loc: Location, - }, + // NOTE: `Expr(Expr)` causes a bad memory access error + Expr(Box), } -impl fmt::Display for Variable { +impl fmt::Display for CastTarget { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Self::Param { nth, name, .. } => write!(f, "{name}#{nth}"), Self::Var { name, .. } => write!(f, "{name}"), - Self::Attr { receiver, attr, .. } => write!(f, "{receiver}.{attr}"), + Self::Expr(expr) => write!(f, "{expr}"), } } } -impl Locational for Variable { +impl Locational for CastTarget { fn loc(&self) -> Location { match self { Self::Param { loc, .. } => *loc, Self::Var { loc, .. } => *loc, - Self::Attr { loc, .. } => *loc, + Self::Expr(expr) => expr.loc(), } } } -impl Variable { +impl CastTarget { pub const fn param(nth: usize, name: Str, loc: Location) -> Self { Self::Param { nth, name, loc } } - pub fn attr(receiver: Variable, attr: Str, loc: Location) -> Self { - Self::Attr { - receiver: Box::new(receiver), - attr, - loc, - } + pub fn expr(expr: Expr) -> Self { + Self::Expr(Box::new(expr)) } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct GuardType { - pub var: Variable, + pub namespace: Str, + pub target: CastTarget, pub to: Box, } impl fmt::Display for GuardType { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{{{} in {}}}", self.var, self.to) + write!(f, "{{{} in {}}}", self.target, self.to) } } impl StructuralEq for GuardType { fn structural_eq(&self, other: &Self) -> bool { - self.var == other.var && self.to.structural_eq(&other.to) + self.target == other.target && self.to.structural_eq(&other.to) } } impl GuardType { - pub fn new(var: Variable, to: Type) -> Self { + pub fn new(namespace: Str, target: CastTarget, to: Type) -> Self { Self { - var, + namespace, + target, to: Box::new(to), } } @@ -1359,8 +1353,8 @@ impl HasType for Type { self } #[inline] - fn ref_mut_t(&mut self) -> &mut Type { - self + fn ref_mut_t(&mut self) -> Option<&mut Type> { + Some(self) } fn inner_ts(&self) -> Vec { match self { @@ -3107,9 +3101,11 @@ impl Type { proj_call(lhs, attr_name.clone(), args) } Self::Structural(ty) => ty.derefine().structuralize(), - Self::Guard(guard) => { - Self::Guard(GuardType::new(guard.var.clone(), guard.to.derefine())) - } + Self::Guard(guard) => Self::Guard(GuardType::new( + guard.namespace.clone(), + guard.target.clone(), + guard.to.derefine(), + )), Self::Bounded { sub, sup } => Self::Bounded { sub: Box::new(sub.derefine()), sup: Box::new(sup.derefine()), @@ -3253,7 +3249,8 @@ impl Type { } Self::Structural(ty) => ty._replace(target, to).structuralize(), Self::Guard(guard) => Self::Guard(GuardType::new( - guard.var.clone(), + guard.namespace, + guard.target.clone(), guard.to._replace(target, to), )), Self::Bounded { sub, sup } => Self::Bounded { @@ -3315,7 +3312,11 @@ impl Type { Self::Or(l, r) => l.normalize() | r.normalize(), Self::Not(ty) => !ty.normalize(), Self::Structural(ty) => ty.normalize().structuralize(), - Self::Guard(guard) => Self::Guard(GuardType::new(guard.var, guard.to.normalize())), + Self::Guard(guard) => Self::Guard(GuardType::new( + guard.namespace, + guard.target, + guard.to.normalize(), + )), Self::Bounded { sub, sup } => Self::Bounded { sub: Box::new(sub.normalize()), sup: Box::new(sup.normalize()), diff --git a/crates/erg_compiler/ty/value.rs b/crates/erg_compiler/ty/value.rs index 1619814b..f12706ef 100644 --- a/crates/erg_compiler/ty/value.rs +++ b/crates/erg_compiler/ty/value.rs @@ -880,8 +880,8 @@ impl HasType for ValueObj { fn ref_t(&self) -> &Type { panic!("cannot get reference of the const") } - fn ref_mut_t(&mut self) -> &mut Type { - panic!("cannot get mutable reference of the const") + fn ref_mut_t(&mut self) -> Option<&mut Type> { + None } /// その要素だけの集合型を返す、クラスが欲しい場合は.classで #[inline] diff --git a/crates/erg_compiler/varinfo.rs b/crates/erg_compiler/varinfo.rs index 817e44a0..115dd0a7 100644 --- a/crates/erg_compiler/varinfo.rs +++ b/crates/erg_compiler/varinfo.rs @@ -216,8 +216,8 @@ impl HasType for VarInfo { &self.t } #[inline] - fn ref_mut_t(&mut self) -> &mut Type { - &mut self.t + fn ref_mut_t(&mut self) -> Option<&mut Type> { + Some(&mut self.t) } #[inline] fn signature_t(&self) -> Option<&Type> { diff --git a/tests/should_ok/assert_cast.er b/tests/should_ok/assert_cast.er index ed063269..ee8a3f9f 100644 --- a/tests/should_ok/assert_cast.er +++ b/tests/should_ok/assert_cast.er @@ -10,3 +10,8 @@ j = json.loads "{ \"a\": [1] }" assert j in {Str: Obj} assert j["a"] in Array(Int) assert j["a"] notin Array(Str) +_: Array(Int) = j["a"] + +.f dic: {Str: Str or Array(Str)} = + assert dic["key"] in Str # Required to pass the check on the next line + assert dic["key"] in {"a", "b", "c"}