Cross Validation
See the scikit-learn cross validation documentation for a fuller discussion of cross validation. This document only describes the extensions made to support Dask arrays.
The simplest way to split one or more Dask arrays is with dask_ml.model_selection.train_test_split().
In [1]: import dask.array as da
In [2]: from dask_ml.datasets import make_regression
In [3]: from dask_ml.model_selection import train_test_split
In [4]: X, y = make_regression(n_samples=125, n_features=4, random_state=0, chunks=50)
In [5]: X
Out[5]: dask.array<normal, shape=(125, 4), dtype=float64, chunksize=(50, 4)>
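Here X holds 125 rows in chunks of 50, so its chunks contain 50, 50, and 25 rows. The training split in the session below has 112 rows in chunks of 45, which is consistent with each chunk being split on its own. The plain-Python sketch below illustrates that idea. `blockwise_split` is a hypothetical helper, and the 10% test fraction (rounded up per chunk) is an assumption chosen because it reproduces those numbers. It is not a description of how dask-ml actually implements the split.

```python
import math
import random


def blockwise_split(chunk_sizes, test_size=0.1, seed=0):
    """Split row indices into train/test sets chunk by chunk.

    Hypothetical helper (not part of dask-ml): rows are shuffled only
    *within* each chunk, never across chunk boundaries. The default
    test_size of 0.1 is an assumption for illustration.
    """
    rng = random.Random(seed)
    train, test = [], []
    train_chunks = []  # number of training rows produced by each chunk
    start = 0
    for n in chunk_sizes:
        rows = list(range(start, start + n))
        rng.shuffle(rows)  # shuffle only inside this chunk
        n_test = math.ceil(test_size * n)  # round the test share up
        test.extend(sorted(rows[:n_test]))
        train.extend(sorted(rows[n_test:]))
        train_chunks.append(n - n_test)
        start += n
    return train, test, train_chunks


# 125 rows stored in chunks of 50, matching the layout of X above
train, test, train_chunks = blockwise_split((50, 50, 25))
print(len(train), len(test), train_chunks)  # prints: 112 13 [45, 45, 22]
```

Splitting each chunk independently means every output chunk is built from a single input chunk, so no rows have to move between chunks. The trade-off is that rows are never mixed across chunk boundaries.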
The interface for splitting Dask arrays is the same as scikit-learn’s version.
In [6]: X_train, X_test, y_train, y_test = train_test_split(X, y)
In [7]: X_train # A dask Array
Out[7]: dask.array<concatenate, shape=(112, 4), dtype=float64, chunksize=(45, 4)>
In [8]: X_train.compute()[:3]