Fix subtyping bug

This commit is contained in:
Shunsuke Shibayama 2022-10-21 20:04:14 +09:00
parent 978866b31a
commit c39973f536
7 changed files with 110 additions and 19 deletions

View file

@ -55,6 +55,14 @@ impl Field {
Field { vis, symbol } Field { vis, symbol }
} }
pub const fn private(symbol: Str) -> Self {
Field::new(Visibility::Private, symbol)
}
pub const fn public(symbol: Str) -> Self {
Field::new(Visibility::Public, symbol)
}
pub fn is_const(&self) -> bool { pub fn is_const(&self) -> bool {
self.symbol.starts_with(char::is_uppercase) self.symbol.starts_with(char::is_uppercase)
} }

View file

@ -784,6 +784,13 @@ impl Context {
(_, TyParam::FreeVar(fv), _) if fv.is_linked() => { (_, TyParam::FreeVar(fv), _) if fv.is_linked() => {
self.supertype_of_tp(lp, &fv.crack(), variance) self.supertype_of_tp(lp, &fv.crack(), variance)
} }
// _: Type :> T == true
(TyParam::Erased(t), TyParam::Type(_), _)
| (TyParam::Type(_), TyParam::Erased(t), _)
if t.as_ref() == &Type =>
{
true
}
(TyParam::Type(l), TyParam::Type(r), Variance::Contravariant) => self.subtype_of(l, r), (TyParam::Type(l), TyParam::Type(r), Variance::Contravariant) => self.subtype_of(l, r),
(TyParam::Type(l), TyParam::Type(r), Variance::Covariant) => { (TyParam::Type(l), TyParam::Type(r), Variance::Covariant) => {
// if matches!(r.as_ref(), &Type::Refinement(_)) { log!(info "{l}, {r}, {}", self.structural_supertype_of(l, r, bounds, Some(lhs_variance))); } // if matches!(r.as_ref(), &Type::Refinement(_)) { log!(info "{l}, {r}, {}", self.structural_supertype_of(l, r, bounds, Some(lhs_variance))); }

View file

@ -9,6 +9,8 @@ use crate::ty::ValueArgs;
use erg_common::astr::AtomicStr; use erg_common::astr::AtomicStr;
use erg_common::color::{RED, RESET, YELLOW}; use erg_common::color::{RED, RESET, YELLOW};
use erg_common::error::{ErrorCore, ErrorKind, Location}; use erg_common::error::{ErrorCore, ErrorKind, Location};
use erg_common::str::Str;
use erg_common::vis::Field;
/// Requirement: Type, Impl := Type -> ClassType /// Requirement: Type, Impl := Type -> ClassType
pub fn class_func(mut args: ValueArgs, ctx: &Context) -> EvalValueResult<ValueObj> { pub fn class_func(mut args: ValueArgs, ctx: &Context) -> EvalValueResult<ValueObj> {
@ -231,3 +233,27 @@ pub fn __dict_getitem__(mut args: ValueArgs, ctx: &Context) -> EvalValueResult<V
)) ))
} }
} }
pub fn __range_getitem__(mut args: ValueArgs, _ctx: &Context) -> EvalValueResult<ValueObj> {
let (_name, fields) = enum_unwrap!(
args.remove_left_or_key("Self").unwrap(),
ValueObj::DataClass { name, fields }
);
let index = enum_unwrap!(args.remove_left_or_key("Index").unwrap(), ValueObj::Nat);
let start = fields.get(&Field::private(Str::ever("start"))).unwrap();
let start = *enum_unwrap!(start, ValueObj::Nat);
let end = fields.get(&Field::private(Str::ever("end"))).unwrap();
let end = *enum_unwrap!(end, ValueObj::Nat);
// FIXME <= if inclusive
if start + index < end {
Ok(ValueObj::Nat(start + index))
} else {
Err(ErrorCore::new(
line!() as usize,
ErrorKind::IndexError,
Location::Unknown,
AtomicStr::from(format!("Index out of range: {}", index)),
None,
))
}
}

View file

