fix: torch type decls

Author: Shunsuke Shibayama (2024-02-03 20:51:58 +09:00)
parent 9c9f8b7a0a
commit 3be5d75d05
13 changed files with 149 additions and 83 deletions

@@ -479,10 +479,10 @@ impl EffectError {
         ErrorCore::new(
             sub,
             switch_lang!(
-                "japanese" => "関数中で可変オブジェクトにアクセスすることは出来ません",
-                "simplified_chinese" => "函数中不能访问可变对象",
-                "traditional_chinese" => "函數中不能訪問可變對象",
-                "english" => "cannot access a mutable object in a function",
+                "japanese" => format!("関数中で可変オブジェクト(: {})にアクセスすることは出来ません", expr.ref_t()),
+                "simplified_chinese" => format!("函数中不能访问可变对象(: {})", expr.ref_t()),
+                "traditional_chinese" => format!("函數中不能訪問可變對象(: {})", expr.ref_t()),
+                "english" => format!("cannot access a mutable object (: {}) in a function", expr.ref_t()),
             ),
             errno,
             HasEffect,

@@ -27,15 +27,23 @@
 .Complex128 = 'complex128': ClassType
 .Size: ClassType
-.Tensor: (T: Type, Shape: [Nat; _]) -> ClassType
-.Tensor(T, _) <: Output T
-.Tensor(_, _).
+.Tensor!: (T: Type, Shape: [Nat; _]) -> ClassType
+.Tensor!(T, _) <: Output T
+.Tensor!(_, _).
     dtype: .DType
     shape: .Size
     view: (|T, Old: [Nat; _], S: {A: [Nat; _] | A.prod() == Old.prod()}|(
-        self: .Tensor(T, Old),
+        self: .Tensor!(T, Old),
         shape: {S},
-    ) -> .Tensor(T, S)) \
-    and (|T|(self: .Tensor(T, _), shape: [Int; _]) -> .Tensor(T, _))
+    ) -> .Tensor!(T, S)) \
+    and (|T|(self: .Tensor!(T, _), shape: [Int; _]) -> .Tensor!(T, _))
+    backward!: |T, S: [Nat; _]|(
+        self: RefMut(.Tensor!(T, S)),
+        gradient := .Tensor!(T, S),
+        retain_graph := Bool,
+        create_graph := Bool,
+    ) => NoneType
+    # TODO: S bound
+    item: |T|(self: Ref .Tensor!(T, _)) -> T
-.relu: |T, S: [Nat; _]|(x: .Tensor(T, S)) -> .Tensor(T, S)
+.relu: |T, S: [Nat; _]|(x: .Tensor!(T, S)) -> .Tensor!(T, S)
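
Under the new declarations, `backward!` and `item` are available on the mutable `Tensor!` type. A minimal usage sketch (the `torch.rand!` constructor is an assumption used only to obtain a tensor; it is not part of this commit):

    torch = pyimport "torch"
    x = torch.rand! [2, 3] # assumed constructor yielding a .Tensor!(T, [2, 3])
    y = x.view([6]) # accepted statically: [2, 3].prod() == [6].prod()
    z = torch.relu y # relu now maps Tensor!(T, S) to Tensor!(T, S)
    # for a rank-0 result `loss` (shape []), the new procedures apply:
    #     loss.backward!()
    #     loss.item()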

@@ -1,4 +1,4 @@
-{Tensor;} = pyimport "torch"
+{Tensor!;} = pyimport "torch"
 .ReLU: ClassType
 .ReLU.
@@ -6,5 +6,5 @@
 .ReLU|<: GenericCallable|.
     __call__: |T, S: [Nat; _]|(
         self: .ReLU,
-        input: Tensor(T, S),
-    ) -> Tensor(T, S)
+        input: Tensor!(T, S),
+    ) -> Tensor!(T, S)

@@ -1,4 +1,4 @@
-{Tensor;} = pyimport "torch"
+{Tensor!;} = pyimport "torch"
 {Module;} = pyimport "torch/nn"
 .Sequential: ClassType
@@ -7,5 +7,5 @@
 .Sequential|<: GenericCallable|.
     __call__: |T, S: [Nat; _]|(
         self: .Sequential,
-        input: Tensor(T, S),
-    ) -> Tensor(T, S)
+        input: Tensor!(T, S),
+    ) -> Tensor!(T, S)

@@ -1,4 +1,4 @@
-{Device; DType; Tensor;} = pyimport "torch"
+{Device; DType; Tensor!;} = pyimport "torch"
 {Module;} = pyimport "torch/nn"
 _ConvNd: ClassType
@@ -23,8 +23,8 @@ _ConvNd <: Module
 .Conv1d|<: GenericCallable|.
     __call__: |T, S: [Nat; _]|(
         self: .Conv1d,
-        input: Tensor(T, S),
-    ) -> Tensor(T, S)
+        input: Tensor!(T, S),
+    ) -> Tensor!(T, S)
 .Conv2d: ClassType
 .Conv2d <: _ConvNd
@@ -45,8 +45,8 @@ _ConvNd <: Module
 .Conv2d|<: GenericCallable|.
     __call__: |T, S: [Nat; _]|(
         self: .Conv2d,
-        input: Tensor(T, S),
-    ) -> Tensor(T, S)
+        input: Tensor!(T, S),
+    ) -> Tensor!(T, S)
 .Conv3d: ClassType
 .Conv3d <: _ConvNd
@@ -67,5 +67,5 @@ _ConvNd <: Module
 .Conv3d|<: GenericCallable|.
     __call__: |T, S: [Nat; _]|(
         self: .Conv3d,
-        input: Tensor(T, S),
-    ) -> Tensor(T, S)
+        input: Tensor!(T, S),
+    ) -> Tensor!(T, S)

@@ -1,4 +1,4 @@
-{Tensor;} = pyimport "torch"
+{Tensor!;} = pyimport "torch"
 .Flatten: ClassType
 .Flatten.
@@ -6,5 +6,5 @@
 .Flatten|<: GenericCallable|.
     __call__: |T, S: [Nat; _]|(
         self: .Flatten,
-        input: Tensor(T, S),
-    ) -> Tensor(T, S)
+        input: Tensor!(T, S),
+    ) -> Tensor!(T, S)

@@ -1,5 +1,5 @@
 {Module;} = pyimport "torch/nn"
-{Device; DType; Tensor;} = pyimport "torch"
+{Device; DType; Tensor!;} = pyimport "torch"
 .Linear: ClassType
 .Linear <: Module
@@ -14,5 +14,5 @@
 .Linear|<: GenericCallable|.
     __call__: |T, S: [Nat; _]|(
         self: .Linear,
-        input: Tensor(T, S),
-    ) -> Tensor(T, S)
+        input: Tensor!(T, S),
+    ) -> Tensor!(T, S)

@@ -1,10 +1,27 @@
-{Tensor;} = pyimport "torch"
+{Tensor!;} = pyimport "torch"
+{Module;} = pyimport "torch/nn"
+_Loss: ClassType
+_Loss <: Module
+_Loss.
+    reduction: Str
+_WeightedLoss: ClassType
+_WeightedLoss <: _Loss
 .CrossEntropyLoss: ClassType
+.CrossEntropyLoss <: _WeightedLoss
 .CrossEntropyLoss.
     __call__: () -> .CrossEntropyLoss
 .CrossEntropyLoss|<: GenericCallable|.
     __call__: |T|(
         self: .CrossEntropyLoss,
-        input: Tensor(T, _),
-    ) -> Tensor(T, [])
+        input: Tensor!(T, _),
+        target: Tensor!(T, _),
+    ) -> Tensor!(T, [])
+.CrossEntropyLoss.
+    forward: |T|(
+        self: .CrossEntropyLoss,
+        input: Tensor!(T, _),
+        target: Tensor!(T, _),
+    ) -> Tensor!(T, [])
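
The loss object is now declared to take the `target` tensor alongside `input` and to return a rank-0 `Tensor!`. A call-site sketch (`logits` and `labels` are illustrative bindings assumed to come from a model and a dataloader):

    nn = pyimport "torch/nn"
    criterion = nn.CrossEntropyLoss()
    loss = criterion(logits, labels) # typed Tensor!(T, []), i.e. a scalar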

@@ -1,13 +1,18 @@
-{.Parameter;} = pyimport "torch/nn/parameter"
-# {.Tensor;} = pyimport "torch"
+# {Tensor!;} = pyimport "torch"
+{Parameter;} = pyimport "torch/nn/parameter"
 .Module: ClassType
 .Module <: InheritableType
+.Module|<: GenericCallable|.
+    __call__: |M <: .Module|(
+        self: M,
+        input: Obj #Tensor!(T, _),
+    ) -> Obj #Tensor!(T, _)
 .Module.
-    parameters: (self: Ref(.Module), recurse := Bool) -> Iterator .Parameter
-    named_parameters: (self: Ref(.Module), prefix := Str, recurse := Bool, remove_duplicate := Bool) -> Iterator((Str, .Parameter))
-    # buffers: (self: Ref(.Module), recurse := Bool) -> Iterator .Tensor
-    # named_buffers: (self: Ref(.Module), prefix := Str, recurse := Bool, remove_duplicate := Bool) -> Iterator((Str, .Tensor))
+    parameters: (self: Ref(.Module), recurse := Bool) -> Iterator Parameter
+    named_parameters: (self: Ref(.Module), prefix := Str, recurse := Bool, remove_duplicate := Bool) -> Iterator((Str, Parameter))
+    # buffers: (self: Ref(.Module), recurse := Bool) -> Iterator .Tensor!
+    # named_buffers: (self: Ref(.Module), prefix := Str, recurse := Bool, remove_duplicate := Bool) -> Iterator((Str, .Tensor!))
     children: (self: Ref(.Module)) -> Iterator .Module
     named_children: (self: Ref(.Module), prefix := Str) -> Iterator((Str, .Module))
     modules: (self: Ref(.Module)) -> Iterator .Module
@@ -16,13 +21,13 @@
     eval: |T <: .Module|(self: Ref(T)) -> T
     zero_grad!: (self: RefMut(.Module), set_to_none := Bool) => NoneType
     compile: (self: Ref(.Module), *args: Obj, **kwargs: Obj) -> .Module
-    # register_buffer!: (self: RefMut(.Module), name: Str, tensor := Tensor, persistent := Bool) => NoneType
-    register_parameter!: (self: RefMut(.Module), name: Str, param := .Parameter) => NoneType
+    # register_buffer!: (self: RefMut(.Module), name: Str, tensor := Tensor!, persistent := Bool) => NoneType
+    register_parameter!: (self: RefMut(.Module), name: Str, param := Parameter) => NoneType
     add_module!: (self: RefMut(.Module), name: Str, module := .Module) => NoneType
     register_module!: (self: RefMut(.Module), name: Str, module := .Module) => NoneType
     get_submodule: (self: Ref(.Module), name: Str) -> .Module
-    get_parameter: (self: Ref(.Module), name: Str) -> .Parameter
-    # get_buffer: (self: Ref(.Module), name: Str) -> .Tensor
+    get_parameter: (self: Ref(.Module), name: Str) -> Parameter
+    # get_buffer: (self: Ref(.Module), name: Str) -> .Tensor!
     get_extra_state: (self: Ref(.Module)) -> Obj
     set_extra_state!: (self: RefMut(.Module), state: Obj) => NoneType
     apply!: |T <: .Module|(self: T, fn: (module: RefMut(T)) => NoneType) => T
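
With `__call__` declared on `.Module` itself, any `Module` subtype instance can be applied directly; the inline `#Tensor!(T, _)` comments mark the `Obj` typing as a placeholder until tensor shapes can be threaded through generically. A one-line sketch, assuming `net` is an instance of some `.Module` subtype and `x` is a `Tensor!`:

    y = net x # resolves via .Module|<: GenericCallable|.__call__, typed Obj -> Obj for now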

@@ -1,4 +1,4 @@
-{Tensor;} = pyimport "torch"
+{Tensor!;} = pyimport "torch"
 {Module;} = pyimport "torch/nn"
 _MaxPoolNd: ClassType
@@ -19,8 +19,8 @@ _MaxPoolNd <: Module
 .MaxPool1d|<: GenericCallable|.
     __call__: |T, S: [Nat; _]|(
         self: .MaxPool1d,
-        input: Tensor(T, S),
-    ) -> Tensor(T, S)
+        input: Tensor!(T, S),
+    ) -> Tensor!(T, S)
 .MaxPool2d: ClassType
 .MaxPool2d <: _MaxPoolNd
@@ -36,8 +36,8 @@ _MaxPoolNd <: Module
 .MaxPool2d|<: GenericCallable|.
     __call__: |T, S: [Nat; _]|(
         self: .MaxPool2d,
-        input: Tensor(T, S),
-    ) -> Tensor(T, S)
+        input: Tensor!(T, S),
+    ) -> Tensor!(T, S)
 .MaxPool3d: ClassType
 .MaxPool3d <: _MaxPoolNd
@@ -53,5 +53,5 @@ _MaxPoolNd <: Module
 .MaxPool3d|<: GenericCallable|.
     __call__: |T, S: [Nat; _]|(
         self: .MaxPool3d,
-        input: Tensor(T, S),
-    ) -> Tensor(T, S)
+        input: Tensor!(T, S),
+    ) -> Tensor!(T, S)

@@ -1,19 +1,21 @@
 {Parameter;} = pyimport "torch/nn/parameter"
-.Optimizer: ClassType
-.Optimizer <: InheritableType
-.Optimizer.
-    __call__: (params: Iterable(Parameter)) -> .Optimizer
+.Optimizer!: ClassType
+.Optimizer! <: InheritableType
+.Optimizer!.
+    __call__: (params: Iterable(Parameter)) -> .Optimizer!
+    zero_grad!: (self: RefMut .Optimizer!) => NoneType
+    step!: (self: RefMut .Optimizer!) => NoneType
-.ASGD: ClassType
-.ASGD <: .Optimizer
-.Adadelta: ClassType
-.Adadelta <: .Optimizer
-.Adagrad: ClassType
-.Adagrad <: .Optimizer
-.Adam: ClassType
-.Adam <: .Optimizer
-.Adam.
+.ASGD!: ClassType
+.ASGD! <: .Optimizer!
+.Adadelta!: ClassType
+.Adadelta! <: .Optimizer!
+.Adagrad!: ClassType
+.Adagrad! <: .Optimizer!
+.Adam!: ClassType
+.Adam! <: .Optimizer!
+.Adam!.
     __call__: (
         params: Iterable(Parameter),
         lr := Float,
@@ -23,23 +25,23 @@
         amsgrad := Bool,
         foreach := Bool,
         maximize := Bool,
-    ) -> .Adam
+    ) -> .Adam!
-.AdamW: ClassType
-.AdamW <: .Optimizer
-.Adamax: ClassType
-.Adamax <: .Optimizer
-.LBFGS: ClassType
-.LBFGS <: .Optimizer
-.NAdam: ClassType
-.NAdam <: .Optimizer
-.RAdam: ClassType
-.RAdam <: .Optimizer
-.RMSprop: ClassType
-.RMSprop <: .Optimizer
-.Rprop: ClassType
-.Rprop <: .Optimizer
-.SGD: ClassType
-.SGD <: .Optimizer
-.SparseAdam: ClassType
-.SparseAdam <: .Optimizer
+.AdamW!: ClassType
+.AdamW! <: .Optimizer!
+.Adamax!: ClassType
+.Adamax! <: .Optimizer!
+.LBFGS!: ClassType
+.LBFGS! <: .Optimizer!
+.NAdam!: ClassType
+.NAdam! <: .Optimizer!
+.RAdam!: ClassType
+.RAdam! <: .Optimizer!
+.RMSprop!: ClassType
+.RMSprop! <: .Optimizer!
+.Rprop!: ClassType
+.Rprop! <: .Optimizer!
+.SGD!: ClassType
+.SGD! <: .Optimizer!
+.SparseAdam!: ClassType
+.SparseAdam! <: .Optimizer!
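
Since the optimizers are now procedural types, the usual update step can be written against the newly declared `zero_grad!` and `step!`. A sketch, assuming `net` is an `nn.Module` instance and `loss` a scalar `Tensor!` computed from its output:

    torch = pyimport "torch"
    optimizer = torch.optim.Adam! net.parameters(), lr:=0.001
    optimizer.zero_grad!()
    loss.backward!()
    optimizer.step!()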

@@ -3,7 +3,7 @@ dataset = pyimport "./dataset"
 {Sampler;} = pyimport "./sampler"
 .DataLoader: ClassType
-.DataLoader <: Iterable((torch.Tensor(_, _), torch.Tensor(_, _)))
+.DataLoader <: Iterable((torch.Tensor!(_, _), torch.Tensor!(_, _)))
 .DataLoader.
     __call__: (
         dataset: dataset.Dataset,

@@ -1,6 +1,11 @@
 datasets = pyimport "torchvision/datasets"
 transforms = pyimport "torchvision/transforms"
 data = pyimport "torch/utils/data"
+nn = pyimport "torch/nn"
+torch = pyimport "torch"
+
+_ = torch.manual_seed! 1
+device = torch.device if torch.cuda.is_available!(), do "cuda", do "mps"
 
 training_data = datasets.FashionMNIST(
     root:="target/data",
@@ -12,3 +17,32 @@ train_dataloader = data.DataLoader(training_data, batch_size:=64)
 
 for! train_dataloader, ((x, y),) =>
     print! x.shape, y.shape
+
+Net = Inherit nn.Module, Additional := {
+    .conv1 = nn.Conv2d;
+    .conv2 = nn.Conv2d;
+    .pool = nn.MaxPool2d;
+    .fc1 = nn.Linear;
+    .fc2 = nn.Linear;
+}
+
+Net.
+    @Override
+    new() = Net::__new__ {
+        conv1 = nn.Conv2d(1, 16, kernel_size:=3, stride:=1, padding:=1);
+        conv2 = nn.Conv2d(16, 32, kernel_size:=3, stride:=1, padding:=1);
+        pool = nn.MaxPool2d(kernel_size:=2, stride:=2);
+        fc1 = nn.Linear(32 * 7 * 7, 128);
+        fc2 = nn.Linear(128, 10)
+    }
+    forward! self, x =
+        x1 = self.pool torch.relu self.conv1 x
+        x2 = self.pool torch.relu self.conv2 x1
+        x3 = x2.view([-1, 32 * 7 * 7])
+        x4 = torch.relu self.fc1 x3
+        x5 = self.fc2 x4
+        x5
+
+net = Net.new().to device
+_ = nn.CrossEntropyLoss()
+_ = torch.optim.Adam! net.parameters(), lr:=0.001
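
The example still discards the loss and the optimizer (`_ = ...`). Under the declarations added in this commit, the training step could be wired up roughly as in the following sketch (not part of the commit; epoch handling and moving the batches to `device` are omitted):

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam! net.parameters(), lr:=0.001
    for! train_dataloader, ((x, y),) =>
        optimizer.zero_grad!()
        pred = net x # .Module.__call__, currently typed Obj -> Obj
        loss = criterion(pred, y) # Tensor!(T, [])
        loss.backward!()
        optimizer.step!()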