feat: enhance torch type decls

This commit is contained in:
Shunsuke Shibayama 2024-02-01 16:01:55 +09:00
parent 2b91f9879c
commit 37a9b024be
12 changed files with 237 additions and 2 deletions

View file

@ -6,6 +6,13 @@
.utils = pyimport "./utils"
{.load!; .save!;} = pyimport "./serialization"
{.manual_seed!;} = import "./random"
.Device = 'device': ClassType
.device: (type: Str) => .Device
.DType = 'dtype': ClassType
.dtype: (type: Str) => .DType
.UInt8 = 'uint8': ClassType
.Int8 = 'int8': ClassType
@ -19,7 +26,6 @@
.Complex64 = 'complex64': ClassType
.Complex128 = 'complex128': ClassType
.DType = 'dtype': ClassType
.Size: ClassType
.Tensor: (T: Type, Shape: [Nat; _]) -> ClassType
.Tensor(T, _) <: Output T

View file

@ -2,9 +2,15 @@
.parameter = pyimport "./parameter"
{
.Conv1d;
.Conv2d;
.Conv3d;
.CrossEntropyLoss;
.Flatten;
.Linear;
.MaxPool1d;
.MaxPool2d;
.MaxPool3d;
.Module;
.ReLU;
} = .modules

View file

@ -1,13 +1,17 @@
.activation = pyimport "./activation"
.container = pyimport "./container"
.conv = pyimport "./conv"
.flatten = pyimport "./flatten"
.linear = pyimport "./linear"
.loss = pyimport "./loss"
.module = pyimport "./module"
.pooling = pyimport "./pooling"
{.ReLU;} = .activation
{.Sequential;} = .container
{.Conv1d; .Conv2d; .Conv3d;} = .conv
{.Flatten;} = .flatten
{.Linear;} = .linear
{.CrossEntropyLoss;} = .loss
{.Module;} = .module
{.MaxPool1d; .MaxPool2d; .MaxPool3d;} = .pooling

View file

@ -0,0 +1,59 @@
{Device; DType;} = pyimport "torch"
{Module;} = pyimport "torch/nn"
_ConvNd: ClassType
_ConvNd <: Module
.Conv1d: ClassType
.Conv1d <: _ConvNd
.Conv1d <: GenericCallable
.Conv1d.
__call__: (
in_channels: Nat,
out_channels: Nat,
kernel_size: Nat or [Nat; 1] or (Nat,),
stride := Nat or [Nat; 1] or (Nat,),
padding := Str or Nat or [Nat; 1] or (Nat,),
dilation := Nat or [Nat; 1] or (Nat,),
groups := Nat,
bias := Bool,
padding_mode := Str,
device := Device or Str or Nat,
dtype := DType or Str,
) -> .Conv1d
.Conv2d: ClassType
.Conv2d <: _ConvNd
.Conv2d <: GenericCallable
.Conv2d.
__call__: (
in_channels: Nat,
out_channels: Nat,
kernel_size: Nat or [Nat; 2] or (Nat, Nat),
stride := Nat or [Nat; 2] or (Nat, Nat),
padding := Str or Nat or [Nat; 2] or (Nat, Nat),
dilation := Nat or [Nat; 2] or (Nat, Nat),
groups := Nat,
bias := Bool,
padding_mode := Str,
device := Device or Str or Nat,
dtype := DType or Str,
) -> .Conv2d
.Conv3d: ClassType
.Conv3d <: _ConvNd
.Conv3d <: GenericCallable
.Conv3d.
__call__: (
in_channels: Nat,
out_channels: Nat,
kernel_size: Nat or [Nat; 3] or (Nat, Nat, Nat),
stride := Nat or [Nat; 3] or (Nat, Nat, Nat),
padding := Str or Nat or [Nat; 3] or (Nat, Nat, Nat),
dilation := Nat or [Nat; 3] or (Nat, Nat, Nat),
groups := Nat,
bias := Bool,
padding_mode := Str,
device := Device or Str or Nat,
dtype := DType or Str,
) -> .Conv3d

View file

@ -1 +1,14 @@
{Module;} = pyimport "torch/nn"
{Device; DType;} = pyimport "torch"
.Linear: ClassType
.Linear <: Module
.Linear <: GenericCallable
.Linear.
__call__: (
in_features: Nat,
out_features: Nat,
bias := Bool,
device := Device or Str or Nat,
dtype := DType or Str,
) -> .Linear

View file

