# Mirror of https://github.com/erg-lang/erg.git (synced 2025-12-23 05:36:48 +00:00).
# Listing metadata from the mirror page: 110 lines, 3 KiB, labelled "Python".
# PyTorch: core module, neural-network layers, and data-loading utilities.
torch = pyimport "torch"
nn = pyimport "torch/nn"
data = pyimport "torch/utils/data"

# torchvision: dataset classes and transform objects.
datasets = pyimport "torchvision/datasets"
transforms = pyimport "torchvision/transforms"

# matplotlib plotting interface.
plt = pyimport "matplotlib/pyplot"

# Seed torch's RNG with a fixed value; the returned object is not needed.
_ = torch.manual_seed!(1)

# Pick "cuda" when torch reports CUDA as available, otherwise "mps".
# NOTE(review): the "mps" fallback is unconditional, so a host with neither
# backend would presumably fail once tensors are moved to `device` - confirm.
device_name = if torch.cuda.is_available!(), do "cuda", do "mps"
device = torch.device device_name

# Per-sample preprocessing: ToTensor followed by Normalize([0.5], [0.5]).
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize([0.5], [0.5])
transform = transforms.Compose([to_tensor, normalize])

batch_size = 512
# Both MNIST splits use the same root directory and the same transform.
data_root = "target/data"

# Training split, served in shuffled batches.
train_set = datasets.MNIST(root:=data_root, train:=True, download:=True, transform:=transform)
train_loader = data.DataLoader(train_set, batch_size:=batch_size, shuffle:=True, num_workers:=2)

# Test split, served in dataset order (shuffle is off).
test_set = datasets.MNIST(root:=data_root, train:=False, download:=True, transform:=transform)
test_loader = data.DataLoader(test_set, batch_size:=batch_size, shuffle:=False, num_workers:=2)

# Small convolutional classifier built on nn.Module: two Conv2d layers that
# share one MaxPool2d, followed by two Linear layers. The record lists the
# attributes added on top of nn.Module.
Net = Inherit nn.Module, Additional := {
    .conv1 = nn.Conv2d;
    .conv2 = nn.Conv2d;
    .pool = nn.MaxPool2d;
    .fc1 = nn.Linear;
    .fc2 = nn.Linear;
}

Net.
    # Delegate initialization to nn.Module.
    @Override
    __init__ ref! self =
        nn.Module.__init__(self)
    # Construct a Net with its layers:
    #   conv1: 1 -> 16 channels, conv2: 16 -> 32 channels (3x3, stride 1, padding 1)
    #   pool:  2x2 max-pool with stride 2
    #   fc1:   32 * 7 * 7 -> 128, fc2: 128 -> 10
    # NOTE(review): the 7 * 7 factor presumably assumes 28x28 inputs halved by
    # the pool twice - confirm against the dataset.
    @Override
    new() = Net {
        conv1 = nn.Conv2d(1, 16, kernel_size:=3, stride:=1, padding:=1);
        conv2 = nn.Conv2d(16, 32, kernel_size:=3, stride:=1, padding:=1);
        pool = nn.MaxPool2d(kernel_size:=2, stride:=2);
        fc1 = nn.Linear(32 * 7 * 7, 128);
        fc2 = nn.Linear(128, 10)
    }
    # Forward pass: conv -> relu -> pool, twice; reshape to rows of
    # 32 * 7 * 7 values; fc1 with relu; fc2 produces the returned tensor.
    forward self, x =
        stage1 = self.pool(torch.relu(self.conv1(x)))
        stage2 = self.pool(torch.relu(self.conv2(stage1)))
        flat = stage2.view([-1, 32 * 7 * 7])
        hidden = torch.relu(self.fc1(flat))
        self.fc2(hidden)

# Model on the selected device, cross-entropy loss, and an Adam optimizer
# over the model's parameters.
net = Net.new().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam!(net.parameters(), lr:=0.001)

# Training loop. Epochs are numbered from 0; the log line prints epoch + 1.
# NOTE(review): `0..4` is a closed range in Erg, i.e. five passes over the
# training data - confirm that is intended.
for! 0..4, epoch =>
    # Loss accumulated since the last log line (fresh for every epoch).
    running_loss = !0.0
    for! enumerate(train_loader), ((step, batch),) =>
        # A batch is indexed as (inputs, labels); move both to the device.
        inputs = batch.0.to device
        labels = batch.1.to device

        # One optimization step: clear gradients, forward, backward, update.
        optimizer.zero_grad!()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward!()
        optimizer.step!()

        running_loss.inc!(loss.item())
        # On every 100th step, print the mean loss of the window and reset it.
        if! step % 100 == 99, do!:
            print!("epoch:", epoch + 1, "loss:", running_loss / 100)
            running_loss.update!(_ -> !0.0)

# Evaluate on the test loader: count the samples seen and the predictions
# that equal their label.
total = !0
correct = !0
# The loop runs inside torch.no_grad().
with! torch.no_grad(), _ =>
    for! test_loader, batch =>
        images = batch.0.to device
        labels = batch.1.to device
        outputs = net(images)
        # torch.max along dim 1 yields a pair; only the second element
        # (bound to `predicted`) is used.
        _, predicted = torch.max(outputs, 1)
        total.inc!(labels.size(0))
        correct.inc!((predicted == labels).sum().item())

# Prints the ratio correct / total.
print!("Accuracy of the network on the 10000 test images:", correct / total)

# Classify the first test sample and show it next to the softmax outputs.
image = test_set[0].0.to device
# unsqueeze(0) adds a leading dimension before the sample is passed to the network.
output = net(image.unsqueeze(0))
predicted = torch.softmax output, dim:=1

# Text column at x = 30: a heading followed by one line per index 0..9.
# (Heading spelling fixed: it previously read "Propabilities:".)
discard plt.text! 30, 3, "Probabilities:", fontsize:=12
for! 0..<10, i =>
    discard plt.text! 30, 5 + i * 2, "\{i}: {:.4f}".format(predicted[0][i]), fontsize:=12

# Convert the sample to a NumPy array on the CPU and draw it in grayscale.
img = image.cpu().squeeze().numpy()
discard plt.imshow! img, cmap:="gray"
plt.show!()