mirror of
https://github.com/erg-lang/erg.git
synced 2025-08-04 10:49:54 +00:00
feat: enhance torch type decls
This commit is contained in:
parent
2b91f9879c
commit
37a9b024be
12 changed files with 237 additions and 2 deletions
|
@ -6,6 +6,13 @@
|
|||
.utils = pyimport "./utils"
|
||||
|
||||
{.load!; .save!;} = pyimport "./serialization"
|
||||
{.manual_seed!;} = import "./random"
|
||||
|
||||
.Device = 'device': ClassType
|
||||
.device: (type: Str) => .Device
|
||||
|
||||
.DType = 'dtype': ClassType
|
||||
.dtype: (type: Str) => .DType
|
||||
|
||||
.UInt8 = 'uint8': ClassType
|
||||
.Int8 = 'int8': ClassType
|
||||
|
@ -19,7 +26,6 @@
|
|||
.Complex64 = 'complex64': ClassType
|
||||
.Complex128 = 'complex128': ClassType
|
||||
|
||||
.DType = 'dtype': ClassType
|
||||
.Size: ClassType
|
||||
.Tensor: (T: Type, Shape: [Nat; _]) -> ClassType
|
||||
.Tensor(T, _) <: Output T
|
||||
|
|
|
@ -2,9 +2,15 @@
|
|||
.parameter = pyimport "./parameter"
|
||||
|
||||
{
|
||||
.Conv1d;
|
||||
.Conv2d;
|
||||
.Conv3d;
|
||||
.CrossEntropyLoss;
|
||||
.Flatten;
|
||||
.Linear;
|
||||
.MaxPool1d;
|
||||
.MaxPool2d;
|
||||
.MaxPool3d;
|
||||
.Module;
|
||||
.ReLU;
|
||||
} = .modules
|
||||
|
|
|
@ -1,13 +1,17 @@
|
|||
.activation = pyimport "./activation"
|
||||
.container = pyimport "./container"
|
||||
.conv = pyimport "./conv"
|
||||
.flatten = pyimport "./flatten"
|
||||
.linear = pyimport "./linear"
|
||||
.loss = pyimport "./loss"
|
||||
.module = pyimport "./module"
|
||||
.pooling = pyimport "./pooling"
|
||||
|
||||
{.ReLU;} = .activation
|
||||
{.Sequential;} = .container
|
||||
{.Conv1d; .Conv2d; .Conv3d;} = .conv
|
||||
{.Flatten;} = .flatten
|
||||
{.Linear;} = .linear
|
||||
{.CrossEntropyLoss;} = .loss
|
||||
{.Module;} = .module
|
||||
{.MaxPool1d; .MaxPool2d; .MaxPool3d;} = .pooling
|
||||
|
|
59
crates/erg_compiler/lib/external/torch.d/nn.d/modules.d/conv.d.er
vendored
Normal file
59
crates/erg_compiler/lib/external/torch.d/nn.d/modules.d/conv.d.er
vendored
Normal file
|
@ -0,0 +1,59 @@
|
|||
{Device; DType;} = pyimport "torch"
|
||||
{Module;} = pyimport "torch/nn"
|
||||
|
||||
_ConvNd: ClassType
|
||||
_ConvNd <: Module
|
||||
|
||||
.Conv1d: ClassType
|
||||
.Conv1d <: _ConvNd
|
||||
.Conv1d <: GenericCallable
|
||||
.Conv1d.
|
||||
__call__: (
|
||||
in_channels: Nat,
|
||||
out_channels: Nat,
|
||||
kernel_size: Nat or [Nat; 1] or (Nat,),
|
||||
stride := Nat or [Nat; 1] or (Nat,),
|
||||
padding := Str or Nat or [Nat; 1] or (Nat,),
|
||||
dilation := Nat or [Nat; 1] or (Nat,),
|
||||
groups := Nat,
|
||||
bias := Bool,
|
||||
padding_mode := Str,
|
||||
device := Device or Str or Nat,
|
||||
dtype := DType or Str,
|
||||
) -> .Conv1d
|
||||
|
||||
.Conv2d: ClassType
|
||||
.Conv2d <: _ConvNd
|
||||
.Conv2d <: GenericCallable
|
||||
.Conv2d.
|
||||
__call__: (
|
||||
in_channels: Nat,
|
||||
out_channels: Nat,
|
||||
kernel_size: Nat or [Nat; 2] or (Nat, Nat),
|
||||
stride := Nat or [Nat; 2] or (Nat, Nat),
|
||||
padding := Str or Nat or [Nat; 2] or (Nat, Nat),
|
||||
dilation := Nat or [Nat; 2] or (Nat, Nat),
|
||||
groups := Nat,
|
||||
bias := Bool,
|
||||
padding_mode := Str,
|
||||
device := Device or Str or Nat,
|
||||
dtype := DType or Str,
|
||||
) -> .Conv2d
|
||||
|
||||
.Conv3d: ClassType
|
||||
.Conv3d <: _ConvNd
|
||||
.Conv3d <: GenericCallable
|
||||
.Conv3d.
|
||||
__call__: (
|
||||
in_channels: Nat,
|
||||
out_channels: Nat,
|
||||
kernel_size: Nat or [Nat; 3] or (Nat, Nat, Nat),
|
||||
stride := Nat or [Nat; 3] or (Nat, Nat, Nat),
|
||||
padding := Str or Nat or [Nat; 3] or (Nat, Nat, Nat),
|
||||
dilation := Nat or [Nat; 3] or (Nat, Nat, Nat),
|
||||
groups := Nat,
|
||||
bias := Bool,
|
||||
padding_mode := Str,
|
||||
device := Device or Str or Nat,
|
||||
dtype := DType or Str,
|
||||
) -> .Conv3d
|
|
@ -1 +1,14 @@
|
|||
{Module;} = pyimport "torch/nn"
|
||||
{Device; DType;} = pyimport "torch"
|
||||
|
||||
.Linear: ClassType
|
||||
.Linear <: Module
|
||||
.Linear <: GenericCallable
|
||||
.Linear.
|
||||
__call__: (
|
||||
in_features: Nat,
|
||||
out_features: Nat,
|
||||
bias := Bool,
|
||||
device := Device or Str or Nat,
|
||||
dtype := DType or Str,
|
||||
) -> .Linear
|
||||
|
|
43
crates/erg_compiler/lib/external/torch.d/nn.d/modules.d/pooling.d.er
vendored
Normal file
43
crates/erg_compiler/lib/external/torch.d/nn.d/modules.d/pooling.d.er
vendored
Normal file
|
@ -0,0 +1,43 @@
|
|||
{Module;} = pyimport "torch/nn"
|
||||
|
||||
_MaxPoolNd: ClassType
|
||||
_MaxPoolNd <: Module
|
||||
|
||||
.MaxPool1d: ClassType
|
||||
.MaxPool1d <: _MaxPoolNd
|
||||
.MaxPool1d <: GenericCallable
|
||||
.MaxPool1d.
|
||||
__call__: (
|
||||
kernel_size: Nat or [Nat; 1] or (Nat,),
|
||||
stride := Nat or [Nat; 1] or (Nat,),
|
||||
padding := Str or Nat or [Nat; 1] or (Nat,),
|
||||
dilation := Nat or [Nat; 1] or (Nat,),
|
||||
return_indices := Bool,
|
||||
ceil_mode := Bool,
|
||||
) -> .MaxPool1d
|
||||
|
||||
.MaxPool2d: ClassType
|
||||
.MaxPool2d <: _MaxPoolNd
|
||||
.MaxPool2d <: GenericCallable
|
||||
.MaxPool2d.
|
||||
__call__: (
|
||||
kernel_size: Nat or [Nat; 2] or (Nat, Nat),
|
||||
stride := Nat or [Nat; 2] or (Nat, Nat),
|
||||
padding := Str or Nat or [Nat; 2] or (Nat, Nat),
|
||||
dilation := Nat or [Nat; 2] or (Nat, Nat),
|
||||
return_indices := Bool,
|
||||
ceil_mode := Bool,
|
||||
) -> .MaxPool2d
|
||||
|
||||
.MaxPool3d: ClassType
|
||||
.MaxPool3d <: _MaxPoolNd
|
||||
.MaxPool3d <: GenericCallable
|
||||
.MaxPool3d.
|
||||
__call__: (
|
||||
kernel_size: Nat or [Nat; 3] or (Nat, Nat, Nat),
|
||||
stride := Nat or [Nat; 3] or (Nat, Nat, Nat),
|
||||
padding := Str or Nat or [Nat; 3] or (Nat, Nat, Nat),
|
||||
dilation := Nat or [Nat; 3] or (Nat, Nat, Nat),
|
||||
return_indices := Bool,
|
||||
ceil_mode := Bool,
|
||||
) -> .MaxPool3d
|
1
crates/erg_compiler/lib/external/torch.d/random.d.er
vendored
Normal file
1
crates/erg_compiler/lib/external/torch.d/random.d.er
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
.manual_seed!: (seed: Int) => Obj
|
|
@ -1,2 +1,9 @@
|
|||
{.DataLoader;} = pyimport "./dataloader"
|
||||
{.Dataset;} = pyimport "./dataset"
|
||||
{
|
||||
.Sampler;
|
||||
.SequentialSampler;
|
||||
.RandomSampler;
|
||||
.SubsetRandomSampler;
|
||||
.WeightedRandomSampler;
|
||||
} = pyimport "./sampler"
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
torch = pyimport "torch"
|
||||
dataset = pyimport "./dataset"
|
||||
{Sampler;} = pyimport "./sampler"
|
||||
|
||||
.DataLoader: ClassType
|
||||
.DataLoader <: Iterable((torch.Tensor(_, _), torch.Tensor(_, _)))
|
||||
|
@ -8,4 +9,17 @@ dataset = pyimport "./dataset"
|
|||
dataset: dataset.Dataset,
|
||||
batch_size := Nat,
|
||||
shuffle := Bool,
|
||||
sampler := Sampler,
|
||||
batch_sampler := Sampler,
|
||||
num_workers := Nat,
|
||||
collate_fn := Obj,
|
||||
pin_memory := Bool,
|
||||
drop_last := Bool,
|
||||
timeout := Float,
|
||||
worker_init_fn := Obj,
|
||||
multiprocessing_context := Obj,
|
||||
generator := Obj,
|
||||
prefetch_factor := Nat,
|
||||
persistent_workers := Bool,
|
||||
pin_memory_device := Str,
|
||||
) -> .DataLoader
|
||||
|
|
13
crates/erg_compiler/lib/external/torch.d/utils.d/data.d/sampler.d.er
vendored
Normal file
13
crates/erg_compiler/lib/external/torch.d/utils.d/data.d/sampler.d.er
vendored
Normal file
|
@ -0,0 +1,13 @@
|
|||
.Sampler: ClassType
|
||||
|
||||
.RandomSampler: ClassType
|
||||
.RandomSampler <: .Sampler
|
||||
|
||||
.SequentialSampler: ClassType
|
||||
.SequentialSampler <: .Sampler
|
||||
|
||||
.SubsetRandomSampler: ClassType
|
||||
.SubsetRandomSampler <: .Sampler
|
||||
|
||||
.WeightedRandomSampler: ClassType
|
||||
.WeightedRandomSampler <: .Sampler
|
|
@ -1 +1,11 @@
|
|||
{.ToTensor;} = pyimport "./transforms"
|
||||
{
|
||||
.Compose;
|
||||
.CenterCrop;
|
||||
.GrayScale;
|
||||
.Normalize;
|
||||
.RandomCrop;
|
||||
.RandomHorizontalFlip;
|
||||
.RandomResizedCrop;
|
||||
.Resize;
|
||||
.ToTensor;
|
||||
} = pyimport "./transforms"
|
||||
|
|
|
@ -1,10 +1,69 @@
|
|||
Transform: ClassType
|
||||
|
||||
.Compose: ClassType
|
||||
.Compose <: GenericCallable
|
||||
.Compose.
|
||||
__call__: (transforms: [Transform; _]) -> .Compose
|
||||
|
||||
.CenterCrop: ClassType
|
||||
.CenterCrop <: Transform
|
||||
.CenterCrop <: GenericCallable
|
||||
.CenterCrop.
|
||||
__call__: (size: Nat or (Nat, Nat) or [Nat; 2]) -> .CenterCrop
|
||||
|
||||
.GrayScale: ClassType
|
||||
.GrayScale <: Transform
|
||||
.GrayScale <: GenericCallable
|
||||
.GrayScale.
|
||||
__call__: () -> .GrayScale
|
||||
|
||||
.Normalize: ClassType
|
||||
.Normalize <: Transform
|
||||
.Normalize <: GenericCallable
|
||||
.Normalize.
|
||||
__call__: (
|
||||
mean: [Float; _],
|
||||
std: [Float; _],
|
||||
inplace := Bool,
|
||||
) -> .Normalize
|
||||
|
||||
.RandomCrop: ClassType
|
||||
.RandomCrop <: Transform
|
||||
.RandomCrop <: GenericCallable
|
||||
.RandomCrop.
|
||||
__call__: (
|
||||
size: Nat or (Nat, Nat) or [Nat; 2],
|
||||
padding := Nat or (Nat, Nat) or (Nat, Nat, Nat, Nat),
|
||||
pad_if_needed:=Bool,
|
||||
fill := Nat or (Nat, Nat, Nat),
|
||||
padding_mode := Str,
|
||||
) -> .RandomCrop
|
||||
|
||||
.RandomHorizontalFlip: ClassType
|
||||
.RandomHorizontalFlip <: Transform
|
||||
.RandomHorizontalFlip <: GenericCallable
|
||||
.RandomHorizontalFlip.
|
||||
__call__: () -> .RandomHorizontalFlip
|
||||
|
||||
.RandomResizedCrop: ClassType
|
||||
.RandomResizedCrop <: Transform
|
||||
.RandomResizedCrop <: GenericCallable
|
||||
.RandomResizedCrop.
|
||||
__call__: (size: Nat or (Nat, Nat) or [Nat; 2]) -> .RandomResizedCrop
|
||||
|
||||
.Resize: ClassType
|
||||
.Resize <: Transform
|
||||
.Resize <: GenericCallable
|
||||
.Resize.
|
||||
__call__: (
|
||||
size: Nat or (Nat, Nat) or [Nat; 2],
|
||||
interpolation := Str,
|
||||
max_size := Nat,
|
||||
antialias := Bool,
|
||||
) -> .Resize
|
||||
|
||||
.ToTensor: ClassType
|
||||
.ToTensor <: Transform
|
||||
.ToTensor <: GenericCallable
|
||||
.ToTensor.
|
||||
__call__: () -> .ToTensor
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue