mirror of
https://github.com/erg-lang/erg.git
synced 2025-08-04 10:49:54 +00:00
WIP: enhance torch
type decls
This commit is contained in:
parent
905a03d578
commit
89b26b3e8c
19 changed files with 81 additions and 3 deletions
|
@ -1,6 +1,9 @@
|
|||
.backends = pyimport "./backends"
|
||||
.cuda = pyimport "./cuda"
|
||||
.nn = pyimport "./nn"
|
||||
.optim = pyimport "./optim"
|
||||
.serialization = pyimport "./serialization"
|
||||
.util = pyimport "./util"
|
||||
.utils = pyimport "./utils"
|
||||
|
||||
{.load!; .save!;} = pyimport "./serialization"
|
||||
|
||||
|
@ -19,6 +22,7 @@
|
|||
.DType = 'dtype': ClassType
|
||||
.Size: ClassType
|
||||
.Tensor: (T: Type, Shape: [Nat; _]) -> ClassType
|
||||
.Tensor.
|
||||
.Tensor(T, _) <: Output T
|
||||
.Tensor(_, _).
|
||||
dtype: .DType
|
||||
shape: .Size
|
||||
|
|
1
crates/erg_compiler/lib/external/torch.d/backends.d/__init__.d.er
vendored
Normal file
1
crates/erg_compiler/lib/external/torch.d/backends.d/__init__.d.er
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
.mps = pyimport "./mps"
|
1
crates/erg_compiler/lib/external/torch.d/backends.d/mps.d/__init__.d.er
vendored
Normal file
1
crates/erg_compiler/lib/external/torch.d/backends.d/mps.d/__init__.d.er
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
.is_available!: () => Bool
|
1
crates/erg_compiler/lib/external/torch.d/cuda.d/__init__.d.er
vendored
Normal file
1
crates/erg_compiler/lib/external/torch.d/cuda.d/__init__.d.er
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
.is_available!: () => Bool
|
|
@ -1 +1,6 @@
|
|||
.modules = pyimport "./modules"
|
||||
.parameter = pyimport "./parameter"
|
||||
|
||||
{.CrossEntropyLoss; .Flatten; .Linear; .Module; .ReLU;} = .modules
|
||||
{.Parameter;} = .parameter
|
||||
{.Module;} = pyimport "./modules/module"
|
||||
|
|
|
@ -1 +1,13 @@
|
|||
{.Module;} = pyimport "./module"
|
||||
.activation = pyimport "./activation"
|
||||
.container = pyimport "./container"
|
||||
.flatten = pyimport "./flatten"
|
||||
.linear = pyimport "./linear"
|
||||
.loss = pyimport "./loss"
|
||||
.module = pyimport "./module"
|
||||
|
||||
{.ReLU;} = .activation
|
||||
{.Sequential;} = .container
|
||||
{.Flatten;} = .flatten
|
||||
{.Linear;} = .linear
|
||||
{.CrossEntropyLoss;} = .loss
|
||||
{.Module;} = .module
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
.ReLU: ClassType
|
||||
.ReLU.
|
||||
__call__: () -> .ReLU
|
|
@ -0,0 +1,3 @@
|
|||
.Sequential: ClassType
|
||||
.Sequential.
|
||||
__call__: (*args: .Module) -> .Sequential
|
|
@ -0,0 +1 @@
|
|||
.Flatten: ClassType
|
|
@ -0,0 +1 @@
|
|||
.Linear: ClassType
|
|
@ -1 +1,5 @@
|
|||
{.Parameter;} = pyimport "torch/nn/parameter"
|
||||
|
||||
.Module: ClassType
|
||||
.Module.
|
||||
parameters: (self: Ref(.Module)) -> Iterator .Parameter
|
||||
|
|
1
crates/erg_compiler/lib/external/torch.d/nn.d/parameter.d.er
vendored
Normal file
1
crates/erg_compiler/lib/external/torch.d/nn.d/parameter.d.er
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
.Parameter: ClassType
|
|
@ -0,0 +1 @@
|
|||
.data = pyimport "./data"
|
|
@ -1,5 +1,7 @@
|
|||
.cifar = pyimport "./cifar"
|
||||
.mnist = pyimport "./mnist"
|
||||
.utils = pyimport "./utils"
|
||||
.vision = pyimport "./vision"
|
||||
|
||||
{.CIFAR10; .CIFAR100;} = .cifar
|
||||
{.MNIST; .FashionMNIST;} = .mnist
|
||||
|
|
23
crates/erg_compiler/lib/external/torchvision.d/datasets.d/cifar.d.er
vendored
Normal file
23
crates/erg_compiler/lib/external/torchvision.d/datasets.d/cifar.d.er
vendored
Normal file
|
@ -0,0 +1,23 @@
|
|||
vision = pyimport "./vision"
|
||||
|
||||
.CIFAR10: ClassType
|
||||
.CIFAR10 <: vision.VisionDataset
|
||||
.CIFAR10.
|
||||
__call__: (
|
||||
root: Str,
|
||||
train := Bool,
|
||||
download := Bool,
|
||||
transform := GenericCallable,
|
||||
target_transform := GenericCallable,
|
||||
) -> .CIFAR10
|
||||
|
||||
.CIFAR100: ClassType
|
||||
.CIFAR100 <: .CIFAR10
|
||||
.CIFAR100.
|
||||
__call__: (
|
||||
root: Str,
|
||||
train := Bool,
|
||||
download := Bool,
|
||||
transform := GenericCallable,
|
||||
target_transform := GenericCallable,
|
||||
) -> .CIFAR100
|
3
crates/erg_compiler/lib/external/torchvision.d/models.d/__init__.d.er
vendored
Normal file
3
crates/erg_compiler/lib/external/torchvision.d/models.d/__init__.d.er
vendored
Normal file
|
@ -0,0 +1,3 @@
|
|||
.resnet = pyimport "./resnet"
|
||||
|
||||
{.ResNet; .resnet18;} = .resnet
|
6
crates/erg_compiler/lib/external/torchvision.d/models.d/resnet.d.er
vendored
Normal file
6
crates/erg_compiler/lib/external/torchvision.d/models.d/resnet.d.er
vendored
Normal file
|
@ -0,0 +1,6 @@
|
|||
{.Module;} = pyimport "torch/nn"
|
||||
|
||||
.ResNet: ClassType
|
||||
.ResNet <: .Module
|
||||
|
||||
.resnet18: () -> .ResNet
|
|
@ -1,3 +1,9 @@
|
|||
.Compose: ClassType
|
||||
.Normalize: ClassType
|
||||
|
||||
.RandomCrop: ClassType
|
||||
.RandomHorizontalFlip: ClassType
|
||||
|
||||
.ToTensor: ClassType
|
||||
.ToTensor <: GenericCallable
|
||||
.ToTensor.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue