mirror of
https://github.com/erg-lang/erg.git
synced 2025-10-03 05:54:33 +00:00
fix: torch type decls
This commit is contained in:
parent
9c9f8b7a0a
commit
3be5d75d05
13 changed files with 149 additions and 83 deletions
|
@ -479,10 +479,10 @@ impl EffectError {
|
||||||
ErrorCore::new(
|
ErrorCore::new(
|
||||||
sub,
|
sub,
|
||||||
switch_lang!(
|
switch_lang!(
|
||||||
"japanese" => "関数中で可変オブジェクトにアクセスすることは出来ません",
|
"japanese" => format!("関数中で可変オブジェクト(: {})にアクセスすることは出来ません", expr.ref_t()),
|
||||||
"simplified_chinese" => "函数中不能访问可变对象",
|
"simplified_chinese" => format!("函数中不能访问可变对象(: {})", expr.ref_t()),
|
||||||
"traditional_chinese" => "函數中不能訪問可變對象",
|
"traditional_chinese" => format!("函數中不能訪問可變對象(: {})", expr.ref_t()),
|
||||||
"english" => "cannot access a mutable object in a function",
|
"english" => format!("cannot access a mutable object (: {}) in a function", expr.ref_t()),
|
||||||
),
|
),
|
||||||
errno,
|
errno,
|
||||||
HasEffect,
|
HasEffect,
|
||||||
|
|
|
@ -27,15 +27,23 @@
|
||||||
.Complex128 = 'complex128': ClassType
|
.Complex128 = 'complex128': ClassType
|
||||||
|
|
||||||
.Size: ClassType
|
.Size: ClassType
|
||||||
.Tensor: (T: Type, Shape: [Nat; _]) -> ClassType
|
.Tensor!: (T: Type, Shape: [Nat; _]) -> ClassType
|
||||||
.Tensor(T, _) <: Output T
|
.Tensor!(T, _) <: Output T
|
||||||
.Tensor(_, _).
|
.Tensor!(_, _).
|
||||||
dtype: .DType
|
dtype: .DType
|
||||||
shape: .Size
|
shape: .Size
|
||||||
view: (|T, Old: [Nat; _], S: {A: [Nat; _] | A.prod() == Old.prod()}|(
|
view: (|T, Old: [Nat; _], S: {A: [Nat; _] | A.prod() == Old.prod()}|(
|
||||||
self: .Tensor(T, Old),
|
self: .Tensor!(T, Old),
|
||||||
shape: {S},
|
shape: {S},
|
||||||
) -> .Tensor(T, S)) \
|
) -> .Tensor!(T, S)) \
|
||||||
and (|T|(self: .Tensor(T, _), shape: [Int; _]) -> .Tensor(T, _))
|
and (|T|(self: .Tensor!(T, _), shape: [Int; _]) -> .Tensor!(T, _))
|
||||||
|
backward!: |T, S: [Nat; _]|(
|
||||||
|
self: RefMut(.Tensor!(T, S)),
|
||||||
|
gradient := .Tensor!(T, S),
|
||||||
|
retain_graph := Bool,
|
||||||
|
create_graph := Bool,
|
||||||
|
) => NoneType
|
||||||
|
# TODO: S bound
|
||||||
|
item: |T|(self: Ref .Tensor!(T, _)) -> T
|
||||||
|
|
||||||
.relu: |T, S: [Nat; _]|(x: .Tensor(T, S)) -> .Tensor(T, S)
|
.relu: |T, S: [Nat; _]|(x: .Tensor!(T, S)) -> .Tensor!(T, S)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
{Tensor;} = pyimport "torch"
|
{Tensor!;} = pyimport "torch"
|
||||||
|
|
||||||
.ReLU: ClassType
|
.ReLU: ClassType
|
||||||
.ReLU.
|
.ReLU.
|
||||||
|
@ -6,5 +6,5 @@
|
||||||
.ReLU|<: GenericCallable|.
|
.ReLU|<: GenericCallable|.
|
||||||
__call__: |T, S: [Nat; _]|(
|
__call__: |T, S: [Nat; _]|(
|
||||||
self: .ReLU,
|
self: .ReLU,
|
||||||
input: Tensor(T, S),
|
input: Tensor!(T, S),
|
||||||
) -> Tensor(T, S)
|
) -> Tensor!(T, S)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
{Tensor;} = pyimport "torch"
|
{Tensor!;} = pyimport "torch"
|
||||||
{Module;} = pyimport "torch/nn"
|
{Module;} = pyimport "torch/nn"
|
||||||
|
|
||||||
.Sequential: ClassType
|
.Sequential: ClassType
|
||||||
|
@ -7,5 +7,5 @@
|
||||||
.Sequential|<: GenericCallable|.
|
.Sequential|<: GenericCallable|.
|
||||||
__call__: |T, S: [Nat; _]|(
|
__call__: |T, S: [Nat; _]|(
|
||||||
self: .Sequential,
|
self: .Sequential,
|
||||||
input: Tensor(T, S),
|
input: Tensor!(T, S),
|
||||||
) -> Tensor(T, S)
|
) -> Tensor!(T, S)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
{Device; DType; Tensor;} = pyimport "torch"
|
{Device; DType; Tensor!;} = pyimport "torch"
|
||||||
{Module;} = pyimport "torch/nn"
|
{Module;} = pyimport "torch/nn"
|
||||||
|
|
||||||
_ConvNd: ClassType
|
_ConvNd: ClassType
|
||||||
|
@ -23,8 +23,8 @@ _ConvNd <: Module
|
||||||
.Conv1d|<: GenericCallable|.
|
.Conv1d|<: GenericCallable|.
|
||||||
__call__: |T, S: [Nat; _]|(
|
__call__: |T, S: [Nat; _]|(
|
||||||
self: .Conv1d,
|
self: .Conv1d,
|
||||||
input: Tensor(T, S),
|
input: Tensor!(T, S),
|
||||||
) -> Tensor(T, S)
|
) -> Tensor!(T, S)
|
||||||
|
|
||||||
.Conv2d: ClassType
|
.Conv2d: ClassType
|
||||||
.Conv2d <: _ConvNd
|
.Conv2d <: _ConvNd
|
||||||
|
@ -45,8 +45,8 @@ _ConvNd <: Module
|
||||||
.Conv2d|<: GenericCallable|.
|
.Conv2d|<: GenericCallable|.
|
||||||
__call__: |T, S: [Nat; _]|(
|
__call__: |T, S: [Nat; _]|(
|
||||||
self: .Conv2d,
|
self: .Conv2d,
|
||||||
input: Tensor(T, S),
|
input: Tensor!(T, S),
|
||||||
) -> Tensor(T, S)
|
) -> Tensor!(T, S)
|
||||||
|
|
||||||
.Conv3d: ClassType
|
.Conv3d: ClassType
|
||||||
.Conv3d <: _ConvNd
|
.Conv3d <: _ConvNd
|
||||||
|
@ -67,5 +67,5 @@ _ConvNd <: Module
|
||||||
.Conv3d|<: GenericCallable|.
|
.Conv3d|<: GenericCallable|.
|
||||||
__call__: |T, S: [Nat; _]|(
|
__call__: |T, S: [Nat; _]|(
|
||||||
self: .Conv3d,
|
self: .Conv3d,
|
||||||
input: Tensor(T, S),
|
input: Tensor!(T, S),
|
||||||
) -> Tensor(T, S)
|
) -> Tensor!(T, S)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
{Tensor;} = pyimport "torch"
|
{Tensor!;} = pyimport "torch"
|
||||||
|
|
||||||
.Flatten: ClassType
|
.Flatten: ClassType
|
||||||
.Flatten.
|
.Flatten.
|
||||||
|
@ -6,5 +6,5 @@
|
||||||
.Flatten|<: GenericCallable|.
|
.Flatten|<: GenericCallable|.
|
||||||
__call__: |T, S: [Nat; _]|(
|
__call__: |T, S: [Nat; _]|(
|
||||||
self: .Flatten,
|
self: .Flatten,
|
||||||
input: Tensor(T, S),
|
input: Tensor!(T, S),
|
||||||
) -> Tensor(T, S)
|
) -> Tensor!(T, S)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
{Module;} = pyimport "torch/nn"
|
{Module;} = pyimport "torch/nn"
|
||||||
{Device; DType; Tensor;} = pyimport "torch"
|
{Device; DType; Tensor!;} = pyimport "torch"
|
||||||
|
|
||||||
.Linear: ClassType
|
.Linear: ClassType
|
||||||
.Linear <: Module
|
.Linear <: Module
|
||||||
|
@ -14,5 +14,5 @@
|
||||||
.Linear|<: GenericCallable|.
|
.Linear|<: GenericCallable|.
|
||||||
__call__: |T, S: [Nat; _]|(
|
__call__: |T, S: [Nat; _]|(
|
||||||
self: .Linear,
|
self: .Linear,
|
||||||
input: Tensor(T, S),
|
input: Tensor!(T, S),
|
||||||
) -> Tensor(T, S)
|
) -> Tensor!(T, S)
|
||||||
|
|
|
@ -1,10 +1,27 @@
|
||||||
{Tensor;} = pyimport "torch"
|
{Tensor!;} = pyimport "torch"
|
||||||
|
{Module;} = pyimport "torch/nn"
|
||||||
|
|
||||||
|
_Loss: ClassType
|
||||||
|
_Loss <: Module
|
||||||
|
_Loss.
|
||||||
|
reduction: Str
|
||||||
|
|
||||||
|
_WeightedLoss: ClassType
|
||||||
|
_WeightedLoss <: _Loss
|
||||||
|
|
||||||
.CrossEntropyLoss: ClassType
|
.CrossEntropyLoss: ClassType
|
||||||
|
.CrossEntropyLoss <: _WeightedLoss
|
||||||
.CrossEntropyLoss.
|
.CrossEntropyLoss.
|
||||||
__call__: () -> .CrossEntropyLoss
|
__call__: () -> .CrossEntropyLoss
|
||||||
.CrossEntropyLoss|<: GenericCallable|.
|
.CrossEntropyLoss|<: GenericCallable|.
|
||||||
__call__: |T|(
|
__call__: |T|(
|
||||||
self: .CrossEntropyLoss,
|
self: .CrossEntropyLoss,
|
||||||
input: Tensor(T, _),
|
input: Tensor!(T, _),
|
||||||
) -> Tensor(T, [])
|
target: Tensor!(T, _),
|
||||||
|
) -> Tensor!(T, [])
|
||||||
|
.CrossEntropyLoss.
|
||||||
|
forward: |T|(
|
||||||
|
self: .CrossEntropyLoss,
|
||||||
|
input: Tensor!(T, _),
|
||||||
|
target: Tensor!(T, _),
|
||||||
|
) -> Tensor!(T, [])
|
||||||
|
|
|
@ -1,13 +1,18 @@
|
||||||
{.Parameter;} = pyimport "torch/nn/parameter"
|
# {Tensor!;} = pyimport "torch"
|
||||||
# {.Tensor;} = pyimport "torch"
|
{Parameter;} = pyimport "torch/nn/parameter"
|
||||||
|
|
||||||
.Module: ClassType
|
.Module: ClassType
|
||||||
.Module <: InheritableType
|
.Module <: InheritableType
|
||||||
|
.Module|<: GenericCallable|.
|
||||||
|
__call__: |M <: .Module|(
|
||||||
|
self: M,
|
||||||
|
input: Obj #Tensor!(T, _),
|
||||||
|
) -> Obj #Tensor!(T, _)
|
||||||
.Module.
|
.Module.
|
||||||
parameters: (self: Ref(.Module), recurse := Bool) -> Iterator .Parameter
|
parameters: (self: Ref(.Module), recurse := Bool) -> Iterator Parameter
|
||||||
named_parameters: (self: Ref(.Module), prefix := Str, recurse := Bool, remove_duplicate := Bool) -> Iterator((Str, .Parameter))
|
named_parameters: (self: Ref(.Module), prefix := Str, recurse := Bool, remove_duplicate := Bool) -> Iterator((Str, Parameter))
|
||||||
# buffers: (self: Ref(.Module), recurse := Bool) -> Iterator .Tensor
|
# buffers: (self: Ref(.Module), recurse := Bool) -> Iterator .Tensor!
|
||||||
# named_buffers: (self: Ref(.Module), prefix := Str, recurse := Bool, remove_duplicate := Bool) -> Iterator((Str, .Tensor))
|
# named_buffers: (self: Ref(.Module), prefix := Str, recurse := Bool, remove_duplicate := Bool) -> Iterator((Str, .Tensor!))
|
||||||
children: (self: Ref(.Module)) -> Iterator .Module
|
children: (self: Ref(.Module)) -> Iterator .Module
|
||||||
named_children: (self: Ref(.Module), prefix := Str) -> Iterator((Str, .Module))
|
named_children: (self: Ref(.Module), prefix := Str) -> Iterator((Str, .Module))
|
||||||
modules: (self: Ref(.Module)) -> Iterator .Module
|
modules: (self: Ref(.Module)) -> Iterator .Module
|
||||||
|
@ -16,13 +21,13 @@
|
||||||
eval: |T <: .Module|(self: Ref(T)) -> T
|
eval: |T <: .Module|(self: Ref(T)) -> T
|
||||||
zero_grad!: (self: RefMut(.Module), set_to_none := Bool) => NoneType
|
zero_grad!: (self: RefMut(.Module), set_to_none := Bool) => NoneType
|
||||||
compile: (self: Ref(.Module), *args: Obj, **kwargs: Obj) -> .Module
|
compile: (self: Ref(.Module), *args: Obj, **kwargs: Obj) -> .Module
|
||||||
# register_buffer!: (self: RefMut(.Module), name: Str, tensor := Tensor, persistent := Bool) => NoneType
|
# register_buffer!: (self: RefMut(.Module), name: Str, tensor := Tensor!, persistent := Bool) => NoneType
|
||||||
register_parameter!: (self: RefMut(.Module), name: Str, param := .Parameter) => NoneType
|
register_parameter!: (self: RefMut(.Module), name: Str, param := Parameter) => NoneType
|
||||||
add_module!: (self: RefMut(.Module), name: Str, module := .Module) => NoneType
|
add_module!: (self: RefMut(.Module), name: Str, module := .Module) => NoneType
|
||||||
register_module!: (self: RefMut(.Module), name: Str, module := .Module) => NoneType
|
register_module!: (self: RefMut(.Module), name: Str, module := .Module) => NoneType
|
||||||
get_submodule: (self: Ref(.Module), name: Str) -> .Module
|
get_submodule: (self: Ref(.Module), name: Str) -> .Module
|
||||||
get_parameter: (self: Ref(.Module), name: Str) -> .Parameter
|
get_parameter: (self: Ref(.Module), name: Str) -> Parameter
|
||||||
# get_buffer: (self: Ref(.Module), name: Str) -> .Tensor
|
# get_buffer: (self: Ref(.Module), name: Str) -> .Tensor!
|
||||||
get_extra_state: (self: Ref(.Module)) -> Obj
|
get_extra_state: (self: Ref(.Module)) -> Obj
|
||||||
set_extra_state!: (self: RefMut(.Module), state: Obj) => NoneType
|
set_extra_state!: (self: RefMut(.Module), state: Obj) => NoneType
|
||||||
apply!: |T <: .Module|(self: T, fn: (module: RefMut(T)) => NoneType) => T
|
apply!: |T <: .Module|(self: T, fn: (module: RefMut(T)) => NoneType) => T
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
{Tensor;} = pyimport "torch"
|
{Tensor!;} = pyimport "torch"
|
||||||
{Module;} = pyimport "torch/nn"
|
{Module;} = pyimport "torch/nn"
|
||||||
|
|
||||||
_MaxPoolNd: ClassType
|
_MaxPoolNd: ClassType
|
||||||
|
@ -19,8 +19,8 @@ _MaxPoolNd <: Module
|
||||||
.MaxPool1d|<: GenericCallable|.
|
.MaxPool1d|<: GenericCallable|.
|
||||||
__call__: |T, S: [Nat; _]|(
|
__call__: |T, S: [Nat; _]|(
|
||||||
self: .MaxPool1d,
|
self: .MaxPool1d,
|
||||||
input: Tensor(T, S),
|
input: Tensor!(T, S),
|
||||||
) -> Tensor(T, S)
|
) -> Tensor!(T, S)
|
||||||
|
|
||||||
.MaxPool2d: ClassType
|
.MaxPool2d: ClassType
|
||||||
.MaxPool2d <: _MaxPoolNd
|
.MaxPool2d <: _MaxPoolNd
|
||||||
|
@ -36,8 +36,8 @@ _MaxPoolNd <: Module
|
||||||
.MaxPool2d|<: GenericCallable|.
|
.MaxPool2d|<: GenericCallable|.
|
||||||
__call__: |T, S: [Nat; _]|(
|
__call__: |T, S: [Nat; _]|(
|
||||||
self: .MaxPool2d,
|
self: .MaxPool2d,
|
||||||
input: Tensor(T, S),
|
input: Tensor!(T, S),
|
||||||
) -> Tensor(T, S)
|
) -> Tensor!(T, S)
|
||||||
|
|
||||||
.MaxPool3d: ClassType
|
.MaxPool3d: ClassType
|
||||||
.MaxPool3d <: _MaxPoolNd
|
.MaxPool3d <: _MaxPoolNd
|
||||||
|
@ -53,5 +53,5 @@ _MaxPoolNd <: Module
|
||||||
.MaxPool3d|<: GenericCallable|.
|
.MaxPool3d|<: GenericCallable|.
|
||||||
__call__: |T, S: [Nat; _]|(
|
__call__: |T, S: [Nat; _]|(
|
||||||
self: .MaxPool3d,
|
self: .MaxPool3d,
|
||||||
input: Tensor(T, S),
|
input: Tensor!(T, S),
|
||||||
) -> Tensor(T, S)
|
) -> Tensor!(T, S)
|
||||||
|
|
|
@ -1,19 +1,21 @@
|
||||||
{Parameter;} = pyimport "torch/nn/parameter"
|
{Parameter;} = pyimport "torch/nn/parameter"
|
||||||
|
|
||||||
.Optimizer: ClassType
|
.Optimizer!: ClassType
|
||||||
.Optimizer <: InheritableType
|
.Optimizer! <: InheritableType
|
||||||
.Optimizer.
|
.Optimizer!.
|
||||||
__call__: (params: Iterable(Parameter)) -> .Optimizer
|
__call__: (params: Iterable(Parameter)) -> .Optimizer!
|
||||||
|
zero_grad!: (self: RefMut .Optimizer!) => NoneType
|
||||||
|
step!: (self: RefMut .Optimizer!) => NoneType
|
||||||
|
|
||||||
.ASGD: ClassType
|
.ASGD!: ClassType
|
||||||
.ASGD <: .Optimizer
|
.ASGD! <: .Optimizer!
|
||||||
.Adadelta: ClassType
|
.Adadelta!: ClassType
|
||||||
.Adadelta <: .Optimizer
|
.Adadelta! <: .Optimizer!
|
||||||
.Adagrad: ClassType
|
.Adagrad!: ClassType
|
||||||
.Adagrad <: .Optimizer
|
.Adagrad! <: .Optimizer!
|
||||||
.Adam: ClassType
|
.Adam!: ClassType
|
||||||
.Adam <: .Optimizer
|
.Adam! <: .Optimizer!
|
||||||
.Adam.
|
.Adam!.
|
||||||
__call__: (
|
__call__: (
|
||||||
params: Iterable(Parameter),
|
params: Iterable(Parameter),
|
||||||
lr := Float,
|
lr := Float,
|
||||||
|
@ -23,23 +25,23 @@
|
||||||
amsgrad := Bool,
|
amsgrad := Bool,
|
||||||
foreach := Bool,
|
foreach := Bool,
|
||||||
maximize := Bool,
|
maximize := Bool,
|
||||||
) -> .Adam
|
) -> .Adam!
|
||||||
|
|
||||||
.AdamW: ClassType
|
.AdamW!: ClassType
|
||||||
.AdamW <: .Optimizer
|
.AdamW! <: .Optimizer!
|
||||||
.Adamax: ClassType
|
.Adamax!: ClassType
|
||||||
.Adamax <: .Optimizer
|
.Adamax! <: .Optimizer!
|
||||||
.LBFGS: ClassType
|
.LBFGS!: ClassType
|
||||||
.LBFGS <: .Optimizer
|
.LBFGS! <: .Optimizer!
|
||||||
.NAdam: ClassType
|
.NAdam!: ClassType
|
||||||
.NAdam <: .Optimizer
|
.NAdam! <: .Optimizer!
|
||||||
.RAdam: ClassType
|
.RAdam!: ClassType
|
||||||
.RAdam <: .Optimizer
|
.RAdam! <: .Optimizer!
|
||||||
.RMSprop: ClassType
|
.RMSprop!: ClassType
|
||||||
.RMSprop <: .Optimizer
|
.RMSprop! <: .Optimizer!
|
||||||
.Rprop: ClassType
|
.Rprop!: ClassType
|
||||||
.Rprop <: .Optimizer
|
.Rprop! <: .Optimizer!
|
||||||
.SGD: ClassType
|
.SGD!: ClassType
|
||||||
.SGD <: .Optimizer
|
.SGD! <: .Optimizer!
|
||||||
.SparseAdam: ClassType
|
.SparseAdam!: ClassType
|
||||||
.SparseAdam <: .Optimizer
|
.SparseAdam! <: .Optimizer!
|
||||||
|
|
|
@ -3,7 +3,7 @@ dataset = pyimport "./dataset"
|
||||||
{Sampler;} = pyimport "./sampler"
|
{Sampler;} = pyimport "./sampler"
|
||||||
|
|
||||||
.DataLoader: ClassType
|
.DataLoader: ClassType
|
||||||
.DataLoader <: Iterable((torch.Tensor(_, _), torch.Tensor(_, _)))
|
.DataLoader <: Iterable((torch.Tensor!(_, _), torch.Tensor!(_, _)))
|
||||||
.DataLoader.
|
.DataLoader.
|
||||||
__call__: (
|
__call__: (
|
||||||
dataset: dataset.Dataset,
|
dataset: dataset.Dataset,
|
||||||
|
|
|
@ -1,6 +1,11 @@
|
||||||
datasets = pyimport "torchvision/datasets"
|
datasets = pyimport "torchvision/datasets"
|
||||||
transforms = pyimport "torchvision/transforms"
|
transforms = pyimport "torchvision/transforms"
|
||||||
data = pyimport "torch/utils/data"
|
data = pyimport "torch/utils/data"
|
||||||
|
nn = pyimport "torch/nn"
|
||||||
|
torch = pyimport "torch"
|
||||||
|
|
||||||
|
_ = torch.manual_seed! 1
|
||||||
|
device = torch.device if torch.cuda.is_available!(), do "cuda", do "mps"
|
||||||
|
|
||||||
training_data = datasets.FashionMNIST(
|
training_data = datasets.FashionMNIST(
|
||||||
root:="target/data",
|
root:="target/data",
|
||||||
|
@ -12,3 +17,32 @@ train_dataloader = data.DataLoader(training_data, batch_size:=64)
|
||||||
|
|
||||||
for! train_dataloader, ((x, y),) =>
|
for! train_dataloader, ((x, y),) =>
|
||||||
print! x.shape, y.shape
|
print! x.shape, y.shape
|
||||||
|
|
||||||
|
Net = Inherit nn.Module, Additional := {
|
||||||
|
.conv1 = nn.Conv2d;
|
||||||
|
.conv2 = nn.Conv2d;
|
||||||
|
.pool = nn.MaxPool2d;
|
||||||
|
.fc1 = nn.Linear;
|
||||||
|
.fc2 = nn.Linear;
|
||||||
|
}
|
||||||
|
|
||||||
|
Net.
|
||||||
|
@Override
|
||||||
|
new() = Net::__new__ {
|
||||||
|
conv1 = nn.Conv2d(1, 16, kernel_size:=3, stride:=1, padding:=1);
|
||||||
|
conv2 = nn.Conv2d(16, 32, kernel_size:=3, stride:=1, padding:=1);
|
||||||
|
pool = nn.MaxPool2d(kernel_size:=2, stride:=2);
|
||||||
|
fc1 = nn.Linear(32 * 7 * 7, 128);
|
||||||
|
fc2 = nn.Linear(128, 10)
|
||||||
|
}
|
||||||
|
forward! self, x =
|
||||||
|
x1 = self.pool torch.relu self.conv1 x
|
||||||
|
x2 = self.pool torch.relu self.conv2 x1
|
||||||
|
x3 = x2.view([-1, 32 * 7 * 7])
|
||||||
|
x4 = torch.relu self.fc1 x3
|
||||||
|
x5 = self.fc2 x4
|
||||||
|
x5
|
||||||
|
|
||||||
|
net = Net.new().to device
|
||||||
|
_ = nn.CrossEntropyLoss()
|
||||||
|
_ = torch.optim.Adam! net.parameters(), lr:=0.001
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue