PyTorch

Skorch brings a Scikit-learn API to PyTorch. Skorch allows PyTorch models to be wrapped in Scikit-learn compatible estimators. So, that means that PyTorch models wrapped in Skorch can be used with the rest of the Dask-ML API. For example, using Dask-ML’s HyperbandSearchCV or Incremental with PyTorch is possible after wrapping with Skorch.

We encourage looking at the Skorch documentation for complete details.

Example usage

First, let’s create a normal PyTorch model:

import torch.nn as nn
import torch.nn.functional as F

class ShallowNet(nn.Module):
    def __init__(self, n_features=5):
        super().__init__()
        self.layer1 = nn.Linear(n_features, 1)

    def forward(self, x):
        return F.relu(self.layer1(x))

With this, it’s easy to use Skorch:

from skorch import NeuralNetRegressor
import torch.optim as optim

niceties = {
    "callbacks": False,
    "warm_start": False,
    "train_split": None,
    "max_epochs": 1,
}

model = NeuralNetRegressor(
    module=ShallowNet,
    module__n_features=5,
    criterion=nn.MSELoss,
    optimizer=optim.SGD,
    optimizer__lr=0.1,
    optimizer__momentum=0.9,
    batch_size=64,
    **niceties,
)

Each parameter that the PyTorch nn.Module takes is prefixed with module__, and same for the optimizer (optim.SGD takes a lr and momentum parameters). The niceties make sure Skorch uses all the data for training and doesn’t print excessive amounts of logs.

Now, this model can be used with Dask-ML. For example, it’s possible to do the following: