diff --git a/compiler/erg_common/dict.rs b/compiler/erg_common/dict.rs index 4aaddcc2..233da03f 100644 --- a/compiler/erg_common/dict.rs +++ b/compiler/erg_common/dict.rs @@ -175,8 +175,8 @@ impl Dict { } #[inline] - pub fn insert(&mut self, k: K, v: V) { - self.dict.insert(k, v); + pub fn insert(&mut self, k: K, v: V) -> Option { + self.dict.insert(k, v) } #[inline] diff --git a/compiler/erg_compiler/context/initialize/mod.rs b/compiler/erg_compiler/context/initialize/mod.rs index ebfd0a88..c61477b9 100644 --- a/compiler/erg_compiler/context/initialize/mod.rs +++ b/compiler/erg_compiler/context/initialize/mod.rs @@ -7,6 +7,7 @@ pub mod py_mods; use std::path::PathBuf; use erg_common::config::ErgConfig; +use erg_common::dict; // use erg_common::error::Location; use erg_common::vis::Visibility; use erg_common::Str; @@ -470,7 +471,12 @@ impl Context { obj.register_builtin_impl("__sizeof__", fn0_met(Obj, Nat), Const, Public); obj.register_builtin_impl("__repr__", fn0_met(Obj, Str), Immutable, Public); obj.register_builtin_impl("__str__", fn0_met(Obj, Str), Immutable, Public); - obj.register_builtin_impl("__dict__", fn0_met(Obj, dict(Str, Obj)), Immutable, Public); + obj.register_builtin_impl( + "__dict__", + fn0_met(Obj, dict! {Str => Obj}.into()), + Immutable, + Public, + ); obj.register_builtin_impl( "__bytes__", fn0_met(Obj, builtin_mono("Bytes")), @@ -907,7 +913,7 @@ impl Context { Self::builtin_poly_class("Set", vec![PS::t_nd("T"), PS::named_nd("N", Nat)], 10); let n = mono_q_tp("N"); let m = mono_q_tp("M"); - let set_t = set(mono_q("T"), n.clone()); + let set_t = set_t(mono_q("T"), n.clone()); set_.register_superclass(Obj, &obj); set_.register_marker_trait(builtin_poly("Output", vec![ty_tp(mono_q("T"))])); let t = fn_met( diff --git a/compiler/erg_compiler/hir.rs b/compiler/erg_compiler/hir.rs index d83c85fe..99195096 100644 --- a/compiler/erg_compiler/hir.rs +++ b/compiler/erg_compiler/hir.rs @@ -1,6 +1,7 @@ /// defines High-level Intermediate Representation use std::fmt; +use erg_common::dict::Dict as HashMap; use erg_common::error::Location; use erg_common::traits::{Locational, NestedDisplay, Stream}; use erg_common::vis::{Field, Visibility}; @@ -14,7 +15,7 @@ use erg_common::{ use erg_parser::ast::{fmt_lines, DefId, DefKind, Params, TypeSpec, VarName}; use erg_parser::token::{Token, TokenKind}; -use erg_type::constructors::{array, set, tuple}; +use erg_type::constructors::{array, dict_t, set_t, tuple}; use erg_type::typaram::TyParam; use erg_type::value::{TypeKind, ValueObj}; use erg_type::{impl_t, impl_t_for_enum, HasType, Type}; @@ -734,11 +735,16 @@ impl_display_from_nested!(NormalDict); impl_locational!(NormalDict, l_brace, r_brace); impl NormalDict { - pub const fn new(l_brace: Token, r_brace: Token, t: Type, kvs: Vec) -> Self { + pub fn new( + l_brace: Token, + r_brace: Token, + kv_ts: HashMap, + kvs: Vec, + ) -> Self { Self { l_brace, r_brace, - t, + t: dict_t(TyParam::Dict(kv_ts)), kvs, } } @@ -801,7 +807,7 @@ impl_t!(NormalSet); impl NormalSet { pub fn new(l_brace: Token, r_brace: Token, elem_t: Type, elems: Args) -> Self { - let t = set(elem_t, TyParam::value(elems.len())); + let t = set_t(elem_t, TyParam::value(elems.len())); Self { l_brace, r_brace, diff --git a/compiler/erg_compiler/lower.rs b/compiler/erg_compiler/lower.rs index 0ca580cf..1946c607 100644 --- a/compiler/erg_compiler/lower.rs +++ b/compiler/erg_compiler/lower.rs @@ -4,6 +4,7 @@ use erg_common::astr::AtomicStr; use erg_common::config::ErgConfig; +use erg_common::dict; use erg_common::error::{Location, MultiErrorDisplay}; use erg_common::set; use erg_common::set::Set; @@ -18,7 +19,8 @@ use erg_parser::token::{Token, TokenKind}; use erg_parser::Parser; use erg_type::constructors::{ - array, array_mut, builtin_mono, builtin_poly, free_var, func, mono, proc, quant, set, set_mut, + array, array_mut, builtin_mono, builtin_poly, free_var, func, mono, proc, quant, set_mut, + set_t, ty_tp, }; use erg_type::free::Constraint; use erg_type::typaram::TyParam; @@ -362,7 +364,6 @@ impl ASTLowerer { let mut union = Type::Never; let mut new_set = vec![]; for elem in elems { - // TODO: Check if the object's type implements Eq let elem = self.lower_expr(elem.expr)?; union = self.ctx.union(&union, elem.ref_t()); if union.is_intersection_type() { @@ -448,7 +449,7 @@ impl ASTLowerer { } else if self.ctx.subtype_of(&elem.t(), &Type::Type) { builtin_poly("SetType", vec![TyParam::t(elem.t()), TyParam::Value(v)]) } else { - set(elem.t(), TyParam::Value(v)) + set_t(elem.t(), TyParam::Value(v)) } } Ok(v @ ValueObj::Mut(_)) if v.class() == builtin_mono("Nat!") => { @@ -469,7 +470,7 @@ impl ASTLowerer { vec![TyParam::t(elem.t()), TyParam::erased(Type::Nat)], ) } else { - set(elem.t(), TyParam::erased(Type::Nat)) + set_t(elem.t(), TyParam::erased(Type::Nat)) } } } @@ -484,8 +485,81 @@ impl ASTLowerer { } } - fn lower_normal_dict(&mut self, _dict: ast::NormalDict) -> LowerResult { - todo!() + fn lower_normal_dict(&mut self, dict: ast::NormalDict) -> LowerResult { + log!(info "enter {}({dict})", fn_name!()); + let mut union = dict! {}; + let mut new_kvs = vec![]; + for kv in dict.kvs { + let loc = kv.loc(); + let key = self.lower_expr(kv.key)?; + let value = self.lower_expr(kv.value)?; + if union.insert(key.t(), value.t()).is_none() { + return Err(LowerErrors::from(LowerError::syntax_error( + self.cfg.input.clone(), + line!() as usize, + loc, + AtomicStr::arc(&self.ctx.name[..]), + switch_lang!( + "japanese" => "集合の要素は全て同じ型である必要があります", + "simplified_chinese" => "集合元素必须全部是相同类型", + "traditional_chinese" => "集合元素必須全部是相同類型", + "english" => "all elements of a set must be of the same type", + ), + Some( + switch_lang!( + "japanese" => "Int or Strなど明示的に型を指定してください", + "simplified_chinese" => "明确指定类型,例如:Int or Str", + "traditional_chinese" => "明確指定類型,例如:Int or Str", + "english" => "please specify the type explicitly, e.g. Int or Str", + ) + .into(), + ), + ))); + } + new_kvs.push(hir::KeyValue::new(key, value)); + } + /*let sup = builtin_poly("Eq", vec![TyParam::t(elem_t.clone())]); + let loc = Location::concat(&dict.l_brace, &dict.r_brace); + // check if elem_t is Eq + if let Err(errs) = self.ctx.sub_unify(&elem_t, &sup, loc, None) { + self.errs.extend(errs.into_iter()); + }*/ + let kv_ts = if union.is_empty() { + dict! { + ty_tp(free_var(self.ctx.level, Constraint::new_type_of(Type::Type))) => + ty_tp(free_var(self.ctx.level, Constraint::new_type_of(Type::Type))) + } + } else { + union + .into_iter() + .map(|(k, v)| (TyParam::t(k), TyParam::t(v))) + .collect() + }; + // TODO: lint + /* + if is_duplicated { + self.warns.push(LowerWarning::syntax_error( + self.cfg.input.clone(), + line!() as usize, + normal_set.loc(), + AtomicStr::arc(&self.ctx.name[..]), + switch_lang!( + "japanese" => "要素が重複しています", + "simplified_chinese" => "元素重复", + "traditional_chinese" => "元素重複", + "english" => "Elements are duplicated", + ), + None, + )); + } + Ok(normal_set) + */ + Ok(hir::NormalDict::new( + dict.l_brace, + dict.r_brace, + kv_ts, + new_kvs, + )) } fn lower_acc(&mut self, acc: ast::Accessor) -> LowerResult { diff --git a/compiler/erg_parser/ast.rs b/compiler/erg_parser/ast.rs index d73e2489..e3222509 100644 --- a/compiler/erg_parser/ast.rs +++ b/compiler/erg_parser/ast.rs @@ -613,8 +613,8 @@ impl KeyValue { #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct NormalDict { - pub(crate) l_brace: Token, - pub(crate) r_brace: Token, + pub l_brace: Token, + pub r_brace: Token, pub kvs: Vec, } diff --git a/compiler/erg_type/constructors.rs b/compiler/erg_type/constructors.rs index 32da9479..31ce967e 100644 --- a/compiler/erg_type/constructors.rs +++ b/compiler/erg_type/constructors.rs @@ -35,16 +35,13 @@ pub fn array_mut(elem_t: Type, len: TyParam) -> Type { builtin_poly("Array!", vec![TyParam::t(elem_t), len]) } -pub fn dict(k_t: Type, v_t: Type) -> Type { - builtin_poly("Dict", vec![TyParam::t(k_t), TyParam::t(v_t)]) -} - +// FIXME pub fn tuple(args: Vec) -> Type { let name = format!("Tuple{}", args.len()); builtin_poly(name, args.into_iter().map(TyParam::t).collect()) } -pub fn set(elem_t: Type, len: TyParam) -> Type { +pub fn set_t(elem_t: Type, len: TyParam) -> Type { builtin_poly("Set", vec![TyParam::t(elem_t), len]) } @@ -52,6 +49,10 @@ pub fn set_mut(elem_t: Type, len: TyParam) -> Type { builtin_poly("Set!", vec![TyParam::t(elem_t), len]) } +pub fn dict_t(dict: TyParam) -> Type { + builtin_poly("Dict", vec![dict]) +} + #[inline] pub fn range(t: Type) -> Type { builtin_poly("Range", vec![TyParam::t(t)]) diff --git a/compiler/erg_type/deserialize.rs b/compiler/erg_type/deserialize.rs index 54c45ec3..2d89077a 100644 --- a/compiler/erg_type/deserialize.rs +++ b/compiler/erg_type/deserialize.rs @@ -5,6 +5,7 @@ use std::string::FromUtf8Error; use erg_common::astr::AtomicStr; use erg_common::cache::CacheSet; use erg_common::config::{ErgConfig, Input}; +use erg_common::dict::Dict; use erg_common::error::{ErrorCore, ErrorKind, Location}; use erg_common::serialize::DataTypePrefix; use erg_common::{fn_name, switch_lang}; @@ -105,7 +106,7 @@ pub type DeserializeResult = Result; pub struct Deserializer { str_cache: CacheSet, arr_cache: CacheSet<[ValueObj]>, - dict_cache: CacheSet<[(ValueObj, ValueObj)]>, + _dict_cache: CacheSet>, } impl Deserializer { @@ -113,7 +114,7 @@ impl Deserializer { Self { str_cache: CacheSet::new(), arr_cache: CacheSet::new(), - dict_cache: CacheSet::new(), + _dict_cache: CacheSet::new(), } } @@ -146,11 +147,6 @@ impl Deserializer { ValueObj::Array(self.arr_cache.get(arr)) } - /// TODO: 使わない? - pub fn get_cached_dict(&mut self, dict: &[(ValueObj, ValueObj)]) -> ValueObj { - ValueObj::Dict(self.dict_cache.get(dict)) - } - pub fn vec_to_bytes(vector: Vec) -> [u8; LEN] { let mut arr = [0u8; LEN]; for (arr_elem, vec_elem) in arr.iter_mut().zip(vector.iter()) { diff --git a/compiler/erg_type/lib.rs b/compiler/erg_type/lib.rs index 65dbd745..d343bd88 100644 --- a/compiler/erg_type/lib.rs +++ b/compiler/erg_type/lib.rs @@ -14,6 +14,7 @@ use std::fmt; use std::ops::{Range, RangeInclusive}; use std::path::PathBuf; +use constructors::dict_t; use erg_common::dict::Dict; use erg_common::set::Set; use erg_common::traits::LimitedDisplay; @@ -1446,6 +1447,16 @@ impl From> for Type { } } +impl From> for Type { + fn from(d: Dict) -> Self { + let d = d + .into_iter() + .map(|(k, v)| (TyParam::t(k), TyParam::t(v))) + .collect(); + dict_t(TyParam::Dict(d)) + } +} + fn get_t_from_tp(tp: &TyParam) -> Option { match tp { TyParam::FreeVar(fv) if fv.is_linked() => get_t_from_tp(&fv.crack()), diff --git a/compiler/erg_type/typaram.rs b/compiler/erg_type/typaram.rs index c3bb7470..6859a124 100644 --- a/compiler/erg_type/typaram.rs +++ b/compiler/erg_type/typaram.rs @@ -2,6 +2,7 @@ use std::cmp::Ordering; use std::fmt; use std::ops::{Add, Div, Mul, Neg, Range, RangeInclusive, Sub}; +use erg_common::dict::Dict; use erg_common::traits::LimitedDisplay; use crate::constructors::int_interval; @@ -127,6 +128,7 @@ pub enum TyParam { Array(Vec), Set(Vec), Tuple(Vec), + Dict(Dict), Mono(Str), MonoProj { obj: Box, @@ -300,6 +302,7 @@ impl LimitedDisplay for TyParam { } write!(f, "}}") } + Self::Dict(dict) => write!(f, "{dict}"), Self::Tuple(tuple) => { write!(f, "(")?; for (i, t) in tuple.iter().enumerate() { diff --git a/compiler/erg_type/value.rs b/compiler/erg_type/value.rs index 4328d4ef..10d8cc18 100644 --- a/compiler/erg_type/value.rs +++ b/compiler/erg_type/value.rs @@ -12,13 +12,14 @@ use erg_common::dict::Dict; use erg_common::error::ErrorCore; use erg_common::serialize::*; use erg_common::set; +use erg_common::set::Set; use erg_common::shared::Shared; use erg_common::vis::Field; use erg_common::{dict, fmt_iter, impl_display_from_debug, switch_lang}; use erg_common::{RcArray, Str}; use crate::codeobj::CodeObj; -use crate::constructors::{array, builtin_mono, builtin_poly, refinement, set as const_set, tuple}; +use crate::constructors::{array, builtin_mono, builtin_poly, refinement, set_t, tuple}; use crate::free::fresh_varname; use crate::typaram::TyParam; use crate::{ConstSubr, HasType, Predicate, Type}; @@ -124,8 +125,8 @@ pub enum ValueObj { Str(Str), Bool(bool), Array(Rc<[ValueObj]>), - Set(Rc<[ValueObj]>), - Dict(Rc<[(ValueObj, ValueObj)]>), + Set(Set), + Dict(Dict), Tuple(Rc<[ValueObj]>), Record(Dict), Code(Box), @@ -508,7 +509,7 @@ impl ValueObj { ), Self::Dict(_dict) => todo!(), Self::Tuple(tup) => tuple(tup.iter().map(|v| v.class()).collect()), - Self::Set(st) => const_set(st.iter().next().unwrap().class(), TyParam::value(st.len())), + Self::Set(st) => set_t(st.iter().next().unwrap().class(), TyParam::value(st.len())), Self::Code(_) => Type::Code, Self::Record(rec) => { Type::Record(rec.iter().map(|(k, v)| (k.clone(), v.class())).collect())