fix: dict update! bug

This commit is contained in:
Shunsuke Shibayama 2023-10-14 21:38:12 +09:00
parent a8c1113df7
commit 0a24c0cb77
15 changed files with 97 additions and 23 deletions

View file

@ -2474,7 +2474,8 @@ impl PyCodeGenerator {
fn emit_call_method(&mut self, obj: Expr, method_name: Identifier, args: Args) {
log!(info "entered {}", fn_name!());
match &method_name.inspect()[..] {
"update!" => {
// mut value class `update!` can be optimized
"update!" if method_name.ref_t().self_t().is_some_and(|t| t.ref_mut_inner().is_some_and(|t| t.is_mut_value_class())) => {
if self.py_version.minor >= Some(11) {
return self.emit_call_update_311(obj, args);
} else {

View file

@ -1524,6 +1524,18 @@ impl Context {
}
}
fn _eval_tp_into_value(&self, tp: TyParam) -> EvalResult<ValueObj> {
self.eval_tp(tp).and_then(|tp| {
self.convert_tp_into_value(tp).map_err(|_| {
EvalErrors::from(EvalError::unreachable(
self.cfg.input.clone(),
fn_name!(),
line!(),
))
})
})
}
/// Evaluate `substituted`.
/// If the evaluation fails, return a harmless type (filled with `Failure`) and errors
pub(crate) fn eval_t_params(

View file

@ -2244,11 +2244,12 @@ impl Context {
vec![],
NoneType,
);
float_mut_mutable.register_builtin_erg_impl(
float_mut_mutable.register_builtin_py_impl(
PROC_UPDATE,
t,
Immutable,
Visibility::BUILTIN_PUBLIC,
Some(FUNC_UPDATE),
);
float_mut.register_trait(mono(MUT_FLOAT), float_mut_mutable);
/* Ratio! */
@ -2268,11 +2269,12 @@ impl Context {
vec![],
NoneType,
);
ratio_mut_mutable.register_builtin_erg_impl(
ratio_mut_mutable.register_builtin_py_impl(
PROC_UPDATE,
t,
Immutable,
Visibility::BUILTIN_PUBLIC,
Some(FUNC_UPDATE),
);
ratio_mut.register_trait(mono(MUT_RATIO), ratio_mut_mutable);
/* Int! */
@ -2308,11 +2310,12 @@ impl Context {
vec![],
NoneType,
);
int_mut_mutable.register_builtin_erg_impl(
int_mut_mutable.register_builtin_py_impl(
PROC_UPDATE,
t,
Immutable,
Visibility::BUILTIN_PUBLIC,
Some(FUNC_UPDATE),
);
int_mut.register_trait(mono(MUT_INT), int_mut_mutable);
let mut nat_mut = Self::builtin_mono_class(MUT_NAT, 2);
@ -2333,11 +2336,12 @@ impl Context {
vec![],
NoneType,
);
nat_mut_mutable.register_builtin_erg_impl(
nat_mut_mutable.register_builtin_py_impl(
PROC_UPDATE,
t,
Immutable,
Visibility::BUILTIN_PUBLIC,
Some(FUNC_UPDATE),
);
nat_mut.register_trait(mono(MUT_NAT), nat_mut_mutable);
/* Bool! */
@ -2358,11 +2362,12 @@ impl Context {
vec![],
NoneType,
);
bool_mut_mutable.register_builtin_erg_impl(
bool_mut_mutable.register_builtin_py_impl(
PROC_UPDATE,
t,
Immutable,
Visibility::BUILTIN_PUBLIC,
Some(FUNC_UPDATE),
);
bool_mut.register_trait(mono(MUT_BOOL), bool_mut_mutable);
let t = pr0_met(mono(MUT_BOOL), NoneType);
@ -2390,11 +2395,12 @@ impl Context {
vec![],
NoneType,
);
str_mut_mutable.register_builtin_erg_impl(
str_mut_mutable.register_builtin_py_impl(
PROC_UPDATE,
t,
Immutable,
Visibility::BUILTIN_PUBLIC,
Some(FUNC_UPDATE),
);
str_mut.register_trait(mono(MUT_STR), str_mut_mutable);
let t = pr_met(
@ -2626,11 +2632,12 @@ impl Context {
)
.quantify();
let mut array_mut_mutable = Self::builtin_methods(Some(mono(MUTABLE)), 2);
array_mut_mutable.register_builtin_erg_impl(
array_mut_mutable.register_builtin_py_impl(
PROC_UPDATE,
t,
Immutable,
Visibility::BUILTIN_PUBLIC,
Some(FUNC_UPDATE),
);
array_mut_.register_trait(array_mut_t.clone(), array_mut_mutable);
/* ByteArray! */
@ -2766,7 +2773,10 @@ impl Context {
poly(ITERABLE, vec![ty_tp(tuple_t(vec![K.clone(), V.clone()]))]),
)],
None,
vec![],
vec![kw(
KW_CONFLICT_RESOLVER,
func2(V.clone(), V.clone(), V.clone()),
)],
NoneType,
)
.quantify();

View file

@ -6,6 +6,7 @@ use erg_common::dict::Dict;
use erg_common::log;
use erg_common::{dict, set};
use crate::context::eval::UndoableLinkedList;
use crate::context::Context;
use crate::feature_error;
use crate::ty::constructors::{and, mono, tuple_t, v_enum};
@ -237,7 +238,8 @@ pub(crate) fn sub_vdict_get<'d>(
}
}
for (idx, kt, v) in matches.into_iter() {
match ctx.sub_unify(idx.typ(), kt.typ(), &(), None) {
let list = UndoableLinkedList::new();
match ctx.undoable_sub_unify(idx.typ(), kt.typ(), &(), &list, None) {
Ok(_) => {
return Some(v);
}
@ -269,7 +271,8 @@ pub(crate) fn sub_tpdict_get<'d>(
}
}
for (idx, kt, v) in matches.into_iter() {
match ctx.sub_unify(idx, kt, &(), None) {
let list = UndoableLinkedList::new();
match ctx.undoable_sub_unify(idx, kt, &(), &list, None) {
Ok(_) => {
return Some(v);
}

View file

@ -562,6 +562,7 @@ const KW_OFFSET: &str = "offset";
const KW_WHENCE: &str = "whence";
const KW_CHARS: &str = "chars";
const KW_OTHER: &str = "other";
const KW_CONFLICT_RESOLVER: &str = "conflict_resolver";
pub fn builtins_path() -> PathBuf {
erg_pystd_path().join("builtins.d.er")

View file

@ -1005,8 +1005,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
if sup.contains_union(&new_sub) {
maybe_sup.link(&new_sub, self.undoable); // Bool <: ?T <: Bool or Y ==> ?T == Bool
} else {
let constr = Constraint::new_sandwiched(new_sub, mem::take(&mut sup));
maybe_sup.update_constraint(constr, self.undoable, true);
maybe_sup.update_tyvar(new_sub, mem::take(&mut sup), self.undoable, true);
}
}
// sub_unify(Nat, ?T(: Type)): (/* ?T(:> Nat) */)
@ -1072,8 +1071,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
{
maybe_sub.link(&sub, self.undoable);
} else {
let constr = Constraint::new_sandwiched(sub, new_sup);
maybe_sub.update_constraint(constr, self.undoable, true);
maybe_sub.update_tyvar(sub, new_sup, self.undoable, true);
}
}
// sub_unify(?T(: Type), Int): (?T(<: Int))
@ -1570,6 +1568,7 @@ impl Context {
unifier.sub_unify_tp(maybe_sub, maybe_sup, variance, is_structural)
}
/// Use `undoable_sub_unify` to temporarily impose type constraints.
pub(crate) fn sub_unify(
&self,
maybe_sub: &Type,

View file

@ -9,24 +9,38 @@ dict = pyimport "Dict"
dic.insert!("b", 2)
assert dic == {"a": 1, "b": 2}
'''
insert!: |K, V|(self: .Dict!(K, V), key: K, value: V) => NoneType
insert!: |K, V|(self: RefMut(.Dict!(K, V)), key: K, value: V) => NoneType
'''erg
dic = !{"a": 1}
x = dic.remove!("a")
assert dic == {}
assert x == 1
'''
remove!: |K, V|(self: .Dict!(K, V), key: K) => V or NoneType
remove!: |K, V|(self: RefMut(.Dict!(K, V)), key: K) => V or NoneType
'''
Update the dictionary with the key-value pairs in `other`.
If `conflict_resolver` is specified, it will be called with the two values and store as a value when there is a conflict.
Otherwise, the value in `other` will be used.
'''
'''erg
dic = !{"a": 1}
dic.update!({"b": 2})
dic.update!([("c", 3)])
assert dic == {"a": 1, "b": 2, "c": 3}
dic.update!({"b": 3}, confilct_resolver := (x, y) -> x + y)
assert dic == {"a": 1, "b": 5, "c": 3}
'''
update!: |K, V|(self: RefMut(.Dict!(K, V)), other: Iterable([K, V]), confilct_resolver := (V, V) -> V) => NoneType
'''
Merge two dictionaries.
If `conflict_resolver` is specified, it will be called with the two values and store as a value when there is a conflict.
Otherwise, the value in `other` will be used.
'''
update!: |K, V|(self: .Dict!(K, V), other: Iterable([K, V])) => NoneType
'''erg
dic = !{"a": 1}
dic.merge!({"b": 2})
assert dic == {"a": 1, "b": 2}
dic.merge!({"b": 3}, confilct_resolver := (x, y) -> x + y)
assert dic == {"a": 1, "b": 5}
'''
merge!: |K, V|(self: .Dict!(K, V), other: .Dict!(K, V)) => NoneType
merge!: |K, V|(self: RefMut(.Dict!(K, V)), other: .Dict!(K, V), confilct_resolver := (V, V) -> V) => NoneType

View file

@ -71,6 +71,9 @@ class Array(list):
def __hash__(self):
return hash(tuple(self))
def update(self, f):
self = Array(f(self))
def type_check(self, t: type) -> bool:
if isinstance(t, list):
if len(t) < len(self):

View file

@ -53,5 +53,8 @@ class BoolMut(NatMut):
else:
return self.value != other.value
def update(self, f):
self.value = Bool(f(self.value))
def invert(self):
self.value = self.value.invert()

View file

@ -4,11 +4,20 @@ class Dict(dict):
def diff(self, other):
return Dict({k: v for k, v in self.items() if k not in other})
# other: Iterable
def extend(self, other):
self.update(other)
def update(self, other, conflict_resolver=None):
if conflict_resolver == None:
super().update(other)
elif isinstance(other, dict):
self.merge(other, conflict_resolver)
else:
for k, v in other:
if k in self:
self[k] = conflict_resolver(self[k], v)
else:
self[k] = v
# other: Dict
def merge(self, other):
self.update(other)
def merge(self, other, conflict_resolver=None):
self.update(other, conflict_resolver)
def insert(self, key, value):
self[key] = value
def remove(self, key):

View file

@ -145,3 +145,6 @@ class FloatMut: # inherits Float
def __neg__(self):
return FloatMut(-self.value)
def update(self, f):
self.value = Float(f(self.value))

View file

@ -153,6 +153,9 @@ class IntMut: # inherits Int
def __neg__(self):
return IntMut(-self.value)
def update(self, f):
self.value = Int(f(self.value))
def inc(self, i=1):
self.value = Int(self.value + i)

View file

@ -123,6 +123,9 @@ class NatMut(IntMut): # and Nat
def __pos__(self):
return self
def update(self, f):
self.value = Nat(f(self.value))
def try_new(i): # -> Result[Nat]
if i >= 0:
return NatMut(i)

View file

@ -76,6 +76,9 @@ class StrMut: # Inherits Str
else:
return self.value != other.value
def update(self, f):
self.value = Str(f(self.value))
def try_new(s: str):
if isinstance(s, str):
self = StrMut()

View file

@ -12,3 +12,10 @@ for! immut_dict.items(), ((k, v),) =>
print! k, v
_ = immut_dict.copy()
mut_dict = !{ "a": 1 }
mut_dict.update! [("a", 2)], (a, b) -> a + b
assert mut_dict == {"a": 3}
mut_dict.insert! "b", 4
i = mut_dict.remove! "a"
assert i == 3