WIP: enhance torch type decls

Shunsuke Shibayama 2023-12-15 21:59:59 +09:00
parent 905a03d578
commit 89b26b3e8c
19 changed files with 81 additions and 3 deletions

View file

@@ -1,6 +1,9 @@
.backends = pyimport "./backends"
.cuda = pyimport "./cuda"
.nn = pyimport "./nn"
.optim = pyimport "./optim"
.serialization = pyimport "./serialization"
.util = pyimport "./util"
.utils = pyimport "./utils"
{.load!; .save!;} = pyimport "./serialization"
@@ -19,6 +22,7 @@
.DType = 'dtype': ClassType
.Size: ClassType
.Tensor: (T: Type, Shape: [Nat; _]) -> ClassType
.Tensor.
.Tensor(T, _) <: Output T
.Tensor(_, _).
    dtype: .DType
    shape: .Size
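
The reworked .Tensor declaration is parameterized by an element type and a shape, so user-facing signatures can pin both down. A minimal sketch of what that enables; show_meta! and the concrete Float / [28, 28] choices are hypothetical, only .dtype and .shape come from the declarations above:

```erg
torch = pyimport "torch"

# Hypothetical procedure: accepts only 28x28 Float tensors and prints the
# dtype/shape attributes declared on .Tensor(_, _) above.
show_meta!(img: torch.Tensor(Float, [28, 28])) =
    print! img.dtype
    print! img.shape
```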

View file

@@ -0,0 +1 @@
.mps = pyimport "./mps"

View file

@@ -0,0 +1 @@
.is_available!: () => Bool

View file

@@ -0,0 +1 @@
.is_available!: () => Bool
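
Both new backend stubs expose only is_available!. With the root re-exports above (.cuda, .backends) in place, a plain availability check should type-check; a minimal sketch:

```erg
torch = pyimport "torch"

# Both flags come straight from the is_available! declarations above.
cuda_ok = torch.cuda.is_available!()
mps_ok = torch.backends.mps.is_available!()
print! cuda_ok
print! mps_ok
```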

View file

@@ -1 +1,6 @@
.modules = pyimport "./modules"
.parameter = pyimport "./parameter"
{.CrossEntropyLoss; .Flatten; .Linear; .Module; .ReLU;} = .modules
{.Parameter;} = .parameter
{.Module;} = pyimport "./modules/module"

View file

@@ -1 +1,13 @@
{.Module;} = pyimport "./module"
.activation = pyimport "./activation"
.container = pyimport "./container"
.flatten = pyimport "./flatten"
.linear = pyimport "./linear"
.loss = pyimport "./loss"
.module = pyimport "./module"
{.ReLU;} = .activation
{.Sequential;} = .container
{.Flatten;} = .flatten
{.Linear;} = .linear
{.CrossEntropyLoss;} = .loss
{.Module;} = .module

View file

@@ -0,0 +1,3 @@
.ReLU: ClassType
.ReLU.
    __call__: () -> .ReLU

View file

@@ -0,0 +1,3 @@
.Sequential: ClassType
.Sequential.
    __call__: (*args: .Module) -> .Sequential
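
ReLU() is now constructible and Sequential.__call__ accepts *args: .Module, so a small container can be assembled. A sketch, assuming the checker accepts ReLU where .Module is expected (that subtyping relation is not declared in this commit yet):

```erg
nn = pyimport "torch/nn"

# Chain two ReLU modules in a Sequential container declared above.
model = nn.Sequential(nn.ReLU(), nn.ReLU())
print! model
```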

View file

@@ -0,0 +1 @@
.Flatten: ClassType

View file

@@ -0,0 +1 @@
.Linear: ClassType

View file

@@ -1 +1,5 @@
{.Parameter;} = pyimport "torch/nn/parameter"
.Module: ClassType
.Module.
    parameters: (self: Ref(.Module)) -> Iterator .Parameter
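
parameters is the first method declared on Module; iterating the returned Iterator with for! is the typical use. A sketch with a made-up helper, dump_params!, that relies only on the signature above:

```erg
nn = pyimport "torch/nn"

# dump_params! is a made-up helper: it only uses the parameters signature
# declared above and prints each Parameter the iterator yields.
dump_params!(m: nn.Module) =
    for! m.parameters(), p =>
        print! p
```

Any value whose class is declared as a Module subtype, such as the ResNet added later in this commit, could then be passed to it.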

View file

@@ -0,0 +1 @@
.Parameter: ClassType

View file

@@ -0,0 +1 @@
.data = pyimport "./data"

View file

@@ -1,5 +1,7 @@
.cifar = pyimport "./cifar"
.mnist = pyimport "./mnist"
.utils = pyimport "./utils"
.vision = pyimport "./vision"
{.CIFAR10; .CIFAR100;} = .cifar
{.MNIST; .FashionMNIST;} = .mnist

View file

@@ -0,0 +1,23 @@
vision = pyimport "./vision"
.CIFAR10: ClassType
.CIFAR10 <: vision.VisionDataset
.CIFAR10.
    __call__: (
        root: Str,
        train := Bool,
        download := Bool,
        transform := GenericCallable,
        target_transform := GenericCallable,
    ) -> .CIFAR10
.CIFAR100: ClassType
.CIFAR100 <: .CIFAR10
.CIFAR100.
    __call__: (
        root: Str,
        train := Bool,
        download := Bool,
        transform := GenericCallable,
        target_transform := GenericCallable,
    ) -> .CIFAR100
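
The CIFAR declarations mirror torchvision's usual keyword parameters. A sketch of constructing a training split, assuming torchvision/datasets can be pyimported the way these declarations import their own submodules, and that defaulted arguments are passed with :=

```erg
datasets = pyimport "torchvision/datasets"

# Fetch CIFAR-10 into ./data; transform and target_transform keep their defaults.
train_set = datasets.CIFAR10("./data", train := True, download := True)
print! train_set
```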

View file

@@ -0,0 +1,3 @@
.resnet = pyimport "./resnet"
{.ResNet; .resnet18;} = .resnet

View file

@@ -0,0 +1,6 @@
{.Module;} = pyimport "torch/nn"
.ResNet: ClassType
.ResNet <: .Module
.resnet18: () -> .ResNet
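
resnet18 is declared as a plain constructor returning a ResNet, and ResNet <: .Module, so the parameters method from the torch/nn declarations applies. A sketch, assuming torchvision/models can be pyimported directly:

```erg
models = pyimport "torchvision/models"

# Build the network and walk its parameters via Module.parameters.
model = models.resnet18()
for! model.parameters(), p =>
    print! p
```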

View file

@@ -1,3 +1,9 @@
.Compose: ClassType
.Normalize: ClassType
.RandomCrop: ClassType
.RandomHorizontalFlip: ClassType
.ToTensor: ClassType
.ToTensor <: GenericCallable
.ToTensor.