feat: support refinement class

This commit is contained in:
Shunsuke Shibayama 2023-10-31 02:23:50 +09:00
parent c72de02c2c
commit 678c02faf9
9 changed files with 94 additions and 4 deletions

View file

@ -536,6 +536,9 @@ impl Context {
(Mono(n), NamedTuple(_)) => &n[..] == "GenericNamedTuple" || &n[..] == "GenericTuple", (Mono(n), NamedTuple(_)) => &n[..] == "GenericNamedTuple" || &n[..] == "GenericTuple",
(Mono(n), Record(_)) => &n[..] == "Record", (Mono(n), Record(_)) => &n[..] == "Record",
(Type, Subr(subr)) => self.supertype_of(&Type, &subr.return_t), (Type, Subr(subr)) => self.supertype_of(&Type, &subr.return_t),
(Type, Poly { name, params }) if &name[..] == "Set" => {
self.convert_tp_into_value(params[0].clone()).is_ok()
}
(Type, Poly { name, params }) (Type, Poly { name, params })
if &name[..] == "Array" || &name[..] == "UnsizedArray" || &name[..] == "Set" => if &name[..] == "Array" || &name[..] == "UnsizedArray" || &name[..] == "Set" =>
{ {

View file

@ -20,8 +20,9 @@ use erg_parser::desugar::Desugarer;
use erg_parser::token::{Token, TokenKind}; use erg_parser::token::{Token, TokenKind};
use crate::ty::constructors::{ use crate::ty::constructors::{
array_t, bounded, dict_t, mono, mono_q, named_free_var, poly, proj, proj_call, ref_, ref_mut, array_t, bounded, closed_range, dict_t, mono, mono_q, named_free_var, poly, proj, proj_call,
refinement, set_t, subr_t, subtypeof, tp_enum, tuple_t, unknown_len_array_t, v_enum, ref_, ref_mut, refinement, set_t, subr_t, subtypeof, tp_enum, tuple_t, unknown_len_array_t,
v_enum,
}; };
use crate::ty::free::HasLevel; use crate::ty::free::HasLevel;
use crate::ty::typaram::{OpKind, TyParam}; use crate::ty::typaram::{OpKind, TyParam};
@ -93,6 +94,10 @@ fn op_to_name(op: OpKind) -> &'static str {
OpKind::BitXor => "__bitxor__", OpKind::BitXor => "__bitxor__",
OpKind::Shl => "__shl__", OpKind::Shl => "__shl__",
OpKind::Shr => "__shr__", OpKind::Shr => "__shr__",
OpKind::ClosedRange => "__rng__",
OpKind::LeftOpenRange => "__lorng__",
OpKind::RightOpenRange => "__rorng__",
OpKind::OpenRange => "__orng__",
} }
} }
@ -413,6 +418,10 @@ impl Context {
TokenKind::PrePlus => Ok(OpKind::Pos), TokenKind::PrePlus => Ok(OpKind::Pos),
TokenKind::PreMinus => Ok(OpKind::Neg), TokenKind::PreMinus => Ok(OpKind::Neg),
TokenKind::PreBitNot => Ok(OpKind::Invert), TokenKind::PreBitNot => Ok(OpKind::Invert),
TokenKind::Closed => Ok(OpKind::ClosedRange),
TokenKind::LeftOpen => Ok(OpKind::LeftOpenRange),
TokenKind::RightOpen => Ok(OpKind::RightOpenRange),
TokenKind::Open => Ok(OpKind::OpenRange),
_other => Err(EvalErrors::from(EvalError::not_const_expr( _other => Err(EvalErrors::from(EvalError::not_const_expr(
self.cfg.input.clone(), self.cfg.input.clone(),
line!() as usize, line!() as usize,
@ -1134,6 +1143,7 @@ impl Context {
line!(), line!(),
))), ))),
}, },
ClosedRange => Ok(ValueObj::range(lhs, rhs)),
_other => Err(EvalErrors::from(EvalError::unreachable( _other => Err(EvalErrors::from(EvalError::unreachable(
self.cfg.input.clone(), self.cfg.input.clone(),
fn_name!(), fn_name!(),
@ -2014,6 +2024,11 @@ impl Context {
Ok(dict_t(TyParam::Dict(dic))) Ok(dict_t(TyParam::Dict(dic)))
} }
ValueObj::Subr(subr) => subr.as_type(self).ok_or(ValueObj::Subr(subr)), ValueObj::Subr(subr) => subr.as_type(self).ok_or(ValueObj::Subr(subr)),
ValueObj::DataClass { name, fields } if &name == "Range" => {
let start = fields["start"].clone();
let end = fields["end"].clone();
Ok(closed_range(start.class(), start, end))
}
other => Err(other), other => Err(other),
} }
} }

View file

@ -50,6 +50,9 @@
__eq__ self, other: .SemVer = __eq__ self, other: .SemVer =
self.major == other.major and self.minor == other.minor and self.patch == other.patch and self.pre == other.pre self.major == other.major and self.minor == other.minor and self.patch == other.patch and self.pre == other.pre
.SemVerPrefix = Class { "~", "==", "<", ">", "<=", "<=", }
.SemVerSpec = Class { .prefix = .SemVerPrefix; .version = .SemVer }
if! __name__ == "__main__", do!: if! __name__ == "__main__", do!:
v = .SemVer.new(0, 0, 1) v = .SemVer.new(0, 0, 1)
assert v.minor == 0 assert v.minor == 0

View file

@ -133,6 +133,28 @@ pub fn singleton(ty: Type, tp: TyParam) -> Type {
#[inline] #[inline]
pub fn int_interval<P, PErr, Q, QErr>(op: IntervalOp, l: P, r: Q) -> Type pub fn int_interval<P, PErr, Q, QErr>(op: IntervalOp, l: P, r: Q) -> Type
where
P: TryInto<TyParam, Error = PErr>,
PErr: fmt::Debug,
Q: TryInto<TyParam, Error = QErr>,
QErr: fmt::Debug,
{
interval(op, Type::Int, l, r)
}
#[inline]
pub fn closed_range<P, PErr, Q, QErr>(t: Type, l: P, r: Q) -> Type
where
P: TryInto<TyParam, Error = PErr>,
PErr: fmt::Debug,
Q: TryInto<TyParam, Error = QErr>,
QErr: fmt::Debug,
{
interval(IntervalOp::Closed, t, l, r)
}
#[inline]
pub fn interval<P, PErr, Q, QErr>(op: IntervalOp, t: Type, l: P, r: Q) -> Type
where where
P: TryInto<TyParam, Error = PErr>, P: TryInto<TyParam, Error = PErr>,
PErr: fmt::Debug, PErr: fmt::Debug,
@ -161,7 +183,7 @@ where
Predicate::le(name.clone(), r), Predicate::le(name.clone(), r),
), ),
IntervalOp::Open if l == TyParam::value(NegInf) && r == TyParam::value(Inf) => { IntervalOp::Open if l == TyParam::value(NegInf) && r == TyParam::value(Inf) => {
return refinement(name, Type::Int, Predicate::TRUE) return refinement(name, t, Predicate::TRUE)
} }
// l<..<r => {I: classof(l) | I >= l+ε and I <= r-ε} // l<..<r => {I: classof(l) | I >= l+ε and I <= r-ε}
IntervalOp::Open => Predicate::and( IntervalOp::Open => Predicate::and(
@ -169,7 +191,7 @@ where
Predicate::le(name.clone(), TyParam::pred(r)), Predicate::le(name.clone(), TyParam::pred(r)),
), ),
}; };
refinement(name, Type::Int, pred) refinement(name, t, pred)
} }
pub fn iter(t: Type) -> Type { pub fn iter(t: Type) -> Type {

View file

@ -49,6 +49,10 @@ pub enum OpKind {
BitXor, BitXor,
Shl, Shl,
Shr, Shr,
ClosedRange,
LeftOpenRange,
RightOpenRange,
OpenRange,
} }
impl fmt::Display for OpKind { impl fmt::Display for OpKind {
@ -79,6 +83,10 @@ impl fmt::Display for OpKind {
Self::BitXor => write!(f, "^^"), Self::BitXor => write!(f, "^^"),
Self::Shl => write!(f, "<<"), Self::Shl => write!(f, "<<"),
Self::Shr => write!(f, ">>"), Self::Shr => write!(f, ">>"),
Self::ClosedRange => write!(f, ".."),
Self::LeftOpenRange => write!(f, "<.."),
Self::RightOpenRange => write!(f, "..<"),
Self::OpenRange => write!(f, "<..<"),
} }
} }
} }
@ -109,6 +117,10 @@ impl TryFrom<TokenKind> for OpKind {
TokenKind::BitXor => Ok(Self::BitXor), TokenKind::BitXor => Ok(Self::BitXor),
TokenKind::Shl => Ok(Self::Shl), TokenKind::Shl => Ok(Self::Shl),
TokenKind::Shr => Ok(Self::Shr), TokenKind::Shr => Ok(Self::Shr),
TokenKind::Closed => Ok(Self::ClosedRange),
TokenKind::LeftOpen => Ok(Self::LeftOpenRange),
TokenKind::RightOpen => Ok(Self::RightOpenRange),
TokenKind::Open => Ok(Self::OpenRange),
_ => Err(()), _ => Err(()),
} }
} }

