PyTorch
Data
- Torch has two classes for handling data: a `Dataset` class to store the data and a `DataLoader` class to iterate over it.
- `DataLoader` supports shuffling, batching and iteration (see the sketch below).
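
A minimal sketch of the two classes, assuming toy tensor data; `TensorDataset` is a ready-made `Dataset`, and the batch size and shuffle flag are arbitrary choices:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Toy data: 100 samples with 8 features each, plus integer class labels
X = torch.randn(100, 8)
y = torch.randint(0, 2, (100,))

# The Dataset stores the samples and labels
dataset = TensorDataset(X, y)

# The DataLoader wraps the Dataset and handles shuffling, batching and iteration
loader = DataLoader(dataset, batch_size=16, shuffle=True)

for xb, yb in loader:
    print(xb.shape, yb.shape)  # torch.Size([16, 8]) torch.Size([16])
    break
```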
Models
- Every model has to inherit from the `nn.Module` class.
- It needs two methods: `__init__()` and `forward()`.
- `forward()` represents a forward pass through the network.
- Before running the model, the model has to be sent to a device.
- This device can be one of `["cuda", "mps", "cpu"]` (see the sketch below).
Training
- For training we need an `optimizer` and a `loss_fn` (an example setup follows the training loop below).
```python
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
```
- This is the typical training workflow.
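
A sketch of the setup that feeds the loop above, assuming the `model`, `loader` and `device` from the earlier sketches; the loss function, optimizer and learning rate are ordinary choices, not something these notes prescribe:

```python
import torch
from torch import nn

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

epochs = 5
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}")
    train(loader, model, loss_fn, optimizer)
```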