API Reference

This page lists all of the estimators and top-level functions in dask_ml. Unless otherwise noted, the estimators implemented in dask-ml are appropriate for parallel and distributed training.

dask_ml.model_selection: Model Selection

Utilities for hyperparameter optimization.

These estimators will operate in parallel. Their scalability depends on the underlying estimators being used.

Dask-ML has a few cross-validation utilities.

model_selection.train_test_split(*arrays[, ...])

Split arrays into random train and test matrices.

model_selection.train_test_split() is a simple helper that uses model_selection.ShuffleSplit internally.
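
The pure-Python sketch below shows roughly what a shuffle-based split does: it permutes the row indices once and slices off a test fraction. The helper `shuffle_split_indices` is hypothetical and works on plain index ranges. dask-ml's own implementation operates on Dask Arrays and differs in detail.

```python
import random

def shuffle_split_indices(n_samples, test_size=0.25, random_state=0):
    """Hypothetical helper: return (train, test) index lists from one random permutation."""
    rng = random.Random(random_state)
    indices = list(range(n_samples))
    rng.shuffle(indices)  # one random permutation of all row indices
    n_test = int(round(n_samples * test_size))
    return indices[n_test:], indices[:n_test]

train_idx, test_idx = shuffle_split_indices(10, test_size=0.3)
print(len(train_idx), len(test_idx))  # 7 3
```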

model_selection.ShuffleSplit([n_splits, ...])

Random permutation cross-validator.

model_selection.KFold([n_splits, shuffle, ...])

K-Folds cross-validator
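
The index bookkeeping behind K-fold cross-validation fits in a few lines of plain Python. This sketch assumes no shuffling. It follows scikit-learn's convention of giving the first `n_samples % n_splits` folds one extra sample. `kfold_indices` is a hypothetical helper, not part of dask-ml.

```python
def kfold_indices(n_samples, n_splits=3):
    """Hypothetical helper: yield (train, test) index lists for unshuffled K-fold CV."""
    # The first n_samples % n_splits folds get one extra sample.
    fold_sizes = [n_samples // n_splits] * n_splits
    for j in range(n_samples % n_splits):
        fold_sizes[j] += 1
    start = 0
    for size in fold_sizes:
        stop = start + size
        test = list(range(start, stop))
        train = [i for i in range(n_samples) if i < start or i >= stop]
        yield train, test
        start = stop

for train, test in kfold_indices(7, n_splits=3):
    print(test)
# [0, 1, 2]
# [3, 4]
# [5, 6]
```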

Dask-ML provides drop-in replacements for grid and randomized search. These are appropriate for datasets where the CV splits fit in memory.

model_selection.GridSearchCV(estimator, ...)

Exhaustive search over specified parameter values for an estimator.

model_selection.RandomizedSearchCV(...[, ...])

Randomized search on hyperparameters.

For hyperparameter optimization on larger-than-memory datasets, Dask-ML provides the following:

model_selection.IncrementalSearchCV(...[, ...])

Incrementally search for hyperparameters on models that support partial_fit.

model_selection.HyperbandSearchCV(estimator, ...)

Find the best parameters for a particular model with an adaptive cross-validation algorithm.


Perform the successive halving algorithm [R424ea1a907b1-1].

model_selection.InverseDecaySearchCV(...[, ...])

Incrementally search for hyperparameters on models that support partial_fit.
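
Successive halving is one of the algorithms listed above. It evaluates many candidates on a small budget, then repeatedly promotes only the best fraction to a larger budget. The sketch below shows that core loop under these assumptions: `evaluate(config, budget)` is a hypothetical user-supplied scoring callback where higher is better, and `eta` is an illustrative reduction factor. It is not dask-ml's implementation.

```python
import math

def successive_halving(configs, evaluate, min_budget=1, eta=3):
    """Hypothetical sketch: keep the best 1/eta of configs each round,
    giving the survivors eta times more budget in the next round."""
    budget = min_budget
    survivors = list(configs)
    while len(survivors) > 1:
        scores = {c: evaluate(c, budget) for c in survivors}
        n_keep = max(1, len(survivors) // eta)
        survivors = sorted(survivors, key=scores.get, reverse=True)[:n_keep]
        budget *= eta
    return survivors[0]

# Toy objective: a "learning rate" near 0.1 scores best; the budget is ignored here.
best = successive_halving(
    [0.001, 0.01, 0.05, 0.1, 0.3, 1.0, 3.0, 10.0, 30.0],
    evaluate=lambda lr, budget: -abs(math.log10(lr) + 1),
)
print(best)  # 0.1
```

In a real search the budget would control how much training each candidate receives, for example the number of `partial_fit` calls. The toy objective above ignores it only to keep the example short.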

dask_ml.ensemble: Ensemble Methods


Blockwise training and ensemble voting classifier.


Blockwise training and ensemble voting regressor.
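
The blockwise voting idea works like this: fit one independent model per block of data, then let the models vote on each prediction. The sketch below uses a hypothetical 1-D threshold rule (`fit_block`) in place of a real scikit-learn estimator. The blocks are plain Python lists rather than Dask Array chunks.

```python
from collections import Counter

def fit_block(X_block, y_block):
    """Hypothetical toy model: a threshold halfway between the two class means."""
    pos = [x for x, y in zip(X_block, y_block) if y == 1]
    neg = [x for x, y in zip(X_block, y_block) if y == 0]
    return (sum(pos) / len(pos) + sum(neg) / len(neg)) / 2

def vote(thresholds, x):
    """Hard voting: every block's model predicts a label; the majority label wins."""
    votes = Counter(int(x > t) for t in thresholds)
    return votes.most_common(1)[0][0]

blocks = [([0.0, 1.0, 4.0, 5.0], [0, 0, 1, 1]),
          ([0.5, 1.5, 3.5, 6.0], [0, 0, 1, 1]),
          ([1.0, 2.0, 5.0, 7.0], [0, 0, 1, 1])]
models = [fit_block(X, y) for X, y in blocks]  # one model per block
print(models)                               # [2.5, 2.875, 3.75]
print(vote(models, 2.6), vote(models, 3.0))  # 0 1
```

At `x = 2.6` only one of the three block models predicts 1, so the ensemble predicts 0. At `x = 3.0` two of three predict 1.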

dask_ml.linear_model: Generalized Linear Models

The dask_ml.linear_model module implements linear models for classification and regression.

linear_model.LinearRegression([penalty, ...])

Estimator for linear regression.

linear_model.LogisticRegression([penalty, ...])

Estimator for logistic regression.

linear_model.PoissonRegression([penalty, ...])

Estimator for Poisson regression.
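
All three estimators fit a generalized linear model by numerical optimization. The sketch below covers only the logistic case. It minimizes the mean log loss with plain batch gradient descent on a single feature. The helper names, learning rate and iteration count are illustrative assumptions, and the code should not be read as dask-ml's solver.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def fit_logistic(X, y, lr=0.5, n_iter=2000):
    """Hypothetical sketch: fit w, b for p(y=1|x) = sigmoid(w*x + b)
    by batch gradient descent on the mean log loss."""
    w, b = 0.0, 0.0
    n = len(X)
    for _ in range(n_iter):
        # Residuals p - y give the gradient of the mean log loss.
        errors = [sigmoid(w * x + b) - t for x, t in zip(X, y)]
        w -= lr * sum(e * x for e, x in zip(errors, X)) / n
        b -= lr * sum(errors) / n
    return w, b

# Not perfectly separable, so the optimum is finite.
X = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]
y = [0, 0, 1, 0, 1, 1]
w, b = fit_logistic(X, y)
preds = [int(sigmoid(w * x + b) > 0.5) for x in X]
print(preds)  # [0, 0, 0, 1, 1, 1]
```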

dask_ml.naive_bayes: Naive Bayes

naive_bayes.GaussianNB([priors, classes])

Fit a naive Bayes model with a Gaussian likelihood.
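
Gaussian naive Bayes estimates a prior, a mean and a variance for each class. It then predicts the class with the highest log posterior. The sketch below does this for a single feature in plain Python. With several features, the per-feature log likelihoods would be summed under the naive independence assumption. `fit_gaussian_nb` and `predict_gaussian_nb` are hypothetical names. Unlike scikit-learn's implementation, the sketch applies no variance smoothing, so a class with zero variance would fail.

```python
import math
from collections import defaultdict

def fit_gaussian_nb(X, y):
    """Hypothetical sketch: per-class (prior, mean, variance) for a 1-D feature."""
    groups = defaultdict(list)
    for x, label in zip(X, y):
        groups[label].append(x)
    params = {}
    for label, xs in groups.items():
        mean = sum(xs) / len(xs)
        var = sum((x - mean) ** 2 for x in xs) / len(xs)
        params[label] = (len(xs) / len(X), mean, var)
    return params

def predict_gaussian_nb(params, x):
    """Pick the class maximizing log prior + Gaussian log likelihood."""
    def log_posterior(label):
        prior, mean, var = params[label]
        return (math.log(prior) - 0.5 * math.log(2 * math.pi * var)
                - (x - mean) ** 2 / (2 * var))
    return max(params, key=log_posterior)

params = fit_gaussian_nb([1.0, 2.0, 3.0, 7.0, 8.0, 9.0], ["a", "a", "a", "b", "b", "b"])
print(predict_gaussian_nb(params, 2.5), predict_gaussian_nb(params, 7.5))  # a b
```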

dask_ml.wrappers: Meta-Estimators

dask-ml provides some meta-estimators that wrap regular estimators following the scikit-learn API so that they work well with Dask Arrays or DataFrames.

wrappers.ParallelPostFit([estimator, ...])

Meta-estimator for parallel predict and transform.

wrappers.Incremental([estimator, scoring, ...])

Metaestimator for feeding Dask Arrays to an estimator blockwise.
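
The two wrappers rely on two access patterns: feeding blocks one at a time to `partial_fit`, and predicting block by block. The sequential sketch below shows both. It uses a hypothetical toy estimator (`RunningMeanRegressor`), and plain Python lists stand in for the chunks of a Dask Array. In dask-ml, the per-block predictions can run in parallel.

```python
class RunningMeanRegressor:
    """Hypothetical toy estimator with a scikit-learn-style partial_fit:
    it always predicts the mean target value seen so far."""

    def __init__(self):
        self.n_, self.total_ = 0, 0.0

    def partial_fit(self, X, y):
        self.n_ += len(y)
        self.total_ += sum(y)
        return self

    def predict(self, X):
        return [self.total_ / self.n_] * len(X)

def blocks(seq, size):
    """Split a list into consecutive blocks, like the chunks of an array."""
    for start in range(0, len(seq), size):
        yield seq[start:start + size]

X = list(range(10))
y = [2.0 * x for x in X]

model = RunningMeanRegressor()
# Incremental-style training: one partial_fit call per block.
for X_block, y_block in zip(blocks(X, 4), blocks(y, 4)):
    model.partial_fit(X_block, y_block)

# ParallelPostFit-style prediction: predict each block, then concatenate.
predictions = [p for X_block in blocks(X, 4) for p in model.predict(X_block)]
print(predictions[:3])  # [9.0, 9.0, 9.0]
```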

dask_ml.cluster: Clustering

Unsupervised Clustering Algorithms

cluster.KMeans([n_clusters, init, ...])

Scalable KMeans for clustering

cluster.SpectralClustering([n_clusters, ...])

Apply parallel Spectral Clustering
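
At its core, k-means alternates two steps. It assigns each point to its nearest center, then moves each center to the mean of its assigned points (Lloyd's algorithm). The 1-D sketch below assumes the starting centers are given. It skips both the initialization step and the distributed execution that make dask-ml's `KMeans` scalable, and `kmeans_1d` is a hypothetical name.

```python
def kmeans_1d(points, centers, n_iter=10):
    """Hypothetical sketch of Lloyd's algorithm on 1-D data."""
    centers = list(centers)
    for _ in range(n_iter):
        # Assignment step: each point joins its nearest center's cluster.
        clusters = [[] for _ in centers]
        for p in points:
            nearest = min(range(len(centers)), key=lambda k: abs(p - centers[k]))
            clusters[nearest].append(p)
        # Update step: move centers to cluster means (keep old center if empty).
        centers = [sum(c) / len(c) if c else centers[k] for k, c in enumerate(clusters)]
    return centers

print(kmeans_1d([1.0, 1.5, 2.0, 8.0, 9.0, 10.0], centers=[0.0, 5.0]))  # [1.5, 9.0]
```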

dask_ml.decomposition: Matrix Decomposition

decomposition.IncrementalPCA([n_components, ...])

Incremental principal components analysis (IPCA).

decomposition.PCA([n_components, copy, ...])

Principal component analysis (PCA)

decomposition.TruncatedSVD([n_components, ...])


dask_ml.preprocessing: Preprocessing Data

Utilities for preprocessing data.

preprocessing.StandardScaler(*[, copy, ...])

Standardize features by removing the mean and scaling to unit variance.

preprocessing.RobustScaler(*[, ...])

Scale features using statistics that are robust to outliers.

preprocessing.MinMaxScaler([feature_range, ...])

Transform features by scaling each feature to a given range.
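
The arithmetic behind the standard and min-max scalers is short enough to show per column in plain Python. These hypothetical helpers use the population standard deviation. They assume a non-constant column, because a zero standard deviation or zero range would cause a division by zero.

```python
def standard_scale(column):
    """Hypothetical helper: (x - mean) / std, with the population std."""
    mean = sum(column) / len(column)
    std = (sum((x - mean) ** 2 for x in column) / len(column)) ** 0.5
    return [(x - mean) / std for x in column]

def min_max_scale(column, feature_range=(0.0, 1.0)):
    """Hypothetical helper: map [min, max] of the column onto feature_range."""
    lo, hi = feature_range
    cmin, cmax = min(column), max(column)
    return [lo + (x - cmin) * (hi - lo) / (cmax - cmin) for x in column]

col = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
print(standard_scale(col))  # [-1.5, -0.5, -0.5, -0.5, 0.0, 0.0, 1.0, 2.0]
print(min_max_scale(col)[0], min_max_scale(col)[-1])  # 0.0 1.0
```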

preprocessing.QuantileTransformer(*[, ...])

Transforms features using quantile information.

preprocessing.Categorizer([categories, columns])

Transform columns of a DataFrame to categorical dtype.

preprocessing.DummyEncoder([columns, drop_first])

Dummy (one-hot) encode categorical columns.


Ordinal (integer) encode categorical columns.


Encode labels with value between 0 and n_classes-1.
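
The encoders above differ mainly in output shape. Label and ordinal encoding map each category to a single integer. Dummy (one-hot) encoding produces one indicator column per category. The plain-Python sketch below shows both. It assumes categories are ordered by sorting, whereas dask-ml's DataFrame encoders typically take their categories from a categorical dtype (see `Categorizer`). The helper names are hypothetical.

```python
def label_encode(values):
    """Hypothetical helper: map each distinct value to 0..n_classes-1 in sorted order."""
    classes = sorted(set(values))
    index = {c: i for i, c in enumerate(classes)}
    return classes, [index[v] for v in values]

def one_hot(values, drop_first=False):
    """Hypothetical helper: one 0/1 column per category; drop_first omits the first."""
    classes = sorted(set(values))
    if drop_first:
        classes = classes[1:]
    return [[int(v == c) for c in classes] for v in values]

colors = ["red", "green", "blue", "green"]
print(label_encode(colors))  # (['blue', 'green', 'red'], [2, 1, 0, 1])
print(one_hot(colors))       # [[0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 1, 0]]
```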

preprocessing.PolynomialFeatures([degree, ...])

Generate polynomial and interaction features.
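
For degree 2 and two inputs `a` and `b`, polynomial expansion yields `1, a, b, a^2, a*b, b^2`. The hypothetical helper below enumerates the same monomials for one row using `itertools.combinations_with_replacement`. It is a sketch of the expansion, not dask-ml's transformer.

```python
from itertools import combinations_with_replacement
from math import prod

def polynomial_features(row, degree=2):
    """Hypothetical helper: all monomials of the inputs up to `degree`, starting with 1."""
    out = []
    for d in range(degree + 1):
        # Each combination of feature indices (with repetition) is one monomial.
        for combo in combinations_with_replacement(range(len(row)), d):
            out.append(prod(row[i] for i in combo))
    return out

print(polynomial_features([2, 3]))  # [1, 2, 3, 4, 6, 9]
```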

preprocessing.BlockTransformer(func, *[, ...])

Construct a transformer from an arbitrary callable.

dask_ml.feature_extraction.text: Feature extraction

feature_extraction.text.CountVectorizer(*[, ...])

Convert a collection of text documents to a matrix of token counts


Convert a collection of text documents to a matrix of token occurrences.


Implements feature hashing, aka the hashing trick.
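
Both a token-count matrix and the hashing trick can be written in a few lines of plain Python. The tokenizer assumes lowercasing and the "two or more word characters" pattern that scikit-learn uses by default. The hashing part uses CRC32 from `zlib` as a stable stand-in hash, so its bucket assignments will not match those of scikit-learn, which uses MurmurHash3 and alternating signs by default. All helper names are hypothetical.

```python
import re
import zlib
from collections import Counter

def tokenize(doc):
    # Assumed tokenizer: lowercase words of two or more word characters.
    return re.findall(r"\b\w\w+\b", doc.lower())

def count_vectorize(docs):
    """Hypothetical helper: build a sorted vocabulary, then count tokens per document."""
    vocab = sorted({tok for doc in docs for tok in tokenize(doc)})
    counts = [Counter(tokenize(doc)) for doc in docs]
    return vocab, [[c[word] for word in vocab] for c in counts]

def hash_vectorize(doc, n_features=8):
    """Hypothetical helper: hashing trick, bucketing tokens with no stored vocabulary."""
    row = [0] * n_features
    for tok in tokenize(doc):
        row[zlib.crc32(tok.encode()) % n_features] += 1
    return row

docs = ["the cat sat", "the cat sat on the mat"]
vocab, X = count_vectorize(docs)
print(vocab)  # ['cat', 'mat', 'on', 'sat', 'the']
print(X)      # [[1, 0, 0, 1, 1], [1, 1, 1, 1, 2]]
```

The hashed vectorizer needs no vocabulary, which makes it easy to apply block by block. The trade-off is that unrelated tokens can collide in the same bucket.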

dask_ml.compose: Composite Estimators

Meta-estimators for building composite models with multiple transformers.

These estimators are useful for working with heterogeneous tabular data.

compose.ColumnTransformer(transformers[, ...])

Applies transformers to columns of an array or pandas DataFrame.


Construct a ColumnTransformer from the given transformers.
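
A column transformer applies a different transformation to each named group of columns and concatenates the results. The sketch below assumes a table stored as a dict of column lists. It also assumes transformers are given as `(function, [column names])` pairs. Columns not listed are dropped, mirroring scikit-learn's default `remainder="drop"`. The helper names are hypothetical.

```python
def column_transform(table, transformers):
    """Hypothetical sketch: apply func to each listed column, then stack results as rows."""
    outputs = []
    for func, columns in transformers:
        for name in columns:
            outputs.append(func(table[name]))
    # Transpose the list of output columns back into rows.
    return [list(row) for row in zip(*outputs)]

def center(col):
    mean = sum(col) / len(col)
    return [x - mean for x in col]

def is_a(col):
    return [int(v == "a") for v in col]

table = {"age": [20.0, 30.0, 40.0], "city": ["a", "b", "a"], "id": [1, 2, 3]}
print(column_transform(table, [(center, ["age"]), (is_a, ["city"])]))
# [[-10.0, 1], [0.0, 0], [10.0, 1]]
```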

dask_ml.impute: Imputing Missing Data

impute.SimpleImputer(*[, missing_values, ...])


dask_ml.metrics: Metrics

Score functions, performance metrics, and pairwise distance computations.

Regression Metrics

metrics.mean_absolute_error(y_true, y_pred)

Mean absolute error regression loss.


Mean absolute percentage error regression loss.

metrics.mean_squared_error(y_true, y_pred[, ...])

Mean squared error regression loss.

metrics.mean_squared_log_error(y_true, y_pred)

Mean squared logarithmic error regression loss.

metrics.r2_score(y_true, y_pred[, ...])

\(R^2\) (coefficient of determination) regression score function.
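
The regression metrics above reduce to averages over paired values. The plain-Python definitions below restate the formulas for reference. They are hypothetical helpers and omit extra options, such as sample weights, that the library functions may accept.

```python
import math

def mae(y_true, y_pred):
    """Mean absolute error."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)

def mse(y_true, y_pred):
    """Mean squared error."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

def msle(y_true, y_pred):
    """Mean squared error of log(1 + y); assumes non-negative values."""
    return sum((math.log1p(t) - math.log1p(p)) ** 2
               for t, p in zip(y_true, y_pred)) / len(y_true)

def r2(y_true, y_pred):
    """1 - (residual sum of squares) / (total sum of squares)."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1 - ss_res / ss_tot

y_true = [3.0, 5.0, 7.0, 9.0]
y_pred = [2.0, 5.0, 8.0, 9.0]
print(mae(y_true, y_pred), mse(y_true, y_pred))  # 0.5 0.5
```

Here the residual sum of squares is 2 and the total sum of squares around the mean of 6 is 20, so `r2` gives 1 - 2/20 = 0.9.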

Classification Metrics

metrics.accuracy_score(y_true, y_pred[, ...])

Accuracy classification score.

metrics.log_loss(y_true, y_pred[, eps, ...])

Log loss, aka logistic loss or cross-entropy loss.
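
Accuracy is the fraction of exact matches. Log loss is the mean negative log-likelihood of the true labels under the predicted probabilities. The binary-only sketch below clips probabilities to `[eps, 1 - eps]` to avoid `log(0)`, echoing the `eps` parameter listed above. The helper names are hypothetical.

```python
import math

def accuracy(y_true, y_pred):
    """Fraction of predictions that exactly match the labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def binary_log_loss(y_true, p_pred, eps=1e-15):
    """Mean negative log-likelihood of 0/1 labels given predicted P(y=1)."""
    total = 0.0
    for t, p in zip(y_true, p_pred):
        p = min(max(p, eps), 1 - eps)  # clip to avoid log(0)
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(y_true)

print(accuracy([0, 1, 1, 0], [0, 1, 0, 0]))  # 0.75
```

A classifier that always predicts probability 0.5 has a log loss of `log(2)` (about 0.693) regardless of the labels.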

dask_ml.xgboost: XGBoost

Train an XGBoost model on Dask Arrays or DataFrames.

This may be used for training an XGBoost model on a cluster. XGBoost will be set up in distributed mode alongside your existing dask.distributed cluster.

XGBClassifier([max_depth, learning_rate, ...])


XGBRegressor([max_depth, learning_rate, ...])


train(client, params, data, labels[, ...])

Train an XGBoost model on a Dask cluster

predict(client, model, data)

Distributed prediction with XGBoost

dask_ml.datasets: Datasets

dask-ml provides some utilities for generating toy datasets.

make_counts([n_samples, n_features, ...])

Generate a dummy dataset for modeling count data.

make_blobs([n_samples, n_features, centers, ...])

Generate isotropic Gaussian blobs for clustering.

make_regression([n_samples, n_features, ...])

Generate a random regression problem.

make_classification([n_samples, n_features, ...])

make_classification_df([n_samples, ...])

Uses the make_classification function to create a dask dataframe for testing.