@ -1534,6 +1534,7 @@ impl Context {
// range.register_superclass(Obj, &obj); // range.register_superclass(Obj, &obj);
range.register_superclass(Type, &type_); range.register_superclass(Type, &type_);
range.register_marker_trait(poly("Output", vec![ty_tp(mono_q("T"))])); range.register_marker_trait(poly("Output", vec![ty_tp(mono_q("T"))]));
range.register_marker_trait(poly("Seq", vec![ty_tp(mono_q("T"))]));
let mut range_eq = Self::builtin_methods(Some(poly("Eq", vec![ty_tp(range_t.clone())])), 2); let mut range_eq = Self::builtin_methods(Some(poly("Eq", vec![ty_tp(range_t.clone())])), 2);
range_eq.register_builtin_impl( range_eq.register_builtin_impl(
"__eq__", "__eq__",
@ -1551,6 +1552,15 @@ impl Context {
Public, Public,
); );
range.register_trait(range_t.clone(), range_iterable); range.register_trait(range_t.clone(), range_iterable);
let range_getitem_t = fn1_kw_met(range_t.clone(), anon(mono_q("T")), mono_q("T"));
let range_getitem_t = quant(range_getitem_t, set! { static_instance("T", Type) });
let get_item = ValueObj::Subr(ConstSubr::Builtin(BuiltinConstSubr::new(
"__getitem__",
__range_getitem__,
range_getitem_t,
None,
)));
range.register_builtin_const("__getitem__", Public, get_item);
/* Proc */ /* Proc */
let mut proc = Self::builtin_mono_class("Proc", 2); let mut proc = Self::builtin_mono_class("Proc", 2);
proc.register_superclass(Obj, &obj); proc.register_superclass(Obj, &obj);

View file

@ -1104,6 +1104,21 @@ impl Context {
self.sub_unify_tp(lhs, lhs2, variance, loc, allow_divergence)?; self.sub_unify_tp(lhs, lhs2, variance, loc, allow_divergence)?;
self.sub_unify_tp(rhs, rhs2, variance, loc, allow_divergence) self.sub_unify_tp(rhs, rhs2, variance, loc, allow_divergence)
} }
(l, TyParam::Erased(t)) => {
let sub_t = self.get_tp_t(l)?;
if self.subtype_of(&sub_t, t) {
Ok(())
} else {
Err(TyCheckErrors::from(TyCheckError::subtyping_error(
self.cfg.input.clone(),
line!() as usize,
&sub_t,
t,
loc,
self.caused_by(),
)))
}
}
(l, r) => panic!("type-parameter unification failed:\nl:{l}\nr: {r}"), (l, r) => panic!("type-parameter unification failed:\nl:{l}\nr: {r}"),
} }
} }
@ -1123,6 +1138,10 @@ impl Context {
*l.borrow_mut() = r.clone(); *l.borrow_mut() = r.clone();
Ok(()) Ok(())
} }
/*(TyParam::Value(ValueObj::Mut(l)), TyParam::Erased(_)) => {
*l.borrow_mut() = after.clone();
Ok(())
}*/
(TyParam::Type(l), TyParam::Type(r)) => self.reunify(l, r, loc), (TyParam::Type(l), TyParam::Type(r)) => self.reunify(l, r, loc),
(TyParam::UnaryOp { op: lop, val: lval }, TyParam::UnaryOp { op: rop, val: rval }) (TyParam::UnaryOp { op: lop, val: lval }, TyParam::UnaryOp { op: rop, val: rval })
if lop == rop => if lop == rop =>

View file

