fix: external type specification bug

This commit is contained in:
Shunsuke Shibayama 2023-12-13 23:33:06 +09:00
parent 83cd92bb48
commit 905a03d578
10 changed files with 153 additions and 64 deletions

View file

@ -3,3 +3,22 @@
.util = pyimport "./util"
{.load!; .save!;} = pyimport "./serialization"
.UInt8 = 'uint8': ClassType
.Int8 = 'int8': ClassType
.Int16 = 'int16': ClassType
.Int32 = 'int32': ClassType
.Int64 = 'int64': ClassType
.Float16 = 'float16': ClassType
.Float32 = 'float32': ClassType
.Float64 = 'float64': ClassType
.Complex32 = 'complex32': ClassType
.Complex64 = 'complex64': ClassType
.Complex128 = 'complex128': ClassType
.DType = 'dtype': ClassType
.Size: ClassType
.Tensor: (T: Type, Shape: [Nat; _]) -> ClassType
.Tensor.
dtype: .DType
shape: .Size

View file

@ -1,6 +1,8 @@
torch = pyimport "torch"
dataset = pyimport "./dataset"
.DataLoader: ClassType
.DataLoader <: Iterable((torch.Tensor(_, _), torch.Tensor(_, _)))
.DataLoader.
__call__: (
dataset: dataset.Dataset,

View file

@ -1,4 +1,7 @@
dataset = pyimport "torch/utils/data/dataset"
.VisionDataset: ClassType
.VisionDataset <: dataset.Dataset
.VisionDataset.
__call__: (
root: Str,

View file

@ -20,33 +20,18 @@ class Float(float):
def __add__(self, other):
return then__(float.__add__(self, other), Float)
def __radd__(self, other):
return then__(float.__add__(float(other), self), Float)
def __sub__(self, other):
return then__(float.__sub__(self, other), Float)
def __rsub__(self, other):
return then__(float.__sub__(float(other), self), Float)
def __mul__(self, other):
return then__(float.__mul__(self, other), Float)
def __rmul__(self, other):
return then__(float.__mul__(float(other), self), Float)
def __div__(self, other):
return then__(float.__div__(self, other), Float)
def __rdiv__(self, other):
return then__(float.__div__(float(other), self), Float)
def __floordiv__(self, other):
return then__(float.__floordiv__(self, other), Float)
def __rfloordiv__(self, other):
return then__(float.__floordiv__(float(other), self), Float)
def __pow__(self, other):
return then__(float.__pow__(self, other), Float)

View file

@ -27,33 +27,18 @@ class Int(int):
def __add__(self, other):
return then__(int.__add__(self, other), Int)
def __radd__(self, other):
return then__(int.__add__(other, self), Int)
def __sub__(self, other):
return then__(int.__sub__(self, other), Int)
def __rsub__(self, other):
return then__(int.__sub__(other, self), Int)
def __mul__(self, other):
return then__(int.__mul__(self, other), Int)
def __rmul__(self, other):
return then__(int.__mul__(other, self), Int)
def __div__(self, other):
return then__(int.__div__(self, other), Int)
def __rdiv__(self, other):
return then__(int.__div__(other, self), Int)
def __floordiv__(self, other):
return then__(int.__floordiv__(self, other), Int)
def __rfloordiv__(self, other):
return then__(int.__floordiv__(other, self), Int)
def __pow__(self, other):
return then__(int.__pow__(self, other), Int)

View file

@ -30,9 +30,6 @@ class Str(str):
def __add__(self, other):
return then__(str.__add__(self, other), Str)
def __radd__(self, other):
return then__(str.__add__(other, self), Str)
def __mul__(self, other):
return then__(str.__mul__(self, other), Str)