diff --git a/.cargo/config.toml b/.cargo/config.toml index 2c53e2b6..67b82bf3 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -83,5 +83,5 @@ dinst = "install --path . --features debug --features els" ntest = "nextest run" # +nightly -# you must specify the --target option +# you must specify the --target option (e.g. x86_64-unknown-linux-gnu, aarch64-apple-darwin, x86_64-pc-windows-msvc) drc_r = "r -Zbuild-std -Zbuild-std-features=core/debug_refcell" diff --git a/crates/els/util.rs b/crates/els/util.rs index e01a2c89..a47595e9 100644 --- a/crates/els/util.rs +++ b/crates/els/util.rs @@ -5,7 +5,7 @@ use std::str::FromStr; use erg_common::consts::CASE_SENSITIVE; use erg_common::normalize_path; -use erg_common::traits::{DequeStream, Locational}; +use erg_common::traits::{DequeStream, Immutable, Locational}; use erg_compiler::erg_parser::token::{Token, TokenStream}; @@ -18,6 +18,8 @@ use crate::server::ELSResult; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct NormalizedUrl(Url); +impl Immutable for NormalizedUrl {} + impl fmt::Display for NormalizedUrl { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.0) diff --git a/crates/erg_common/cache.rs b/crates/erg_common/cache.rs index 7597026e..a0b0d4f3 100644 --- a/crates/erg_common/cache.rs +++ b/crates/erg_common/cache.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use crate::dict::Dict; use crate::set::Set; use crate::shared::Shared; +use crate::traits::Immutable; use crate::{ArcArray, Str}; #[derive(Debug)] @@ -45,7 +46,7 @@ impl CacheSet { } } -impl CacheSet<[T]> { +impl CacheSet<[T]> { pub fn get(&self, q: &[T]) -> Arc<[T]> { if let Some(cached) = self.0.borrow().get(q) { return cached.clone(); @@ -56,7 +57,18 @@ impl CacheSet<[T]> { } } -impl CacheSet { +impl CacheSet<[T]> { + pub fn linear_get(&self, q: &[T]) -> Arc<[T]> { + if let Some(cached) = self.0.borrow().linear_get(q) { + return cached.clone(); + } // &self.0 is dropped + let s = ArcArray::from(q); + self.0.borrow_mut().insert(s.clone()); + s + } +} + +impl CacheSet { pub fn get(&self, q: &Q) -> Arc where Arc: Borrow, @@ -71,6 +83,21 @@ impl CacheSet { } } +impl CacheSet { + pub fn linear_get(&self, q: &Q) -> Arc + where + Arc: Borrow, + Q: ?Sized + Eq + ToOwned, + { + if let Some(cached) = self.0.borrow().linear_get(q) { + return cached.clone(); + } // &self.0 is dropped + let s = Arc::from(q.to_owned()); + self.0.borrow_mut().insert(s.clone()); + s + } +} + #[derive(Debug, Clone)] pub struct CacheDict(Shared>>); @@ -86,15 +113,26 @@ impl CacheDict { } } -impl CacheDict { +impl CacheDict { pub fn get(&self, k: &Q) -> Option> where K: Borrow, { self.0.borrow().get(k).cloned() } +} +impl CacheDict { pub fn insert(&self, k: K, v: V) { self.0.borrow_mut().insert(k, Arc::new(v)); } } + +impl CacheDict { + pub fn linear_get(&self, k: &Q) -> Option> + where + K: Borrow, + { + self.0.borrow().linear_get(k).cloned() + } +} diff --git a/crates/erg_common/dict.rs b/crates/erg_common/dict.rs index 4e8e1008..9f46fd13 100644 --- a/crates/erg_common/dict.rs +++ b/crates/erg_common/dict.rs @@ -9,6 +9,7 @@ use std::ops::{Index, IndexMut}; use crate::fxhash::FxHashMap; use crate::get_hash; +use crate::traits::Immutable; #[macro_export] macro_rules! dict { @@ -25,13 +26,13 @@ pub struct Dict { dict: FxHashMap, } -impl PartialEq for Dict { +impl PartialEq for Dict { fn eq(&self, other: &Dict) -> bool { self.dict == other.dict } } -impl Eq for Dict {} +impl Eq for Dict {} impl Hash for Dict { fn hash(&self, state: &mut H) { @@ -82,7 +83,7 @@ impl From> for Dict { } } -impl Index<&Q> for Dict +impl Index<&Q> for Dict where K: Borrow, Q: Hash + Eq, @@ -94,7 +95,7 @@ where } } -impl IndexMut<&Q> for Dict +impl IndexMut<&Q> for Dict where K: Borrow, Q: Hash + Eq, @@ -195,6 +196,15 @@ impl Dict { { self.dict.retain(f); } + + pub fn get_by(&self, k: &K, cmp: impl Fn(&K, &K) -> bool) -> Option<&V> { + for (k_, v) in self.dict.iter() { + if cmp(k, k_) { + return Some(v); + } + } + None + } } impl IntoIterator for Dict { @@ -215,7 +225,47 @@ impl<'a, K, V> IntoIterator for &'a Dict { } } -impl Dict { +impl Dict { + /// K: interior-mutable + pub fn linear_get(&self, key: &Q) -> Option<&V> + where + K: Borrow, + Q: Eq + ?Sized, + { + self.dict + .iter() + .find(|(k, _)| (*k).borrow() == key) + .map(|(_, v)| v) + } + + pub fn linear_get_mut(&mut self, key: &Q) -> Option<&mut V> + where + K: Borrow, + Q: Eq + ?Sized, + { + self.dict + .iter_mut() + .find(|(k, _)| (*k).borrow() == key) + .map(|(_, v)| v) + } +} + +impl Dict { + /// K: interior-mutable + pub fn linear_eq(&self, other: &Self) -> bool { + if self.len() != other.len() { + return false; + } + for (k, v) in self.iter() { + if other.linear_get(k) != Some(v) { + return false; + } + } + true + } +} + +impl Dict { #[inline] pub fn get(&self, k: &Q) -> Option<&V> where @@ -225,15 +275,6 @@ impl Dict { self.dict.get(k) } - pub fn get_by(&self, k: &K, cmp: impl Fn(&K, &K) -> bool) -> Option<&V> { - for (k_, v) in self.dict.iter() { - if cmp(k, k_) { - return Some(v); - } - } - None - } - #[inline] pub fn get_mut(&mut self, k: &Q) -> Option<&mut V> where @@ -260,11 +301,6 @@ impl Dict { self.dict.contains_key(k) } - #[inline] - pub fn insert(&mut self, k: K, v: V) -> Option { - self.dict.insert(k, v) - } - #[inline] pub fn remove(&mut self, k: &Q) -> Option where @@ -281,6 +317,13 @@ impl Dict { { self.dict.remove_entry(k) } +} + +impl Dict { + #[inline] + pub fn insert(&mut self, k: K, v: V) -> Option { + self.dict.insert(k, v) + } /// NOTE: This method does not consider pairing with values and keys. That is, a value may be paired with a different key (can be considered equal). /// If you need to consider the pairing of the keys and values, use `guaranteed_extend` instead. diff --git a/crates/erg_common/pathutil.rs b/crates/erg_common/pathutil.rs index 2f4f3f0f..f713de3f 100644 --- a/crates/erg_common/pathutil.rs +++ b/crates/erg_common/pathutil.rs @@ -6,6 +6,7 @@ use std::path::{Component, Path, PathBuf}; use crate::consts::PYTHON_MODE; use crate::env::erg_pkgs_path; +use crate::traits::Immutable; use crate::{normalize_path, Str}; /// Guaranteed equivalence path. @@ -16,6 +17,8 @@ use crate::{normalize_path, Str}; #[derive(Debug, Clone, PartialEq, Eq, Hash, Default, PartialOrd, Ord)] pub struct NormalizedPathBuf(PathBuf); +impl Immutable for NormalizedPathBuf {} + impl fmt::Display for NormalizedPathBuf { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.display()) diff --git a/crates/erg_common/set.rs b/crates/erg_common/set.rs index 3500afa1..c925cf1e 100644 --- a/crates/erg_common/set.rs +++ b/crates/erg_common/set.rs @@ -5,6 +5,7 @@ use std::hash::{Hash, Hasher}; use std::iter::FromIterator; use crate::fxhash::FxHashSet; +use crate::traits::Immutable; use crate::{debug_fmt_iter, fmt_iter, get_hash}; #[cfg(feature = "pylib")] @@ -46,13 +47,13 @@ where } } -impl PartialEq for Set { +impl PartialEq for Set { fn eq(&self, other: &Set) -> bool { self.elems.eq(&other.elems) } } -impl Eq for Set {} +impl Eq for Set {} impl Default for Set { fn default() -> Self { @@ -108,9 +109,15 @@ impl Set { elems: FxHashSet::default(), } } -} -impl Set { + pub fn get_by(&self, value: &Q, cmp: impl Fn(&Q, &Q) -> bool) -> Option<&T> + where + T: Borrow, + Q: ?Sized, + { + self.elems.iter().find(|&v| cmp(v.borrow(), value)) + } + pub fn with_capacity(capacity: usize) -> Self { Self { elems: FxHashSet::with_capacity_and_hasher(capacity, Default::default()), @@ -163,7 +170,50 @@ impl<'a, T> IntoIterator for &'a Set { } } -impl Set { +impl Set { + pub fn linear_get(&self, value: &Q) -> Option<&T> + where + T: Borrow, + Q: ?Sized + Eq, + { + self.elems.iter().find(|x| (*x).borrow() == value) + } + + pub fn linear_contains(&self, value: &Q) -> bool + where + T: Borrow, + Q: ?Sized + Eq, + { + self.elems.iter().any(|x| (*x).borrow() == value) + } + + pub fn linear_eq(&self, other: &Set) -> bool { + self.len() == other.len() && self.iter().all(|x| other.linear_contains(x)) + } + + pub fn linear_remove(&mut self, value: &Q) -> bool + where + T: Borrow, + Q: ?Sized + Eq, + { + let mut found = false; + self.elems.retain(|x| { + let eq = (*x).borrow() == value; + if eq { + found = true; + } + !eq + }); + found + } + + pub fn linear_exclude(mut self, other: &T) -> Set { + self.linear_remove(other); + self + } +} + +impl Set { #[inline] pub fn get(&self, value: &Q) -> Option<&T> where @@ -173,14 +223,6 @@ impl Set { self.elems.get(value) } - pub fn get_by(&self, value: &Q, cmp: impl Fn(&Q, &Q) -> bool) -> Option<&T> - where - T: Borrow, - Q: ?Sized + Hash + Eq, - { - self.elems.iter().find(|&v| cmp(v.borrow(), value)) - } - #[inline] pub fn contains(&self, value: &Q) -> bool where @@ -190,12 +232,6 @@ impl Set { self.elems.contains(value) } - /// newly inserted: true, already present: false - #[inline] - pub fn insert(&mut self, value: T) -> bool { - self.elems.insert(value) - } - #[inline] pub fn remove(&mut self, value: &Q) -> bool where @@ -205,6 +241,42 @@ impl Set { self.elems.remove(value) } + pub fn exclude(mut self, other: &T) -> Set { + self.remove(other); + self + } +} + +impl Set { + /// ``` + /// # use erg_common::set; + /// # use erg_common::set::Set; + /// assert_eq!(Set::multi_intersection([set!{1, 3}, set!{1, 2}].into_iter()), set!{1}); + /// assert_eq!(Set::multi_intersection([set!{1, 3}, set!{1, 2}, set!{2}].into_iter()), set!{1, 2}); + /// assert_eq!(Set::multi_intersection([set!{1, 3}, set!{1, 2}, set!{2, 3}].into_iter()), set!{1, 2, 3}); + /// ``` + pub fn multi_intersection(mut i: I) -> Set + where + I: Iterator> + Clone, + { + let mut res = set! {}; + while let Some(s) = i.next() { + res = res.union_from_iter( + s.into_iter() + .filter(|x| i.clone().any(|set| set.contains(x))), + ); + } + res + } +} + +impl Set { + /// newly inserted: true, already present: false + #[inline] + pub fn insert(&mut self, value: T) -> bool { + self.elems.insert(value) + } + #[inline] pub fn extend>(&mut self, iter: I) { self.elems.extend(iter); @@ -295,27 +367,6 @@ impl Set { self.intersection(&iter.collect()) } - /// ``` - /// # use erg_common::set; - /// # use erg_common::set::Set; - /// assert_eq!(Set::multi_intersection([set!{1, 3}, set!{1, 2}].into_iter()), set!{1}); - /// assert_eq!(Set::multi_intersection([set!{1, 3}, set!{1, 2}, set!{2}].into_iter()), set!{1, 2}); - /// assert_eq!(Set::multi_intersection([set!{1, 3}, set!{1, 2}, set!{2, 3}].into_iter()), set!{1, 2, 3}); - /// ``` - pub fn multi_intersection(mut i: I) -> Set - where - I: Iterator> + Clone, - { - let mut res = set! {}; - while let Some(s) = i.next() { - res = res.union_from_iter( - s.into_iter() - .filter(|x| i.clone().any(|set| set.contains(x))), - ); - } - res - } - pub fn difference(&self, other: &Set) -> Set { let u = self.elems.difference(&other.elems); Self { @@ -331,11 +382,6 @@ impl Set { self.insert(other); self } - - pub fn exclude(mut self, other: &T) -> Set { - self.remove(other); - self - } } impl Set { diff --git a/crates/erg_common/traits.rs b/crates/erg_common/traits.rs index 4f07cfe0..7c0c84bf 100644 --- a/crates/erg_common/traits.rs +++ b/crates/erg_common/traits.rs @@ -1371,3 +1371,42 @@ impl OptionalTranspose for Option> { pub trait New { fn new(cfg: ErgConfig) -> Self; } + +/// Indicates that the type has no interior mutability +// TODO: auto trait +pub trait Immutable {} + +impl Immutable for () {} +impl Immutable for bool {} +impl Immutable for char {} +impl Immutable for u8 {} +impl Immutable for u16 {} +impl Immutable for u32 {} +impl Immutable for u64 {} +impl Immutable for u128 {} +impl Immutable for usize {} +impl Immutable for i8 {} +impl Immutable for i16 {} +impl Immutable for i32 {} +impl Immutable for i64 {} +impl Immutable for i128 {} +impl Immutable for isize {} +impl Immutable for f32 {} +impl Immutable for f64 {} +impl Immutable for str {} +impl Immutable for String {} +impl Immutable for crate::Str {} +impl Immutable for std::path::PathBuf {} +impl Immutable for std::path::Path {} +impl Immutable for std::ffi::OsString {} +impl Immutable for std::ffi::OsStr {} +impl Immutable for std::time::Duration {} +impl Immutable for std::time::SystemTime {} +impl Immutable for std::time::Instant {} +impl Immutable for &T {} +impl Immutable for Option {} +impl Immutable for Vec {} +impl Immutable for [T] {} +impl Immutable for Box {} +impl Immutable for std::rc::Rc {} +impl Immutable for std::sync::Arc {} diff --git a/crates/erg_common/tsort.rs b/crates/erg_common/tsort.rs index 8b114d44..78c3efbd 100644 --- a/crates/erg_common/tsort.rs +++ b/crates/erg_common/tsort.rs @@ -1,6 +1,7 @@ //! Topological sort use crate::dict::Dict; use crate::set::Set; +use crate::traits::Immutable; use std::fmt::Debug; use std::hash::Hash; @@ -48,13 +49,13 @@ impl TopoSortError { } #[derive(Debug, Clone, PartialEq, Eq)] -pub struct Node { +pub struct Node { pub id: T, pub data: U, pub depends_on: Set, } -impl Node { +impl Node { pub const fn new(id: T, data: U, depends_on: Set) -> Self { Node { id, @@ -93,12 +94,12 @@ fn _reorder_by_idx(mut v: Vec, idx: Vec) -> Vec { v } -fn reorder_by_key(mut g: Graph, idx: Vec) -> Graph { +fn reorder_by_key(mut g: Graph, idx: Vec) -> Graph { g.sort_by_key(|node| idx.iter().position(|k| k == &node.id).unwrap()); g } -fn dfs( +fn dfs( g: &Graph, v: T, used: &mut Set, @@ -125,7 +126,7 @@ fn dfs( /// perform topological sort on a graph #[allow(clippy::result_unit_err)] -pub fn tsort( +pub fn tsort( g: Graph, ) -> Result, TopoSortError> { let n = g.len(); diff --git a/crates/erg_compiler/context/compare.rs b/crates/erg_compiler/context/compare.rs index 03d7017d..8acac4f0 100644 --- a/crates/erg_compiler/context/compare.rs +++ b/crates/erg_compiler/context/compare.rs @@ -7,7 +7,7 @@ use erg_common::dict::Dict; use erg_common::set::Set; use erg_common::style::colors::DEBUG_ERROR; use erg_common::traits::StructuralEq; -use erg_common::{assume_unreachable, log, set_recursion_limit}; +use erg_common::{assume_unreachable, log}; use erg_common::{Str, Triple}; use crate::context::eval::UndoableLinkedList; @@ -126,7 +126,6 @@ impl Context { /// lhs :> rhs ? pub(crate) fn supertype_of(&self, lhs: &Type, rhs: &Type) -> bool { - set_recursion_limit!(false, 128); let res = match Self::cheap_supertype_of(lhs, rhs) { (Absolutely, judge) => judge, (Maybe, judge) => { @@ -1040,7 +1039,7 @@ impl Context { } for (sub_k, sub_v) in sub_d.iter() { if let Some(sup_v) = sup_d - .get(sub_k) + .linear_get(sub_k) .or_else(|| sub_tpdict_get(sup_d, sub_k, self)) { if !self.supertype_of_tp(sup_v, sub_v, variance) { @@ -1794,7 +1793,7 @@ impl Context { if self.subtype_of(&t, elem) { return intersection.clone(); } else if self.supertype_of(&t, elem) { - return constructors::ands(ands.exclude(&t).include(elem.clone())); + return constructors::ands(ands.linear_exclude(&t).include(elem.clone())); } } and(intersection.clone(), elem.clone()) @@ -2024,7 +2023,7 @@ impl Context { "or" => self.is_sub_pred_of(existing, pred), _ => unreachable!(), }) { - reduced.remove(old); + reduced.linear_remove(old); } // insert if necessary if reduced.iter().all(|existing| match mode { diff --git a/crates/erg_compiler/context/eval.rs b/crates/erg_compiler/context/eval.rs index 5e5ab3bc..9b3f5454 100644 --- a/crates/erg_compiler/context/eval.rs +++ b/crates/erg_compiler/context/eval.rs @@ -9,7 +9,7 @@ use erg_common::log; use erg_common::set::Set; use erg_common::shared::Shared; use erg_common::traits::{Locational, Stream}; -use erg_common::{dict, fmt_vec, fn_name, option_enum_unwrap, set, set_recursion_limit, Triple}; +use erg_common::{dict, fmt_vec, fn_name, option_enum_unwrap, set, Triple}; use erg_common::{ArcArray, Str}; use OpKind::*; @@ -2068,7 +2068,6 @@ impl Context { level: usize, t_loc: &impl Locational, ) -> Failable { - set_recursion_limit!(Ok(Failure), 128); let mut errs = EvalErrors::empty(); match substituted { Type::FreeVar(fv) if fv.is_linked() => { @@ -3943,8 +3942,8 @@ impl Context { (TyParam::Erased(l), TyParam::Erased(r)) => l == r, (TyParam::List(l), TyParam::List(r)) => l == r, (TyParam::Tuple(l), TyParam::Tuple(r)) => l == r, - (TyParam::Set(l), TyParam::Set(r)) => l == r, // FIXME: - (TyParam::Dict(l), TyParam::Dict(r)) => l == r, + (TyParam::Set(l), TyParam::Set(r)) => l.linear_eq(r), + (TyParam::Dict(l), TyParam::Dict(r)) => l.linear_eq(r), (TyParam::Lambda(l), TyParam::Lambda(r)) => l == r, (TyParam::FreeVar { .. }, TyParam::FreeVar { .. }) => true, (TyParam::Mono(l), TyParam::Mono(r)) => { diff --git a/crates/erg_compiler/context/generalize.rs b/crates/erg_compiler/context/generalize.rs index 661ecdc1..59da39b9 100644 --- a/crates/erg_compiler/context/generalize.rs +++ b/crates/erg_compiler/context/generalize.rs @@ -100,19 +100,19 @@ impl Generalizer { let nd_params = lambda .nd_params .into_iter() - .map(|pt| pt.map_type(|t| self.generalize_t(t, uninit))) + .map(|pt| pt.map_type(&mut |t| self.generalize_t(t, uninit))) .collect::>(); let var_params = lambda .var_params - .map(|pt| pt.map_type(|t| self.generalize_t(t, uninit))); + .map(|pt| pt.map_type(&mut |t| self.generalize_t(t, uninit))); let d_params = lambda .d_params .into_iter() - .map(|pt| pt.map_type(|t| self.generalize_t(t, uninit))) + .map(|pt| pt.map_type(&mut |t| self.generalize_t(t, uninit))) .collect::>(); let kw_var_params = lambda .kw_var_params - .map(|pt| pt.map_type(|t| self.generalize_t(t, uninit))); + .map(|pt| pt.map_type(&mut |t| self.generalize_t(t, uninit))); let body = lambda .body .into_iter() diff --git a/crates/erg_compiler/context/initialize/const_func.rs b/crates/erg_compiler/context/initialize/const_func.rs index 6093b5db..e98178c4 100644 --- a/crates/erg_compiler/context/initialize/const_func.rs +++ b/crates/erg_compiler/context/initialize/const_func.rs @@ -342,7 +342,10 @@ pub(crate) fn __dict_getitem__(mut args: ValueArgs, ctx: &Context) -> EvalValueR let index = args .remove_left_or_key("Index") .ok_or_else(|| not_passed("Index"))?; - if let Some(v) = slf.get(&index).or_else(|| sub_vdict_get(&slf, &index, ctx)) { + if let Some(v) = slf + .linear_get(&index) + .or_else(|| sub_vdict_get(&slf, &index, ctx)) + { Ok(v.clone().into()) } else if let Some(v) = sub_vdict_get(&homogenize_dict(&slf, ctx), &index, ctx).cloned() { Ok(v.into()) diff --git a/crates/erg_compiler/context/inquire.rs b/crates/erg_compiler/context/inquire.rs index 00e4f988..9c2ffcd9 100644 --- a/crates/erg_compiler/context/inquire.rs +++ b/crates/erg_compiler/context/inquire.rs @@ -1235,11 +1235,11 @@ impl Context { let non_default_params = subr_t .non_default_params .iter() - .map(|pt| pt.clone().map_type(|t| self.readable_type(t))); + .map(|pt| pt.clone().map_type(&mut |t| self.readable_type(t))); let default_params = subr_t .default_params .iter() - .map(|pt| pt.clone().map_type(|t| self.readable_type(t))); + .map(|pt| pt.clone().map_type(&mut |t| self.readable_type(t))); Err(TyCheckError::overload_error( self.cfg.input.clone(), line!() as usize, diff --git a/crates/erg_compiler/context/register.rs b/crates/erg_compiler/context/register.rs index f5194d2e..edde9f02 100644 --- a/crates/erg_compiler/context/register.rs +++ b/crates/erg_compiler/context/register.rs @@ -596,7 +596,7 @@ impl Context { } if let Some(var_params) = &mut params.var_params { if let Some(pt) = &subr_t.var_params { - let pt = pt.clone().map_type(unknown_len_list_t); + let pt = pt.clone().map_type(&mut unknown_len_list_t); if let Err(es) = self.assign_param(var_params, Some(&pt), tmp_tv_cache, ParamKind::VarParams) { @@ -620,7 +620,7 @@ impl Context { } if let Some(kw_var_params) = &mut params.kw_var_params { if let Some(pt) = &subr_t.var_params { - let pt = pt.clone().map_type(str_dict_t); + let pt = pt.clone().map_type(&mut str_dict_t); if let Err(es) = self.assign_param( kw_var_params, Some(&pt), diff --git a/crates/erg_compiler/context/unify.rs b/crates/erg_compiler/context/unify.rs index 362cd12f..a9cdef89 100644 --- a/crates/erg_compiler/context/unify.rs +++ b/crates/erg_compiler/context/unify.rs @@ -302,7 +302,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { return Ok(()); } for (sub_k, sub_v) in sub.iter() { - if let Some(sup_v) = sup.get(sub_k) { + if let Some(sup_v) = sup.linear_get(sub_k) { self.sub_unify_value(sub_v, sup_v)?; } else { log!(err "{sup} does not have key {sub_k}"); @@ -628,7 +628,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { (TyParam::Dict(sub), TyParam::Dict(sup)) => { for (sub_k, sub_v) in sub.iter() { if let Some(sup_v) = sup - .get(sub_k) + .linear_get(sub_k) .or_else(|| sub_tpdict_get(sup, sub_k, self.ctx)) { // self.sub_unify_tp(sub_k, sup_k, _variance, loc, allow_divergence)?; diff --git a/crates/erg_compiler/hir.rs b/crates/erg_compiler/hir.rs index 14564516..5311cf36 100644 --- a/crates/erg_compiler/hir.rs +++ b/crates/erg_compiler/hir.rs @@ -2098,7 +2098,8 @@ impl Params { pub type Decorator = Expr; -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[allow(clippy::derived_hash_with_manual_eq)] +#[derive(Debug, Clone, Hash)] pub struct SubrSignature { pub decorators: HashSet, pub ident: Identifier, @@ -2108,6 +2109,19 @@ pub struct SubrSignature { pub captured_names: Vec, } +impl PartialEq for SubrSignature { + fn eq(&self, other: &Self) -> bool { + self.ident == other.ident + && self.bounds == other.bounds + && self.params == other.params + && self.return_t_spec == other.return_t_spec + && self.captured_names == other.captured_names + && self.decorators.linear_eq(&other.decorators) + } +} + +impl Eq for SubrSignature {} + impl NestedDisplay for SubrSignature { fn fmt_nest(&self, f: &mut fmt::Formatter<'_>, _level: usize) -> fmt::Result { write!( diff --git a/crates/erg_compiler/lower.rs b/crates/erg_compiler/lower.rs index c4af5727..4d2d7160 100644 --- a/crates/erg_compiler/lower.rs +++ b/crates/erg_compiler/lower.rs @@ -874,7 +874,7 @@ impl GenericASTLowerer { }; if let Some(popped_val_t) = union.insert(key.t(), value.t()) { if PYTHON_MODE { - if let Some(val_t) = union.get_mut(key.ref_t()) { + if let Some(val_t) = union.linear_get_mut(key.ref_t()) { *val_t = self.module.context.union(&mem::take(val_t), &popped_val_t); } } else { diff --git a/crates/erg_compiler/module/cache.rs b/crates/erg_compiler/module/cache.rs index b69110a0..14bce069 100644 --- a/crates/erg_compiler/module/cache.rs +++ b/crates/erg_compiler/module/cache.rs @@ -443,6 +443,6 @@ impl SharedGeneralizationCache { } pub fn get(&self, key: &FreeTyVar) -> Option { - self.0.borrow().get(key).cloned() + self.0.borrow().linear_get(key).cloned() } } diff --git a/crates/erg_compiler/ty/deserialize.rs b/crates/erg_compiler/ty/deserialize.rs index c5a3d6bb..5834237d 100644 --- a/crates/erg_compiler/ty/deserialize.rs +++ b/crates/erg_compiler/ty/deserialize.rs @@ -140,7 +140,7 @@ impl Deserializer { } fn get_cached_arr(&mut self, arr: &[ValueObj]) -> ValueObj { - ValueObj::List(self.arr_cache.get(arr)) + ValueObj::List(self.arr_cache.linear_get(arr)) } pub fn vec_to_bytes(vector: Vec) -> [u8; LEN] { diff --git a/crates/erg_compiler/ty/free.rs b/crates/erg_compiler/ty/free.rs index 9a80e801..f93b3ff3 100644 --- a/crates/erg_compiler/ty/free.rs +++ b/crates/erg_compiler/ty/free.rs @@ -781,12 +781,12 @@ impl Free { } /// interior-mut - pub fn do_avoiding_recursion O>(&self, f: F) -> O { + pub fn do_avoiding_recursion(&self, f: impl FnOnce() -> O) -> O { self._do_avoiding_recursion(None, f) } /// interior-mut - pub fn do_avoiding_recursion_with O>(&self, placeholder: &Type, f: F) -> O { + pub fn do_avoiding_recursion_with(&self, placeholder: &Type, f: impl FnOnce() -> O) -> O { self._do_avoiding_recursion(Some(placeholder), f) } @@ -868,8 +868,10 @@ impl bool { if let (Some((l, r)), Some((l2, r2))) = (self.get_subsup(), other.get_subsup()) { self.dummy_link(); + other.dummy_link(); let res = l.structural_eq(&l2) && r.structural_eq(&r2); self.undo(); + other.undo(); res } else if let (Some(l), Some(r)) = (self.get_type(), other.get_type()) { l.structural_eq(&r) diff --git a/crates/erg_compiler/ty/mod.rs b/crates/erg_compiler/ty/mod.rs index adce5e1b..b45d65ba 100644 --- a/crates/erg_compiler/ty/mod.rs +++ b/crates/erg_compiler/ty/mod.rs @@ -288,7 +288,7 @@ impl ParamTy { } } - pub fn map_type(self, f: impl FnOnce(Type) -> Type) -> Self { + pub fn map_type(self, f: &mut impl FnMut(Type) -> Type) -> Self { match self { Self::Pos(ty) => Self::Pos(f(ty)), Self::Kw { name, ty } => Self::Kw { name, ty: f(ty) }, @@ -300,7 +300,7 @@ impl ParamTy { } } - pub fn map_default_type(self, f: impl FnOnce(Type) -> Type) -> Self { + pub fn map_default_type(self, f: &mut impl FnMut(Type) -> Type) -> Self { match self { Self::KwWithDefault { name, ty, default } => Self::KwWithDefault { name, @@ -581,7 +581,7 @@ impl SubrType { || self.return_t.contains_tp(target) } - pub fn map(self, f: impl Fn(Type) -> Type + Copy) -> Self { + pub fn map(self, f: &mut impl FnMut(Type) -> Type) -> Self { Self::new( self.kind, self.non_default_params @@ -859,25 +859,25 @@ impl SubrType { let non_default_params = self .non_default_params .iter() - .map(|pt| pt.clone().map_type(|t| t.derefine())) + .map(|pt| pt.clone().map_type(&mut |t| t.derefine())) .collect(); let var_params = self .var_params .as_ref() - .map(|pt| pt.clone().map_type(|t| t.derefine())); + .map(|pt| pt.clone().map_type(&mut |t| t.derefine())); let default_params = self .default_params .iter() .map(|pt| { pt.clone() - .map_type(|t| t.derefine()) - .map_default_type(|t| t.derefine()) + .map_type(&mut |t| t.derefine()) + .map_default_type(&mut |t| t.derefine()) }) .collect(); let kw_var_params = self .kw_var_params .as_ref() - .map(|pt| pt.clone().map_type(|t| t.derefine())); + .map(|pt| pt.clone().map_type(&mut |t| t.derefine())); Self::new( self.kind, non_default_params, @@ -1016,13 +1016,34 @@ impl SubrType { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Hash)] pub enum RefineKind { Interval { min: TyParam, max: TyParam }, // e.g. {I: Int | I >= 2; I <= 10} 2..10 Enum(Set), // e.g. {I: Int | I == 1 or I == 2} {1, 2} Complex, } +impl PartialEq for RefineKind { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + ( + Self::Interval { + min: lmin, + max: lmax, + }, + Self::Interval { + min: rmin, + max: rmax, + }, + ) => lmin == rmin && lmax == rmax, + (Self::Enum(lset), Self::Enum(rset)) => lset.linear_eq(rset), + (Self::Complex, Self::Complex) => true, + _ => false, + } + } +} +impl Eq for RefineKind {} + /// e.g. /// ```erg /// {I: Int | I >= 0} @@ -1483,8 +1504,8 @@ impl PartialEq for Type { (Self::NamedTuple(lhs), Self::NamedTuple(rhs)) => lhs == rhs, (Self::Refinement(l), Self::Refinement(r)) => l == r, (Self::Quantified(l), Self::Quantified(r)) => l == r, - (Self::And(_, _), Self::And(_, _)) => self.ands() == other.ands(), - (Self::Or(_, _), Self::Or(_, _)) => self.ors() == other.ors(), + (Self::And(_, _), Self::And(_, _)) => self.ands().linear_eq(&other.ands()), + (Self::Or(_, _), Self::Or(_, _)) => self.ors().linear_eq(&other.ors()), (Self::Not(l), Self::Not(r)) => l == r, ( Self::Poly { @@ -4160,7 +4181,7 @@ impl Type { } Self::NamedTuple(r) } - Self::Subr(subr) => Self::Subr(subr.map(|t| t.eliminate_recursion(target))), + Self::Subr(subr) => Self::Subr(subr.map(&mut |t| t.eliminate_recursion(target))), Self::Callable { param_ts, return_t } => { let param_ts = param_ts .into_iter() @@ -4228,20 +4249,20 @@ impl Type { .iter() .map(|pt| { pt.clone() - .map_type(|t| t.replace(&Self::Failure, &Self::Obj)) + .map_type(&mut |t| t.replace(&Self::Failure, &Self::Obj)) }) .collect(); let var_params = subr.var_params.as_ref().map(|pt| { pt.clone() - .map_type(|t| t.replace(&Self::Failure, &Self::Obj)) + .map_type(&mut |t| t.replace(&Self::Failure, &Self::Obj)) }); let default_params = subr .default_params .iter() .map(|pt| { pt.clone() - .map_type(|t| t.replace(&Self::Failure, &Self::Obj)) - .map_default_type(|t| { + .map_type(&mut |t| t.replace(&Self::Failure, &Self::Obj)) + .map_default_type(&mut |t| { let typ = pt.typ().clone().replace(&Self::Failure, &Self::Obj); t.replace(&Self::Failure, &typ) & typ }) @@ -4249,8 +4270,8 @@ impl Type { .collect(); let kw_var_params = subr.kw_var_params.as_ref().map(|pt| { pt.clone() - .map_type(|t| t.replace(&Self::Failure, &Self::Obj)) - .map_default_type(|t| { + .map_type(&mut |t| t.replace(&Self::Failure, &Self::Obj)) + .map_default_type(&mut |t| { let typ = pt.typ().clone().replace(&Self::Failure, &Self::Obj); t.replace(&Self::Failure, &typ) & typ }) @@ -4285,97 +4306,91 @@ impl Type { } } - /// Unlike `replace`, this does not make a look-up table. - fn _replace(mut self, target: &Type, to: &Type) -> Type { - if self.structural_eq(target) { - self = to.clone(); - } + fn map(self, f: &mut impl FnMut(Type) -> Type) -> Type { match self { - Self::FreeVar(fv) if fv.is_linked() => fv.unwrap_linked()._replace(target, to), + Self::FreeVar(fv) if fv.is_linked() => fv.unwrap_linked().map(f), Self::FreeVar(fv) => { let fv_clone = fv.deep_clone(); if let Some((sub, sup)) = fv_clone.get_subsup() { fv.dummy_link(); fv_clone.dummy_link(); - let sub = sub._replace(target, to); - let sup = sup._replace(target, to); + let sub = sub.map(f); + let sup = sup.map(f); fv.undo(); fv_clone.undo(); fv_clone.update_constraint(Constraint::new_sandwiched(sub, sup), true); } else if let Some(ty) = fv_clone.get_type() { - fv_clone - .update_constraint(Constraint::new_type_of(ty._replace(target, to)), true); + fv_clone.update_constraint(Constraint::new_type_of(ty.map(f)), true); } Self::FreeVar(fv_clone) } Self::Refinement(mut refine) => { - refine.t = Box::new(refine.t._replace(target, to)); - refine.pred = Box::new(refine.pred._replace_t(target, to)); + refine.t = Box::new(refine.t.map(f)); + refine.pred = Box::new(refine.pred.map_t(f)); Self::Refinement(refine) } Self::Record(mut rec) => { for v in rec.values_mut() { - *v = std::mem::take(v)._replace(target, to); + *v = std::mem::take(v).map(f); } Self::Record(rec) } Self::NamedTuple(mut r) => { for (_, v) in r.iter_mut() { - *v = std::mem::take(v)._replace(target, to); + *v = std::mem::take(v).map(f); } Self::NamedTuple(r) } - Self::Subr(subr) => Self::Subr(subr._replace(target, to)), + Self::Subr(subr) => Self::Subr(subr.map(f)), Self::Callable { param_ts, return_t } => { - let param_ts = param_ts - .into_iter() - .map(|t| t._replace(target, to)) - .collect(); - let return_t = Box::new(return_t._replace(target, to)); + let param_ts = param_ts.into_iter().map(|t| t.map(f)).collect(); + let return_t = Box::new(return_t.map(f)); Self::Callable { param_ts, return_t } } - Self::Quantified(quant) => quant._replace(target, to).quantify(), + Self::Quantified(quant) => quant.map(f).quantify(), Self::Poly { name, params } => { - let params = params - .into_iter() - .map(|tp| tp.replace_t(target, to)) - .collect(); + let params = params.into_iter().map(|tp| tp.map_t(f)).collect(); Self::Poly { name, params } } - Self::Ref(t) => Self::Ref(Box::new(t._replace(target, to))), + Self::Ref(t) => Self::Ref(Box::new(t.map(f))), Self::RefMut { before, after } => Self::RefMut { - before: Box::new(before._replace(target, to)), - after: after.map(|t| Box::new(t._replace(target, to))), + before: Box::new(before.map(f)), + after: after.map(|t| Box::new(t.map(f))), }, - Self::And(l, r) => l._replace(target, to) & r._replace(target, to), - Self::Or(l, r) => l._replace(target, to) | r._replace(target, to), - Self::Not(ty) => !ty._replace(target, to), - Self::Proj { lhs, rhs } => lhs._replace(target, to).proj(rhs), + Self::And(l, r) => l.map(f) & r.map(f), + Self::Or(l, r) => l.map(f) | r.map(f), + Self::Not(ty) => !ty.map(f), + Self::Proj { lhs, rhs } => lhs.map(f).proj(rhs), Self::ProjCall { lhs, attr_name, args, } => { - let args = args - .into_iter() - .map(|tp| tp.replace_t(target, to)) - .collect(); - proj_call(lhs.replace_t(target, to), attr_name, args) + let args = args.into_iter().map(|tp| tp.map_t(f)).collect(); + proj_call(lhs.map_t(f), attr_name, args) } - Self::Structural(ty) => ty._replace(target, to).structuralize(), + Self::Structural(ty) => ty.map(f).structuralize(), Self::Guard(guard) => Self::Guard(GuardType::new( guard.namespace, guard.target.clone(), - guard.to._replace(target, to), + guard.to.map(f), )), Self::Bounded { sub, sup } => Self::Bounded { - sub: Box::new(sub._replace(target, to)), - sup: Box::new(sup._replace(target, to)), + sub: Box::new(sub.map(f)), + sup: Box::new(sup.map(f)), }, mono_type_pattern!() => self, } } + /// Unlike `replace`, this does not make a look-up table. + fn _replace(mut self, target: &Type, to: &Type) -> Type { + if self.structural_eq(target) { + self = to.clone(); + } + self.map(&mut |t| t._replace(target, to)) + } + fn _replace_tp(self, target: &TyParam, to: &TyParam) -> Type { match self { Self::FreeVar(fv) if fv.is_linked() => fv.unwrap_linked()._replace_tp(target, to), diff --git a/crates/erg_compiler/ty/typaram.rs b/crates/erg_compiler/ty/typaram.rs index efdf1d79..146db392 100644 --- a/crates/erg_compiler/ty/typaram.rs +++ b/crates/erg_compiler/ty/typaram.rs @@ -298,7 +298,7 @@ impl PartialEq for TyParam { (Self::List(l), Self::List(r)) => l == r, (Self::UnsizedList(l), Self::UnsizedList(r)) => l == r, (Self::Tuple(l), Self::Tuple(r)) => l == r, - (Self::Dict(l), Self::Dict(r)) => l == r, + (Self::Dict(l), Self::Dict(r)) => l.linear_eq(r), (Self::Record(l), Self::Record(r)) => l == r, ( Self::DataClass { @@ -310,7 +310,7 @@ impl PartialEq for TyParam { fields: rfs, }, ) => ln == rn && lfs == rfs, - (Self::Set(l), Self::Set(r)) => l == r, + (Self::Set(l), Self::Set(r)) => l.linear_eq(r), (Self::Lambda(l), Self::Lambda(r)) => l == r, (Self::Mono(l), Self::Mono(r)) => l == r, ( diff --git a/crates/erg_compiler/ty/value.rs b/crates/erg_compiler/ty/value.rs index d2df2b51..e95d55d9 100644 --- a/crates/erg_compiler/ty/value.rs +++ b/crates/erg_compiler/ty/value.rs @@ -641,7 +641,7 @@ impl Rem for Float { /// 値オブジェクト /// コンパイル時評価ができ、シリアライズも可能(Typeなどはシリアライズ不可) -#[derive(Clone, PartialEq, Default, Hash)] +#[derive(Clone, Default, Hash)] pub enum ValueObj { Int(i32), Nat(u64), @@ -970,6 +970,44 @@ impl LimitedDisplay for ValueObj { } } +impl PartialEq for ValueObj { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Int(i1), Self::Int(i2)) => i1 == i2, + (Self::Nat(n1), Self::Nat(n2)) => n1 == n2, + (Self::Float(f1), Self::Float(f2)) => f1 == f2, + (Self::Str(s1), Self::Str(s2)) => s1 == s2, + (Self::Bool(b1), Self::Bool(b2)) => b1 == b2, + (Self::List(l1), Self::List(l2)) => l1 == l2, + (Self::UnsizedList(l1), Self::UnsizedList(l2)) => l1 == l2, + (Self::Set(s1), Self::Set(s2)) => s1.linear_eq(s2), + (Self::Dict(d1), Self::Dict(d2)) => d1.linear_eq(d2), + (Self::Tuple(t1), Self::Tuple(t2)) => t1 == t2, + (Self::Record(r1), Self::Record(r2)) => r1 == r2, + ( + Self::DataClass { + name: n1, + fields: f1, + }, + Self::DataClass { + name: n2, + fields: f2, + }, + ) => n1 == n2 && f1 == f2, + (Self::Code(c1), Self::Code(c2)) => c1 == c2, + (Self::Subr(s1), Self::Subr(s2)) => s1 == s2, + (Self::Type(t1), Self::Type(t2)) => t1 == t2, + (Self::None, Self::None) + | (Self::Ellipsis, Self::Ellipsis) + | (Self::NotImplemented, Self::NotImplemented) + | (Self::NegInf, Self::NegInf) + | (Self::Inf, Self::Inf) + | (Self::Failure, Self::Failure) => true, + _ => false, + } + } +} + impl Eq for ValueObj {} impl Neg for ValueObj { diff --git a/crates/erg_compiler/ty/vis.rs b/crates/erg_compiler/ty/vis.rs index 498703a8..18ff69b9 100644 --- a/crates/erg_compiler/ty/vis.rs +++ b/crates/erg_compiler/ty/vis.rs @@ -4,6 +4,7 @@ use std::fmt; #[allow(unused_imports)] use erg_common::log; use erg_common::set::Set; +use erg_common::traits::Immutable; use erg_common::{switch_lang, Str}; use erg_parser::ast::AccessModifier; @@ -170,6 +171,8 @@ pub struct Field { pub symbol: Str, } +impl Immutable for Field {} + impl PartialEq for Field { fn eq(&self, other: &Self) -> bool { self.symbol == other.symbol diff --git a/crates/erg_compiler/varinfo.rs b/crates/erg_compiler/varinfo.rs index 5e64419f..343c3cc2 100644 --- a/crates/erg_compiler/varinfo.rs +++ b/crates/erg_compiler/varinfo.rs @@ -4,6 +4,7 @@ use std::path::Path; use erg_common::error::Location; use erg_common::pathutil::NormalizedPathBuf; use erg_common::set::Set; +use erg_common::traits::Immutable; use erg_common::{switch_lang, Str}; use erg_parser::ast::DefId; @@ -141,6 +142,8 @@ pub struct AbsLocation { pub loc: Location, } +impl Immutable for AbsLocation {} + impl fmt::Display for AbsLocation { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { if let Some(module) = &self.module { diff --git a/crates/erg_parser/ast.rs b/crates/erg_parser/ast.rs index 8bc4fc3f..93b22dea 100644 --- a/crates/erg_parser/ast.rs +++ b/crates/erg_parser/ast.rs @@ -8,7 +8,7 @@ use erg_common::error::Location; use erg_common::io::Input; use erg_common::set::Set as HashSet; // use erg_common::dict::Dict as HashMap; -use erg_common::traits::{Locational, NestedDisplay, Stream}; +use erg_common::traits::{Immutable, Locational, NestedDisplay, Stream}; use erg_common::{ fmt_option, fmt_vec, impl_display_for_enum, impl_display_from_nested, impl_displayable_stream_for_wrapper, impl_from_trait_for_enum, impl_locational, @@ -3806,6 +3806,8 @@ impl Locational for TypeBoundSpecs { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Decorator(pub Expr); +impl Immutable for Decorator {} + impl Decorator { pub const fn new(expr: Expr) -> Self { Self(expr) @@ -3825,6 +3827,8 @@ impl Decorator { #[derive(Debug, Clone, Eq)] pub struct VarName(Token); +impl Immutable for VarName {} + impl PartialEq for VarName { fn eq(&self, other: &Self) -> bool { self.0 == other.0