dask_ml.model_selection.train_test_split(*arrays, test_size=None, train_size=None, random_state=None, shuffle=None, blockwise=True, convert_mixed_types=False, **options)

Split arrays into random train and test matrices.

Parameters

*arrays : sequence of Dask Arrays, DataFrames, or Series

Non-dask objects will be passed through to sklearn.model_selection.train_test_split().

test_size : float or int, default 0.1
train_size : float or int, optional
random_state : int, RandomState instance or None, optional (default=None)

If int, random_state is the seed used by the random number generator; if a RandomState instance, random_state is the random number generator itself; if None, the random number generator is the RandomState instance used by np.random.
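The rule above is the same one scikit-learn's check_random_state applies. The sketch below re-implements it with NumPy to show how each accepted value is resolved. The helper name resolve_random_state is hypothetical, and whether dask_ml routes random_state through exactly this logic is an assumption.

```python
import numbers

import numpy as np


def resolve_random_state(random_state=None):
    """Hypothetical helper: map a random_state argument to a RandomState."""
    if random_state is None:
        # None: reuse the global RandomState that backs np.random.* functions.
        return np.random.mtrand._rand
    if isinstance(random_state, numbers.Integral):
        # int: use it as the seed of a fresh, independent generator.
        return np.random.RandomState(random_state)
    if isinstance(random_state, np.random.RandomState):
        # RandomState instance: use it directly as the generator.
        return random_state
    raise ValueError(f"{random_state!r} cannot be used to seed a RandomState")


# The same int seed gives the same permutation on every call.
first = resolve_random_state(0).permutation(10)
second = resolve_random_state(0).permutation(10)
```

One consequence of this design: passing the same RandomState instance to two calls gives different splits, because the instance's state advances; passing the same int gives identical splits.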

shuffle : bool, default None

Whether to shuffle the data before splitting.

blockwise : bool, default True

Whether to shuffle data only within blocks (True), or allow data to be shuffled between blocks (False). Shuffling between blocks can be much more expensive, especially in distributed environments.

The default, True, shuffles data only within blocks. For Dask Arrays, set blockwise=False to shuffle data between blocks as well. For Dask DataFrames, blockwise=False is not currently supported and raises a ValueError.
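To make the blockwise behaviour concrete, here is a conceptual NumPy sketch of what blockwise=True means: the rows of each block are permuted and split on their own, so every train and test index stays inside the block it came from and no rows move between blocks. This is not dask_ml's implementation; the helper name blockwise_split_indices and the round-up rule for the per-block test count are assumptions, and dask_ml's own rounding may differ.

```python
import math

import numpy as np


def blockwise_split_indices(chunks, test_size, seed=0):
    """Hypothetical helper: shuffle and split row indices independently per block.

    chunks is a tuple of block lengths along axis 0, e.g. (40, 40, 20).
    Returns (train, test): lists holding one global-index array per block.
    """
    rng = np.random.RandomState(seed)
    train, test = [], []
    offset = 0
    for n in chunks:
        # Assumption: round the per-block test count up.
        n_test = math.ceil(test_size * n)
        # Permute only this block's rows, then shift to global row indices.
        perm = offset + rng.permutation(n)
        test.append(perm[:n_test])
        train.append(perm[n_test:])
        offset += n
    return train, test


train, test = blockwise_split_indices((40, 40, 20), test_size=0.25)
print([len(t) for t in train], [len(t) for t in test])  # [30, 30, 15] [10, 10, 5]
```

With blockwise=False, the permutation would instead run over all 100 row indices at once, so each output block can draw rows from every input block; that data movement is what makes it more expensive in a distributed setting.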

convert_mixed_types : bool, default False

Whether to convert Dask DataFrames and Series to Dask Arrays when arrays contains a mixture of types (for example, a Dask DataFrame X with a Dask Array y). The conversion triggers some computation to determine the length of each block.

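Returns

splitting : list, length=2 * len(arrays)

List containing the train-test split of the inputs, with the train part followed by the test part for each input in turn: [a_train, a_test, b_train, b_test, ...].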


Examples

>>> from dask_ml.datasets import make_regression
>>> from dask_ml.model_selection import train_test_split
>>> X, y = make_regression(n_samples=125, n_features=4, chunks=50,
...                        random_state=0)
>>> X_train, X_test, y_train, y_test = train_test_split(X, y,
...                                                     random_state=0)
>>> X_train
dask.array<concatenate, shape=(113, 4), dtype=float64, chunksize=(45, 4)>
>>> X_train.compute()[:2]
array([[ 0.12372191,  0.58222459,  0.92950511, -2.09460307],
       [ 0.99439439, -0.70972797, -0.27567053,  1.73887268]])