
PyTorch

Data

  • PyTorch has two classes for handling data: a Dataset class that stores the samples and a DataLoader class that wraps it for iteration
  • DataLoader supports shuffling, batching and iteration, as in the sketch below
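  • A minimal sketch of the two classes working together (the toy tensors X and y, their sizes, and the choice of TensorDataset as the concrete Dataset are assumptions for illustration):
import torch
from torch.utils.data import TensorDataset, DataLoader

# Wrap plain tensors in a Dataset so the DataLoader can index them
X = torch.randn(100, 8)          # 100 samples, 8 features (illustrative)
y = torch.randint(0, 2, (100,))  # binary labels (illustrative)
dataset = TensorDataset(X, y)

# The DataLoader handles batching, shuffling and iteration
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

for batch_X, batch_y in dataloader:
    print(batch_X.shape, batch_y.shape)  # torch.Size([16, 8]) torch.Size([16])
    break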

Models

  • Every model has to inherit from the nn.Module class
  • It needs two methods
    • __init__()
    • forward()
  • forward() defines the forward pass through the network
  • Before running the model, it has to be moved to a device with .to(device)
  • The device can be one of "cuda", "mps" or "cpu"; see the sketch after this list
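  • A minimal sketch of such a model (the class name NeuralNetwork, the layer sizes, and the device-selection expression are illustrative, not from the notes):
import torch
from torch import nn

# Pick the best available device
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(8, 32),
            nn.ReLU(),
            nn.Linear(32, 2),
        )

    def forward(self, x):
        # Defines the forward pass; call model(x) rather than model.forward(x)
        return self.layers(x)

model = NeuralNetwork().to(device)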

Training

  • For training we need an optimizer and a loss_fn
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
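        # Move the batch to the same device as the model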
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
  • This is the typical training workflow; a sketch of driving it over several epochs follows
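  • A sketch of calling train(), assuming the model, dataloader and device defined in the earlier sketches (SGD, CrossEntropyLoss, the learning rate and the epoch count are illustrative choices):
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

epochs = 5
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}\n-------------------------------")
    train(dataloader, model, loss_fn, optimizer)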