feat: add HasScalarType

This commit is contained in:
Shunsuke Shibayama 2024-02-09 18:37:33 +09:00
parent d232acc3e4
commit e5c8f177ba
7 changed files with 136 additions and 4 deletions

View file

@ -1730,6 +1730,15 @@ impl Context {
),
)
.unwrap();
array_
.register_marker_trait(
self,
poly(
HAS_SCALAR_TYPE,
vec![ty_tp(arr_t.clone()).proj_call(FUNC_SCALAR_TYPE.into(), vec![])],
),
)
.unwrap();
let mut array_sized = Self::builtin_methods(Some(mono(SIZED)), 2);
array_sized.register_builtin_erg_impl(
FUNDAMENTAL_LEN,
@ -1760,6 +1769,19 @@ impl Context {
None,
)));
array_.register_builtin_const(FUNC_SHAPE, Visibility::BUILTIN_PUBLIC, None, shape);
let array_scalar_type_t = fn0_met(Type, Type).quantify();
let array_scalar_type = ValueObj::Subr(ConstSubr::Builtin(BuiltinConstSubr::new(
FUNC_SCALAR_TYPE,
array_scalar_type,
array_scalar_type_t,
None,
)));
array_.register_builtin_const(
FUNC_SCALAR_TYPE,
Visibility::BUILTIN_PUBLIC,
None,
array_scalar_type,
);
let mut array_eq = Self::builtin_methods(Some(mono(EQ)), 2);
array_eq.register_builtin_erg_impl(
OP_EQ,

View file

@ -535,6 +535,86 @@ pub(crate) fn array_shape(mut args: ValueArgs, ctx: &Context) -> EvalValueResult
Ok(arr)
}
fn _array_scalar_type(mut typ: Type, ctx: &Context) -> Result<Type, String> {
loop {
if matches!(&typ.qual_name()[..], "Array" | "Array!" | "UnsizedArray") {
let tp = typ.typarams().remove(0);
match ctx.convert_tp_into_type(tp) {
Ok(typ_) => {
typ = typ_;
}
Err(err) => {
return Err(format!("Cannot convert {err} into type"));
}
}
} else {
return Ok(typ);
}
}
}
pub(crate) fn array_scalar_type(mut args: ValueArgs, ctx: &Context) -> EvalValueResult<TyParam> {
let slf = args
.remove_left_or_key("Self")
.ok_or_else(|| not_passed("Self"))?;
let Ok(slf) = ctx.convert_value_into_type(slf.clone()) else {
return Err(type_mismatch("Type", slf, "Self"));
};
let res = _array_scalar_type(slf, ctx).unwrap();
Ok(TyParam::t(res))
}
fn _scalar_type(mut value: ValueObj, _ctx: &Context) -> Result<Type, String> {
loop {
match value {
ValueObj::Array(a) => match a.first() {
Some(elem) => {
value = elem.clone();
}
None => {
return Ok(Type::Never);
}
},
ValueObj::Set(s) => match s.iter().next() {
Some(elem) => {
value = elem.clone();
}
None => {
return Ok(Type::Never);
}
},
ValueObj::Tuple(t) => match t.first() {
Some(elem) => {
value = elem.clone();
}
None => {
return Ok(Type::Never);
}
},
ValueObj::UnsizedArray(a) => {
value = *a.clone();
}
other => {
return Ok(other.class());
}
}
}
}
/// ```erg
/// [1, 2].scalar_type() == Nat
/// [[1, 2], [3, 4], [5, 6]].scalar_type() == Nat
/// ```
#[allow(unused)]
pub(crate) fn scalar_type(mut args: ValueArgs, ctx: &Context) -> EvalValueResult<TyParam> {
let arr = args
.remove_left_or_key("Self")
.ok_or_else(|| not_passed("Self"))?;
let res = _scalar_type(arr, ctx).unwrap();
let arr = TyParam::t(res);
Ok(arr)
}
fn _array_sum(arr: ValueObj, _ctx: &Context) -> Result<ValueObj, String> {
match arr {
ValueObj::Array(a) => {

View file

@ -100,6 +100,7 @@ const INDEXABLE: &str = "Indexable";
const MAPPING: &str = "Mapping";
const MUTABLE_MAPPING: &str = "Mapping!";
const HAS_SHAPE: &str = "HasShape";
const HAS_SCALAR_TYPE: &str = "HasScalarType";
const EQ: &str = "Eq";
const IRREGULAR_EQ: &str = "IrregularEq";
const HASH: &str = "Hash";
@ -277,6 +278,7 @@ const SYMMETRIC_DIFFERENCE: &str = "symmetric_difference";
const MEMORYVIEW: &str = "MemoryView";
const FUNC_UNION: &str = "union";
const FUNC_SHAPE: &str = "shape";
const FUNC_SCALAR_TYPE: &str = "scalar_type";
const FUNC_AS_DICT: &str = "as_dict";
const FUNC_AS_RECORD: &str = "as_record";
const FUNC_INC: &str = "inc";

View file

@ -379,6 +379,11 @@ impl Context {
let S = mono_q_tp(TY_S, instanceof(unknown_len_array_t(Nat)));
let params = vec![PS::named_nd("S", unknown_len_array_t(Nat))];
let has_shape = Self::builtin_poly_trait(HAS_SHAPE, params.clone(), 2);
/* HasScalarType */
let Ty = mono_q_tp(TY_T, instanceof(Type));
let params = vec![PS::t(TY_T, false, WithDefault)];
let mut has_scalar_type = Self::builtin_poly_trait(HAS_SCALAR_TYPE, params.clone(), 2);
has_scalar_type.register_superclass(poly(OUTPUT, vec![Ty.clone()]), &output);
/* Num */
let R = mono_q(TY_R, instanceof(Type));
let params = vec![PS::t(TY_R, false, WithDefault)];
@ -585,6 +590,13 @@ impl Context {
Const,
None,
);
self.register_builtin_type(
poly(HAS_SCALAR_TYPE, vec![Ty]),
has_scalar_type,
vis.clone(),
Const,
None,
);
self.register_builtin_type(poly(ADD, ty_params.clone()), add, vis.clone(), Const, None);
self.register_builtin_type(poly(SUB, ty_params.clone()), sub, vis.clone(), Const, None);
self.register_builtin_type(poly(MUL, ty_params.clone()), mul, vis.clone(), Const, None);

View file

@ -69,7 +69,7 @@
.all: |T <: Num|(object: .NDArray(T),) -> Bool
.any: |T <: Num|(object: .NDArray(T),) -> Bool
.arange: |T <: Num|(start: T, stop := T, step := T) -> .NDArray(T)
.array: |T, S: [Nat; _]|(object: Iterable(T) and HasShape(S),) -> .NDArray(T, S)
.array: |T, S: [Nat; _]|(object: HasScalarType(T) and HasShape(S),) -> .NDArray(T, S)
.linspace: |T <: Num|(start: T, stop: T, num := Nat, endpoint := Bool, retstep := Bool, dtype := Type, axis := Nat) -> .NDArray(T)
.max: |T <: Num|(object: .NDArray(T),) -> T
.mean: |T <: Num|(object: .NDArray(T),) -> T

View file

@ -29,7 +29,11 @@ np = pyimport "numpy"
.Complex64 = 'complex64': ClassType
.Complex128 = 'complex128': ClassType
.Size: ClassType
.Size: (S: [Nat; _]) -> ClassType
.Size(S).
__call__: (size: {S}) -> .Size(S)
.Size(S)|<: Eq|.
__eq__: (self: .Size(S), other: .Size(S)) -> Bool
.Tensor!: (T: Type, Shape: [Nat; _]) -> ClassType
.Tensor!(T, _) <: Output T
.Tensor!(T, S)|<: IrregularEq|.
@ -39,9 +43,9 @@ np = pyimport "numpy"
__getitem__: (self: .Tensor!(T, S), index: Nat or [Nat; _]) -> .Tensor!(T, _)
.Tensor!(T, S).
data: .Tensor!(T, S)
shape: .Size(S)
.Tensor!(_, _).
dtype: .DType
shape: .Size
clone: |T, S: [Nat; _]|(self: .Tensor!(T, S)) -> .Tensor!(T, S)
cpu: |T, S: [Nat; _]|(self: .Tensor!(T, S)) -> .Tensor!(T, S)
detach: |T, S: [Nat; _]|(self: .Tensor!(T, S)) -> .Tensor!(T, S)
@ -79,5 +83,5 @@ np = pyimport "numpy"
and (|T|(input: .Tensor!(T, _)) -> .Tensor!(T, _))
.min: (|T|(input: .Tensor!(T, _), dim: Nat, keepdim := Bool) -> (.Tensor!(T, _)), .Tensor!(T, _)) \
and (|T|(input: .Tensor!(T, _)) -> .Tensor!(T, _))
.tensor: (|T, S: [T; _]|(data: {S}, dtype := .DType, device := .Device) -> .Tensor!(T, S)) \
.tensor: (|T, S: [Nat; _]|(data: HasScalarType(T) and HasShape(S), dtype := .DType, device := .Device) -> .Tensor!(T, S)) \
and (|T|(data: [T; _], dtype := .DType, device := .Device) -> .Tensor!(T, _))

View file

@ -1023,6 +1023,18 @@ impl ValueObj {
}
}
pub const fn is_container(&self) -> bool {
matches!(
self,
Self::Array(_)
| Self::UnsizedArray(_)
| Self::Set(_)
| Self::Dict(_)
| Self::Tuple(_)
| Self::Record(_)
)
}
pub fn from_str(t: Type, mut content: Str) -> Option<Self> {
match t {
Type::Int => content.replace('_', "").parse::<i32>().ok().map(Self::Int),