@ -26,7 +26,8 @@ def in_operator(x, y):
return True return True
# TODO: trait check # TODO: trait check
return False return False
elif (type(y) == list or type(y) == set) and type(y[0]) == type: elif (type(y) == list or type(y) == set) \
and (type(y[0]) == type or issubclass(type(y[0]), Range)):
# FIXME: # FIXME:
type_check = in_operator(x[0], y[0]) type_check = in_operator(x[0], y[0])
len_check = len(x) == len(y) len_check = len(x) == len(y)
@ -82,9 +83,26 @@ class Range:
def __contains__(self, item): def __contains__(self, item):
pass pass
def __getitem__(self, item): def __getitem__(self, item):
pass res = self.start + item
if res in self:
return res
else:
raise IndexError("Index out of range")
def __len__(self): def __len__(self):
pass if self.start in self:
if self.end in self:
# len(1..4) == 4
return self.end - self.start + 1
else:
# len(1..<4) == 3
return self.end - self.start
else:
if self.end in self:
# len(1<..4) == 3
return self.end - self.start
else:
# len(1<..<4) == 2
return self.end - self.start - 2
def __iter__(self): def __iter__(self):
return RangeIterator(rng=self) return RangeIterator(rng=self)
@ -96,37 +114,21 @@ Iterable.register(Range)
class LeftOpenRange(Range): class LeftOpenRange(Range):
def __contains__(self, item): def __contains__(self, item):
return self.start < item <= self.end return self.start < item <= self.end
def __getitem__(self, item):
return NotImplemented
def __len__(self):
return NotImplemented
# represents `start..<end` # represents `start..<end`
class RightOpenRange(Range): class RightOpenRange(Range):
def __contains__(self, item): def __contains__(self, item):
return self.start <= item < self.end return self.start <= item < self.end
def __getitem__(self, item):
return NotImplemented
def __len__(self):
return NotImplemented
# represents `start<..<end` # represents `start<..<end`
class OpenRange(Range): class OpenRange(Range):
def __contains__(self, item): def __contains__(self, item):
return self.start < item < self.end return self.start < item < self.end
def __getitem__(self, item):
return NotImplemented
def __len__(self):
return NotImplemented
# represents `start..end` # represents `start..end`
class ClosedRange(Range): class ClosedRange(Range):
def __contains__(self, item): def __contains__(self, item):
return self.start <= item <= self.end return self.start <= item <= self.end
def __getitem__(self, item):
return NotImplemented
def __len__(self):
return NotImplemented
class RangeIterator: class RangeIterator:
def __init__(self, rng): def __init__(self, rng):

View file

@ -129,6 +129,10 @@ pub enum ValueObj {
Dict(Dict<ValueObj, ValueObj>), Dict(Dict<ValueObj, ValueObj>),
Tuple(Rc<[ValueObj]>), Tuple(Rc<[ValueObj]>),
Record(Dict<Field, ValueObj>), Record(Dict<Field, ValueObj>),
DataClass {
name: Str,
fields: Dict<Field, ValueObj>,
},
Code(Box<CodeObj>), Code(Box<CodeObj>),
Subr(ConstSubr), Subr(ConstSubr),
Type(TypeObj), Type(TypeObj),
@ -199,6 +203,16 @@ impl fmt::Debug for ValueObj {
} }
write!(f, "}}") write!(f, "}}")
} }
Self::DataClass { name, fields } => {
write!(f, "{name} {{")?;
for (i, (k, v)) in fields.iter().enumerate() {
if i != 0 {
write!(f, "; ")?;
}
write!(f, "{k} = {v}")?;
}
write!(f, "}}")
}
Self::Subr(subr) => write!(f, "{subr}"), Self::Subr(subr) => write!(f, "{subr}"),
Self::Type(t) => write!(f, "{t}"), Self::Type(t) => write!(f, "{t}"),
Self::None => write!(f, "None"), Self::None => write!(f, "None"),
@ -247,6 +261,10 @@ impl Hash for ValueObj {
Self::Set(st) => st.hash(state), Self::Set(st) => st.hash(state),
Self::Code(code) => code.hash(state), Self::Code(code) => code.hash(state),
Self::Record(rec) => rec.hash(state), Self::Record(rec) => rec.hash(state),
Self::DataClass { name, fields } => {
name.hash(state);
fields.hash(state);
}
Self::Subr(subr) => subr.hash(state), Self::Subr(subr) => subr.hash(state),
Self::Type(t) => t.hash(state), Self::Type(t) => t.hash(state),
Self::None => { Self::None => {
@ -538,6 +556,7 @@ impl ValueObj {
Self::Record(rec) => { Self::Record(rec) => {
Type::Record(rec.iter().map(|(k, v)| (k.clone(), v.class())).collect()) Type::Record(rec.iter().map(|(k, v)| (k.clone(), v.class())).collect())
} }
Self::DataClass { name, .. } => Type::Mono(name.clone()),
Self::Subr(subr) => subr.sig_t().clone(), Self::Subr(subr) => subr.sig_t().clone(),
Self::Type(t_obj) => match t_obj { Self::Type(t_obj) => match t_obj {
// TODO: builtin // TODO: builtin