dask_ml.model_selection.train_test_split(*arrays, test_size=None, train_size=None, random_state=None, shuffle=None, blockwise=None, convert_mixed_types=False, **options)

Split arrays into random train and test matrices.

*arrays : Sequence of Dask Arrays, DataFrames, or Series

Non-dask objects will be passed through to sklearn.model_selection.train_test_split().

test_size : float or int, default 0.1
train_size : float or int, optional
random_state : int, RandomState instance or None, optional (default=None)

If int, random_state is the seed used by the random number generator. If a RandomState instance, random_state is the random number generator. If None, the random number generator is the RandomState instance used by np.random.
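
The three cases above can be sketched as a small resolver. This is an illustration only: the helper name `resolve_random_state` is hypothetical and not part of dask-ml, and the sketch assumes the behavior follows the usual scikit-learn convention rather than reproducing dask-ml's actual code.

```python
import numpy as np


def resolve_random_state(random_state=None):
    """Hypothetical helper: turn random_state into a RandomState instance."""
    if random_state is None:
        # Assumption: reuse NumPy's global RandomState singleton, the one
        # behind np.random.seed() and np.random.permutation().
        return np.random.mtrand._rand
    if isinstance(random_state, (int, np.integer)):
        # An int is used as the seed of a fresh generator.
        return np.random.RandomState(random_state)
    if isinstance(random_state, np.random.RandomState):
        # An existing generator is used as-is.
        return random_state
    raise ValueError(f"Cannot use {random_state!r} to seed a RandomState")


# The same integer seed yields the same shuffle order on every call.
first = resolve_random_state(0).permutation(10)
second = resolve_random_state(0).permutation(10)
same_order = bool((first == second).all())
```

Passing an int is the simplest way to get a reproducible split; passing None ties the result to whatever state the global generator happens to be in.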

shuffle : bool, default None

Whether to shuffle the data before splitting.

blockwise : bool, optional

Whether to shuffle data only within blocks (True), or allow data to be shuffled between blocks (False). Shuffling between blocks can be much more expensive, especially in distributed environments.

The default behavior depends on the types in arrays. For Dask Arrays, the default is True (data are not shuffled between blocks). For Dask DataFrames, the default and only allowed value is False (data are shuffled between blocks).
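
The difference between the two modes can be sketched with plain NumPy arrays standing in for the blocks of a Dask Array. This is a conceptual sketch, not dask-ml's implementation: the function names `blockwise_split` and `global_split` are hypothetical, and rounding the per-block test size down is an assumption.

```python
import numpy as np


def blockwise_split(blocks, test_size=0.1, seed=0):
    """Split each block on its own; rows never move between blocks."""
    rng = np.random.RandomState(seed)
    train, test = [], []
    for block in blocks:
        order = rng.permutation(len(block))    # shuffle within the block only
        n_test = int(len(block) * test_size)   # assumed rounding: floor
        test.append(block[order[:n_test]])
        train.append(block[order[n_test:]])
    return train, test


def global_split(blocks, test_size=0.1, seed=0):
    """Shuffle all rows together, so rows can cross block boundaries."""
    rng = np.random.RandomState(seed)
    rows = np.concatenate(blocks)              # needs every row in one place
    order = rng.permutation(len(rows))
    n_test = int(len(rows) * test_size)
    return rows[order[n_test:]], rows[order[:n_test]]


# Three blocks of row ids with sizes 50, 50 and 25.
blocks = [np.arange(0, 50), np.arange(50, 100), np.arange(100, 125)]
train, test = blockwise_split(blocks)
train_sizes = [len(part) for part in train]    # [45, 45, 23]
```

In the blockwise variant every training part contains only rows from its own block, so no data has to move. The global variant has to gather all rows before shuffling, which is the step that tends to be costly on a cluster.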

convert_mixed_types : bool, default False

Whether to convert dask DataFrames and Series to dask Arrays when arrays contains a mixture of types. This results in some computation to determine the length of each block.

splitting : list, length=2 * len(arrays)

List containing train-test split of inputs
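
The ordering of the returned list can be illustrated with NumPy inputs. Since non-dask objects are passed through to scikit-learn, this sketch calls `sklearn.model_selection.train_test_split` directly; it assumes scikit-learn is installed and sets `test_size` explicitly, because scikit-learn's own default differs from the one documented here.

```python
import numpy as np
from sklearn.model_selection import train_test_split  # scikit-learn's version

X = np.arange(40).reshape(20, 2)
y = np.arange(20)

# Two inputs produce four outputs: a train/test pair per input, in input order.
splitting = train_test_split(X, y, test_size=0.1, random_state=0)
X_train, X_test, y_train, y_test = splitting
shapes = [part.shape for part in splitting]
```

Rows stay aligned across inputs: the i-th row of `X_train` and the i-th entry of `y_train` come from the same original sample.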


>>> import dask.array as da
>>> from dask_ml.datasets import make_regression
>>> from dask_ml.model_selection import train_test_split
>>> X, y = make_regression(n_samples=125, n_features=4, chunks=50,
...                    random_state=0)
>>> X_train, X_test, y_train, y_test = train_test_split(X, y,
...                                                     random_state=0)
>>> X_train
dask.array<concatenate, shape=(113, 4), dtype=float64, chunksize=(45, 4)>
>>> X_train.compute()[:2]
array([[ 0.12372191,  0.58222459,  0.92950511, -2.09460307],
       [ 0.99439439, -0.70972797, -0.27567053,  1.73887268]])