feat: add torch type declaration

This commit is contained in:
Shunsuke Shibayama 2023-12-09 12:05:07 +09:00
parent 480c9e4f83
commit 1aa61cd6a6
24 changed files with 84 additions and 0 deletions

View file

@ -0,0 +1 @@
{.load!; .save!;} = pyimport "./serialization"

View file

@ -0,0 +1 @@
{.Module;} = pyimport "./modules/module"

View file

@ -0,0 +1 @@
{.Module;} = pyimport "./module"

View file

@ -0,0 +1,3 @@
.CrossEntropyLoss: ClassType
.CrossEntropyLoss.
__call__: () -> .CrossEntropyLoss

View file

@ -0,0 +1 @@
.Module: ClassType

View file

View file

@ -0,0 +1,2 @@
.load!: (f: PathLike) => NoneType
.save!: (obj: Obj, f: PathLike) => NoneType

View file

@ -0,0 +1,2 @@
{.DataLoader;} = pyimport "./dataloader"
{.Dataset;} = pyimport "./dataset"

View file

@ -0,0 +1,9 @@
dataset = pyimport "./dataset"
.DataLoader: ClassType
.DataLoader.
__call__: (
dataset: dataset.Dataset,
batch_size := Nat,
shuffle := Bool,
) -> .DataLoader

View file

@ -0,0 +1 @@
.Dataset: ClassType

View file

@ -0,0 +1,28 @@
.Optimizer: ClassType
.ASGD: ClassType
.ASGD <: .Optimizer
.Adadelta: ClassType
.Adadelta <: .Optimizer
.Adagrad: ClassType
.Adagrad <: .Optimizer
.Adam: ClassType
.Adam <: .Optimizer
.AdamW: ClassType
.AdamW <: .Optimizer
.Adamax: ClassType
.Adamax <: .Optimizer
.LBFGS: ClassType
.LBFGS <: .Optimizer
.NAdam: ClassType
.NAdam <: .Optimizer
.RAdam: ClassType
.RAdam <: .Optimizer
.RMSprop: ClassType
.RMSprop <: .Optimizer
.Rprop: ClassType
.Rprop <: .Optimizer
.SGD: ClassType
.SGD <: .Optimizer
.SparseAdam: ClassType
.SparseAdam <: .Optimizer

View file

@ -0,0 +1,23 @@
vision = pyimport "./vision"
.MNIST: ClassType
.MNIST <: vision.VisionDataset
.MNIST.
__call__: (
root: Str,
train := Bool,
download := Bool,
transform := GenericCallable,
target_transform := GenericCallable,
) -> .MNIST
.FashionMNIST: ClassType
.FashionMNIST <: .MNIST
.FashionMNIST.
__call__: (
root: Str,
train := Bool,
download := Bool,
transform := GenericCallable,
target_transform := GenericCallable,
) -> .FashionMNIST

View file

@ -0,0 +1,9 @@
.VisionDataset: ClassType
.VisionDataset.
__call__: (
root: Str,
train := Bool,
download := Bool,
transform := GenericCallable,
target_transform := GenericCallable,
) -> .VisionDataset

View file

@ -0,0 +1,3 @@
.ToTensor: ClassType
.ToTensor.
__call__: () -> .ToTensor