dask_ml.model_selection.train_test_split
dask_ml.model_selection.train_test_split(*arrays, test_size=None, train_size=None, random_state=None, shuffle=None, blockwise=True, convert_mixed_types=False, **options)
Split arrays into random train and test matrices.
- Parameters
- *arrays : Sequence of Dask Arrays, DataFrames, or Series
Non-dask objects will be passed through to sklearn.model_selection.train_test_split().
- test_size : float or int, default 0.1
- train_size : float or int, optional
- random_state : int, RandomState instance or None, optional (default=None)
If int, random_state is the seed used by the random number generator; if RandomState instance, random_state is the random number generator; if None, the random number generator is the RandomState instance used by np.random.
- shuffle : bool, default None
Whether to shuffle the data before splitting.
- blockwise : bool, default True
Whether to shuffle data only within blocks (True), or allow data to be shuffled between blocks (False). Shuffling between blocks can be much more expensive, especially in distributed environments.
The default is True: data are only shuffled within blocks. For Dask Arrays, set blockwise=False to shuffle data between blocks as well. For Dask DataFrames, blockwise=False is not currently supported and a ValueError will be raised.
- convert_mixed_types : bool, default False
Whether to convert dask DataFrames and Series to dask Arrays when arrays contains a mixture of types. This results in some computation to determine the length of each block.
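To make the blockwise trade-off concrete, here is a minimal NumPy sketch of the idea. It is not dask-ml's implementation. The helper split_indices is hypothetical, and its rounding rule (test rows per block = ceil(test_size * block length)) is an assumption, so its row counts need not match those produced by train_test_split. With blockwise=True, rows are permuted only inside their own block, so every block contributes its own share of test rows. With blockwise=False, a single permutation over all rows lets any row land in either split, which for real Dask collections means moving data between blocks.

```python
import numpy as np


def split_indices(n_rows, chunk_size, test_size=0.1, blockwise=True, seed=0):
    """Return (train_idx, test_idx) for a conceptual row split.

    Hypothetical helper: illustrates the idea behind ``blockwise`` only.
    It is not dask-ml's implementation and assumes a ceil-based rounding rule.
    """
    rng = np.random.RandomState(seed)
    if not blockwise:
        # Global shuffle: one permutation over all rows, so any row may
        # end up in either split (the expensive case for chunked data).
        perm = rng.permutation(n_rows)
        n_test = int(np.ceil(test_size * n_rows))
        return np.sort(perm[n_test:]), np.sort(perm[:n_test])

    train_parts, test_parts = [], []
    # Walk the block boundaries; rows are only permuted inside their own block.
    for start in range(0, n_rows, chunk_size):
        stop = min(start + chunk_size, n_rows)
        perm = start + rng.permutation(stop - start)
        n_test = int(np.ceil(test_size * (stop - start)))
        test_parts.append(np.sort(perm[:n_test]))
        train_parts.append(np.sort(perm[n_test:]))
    return np.concatenate(train_parts), np.concatenate(test_parts)


train, test = split_indices(125, 50, test_size=0.1, blockwise=True)
g_train, g_test = split_indices(125, 50, test_size=0.1, blockwise=False)
```

For 125 rows in blocks of 50, 50 and 25, the blockwise call draws test rows from each block separately. The global call may draw them from anywhere in the array.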
- Returns
- splitting : list, length=2 * len(arrays)
List containing the train-test split of the inputs.
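The Returns layout can be mimicked with plain NumPy indexing. The helper flat_split below is hypothetical and only reproduces the documented ordering: each input's train part is followed by its test part, matching the X_train, X_test, y_train, y_test unpacking in the example. It does not perform any random splitting itself.

```python
import numpy as np


def flat_split(arrays, train_idx, test_idx):
    """Hypothetical helper mirroring the documented return layout.

    For inputs (a, b, ...) it returns [a_train, a_test, b_train, b_test, ...],
    a list of length 2 * len(arrays). The same row indices are used for every
    input, so corresponding rows stay aligned (e.g. features with targets).
    """
    out = []
    for arr in arrays:
        out.append(arr[train_idx])  # train part of this input
        out.append(arr[test_idx])   # test part of this input
    return out


X = np.arange(20).reshape(10, 2)  # row i is [2*i, 2*i + 1]
y = np.arange(10)                 # target i belongs to row i
parts = flat_split([X, y], np.arange(8), np.arange(8, 10))
X_train, X_test, y_train, y_test = parts
```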
Examples
>>> import dask.array as da
>>> from dask_ml.datasets import make_regression
>>> from dask_ml.model_selection import train_test_split
>>> X, y = make_regression(n_samples=125, n_features=4, chunks=50,
...                        random_state=0)
>>> X_train, X_test, y_train, y_test = train_test_split(X, y,
...                                                     random_state=0)
>>> X_train
dask.array<concatenate, shape=(113, 4), dtype=float64, chunksize=(45, 4)>
>>> X_train.compute()[:2]
array([[ 0.12372191,  0.58222459,  0.92950511, -2.09460307],
       [ 0.99439439, -0.70972797, -0.27567053,  1.73887268]])