erg/examples/pytorch.er
2024-02-09 14:38:57 +09:00

110 lines
3 KiB
Python

torch = pyimport "torch"
nn = pyimport "torch/nn"
data = pyimport "torch/utils/data"
datasets = pyimport "torchvision/datasets"
transforms = pyimport "torchvision/transforms"
plt = pyimport "matplotlib/pyplot"
# Seed PyTorch's global RNG with a fixed value (1).
_ = torch.manual_seed! 1
# Use CUDA when torch reports it as available; otherwise "mps" is chosen
# unconditionally.
# NOTE(review): there is no CPU fallback here, so a machine with neither CUDA
# nor MPS would presumably fail once tensors are moved to `device` - confirm.
device = torch.device if torch.cuda.is_available!(), do "cuda", do "mps"

# Preprocessing applied to every image: convert to a tensor, then normalise
# with the single-channel arguments [0.5] and [0.5].
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
batch_size = 512

# MNIST train / test splits, both stored under (and downloaded to) target/data.
train_set = datasets.MNIST(root:="target/data", train:=True, download:=True, transform:=transform)
test_set = datasets.MNIST(root:="target/data", train:=False, download:=True, transform:=transform)

# Batched loaders over the two splits; only the training loader shuffles.
train_loader = data.DataLoader(train_set, batch_size:=batch_size, shuffle:=True, num_workers:=2)
test_loader = data.DataLoader(test_set, batch_size:=batch_size, shuffle:=False, num_workers:=2)
# Small convolutional classifier built on nn.Module.
# Layers (see `new`): two 3x3 convolutions (1 -> 16 -> 32 channels), one shared
# 2x2 max-pool, and two linear layers (32*7*7 -> 128 -> 10).
Net = Inherit nn.Module, Additional := {
    .conv1 = nn.Conv2d;
    .conv2 = nn.Conv2d;
    .pool = nn.MaxPool2d;
    .fc1 = nn.Linear;
    .fc2 = nn.Linear;
}
Net.
    # Delegates to nn.Module's own initialiser.
    @Override
    __init__ ref! self =
        nn.Module.__init__(self)
    # Builds a network with freshly constructed layers.
    @Override
    new() = Net {
        conv1 = nn.Conv2d(1, 16, kernel_size:=3, stride:=1, padding:=1);
        conv2 = nn.Conv2d(16, 32, kernel_size:=3, stride:=1, padding:=1);
        pool = nn.MaxPool2d(kernel_size:=2, stride:=2);
        fc1 = nn.Linear(32 * 7 * 7, 128);
        fc2 = nn.Linear(128, 10)
    }
    # Forward pass: (conv -> ReLU -> pool) twice, flatten to 32*7*7 features,
    # then fc1 -> ReLU -> fc2. Returns the raw output of fc2.
    forward self, x =
        stage1 = self.pool(torch.relu(self.conv1(x)))
        stage2 = self.pool(torch.relu(self.conv2(stage1)))
        flat = stage2.view([-1, 32 * 7 * 7])
        hidden = torch.relu(self.fc1(flat))
        self.fc2(hidden)
# Model on the selected device, cross-entropy loss, Adam optimiser (lr = 0.001).
net = Net.new().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam!(net.parameters(), lr:=0.001)

# Training: iterate over the range 0..4, one full pass over train_loader each time.
for! 0..4, epoch =>
    # Loss accumulated since the last report.
    running_loss = !0.0
    for! enumerate(train_loader), ((i, batch),) =>
        inputs = batch.0.to(device)
        labels = batch.1.to(device)
        optimizer.zero_grad!()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward!()
        optimizer.step!()
        running_loss.inc!(loss.item())
        # On every 100th batch, print the mean loss of the window and reset it.
        if! i % 100 == 99, do!:
            print! "epoch:", epoch + 1, "loss:", running_loss / 100
            running_loss.update! _ -> !0.0
# Evaluation over test_loader: count how many predictions match their labels.
total = !0
correct = !0
# Run the whole pass inside torch.no_grad().
with! torch.no_grad(), _ =>
    for! test_loader, batch =>
        images = batch.0.to(device)
        labels = batch.1.to(device)
        outputs = net(images)
        # Second result of torch.max along dim 1 is kept as the predicted class.
        _, best_class = torch.max(outputs, 1)
        total.inc!(labels.size(0))
        correct.inc!((best_class == labels).sum().item())
# Printed value is the ratio correct / total.
print! "Accuracy of the network on the 10000 test images:", correct / total
# Classify the first test image and display it next to the softmax scores.
# Take element 0 of the first test sample (the image) and move it to `device`.
image = test_set[0].0.to device
# unsqueeze(0) adds a leading dimension before the image is passed to the network.
output = net(image.unsqueeze(0))
predicted = torch.softmax output, dim:=1
# Heading for the score list (label text fixed: was misspelled "Propabilities:").
discard plt.text! 30, 3, "Probabilities:", fontsize:=12
# One text row per class index 0..9, spaced 2 units apart below the heading.
for! 0..<10, i =>
    discard plt.text! 30, 5 + i * 2, "\{i}: {:.4f}".format(predicted[0][i]), fontsize:=12
# Bring the image back to the CPU and convert it to a NumPy array for imshow.
img = image.cpu().squeeze().numpy()
discard plt.imshow! img, cmap:="gray"
plt.show!()