fix: mutable object class bug

This commit is contained in:
Shunsuke Shibayama 2024-02-08 03:04:33 +09:00
parent fce88717b0
commit 3eec9ed590
7 changed files with 135 additions and 117 deletions

View file

@ -1,7 +1,7 @@
from _erg_nat import Nat from _erg_nat import Nat
from _erg_nat import NatMut from _erg_nat import NatMut
from _erg_result import Error from _erg_result import Error
from _erg_type import MutType
class Bool(Nat): class Bool(Nat):
def try_new(b: bool): # -> Result[Nat] def try_new(b: bool): # -> Result[Nat]
@ -42,16 +42,16 @@ class BoolMut(NatMut):
return self.value.__hash__() return self.value.__hash__()
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, bool): if isinstance(other, MutType):
return self.value == other
else:
return self.value == other.value return self.value == other.value
else:
return self.value == other
def __ne__(self, other): def __ne__(self, other):
if isinstance(other, bool): if isinstance(other, MutType):
return self.value != other
else:
return self.value != other.value return self.value != other.value
else:
return self.value != other
def update(self, f): def update(self, f):
self.value = Bool(f(self.value)) self.value = Bool(f(self.value))

View file

@ -1,6 +1,6 @@
from _erg_result import Error from _erg_result import Error
from _erg_control import then__ from _erg_control import then__
from _erg_type import MutType
class Float(float): class Float(float):
EPSILON = 2.220446049250313e-16 EPSILON = 2.220446049250313e-16
@ -32,6 +32,9 @@ class Float(float):
def __floordiv__(self, other): def __floordiv__(self, other):
return then__(float.__floordiv__(self, other), Float) return then__(float.__floordiv__(self, other), Float)
def __truediv__(self, other):
return then__(float.__truediv__(self, other), Float)
def __pow__(self, other): def __pow__(self, other):
return then__(float.__pow__(self, other), Float) return then__(float.__pow__(self, other), Float)
@ -47,7 +50,7 @@ class Float(float):
def nearly_eq(self, other, epsilon=EPSILON): def nearly_eq(self, other, epsilon=EPSILON):
return abs(self - other) < epsilon return abs(self - other) < epsilon
class FloatMut: # inherits Float class FloatMut(MutType): # inherits Float
value: Float value: Float
EPSILON = 2.220446049250313e-16 EPSILON = 2.220446049250313e-16
@ -64,71 +67,80 @@ class FloatMut: # inherits Float
def __deref__(self): def __deref__(self):
return self.value return self.value
def __float__(self):
return self.value.__float__()
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, Float): if isinstance(other, MutType):
return self.value == other
else:
return self.value == other.value return self.value == other.value
else:
return self.value == other
def __ne__(self, other): def __ne__(self, other):
if isinstance(other, Float): if isinstance(other, MutType):
return self.value != other
else:
return self.value != other.value return self.value != other.value
else:
return self.value != other
def __le__(self, other): def __le__(self, other):
if isinstance(other, Float): if isinstance(other, MutType):
return self.value <= other
else:
return self.value <= other.value return self.value <= other.value
else:
return self.value <= other
def __ge__(self, other): def __ge__(self, other):
if isinstance(other, Float): if isinstance(other, MutType):
return self.value >= other
else:
return self.value >= other.value return self.value >= other.value
else:
return self.value >= other
def __lt__(self, other): def __lt__(self, other):
if isinstance(other, Float): if isinstance(other, MutType):
return self.value < other
else:
return self.value < other.value return self.value < other.value
else:
return self.value < other
def __gt__(self, other): def __gt__(self, other):
if isinstance(other, Float): if isinstance(other, MutType):
return self.value > other
else:
return self.value > other.value return self.value > other.value
else:
return self.value > other
def __add__(self, other): def __add__(self, other):
if isinstance(other, Float): if isinstance(other, MutType):
return FloatMut(self.value + other)
else:
return FloatMut(self.value + other.value) return FloatMut(self.value + other.value)
else:
return FloatMut(self.value + other)
def __sub__(self, other): def __sub__(self, other):
if isinstance(other, Float): if isinstance(other, MutType):
return FloatMut(self.value - other)
else:
return FloatMut(self.value - other.value) return FloatMut(self.value - other.value)
else:
return FloatMut(self.value - other)
def __mul__(self, other): def __mul__(self, other):
if isinstance(other, Float): if isinstance(other, MutType):
return FloatMut(self.value * other)
else:
return FloatMut(self.value * other.value) return FloatMut(self.value * other.value)
else:
return FloatMut(self.value * other)
def __floordiv__(self, other): def __floordiv__(self, other):
if isinstance(other, Float): if isinstance(other, MutType):
return FloatMut(self.value // other)
else:
return FloatMut(self.value // other.value) return FloatMut(self.value // other.value)
else:
return FloatMut(self.value // other)
def __truediv__(self, other):
if isinstance(other, MutType):
return FloatMut(self.value / other.value)
else:
return FloatMut(self.value / other)
def __pow__(self, other): def __pow__(self, other):
if isinstance(other, Float): if isinstance(other, MutType):
return FloatMut(self.value**other)
else:
return FloatMut(self.value**other.value) return FloatMut(self.value**other.value)
else:
return FloatMut(self.value**other)
def __pos__(self): def __pos__(self):
return self return self

View file

@ -1,6 +1,6 @@
from _erg_result import Error from _erg_result import Error
from _erg_control import then__ from _erg_control import then__
from _erg_type import MutType
class Int(int): class Int(int):
def try_new(i): # -> Result[Nat] def try_new(i): # -> Result[Nat]
@ -52,7 +52,7 @@ class Int(int):
return then__(int.__neg__(self), Int) return then__(int.__neg__(self), Int)
class IntMut: # inherits Int class IntMut(MutType): # inherits Int
value: Int value: Int
def __init__(self, i): def __init__(self, i):
@ -67,70 +67,70 @@ class IntMut: # inherits Int
return self.value.__hash__() return self.value.__hash__()
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, Int): if isinstance(other, MutType):
return self.value == other
else:
return self.value == other.value return self.value == other.value
else:
return self.value == other
def __ne__(self, other): def __ne__(self, other):
if isinstance(other, Int): if isinstance(other, MutType):
return self.value != other
else:
return self.value != other.value return self.value != other.value
else:
return self.value != other
def __le__(self, other): def __le__(self, other):
if isinstance(other, Int): if isinstance(other, MutType):
return self.value <= other
else:
return self.value <= other.value return self.value <= other.value
else:
return self.value <= other
def __ge__(self, other): def __ge__(self, other):
if isinstance(other, Int): if isinstance(other, MutType):
return self.value >= other
else:
return self.value >= other.value return self.value >= other.value
else:
return self.value >= other
def __lt__(self, other): def __lt__(self, other):
if isinstance(other, Int): if isinstance(other, MutType):
return self.value < other
else:
return self.value < other.value return self.value < other.value
else:
return self.value < other
def __gt__(self, other): def __gt__(self, other):
if isinstance(other, Int): if isinstance(other, MutType):
return self.value > other
else:
return self.value > other.value return self.value > other.value
else:
return self.value > other
def __add__(self, other): def __add__(self, other):
if isinstance(other, Int): if isinstance(other, MutType):
return IntMut(self.value + other)
else:
return IntMut(self.value + other.value) return IntMut(self.value + other.value)
else:
return IntMut(self.value + other)
def __sub__(self, other): def __sub__(self, other):
if isinstance(other, Int): if isinstance(other, MutType):
return IntMut(self.value - other)
else:
return IntMut(self.value - other.value) return IntMut(self.value - other.value)
else:
return IntMut(self.value - other)
def __mul__(self, other): def __mul__(self, other):
if isinstance(other, Int): if isinstance(other, MutType):
return IntMut(self.value * other)
else:
return IntMut(self.value * other.value) return IntMut(self.value * other.value)
else:
return IntMut(self.value * other)
def __floordiv__(self, other): def __floordiv__(self, other):
if isinstance(other, Int): if isinstance(other, MutType):
return IntMut(self.value // other)
else:
return IntMut(self.value // other.value) return IntMut(self.value // other.value)
else:
return IntMut(self.value // other)
def __pow__(self, other): def __pow__(self, other):
if isinstance(other, Int): if isinstance(other, MutType):
return IntMut(self.value**other)
else:
return IntMut(self.value**other.value) return IntMut(self.value**other.value)
else:
return IntMut(self.value**other)
def __pos__(self): def __pos__(self):
return self return self

View file

@ -2,6 +2,7 @@ from _erg_result import Error
from _erg_int import Int from _erg_int import Int
from _erg_int import IntMut # don't unify with the above line from _erg_int import IntMut # don't unify with the above line
from _erg_control import then__ from _erg_control import then__
from _erg_type import MutType
class Nat(Int): class Nat(Int):
def __init__(self, i): def __init__(self, i):
@ -55,70 +56,70 @@ class NatMut(IntMut): # and Nat
return self.value.__hash__() return self.value.__hash__()
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, int): if isinstance(other, MutType):
return self.value == other
else:
return self.value == other.value return self.value == other.value
else:
return self.value == other
def __ne__(self, other): def __ne__(self, other):
if isinstance(other, int): if isinstance(other, MutType):
return self.value != other
else:
return self.value != other.value return self.value != other.value
else:
return self.value != other
def __le__(self, other): def __le__(self, other):
if isinstance(other, int): if isinstance(other, MutType):
return self.value <= other
else:
return self.value <= other.value return self.value <= other.value
else:
return self.value <= other
def __ge__(self, other): def __ge__(self, other):
if isinstance(other, int): if isinstance(other, MutType):
return self.value >= other
else:
return self.value >= other.value return self.value >= other.value
else:
return self.value >= other
def __lt__(self, other): def __lt__(self, other):
if isinstance(other, int): if isinstance(other, MutType):
return self.value < other
else:
return self.value < other.value return self.value < other.value
else:
return self.value < other
def __gt__(self, other): def __gt__(self, other):
if isinstance(other, int): if isinstance(other, MutType):
return self.value > other
else:
return self.value > other.value return self.value > other.value
else:
return self.value > other
def __add__(self, other): def __add__(self, other):
if isinstance(other, Nat): if isinstance(other, MutType):
return NatMut(self.value + other)
else:
return NatMut(self.value + other.value) return NatMut(self.value + other.value)
else:
return NatMut(self.value + other)
def __radd__(self, other): def __radd__(self, other):
if isinstance(other, Nat): if isinstance(other, MutType):
return Nat(other + self.value)
else:
return Nat(other.value + self.value) return Nat(other.value + self.value)
else:
return Nat(other + self.value)
def __mul__(self, other): def __mul__(self, other):
if isinstance(other, Nat): if isinstance(other, MutType):
return NatMut(self.value * other)
else:
return NatMut(self.value * other.value) return NatMut(self.value * other.value)
else:
return NatMut(self.value * other)
def __rmul__(self, other): def __rmul__(self, other):
if isinstance(other, Nat): if isinstance(other, MutType):
return Nat(other * self.value)
else:
return Nat(other.value * self.value) return Nat(other.value * self.value)
else:
return Nat(other * self.value)
def __pow__(self, other): def __pow__(self, other):
if isinstance(other, Nat): if isinstance(other, MutType):
return NatMut(self.value**other)
else:
return NatMut(self.value**other.value) return NatMut(self.value**other.value)
else:
return NatMut(self.value**other)
def __pos__(self): def __pos__(self):
return self return self

View file

@ -1,6 +1,7 @@
from _erg_result import Error from _erg_result import Error
from _erg_int import Int from _erg_int import Int
from _erg_control import then__ from _erg_control import then__
from _erg_type import MutType
class Str(str): class Str(str):
def __instancecheck__(cls, obj): def __instancecheck__(cls, obj):
@ -46,7 +47,7 @@ class Str(str):
return str.__getitem__(self, index_or_slice) return str.__getitem__(self, index_or_slice)
class StrMut: # Inherits Str class StrMut(MutType): # Inherits Str
value: Str value: Str
def __init__(self, s: str): def __init__(self, s: str):
@ -62,16 +63,16 @@ class StrMut: # Inherits Str
return self.value.__hash__() return self.value.__hash__()
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, Str): if isinstance(other, MutType):
return self.value == other
else:
return self.value == other.value return self.value == other.value
else:
return self.value == other
def __ne__(self, other): def __ne__(self, other):
if isinstance(other, Str): if isinstance(other, MutType):
return self.value != other
else:
return self.value != other.value return self.value != other.value
else:
return self.value != other
def update(self, f): def update(self, f):
self.value = Str(f(self.value)) self.value = Str(f(self.value))

View file

@ -50,3 +50,6 @@ def _isinstance(obj, classinfo) -> bool:
return isinstance(obj, classinfo) return isinstance(obj, classinfo)
except: except:
return False return False
class MutType:
value: object

View file

@ -421,6 +421,7 @@ impl PyScriptGenerator {
.replace("from _erg_type import is_type", "") .replace("from _erg_type import is_type", "")
.replace("from _erg_type import _isinstance", "") .replace("from _erg_type import _isinstance", "")
.replace("from _erg_type import UnionType", "") .replace("from _erg_type import UnionType", "")
.replace("from _erg_type import MutType", "")
} }
fn load_namedtuple_if_not(&mut self) { fn load_namedtuple_if_not(&mut self) {