fix: Context::union/intersection

This commit is contained in:
Shunsuke Shibayama 2023-04-10 11:51:45 +09:00
parent 3ed863cef6
commit 7c8b8a66a1
4 changed files with 80 additions and 46 deletions

View file

@ -48,7 +48,8 @@ impl Context {
pub(crate) fn eq_tp(&self, lhs: &TyParam, rhs: &TyParam) -> bool { pub(crate) fn eq_tp(&self, lhs: &TyParam, rhs: &TyParam) -> bool {
match (lhs, rhs) { match (lhs, rhs) {
(TyParam::Type(lhs), TyParam::Type(rhs)) => return self.same_type_of(lhs, rhs), (TyParam::Type(lhs), TyParam::Type(rhs))
| (TyParam::Erased(lhs), TyParam::Erased(rhs)) => return self.same_type_of(lhs, rhs),
(TyParam::Mono(l), TyParam::Mono(r)) => { (TyParam::Mono(l), TyParam::Mono(r)) => {
if let (Some(l), Some(r)) = (self.rec_get_const_obj(l), self.rec_get_const_obj(r)) { if let (Some(l), Some(r)) = (self.rec_get_const_obj(l), self.rec_get_const_obj(r)) {
return l == r; return l == r;
@ -502,7 +503,7 @@ impl Context {
(Type, Poly { name, params }) | (Poly { name, params }, Type) (Type, Poly { name, params }) | (Poly { name, params }, Type)
if &name[..] == "Array" || &name[..] == "Set" => if &name[..] == "Array" || &name[..] == "Set" =>
{ {
let elem_t = self.convert_tp_into_ty(params[0].clone()).unwrap(); let elem_t = self.convert_tp_into_type(params[0].clone()).unwrap();
self.supertype_of(&Type, &elem_t) self.supertype_of(&Type, &elem_t)
} }
(Type, Poly { name, params }) | (Poly { name, params }, Type) (Type, Poly { name, params }) | (Poly { name, params }, Type)
@ -515,7 +516,7 @@ impl Context {
return self.supertype_of(&Type, &arr_t); return self.supertype_of(&Type, &arr_t);
} else if let Ok(tps) = Vec::try_from(params[0].clone()) { } else if let Ok(tps) = Vec::try_from(params[0].clone()) {
for tp in tps { for tp in tps {
let Ok(t) = self.convert_tp_into_ty(tp) else { let Ok(t) = self.convert_tp_into_type(tp) else {
return false; return false;
}; };
if !self.supertype_of(&Type, &t) { if !self.supertype_of(&Type, &t) {
@ -539,10 +540,10 @@ impl Context {
return false; return false;
}; };
for (k, v) in dict.into_iter() { for (k, v) in dict.into_iter() {
let Ok(k) = self.convert_tp_into_ty(k) else { let Ok(k) = self.convert_tp_into_type(k) else {
return false; return false;
}; };
let Ok(v) = self.convert_tp_into_ty(v) else { let Ok(v) = self.convert_tp_into_type(v) else {
return false; return false;
}; };
if !self.supertype_of(&Type, &k) || !self.supertype_of(&Type, &v) { if !self.supertype_of(&Type, &k) || !self.supertype_of(&Type, &v) {
@ -682,8 +683,8 @@ impl Context {
} }
// [Int; 2] :> [Int; 3] // [Int; 2] :> [Int; 3]
if &ln[..] == "Array" || &ln[..] == "Set" { if &ln[..] == "Array" || &ln[..] == "Set" {
let lt = self.convert_tp_into_ty(lparams[0].clone()).unwrap(); let lt = self.convert_tp_into_type(lparams[0].clone()).unwrap();
let rt = self.convert_tp_into_ty(rparams[0].clone()).unwrap(); let rt = self.convert_tp_into_type(rparams[0].clone()).unwrap();
let llen = lparams[1].clone(); let llen = lparams[1].clone();
let rlen = rparams[1].clone(); let rlen = rparams[1].clone();
self.supertype_of(&lt, &rt) self.supertype_of(&lt, &rt)
@ -975,16 +976,6 @@ impl Context {
if lhs == rhs { if lhs == rhs {
return lhs.clone(); return lhs.clone();
} }
// `?T or ?U` will not be unified
// `Set!(?T, 3) or Set(?T, 3)` wii be unified to Set(?T, 3)
if !lhs.is_unbound_var() && !rhs.is_unbound_var() {
match (self.supertype_of(lhs, rhs), self.subtype_of(lhs, rhs)) {
(true, true) => return lhs.clone(), // lhs = rhs
(true, false) => return lhs.clone(), // lhs :> rhs
(false, true) => return rhs.clone(),
(false, false) => {}
}
}
match (lhs, rhs) { match (lhs, rhs) {
(FreeVar(fv), other) | (other, FreeVar(fv)) if fv.is_linked() => { (FreeVar(fv), other) | (other, FreeVar(fv)) if fv.is_linked() => {
self.union(&fv.crack(), other) self.union(&fv.crack(), other)
@ -1007,21 +998,45 @@ impl Context {
let mut unified_params = vec![]; let mut unified_params = vec![];
for (lp, rp) in lps.iter().zip(rps.iter()) { for (lp, rp) in lps.iter().zip(rps.iter()) {
match (lp, rp) { match (lp, rp) {
(TyParam::Value(ValueObj::Type(l)), TyParam::Value(ValueObj::Type(r))) => {
unified_params.push(TyParam::t(self.union(l.typ(), r.typ())));
}
(TyParam::Value(ValueObj::Type(l)), TyParam::Type(r)) => {
unified_params.push(TyParam::t(self.union(l.typ(), r)));
}
(TyParam::Type(l), TyParam::Value(ValueObj::Type(r))) => {
unified_params.push(TyParam::t(self.union(l, r.typ())));
}
(TyParam::Type(l), TyParam::Type(r)) => { (TyParam::Type(l), TyParam::Type(r)) => {
unified_params.push(TyParam::t(self.union(l, r))) unified_params.push(TyParam::t(self.union(l, r)));
} }
(_, _) => { (_, _) => {
if self.eq_tp(lp, rp) { if self.eq_tp(lp, rp) {
unified_params.push(lp.clone()); unified_params.push(lp.clone());
} else { } else {
return or(lhs.clone(), rhs.clone()); return self.simple_union(lhs, rhs);
} }
} }
} }
} }
poly(ln, unified_params) poly(ln, unified_params)
} }
(l, r) => or(l.clone(), r.clone()), _ => self.simple_union(lhs, rhs),
}
}
fn simple_union(&self, lhs: &Type, rhs: &Type) -> Type {
// `?T or ?U` will not be unified
// `Set!(?T(<: Int), 3) or Set(?U(<: Nat), 3)` wii be unified to Set(?T, 3)
if lhs.is_unbound_var() || rhs.is_unbound_var() {
or(lhs.clone(), rhs.clone())
} else {
match (self.supertype_of(lhs, rhs), self.subtype_of(lhs, rhs)) {
(true, true) => lhs.clone(), // lhs = rhs
(true, false) => lhs.clone(), // lhs :> rhs
(false, true) => rhs.clone(),
(false, false) => or(lhs.clone(), rhs.clone()),
}
} }
} }
@ -1040,15 +1055,6 @@ impl Context {
if lhs == rhs { if lhs == rhs {
return lhs.clone(); return lhs.clone();
} }
// ?T and ?U will not be unified
if !lhs.is_unbound_var() && !rhs.is_unbound_var() {
match (self.supertype_of(lhs, rhs), self.subtype_of(lhs, rhs)) {
(true, true) => return lhs.clone(), // lhs = rhs
(true, false) => return rhs.clone(), // lhs :> rhs
(false, true) => return lhs.clone(),
(false, false) => {}
}
}
match (lhs, rhs) { match (lhs, rhs) {
(FreeVar(fv), other) | (other, FreeVar(fv)) if fv.is_linked() => { (FreeVar(fv), other) | (other, FreeVar(fv)) if fv.is_linked() => {
self.intersection(&fv.crack(), other) self.intersection(&fv.crack(), other)
@ -1071,10 +1077,29 @@ impl Context {
Type::Record(rec2) => Type::Record(rec.clone().diff(rec2)), Type::Record(rec2) => Type::Record(rec.clone().diff(rec2)),
_ => Type::Never, _ => Type::Never,
}, },
(l, r) if self.is_trait(l) && self.is_trait(r) => and(l.clone(), r.clone()),
(_, Not(r)) => self.diff(lhs, r), (_, Not(r)) => self.diff(lhs, r),
(Not(l), _) => self.diff(rhs, l), (Not(l), _) => self.diff(rhs, l),
(_l, _r) => Type::Never, _ => self.simple_intersection(lhs, rhs),
}
}
fn simple_intersection(&self, lhs: &Type, rhs: &Type) -> Type {
// ?T and ?U will not be unified
if lhs.is_unbound_var() || rhs.is_unbound_var() {
and(lhs.clone(), rhs.clone())
} else {
match (self.supertype_of(lhs, rhs), self.subtype_of(lhs, rhs)) {
(true, true) => lhs.clone(), // lhs = rhs
(true, false) => rhs.clone(), // lhs :> rhs
(false, true) => lhs.clone(),
(false, false) => {
if self.is_trait(lhs) && self.is_trait(rhs) {
and(lhs.clone(), rhs.clone())
} else {
Type::Never
}
}
}
} }
} }

View file

@ -1233,23 +1233,23 @@ impl Context {
} }
} }
pub(crate) fn convert_tp_into_ty(&self, tp: TyParam) -> Result<Type, ()> { pub(crate) fn convert_tp_into_type(&self, tp: TyParam) -> Result<Type, TyParam> {
match tp { match tp {
TyParam::Array(tps) => { TyParam::Array(tps) => {
let len = tps.len(); let len = tps.len();
let mut t = Type::Never; let mut t = Type::Never;
for elem_tp in tps { for elem_tp in tps {
let elem_t = self.convert_tp_into_ty(elem_tp)?; let elem_t = self.convert_tp_into_type(elem_tp)?;
// not union // not union
t = self.union(&t, &elem_t); t = self.union(&t, &elem_t);
} }
Ok(array_t(t, TyParam::value(len))) Ok(array_t(t, TyParam::value(len)))
} }
TyParam::FreeVar(fv) if fv.is_linked() => self.convert_tp_into_ty(fv.crack().clone()), TyParam::FreeVar(fv) if fv.is_linked() => self.convert_tp_into_type(fv.crack().clone()),
TyParam::Type(t) => Ok(t.as_ref().clone()), TyParam::Type(t) => Ok(t.as_ref().clone()),
TyParam::Value(v) => Type::try_from(v).or(Err(())), TyParam::Value(v) => Type::try_from(v).map_err(TyParam::Value),
// TODO: Dict, Set // TODO: Dict, Set
_ => Err(()), other => Err(other),
} }
} }
@ -1259,8 +1259,8 @@ impl Context {
let dict = Dict::try_from(params[0].clone())?; let dict = Dict::try_from(params[0].clone())?;
let mut new_dict = dict! {}; let mut new_dict = dict! {};
for (k, v) in dict.into_iter() { for (k, v) in dict.into_iter() {
let k = self.convert_tp_into_ty(k)?; let k = self.convert_tp_into_type(k).map_err(|_| ())?;
let v = self.convert_tp_into_ty(v)?; let v = self.convert_tp_into_type(v).map_err(|_| ())?;
new_dict.insert(k, v); new_dict.insert(k, v);
} }
Ok(new_dict) Ok(new_dict)
@ -1272,7 +1272,9 @@ impl Context {
fn convert_type_to_array(&self, ty: Type) -> Result<Vec<ValueObj>, ()> { fn convert_type_to_array(&self, ty: Type) -> Result<Vec<ValueObj>, ()> {
match ty { match ty {
Type::Poly { name, params } if &name[..] == "Array" || &name[..] == "Array!" => { Type::Poly { name, params } if &name[..] == "Array" || &name[..] == "Array!" => {
let t = self.convert_tp_into_ty(params[0].clone())?; let t = self
.convert_tp_into_type(params[0].clone())
.map_err(|_| ())?;
let len = enum_unwrap!(params[1], TyParam::Value:(ValueObj::Nat:(_))); let len = enum_unwrap!(params[1], TyParam::Value:(ValueObj::Nat:(_)));
Ok(vec![ValueObj::builtin_type(t); len as usize]) Ok(vec![ValueObj::builtin_type(t); len as usize])
} }

View file

@ -2874,7 +2874,7 @@ impl Context {
params params
.iter() .iter()
.map(|tp| { .map(|tp| {
if let Ok(t) = self.convert_tp_into_ty(tp.clone()) { if let Ok(t) = self.convert_tp_into_type(tp.clone()) {
TyParam::t(self.meta_type(&t)) TyParam::t(self.meta_type(&t))
} else { } else {
tp.clone() tp.clone()

View file

@ -3,6 +3,7 @@ use std::mem;
use std::option::Option; use std::option::Option;
use erg_common::fresh::fresh_varname; use erg_common::fresh::fresh_varname;
use erg_common::style::Stylize;
use erg_common::traits::Locational; use erg_common::traits::Locational;
use erg_common::Str; use erg_common::Str;
#[allow(unused_imports)] #[allow(unused_imports)]
@ -379,7 +380,7 @@ impl Context {
} }
} }
(l, TyParam::Type(r)) => { (l, TyParam::Type(r)) => {
let l = self.convert_tp_into_ty(l.clone()).map_err(|_| { let l = self.convert_tp_into_type(l.clone()).map_err(|_| {
TyCheckError::tp_to_type_error( TyCheckError::tp_to_type_error(
self.cfg.input.clone(), self.cfg.input.clone(),
line!() as usize, line!() as usize,
@ -392,7 +393,7 @@ impl Context {
Ok(()) Ok(())
} }
(TyParam::Type(l), r) => { (TyParam::Type(l), r) => {
let r = self.convert_tp_into_ty(r.clone()).map_err(|_| { let r = self.convert_tp_into_type(r.clone()).map_err(|_| {
TyCheckError::tp_to_type_error( TyCheckError::tp_to_type_error(
self.cfg.input.clone(), self.cfg.input.clone(),
line!() as usize, line!() as usize,
@ -848,11 +849,17 @@ impl Context {
{ {
let (l, r) = new_sub.union_pair().unwrap(); let (l, r) = new_sub.union_pair().unwrap();
if self.unify(&l, &r).is_none() { if self.unify(&l, &r).is_none() {
let maybe_sub_ = maybe_sub
.to_string()
.with_color(erg_common::style::Color::Yellow);
let new_sub = new_sub
.to_string()
.with_color(erg_common::style::Color::Yellow);
let hint = switch_lang!( let hint = switch_lang!(
"japanese" => format!("{maybe_sub}から{new_sub}への暗黙の型拡大はデフォルトでは禁止されています。明示的に型指定してください"), "japanese" => format!("{maybe_sub_}から{new_sub}への暗黙の型拡大はデフォルトでは禁止されています。明示的に型指定してください"),
"simplified_chinese" => format!("隐式扩展{maybe_sub}{new_sub}被默认禁止。请明确指定类型。"), "simplified_chinese" => format!("隐式扩展{maybe_sub_}{new_sub}被默认禁止。请明确指定类型。"),
"traditional_chinese" => format!("隱式擴展{maybe_sub}{new_sub}被默認禁止。請明確指定類型。"), "traditional_chinese" => format!("隱式擴展{maybe_sub_}{new_sub}被默認禁止。請明確指定類型。"),
"english" => format!("Implicitly widening {maybe_sub} to {new_sub} is prohibited by default. Consider specifying the type explicitly."), "english" => format!("Implicitly widening {maybe_sub_} to {new_sub} is prohibited by default. Consider specifying the type explicitly."),
); );
return Err(TyCheckErrors::from(TyCheckError::type_mismatch_error( return Err(TyCheckErrors::from(TyCheckError::type_mismatch_error(
self.cfg.input.clone(), self.cfg.input.clone(),