@ -0,0 +1,43 @@
{Module;} = pyimport "torch/nn"
_MaxPoolNd: ClassType
_MaxPoolNd <: Module
.MaxPool1d: ClassType
.MaxPool1d <: _MaxPoolNd
.MaxPool1d <: GenericCallable
.MaxPool1d.
__call__: (
kernel_size: Nat or [Nat; 1] or (Nat,),
stride := Nat or [Nat; 1] or (Nat,),
padding := Str or Nat or [Nat; 1] or (Nat,),
dilation := Nat or [Nat; 1] or (Nat,),
return_indices := Bool,
ceil_mode := Bool,
) -> .MaxPool1d
.MaxPool2d: ClassType
.MaxPool2d <: _MaxPoolNd
.MaxPool2d <: GenericCallable
.MaxPool2d.
__call__: (
kernel_size: Nat or [Nat; 2] or (Nat, Nat),
stride := Nat or [Nat; 2] or (Nat, Nat),
padding := Str or Nat or [Nat; 2] or (Nat, Nat),
dilation := Nat or [Nat; 2] or (Nat, Nat),
return_indices := Bool,
ceil_mode := Bool,
) -> .MaxPool2d
.MaxPool3d: ClassType
.MaxPool3d <: _MaxPoolNd
.MaxPool3d <: GenericCallable
.MaxPool3d.
__call__: (
kernel_size: Nat or [Nat; 3] or (Nat, Nat, Nat),
stride := Nat or [Nat; 3] or (Nat, Nat, Nat),
padding := Str or Nat or [Nat; 3] or (Nat, Nat, Nat),
dilation := Nat or [Nat; 3] or (Nat, Nat, Nat),
return_indices := Bool,
ceil_mode := Bool,
) -> .MaxPool3d

View file

@ -0,0 +1 @@
.manual_seed!: (seed: Int) => Obj

View file

@ -1,2 +1,9 @@
{.DataLoader;} = pyimport "./dataloader"
{.Dataset;} = pyimport "./dataset"
{
.Sampler;
.SequentialSampler;
.RandomSampler;
.SubsetRandomSampler;
.WeightedRandomSampler;
} = pyimport "./sampler"

View file

@ -1,5 +1,6 @@
torch = pyimport "torch"
dataset = pyimport "./dataset"
{Sampler;} = pyimport "./sampler"
.DataLoader: ClassType
.DataLoader <: Iterable((torch.Tensor(_, _), torch.Tensor(_, _)))
@ -8,4 +9,17 @@ dataset = pyimport "./dataset"
dataset: dataset.Dataset,
batch_size := Nat,
shuffle := Bool,
sampler := Sampler,
batch_sampler := Sampler,
num_workers := Nat,
collate_fn := Obj,
pin_memory := Bool,
drop_last := Bool,
timeout := Float,
worker_init_fn := Obj,
multiprocessing_context := Obj,
generator := Obj,
prefetch_factor := Nat,
persistent_workers := Bool,
pin_memory_device := Str,
) -> .DataLoader

View file

@ -0,0 +1,13 @@
.Sampler: ClassType
.RandomSampler: ClassType
.RandomSampler <: .Sampler
.SequentialSampler: ClassType
.SequentialSampler <: .Sampler
.SubsetRandomSampler: ClassType
.SubsetRandomSampler <: .Sampler
.WeightedRandomSampler: ClassType
.WeightedRandomSampler <: .Sampler

View file

@ -1 +1,11 @@
{.ToTensor;} = pyimport "./transforms"
{
.Compose;
.CenterCrop;
.GrayScale;
.Normalize;
.RandomCrop;
.RandomHorizontalFlip;
.RandomResizedCrop;
.Resize;
.ToTensor;
} = pyimport "./transforms"

View file

@ -1,10 +1,69 @@
Transform: ClassType
.Compose: ClassType
.Compose <: GenericCallable
.Compose.
__call__: (transforms: [Transform; _]) -> .Compose
.CenterCrop: ClassType
.CenterCrop <: Transform
.CenterCrop <: GenericCallable
.CenterCrop.
__call__: (size: Nat or (Nat, Nat) or [Nat; 2]) -> .CenterCrop
.GrayScale: ClassType
.GrayScale <: Transform
.GrayScale <: GenericCallable
.GrayScale.
__call__: () -> .GrayScale
.Normalize: ClassType
.Normalize <: Transform
.Normalize <: GenericCallable
.Normalize.
__call__: (
mean: [Float; _],
std: [Float; _],
inplace := Bool,
) -> .Normalize
.RandomCrop: ClassType
.RandomCrop <: Transform
.RandomCrop <: GenericCallable
.RandomCrop.
__call__: (
size: Nat or (Nat, Nat) or [Nat; 2],
padding := Nat or (Nat, Nat) or (Nat, Nat, Nat, Nat),
pad_if_needed:=Bool,
fill := Nat or (Nat, Nat, Nat),
padding_mode := Str,
) -> .RandomCrop
.RandomHorizontalFlip: ClassType
.RandomHorizontalFlip <: Transform
.RandomHorizontalFlip <: GenericCallable
.RandomHorizontalFlip.
__call__: () -> .RandomHorizontalFlip
.RandomResizedCrop: ClassType
.RandomResizedCrop <: Transform
.RandomResizedCrop <: GenericCallable
.RandomResizedCrop.
__call__: (size: Nat or (Nat, Nat) or [Nat; 2]) -> .RandomResizedCrop
.Resize: ClassType
.Resize <: Transform
.Resize <: GenericCallable
.Resize.
__call__: (
size: Nat or (Nat, Nat) or [Nat; 2],
interpolation := Str,
max_size := Nat,
antialias := Bool,
) -> .Resize
.ToTensor: ClassType
.ToTensor <: Transform
.ToTensor <: GenericCallable
.ToTensor.
__call__: () -> .ToTensor