View file

@ -955,6 +955,7 @@ impl ValueObj {
ValueObj::Type(TypeObj::Generated(gen)) ValueObj::Type(TypeObj::Generated(gen))
} }
/// closed range (..)
pub fn range(start: Self, end: Self) -> Self { pub fn range(start: Self, end: Self) -> Self {
Self::DataClass { Self::DataClass {
name: "Range".into(), name: "Range".into(),

View file

@ -0,0 +1,7 @@
Suite = Class { "Heart", "Diamond", "Spade", "Club" }
_ = Suite.new "Invalid" # ERR
Month = Class 1..12
_ = Month.new 13 # ERR

View file

@ -0,0 +1,17 @@
Suite = Class { "Heart", "Diamond", "Spade", "Club" }
Suite.
is_heart self = self::base == "Heart"
h = Suite.new "Heart"
d = Suite.new "Diamond"
assert h.is_heart()
assert not d.is_heart()
Month = Class 1..12
Month.
is_first_half self = self::base <= 6
jan = Month.new 1
dec = Month.new 12
assert jan.is_first_half()
assert not dec.is_first_half()

View file

@ -327,6 +327,11 @@ fn exec_refinement() -> Result<(), ()> {
expect_success("tests/should_ok/refinement.er", 0) expect_success("tests/should_ok/refinement.er", 0)
} }
#[test]
fn exec_refinement_class() -> Result<(), ()> {
expect_success("tests/should_ok/refinement_class.er", 0)
}
#[test] #[test]
fn exec_return() -> Result<(), ()> { fn exec_return() -> Result<(), ()> {
expect_success("tests/should_ok/return.er", 0) expect_success("tests/should_ok/return.er", 0)
@ -603,6 +608,11 @@ fn exec_refinement_err() -> Result<(), ()> {
expect_failure("tests/should_err/refinement.er", 0, 8) expect_failure("tests/should_err/refinement.er", 0, 8)
} }
#[test]
fn exec_refinement_class_err() -> Result<(), ()> {
expect_failure("tests/should_err/refinement_class.er", 0, 2)
}
#[test] #[test]
fn exec_var_args_err() -> Result<(), ()> { fn exec_var_args_err() -> Result<(), ()> {
expect_failure("tests/should_err/var_args.er", 0, 3) expect_failure("tests/should_err/var_args.er", 0, 3)