Preprocessing¶
dask_ml.preprocessing contains some scikit-learn style transformers that can be used in Pipelines to perform various data transformations as part of the model fitting process. These transformers work well on dask collections (dask.array, dask.dataframe), NumPy arrays, and pandas DataFrames. They fit and transform in parallel.
Scikit-Learn Clones¶
Some of the transformers are (mostly) drop-in replacements for their scikit-learn counterparts.
| MinMaxScaler | Transform features by scaling each feature to a given range. |
| QuantileTransformer | Transforms features using quantile information. |
| RobustScaler | Scale features using statistics that are robust to outliers. |
| StandardScaler | Standardize features by removing the mean and scaling to unit variance. |
| LabelEncoder | Encode labels with value between 0 and n_classes-1. |
| OneHotEncoder | Encode categorical integer features as a one-hot numeric array. |
| PolynomialFeatures | Generate polynomial and interaction features. |
These can be used just like the scikit-learn versions, except that:
- They operate on dask collections in parallel
- .transform will return a dask.array or dask.dataframe when the input is a dask collection
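For example, here is a minimal sketch (the data and sizes are made up for illustration) of fitting one of these clones, StandardScaler, on a Dask Array. Fitting computes the statistics in parallel, and transform returns a lazy dask.array.
import dask.array as da
from dask_ml.preprocessing import StandardScaler

# Illustrative data: a 1000 x 3 Dask Array split into four chunks
X = da.random.uniform(size=(1000, 3), chunks=(250, 3))

scaler = StandardScaler()
scaler.fit(X)              # mean and variance computed in parallel
Xt = scaler.transform(X)   # still a lazy dask.array
print(type(Xt))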
See sklearn.preprocessing for more information about any particular transformer. Scikit-learn also provides some stateless transformers that avoid the large-memory fitting step entirely: FeatureHasher (a good alternative to DictVectorizer and CountVectorizer) and HashingVectorizer (best suited to text, as an alternative to CountVectorizer). Because they are not stateful, they can be used easily with Dask via map_partitions:
In [1]: import dask.bag as db
In [2]: from sklearn.feature_extraction import FeatureHasher
In [3]: D = [{'dog': 1, 'cat':2, 'elephant':4}, {'dog': 2, 'run': 5}]
In [4]: b = db.from_sequence(D)
In [5]: h = FeatureHasher()
In [6]: b.map_partitions(h.transform).compute()
Out[6]:
[<Compressed Sparse Row sparse matrix of dtype 'float64'
with 3 stored elements and shape (1, 1048576)>,
<Compressed Sparse Row sparse matrix of dtype 'float64'
with 2 stored elements and shape (1, 1048576)>]
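HashingVectorizer works the same way. Here is a brief sketch (the example documents are made up) of applying it to a bag of text, one partition at a time:
import dask.bag as db
from sklearn.feature_extraction.text import HashingVectorizer

texts = db.from_sequence(
    ["the cat sat", "the dog ran", "cats and dogs"], npartitions=2
)
vec = HashingVectorizer()

# Each partition is transformed independently into a sparse matrix
texts.map_partitions(vec.transform).compute()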
Note
dask_ml.preprocessing.LabelEncoder and dask_ml.preprocessing.OneHotEncoder will use the categorical dtype information for a dask or pandas Series with a pandas.api.types.CategoricalDtype. This improves performance, but may lead to different encodings depending on the categories. See the class docstrings for more.
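As a rough sketch of what this note means (the data here is made up), LabelEncoder can take the categories straight from a categorical dask Series rather than scanning the data for unique values:
import pandas as pd
import dask.dataframe as dd
from dask_ml.preprocessing import LabelEncoder

s = pd.Series(["a", "b", "a", "c"], dtype="category")
ds = dd.from_pandas(s, npartitions=2)

le = LabelEncoder()
codes = le.fit_transform(ds)   # integer codes, computed lazily
print(le.classes_)             # taken from the categorical dtype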
Encoding Categorical Features¶
dask_ml.preprocessing.OneHotEncoder can be useful for "one-hot" (or "dummy") encoding features. See the scikit-learn documentation for a full discussion. This section focuses only on the differences from scikit-learn.
Dask-ML Supports pandas’ Categorical dtype¶
Dask-ML supports and uses the type information from pandas Categorical dtype. See https://pandas.pydata.org/pandas-docs/stable/categorical.html for an introduction. For large datasets, using categorical dtypes is crucial for achieving performance.
This will have a couple of effects on the learned attributes and transformed values.
1. The learned categories_ may differ. Scikit-Learn requires the categories to be sorted. With a CategoricalDtype the categories do not need to be sorted.
2. The output of OneHotEncoder.transform() will be the same type as the input. Passing a pandas DataFrame returns a pandas DataFrame, instead of a NumPy array. Likewise, a Dask DataFrame returns a Dask DataFrame.
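A minimal sketch of the second point, using a small made-up DataFrame with a categorical column: passing a Dask DataFrame through OneHotEncoder gives back a Dask DataFrame.
import pandas as pd
import dask.dataframe as dd
from dask_ml.preprocessing import OneHotEncoder

df = pd.DataFrame({"B": pd.Categorical(["a", "b", "c", "c"])})
ddf = dd.from_pandas(df, npartitions=2)

enc = OneHotEncoder(sparse=False)
enc.fit(ddf)
transformed = enc.transform(ddf)   # a Dask DataFrame, not a NumPy array
print(type(transformed))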
Dask-ML’s Sparse Support¶
The default behavior of OneHotEncoder is to return a sparse array. Scikit-Learn returns a SciPy sparse matrix for ndarrays passed to transform.
When passed a Dask Array, OneHotEncoder.transform() returns a Dask Array where each block is a scipy sparse matrix. SciPy sparse matrices don't support the same API as the NumPy ndarray, so most methods won't work on the result. Even basic things like compute will fail. To work around this, we currently recommend converting each block to a format Dask understands, such as a pydata/sparse COO array (shown below) or a dense NumPy array.
from dask_ml.preprocessing import OneHotEncoder
import dask.array as da
import numpy as np

enc = OneHotEncoder(sparse=True)
X = da.from_array(np.array([['A'], ['B'], ['A'], ['C']]), chunks=2)

enc = enc.fit(X)
result = enc.transform(X)

# Each block of `result` is a scipy sparse matrix
result.blocks[0].compute()

# This would fail!
# result.compute()

# Convert to, say, pydata/sparse COO arrays instead
from sparse import COO

result.map_blocks(COO.from_scipy_sparse, dtype=result.dtype).compute()
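Alternatively, continuing from the example above, each scipy sparse block can be densified instead; a sketch, reasonable only when the result fits in memory:
# Densify each block; fine for small results, memory-hungry for large ones
dense = result.map_blocks(lambda block: block.toarray(), dtype=result.dtype)
dense.compute()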
Dask-ML’s support for sparse data is currently in flux. Reach out if you have any issues.
Additional Transformers¶
Other transformers are specific to dask-ml.
| Categorizer | Transform columns of a DataFrame to categorical dtype. |
| DummyEncoder | Dummy (one-hot) encode categorical columns. |
| OrdinalEncoder | Ordinal (integer) encode categorical columns. |
Both dask_ml.preprocessing.Categorizer and dask_ml.preprocessing.DummyEncoder deal with converting non-numeric data to numeric data. They are useful as a preprocessing step in a pipeline where you start with heterogeneous data (a mix of numeric and non-numeric), but the estimator requires all numeric data.
In this toy example, we use a dataset with two columns: 'A' is numeric and 'B' contains text data. We make a small pipeline to
1. categorize the text data,
2. dummy encode the categorical data, and
3. fit a logistic regression.
In [7]: from dask_ml.preprocessing import Categorizer, DummyEncoder
In [8]: from sklearn.linear_model import LogisticRegression
In [9]: from sklearn.pipeline import make_pipeline
In [10]: import pandas as pd
In [11]: import dask.dataframe as dd
In [12]: df = pd.DataFrame({"A": [1, 2, 1, 2], "B": ["a", "b", "c", "c"]})
In [13]: X = dd.from_pandas(df, npartitions=2)
In [14]: y = dd.from_pandas(pd.Series([0, 1, 1, 0]), npartitions=2)
In [15]: pipe = make_pipeline(
....: Categorizer(),
....: DummyEncoder(),
....: LogisticRegression(solver='lbfgs')
....: )
....:
In [16]: pipe.fit(X, y)
Out[16]:
Pipeline(steps=[('categorizer', Categorizer()),
('dummyencoder', DummyEncoder()),
('logisticregression', LogisticRegression())])
Categorizer will convert a subset of the columns in X to categorical dtype (see the pandas documentation on categorical data for more about how pandas handles it). By default, it converts all the object dtype columns.
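Categorizer also accepts a columns argument to restrict which columns are converted (see its docstring). A small sketch, re-creating the toy data from above:
import pandas as pd
import dask.dataframe as dd
from dask_ml.preprocessing import Categorizer

df = pd.DataFrame({"A": [1, 2, 1, 2], "B": ["a", "b", "c", "c"]})
X = dd.from_pandas(df, npartitions=2)

# Only categorize column 'B'; 'A' keeps its numeric dtype
cat = Categorizer(columns=["B"])
print(cat.fit_transform(X).dtypes)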
DummyEncoder will dummy (or one-hot) encode the dataset. This replaces each categorical column with multiple columns, where the values are either 0 or 1, depending on whether the value in the original column matches that column's category.
In [17]: df['B']
Out[17]:
0 a
1 b
2 c
3 c
Name: B, dtype: object
In [18]: pd.get_dummies(df['B'])
Out[18]:
a b c
0 True False False
1 False True False
2 False False True
3 False False True
Wherever the original was 'a', the transformed data now has a True (1) in the a column and a False (0) everywhere else.
Why was the Categorizer step necessary? Why couldn't we operate directly on the object (string) dtype column? Doing this would be fragile, especially when using dask.dataframe, since the shape of the output would depend on the values present. For example, suppose that we saw only the first two rows in the training dataset, and the last two rows in the test dataset. Then, when training, our transformed columns would be:
In [19]: pd.get_dummies(df.loc[[0, 1], 'B'])
Out[19]:
a b
0 True False
1 False True
while on the test dataset, they would be:
In [20]: pd.get_dummies(df.loc[[2, 3], 'B'])
Out[20]:
c
2 True
3 True
This is incorrect! The columns don't match.
When we categorize the data, we can be confident that all the possible values have been specified, so the output shape no longer depends on the values in whatever subset of the data we currently see. Instead, it depends on the categories, which are identical in all the subsets.
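A quick sketch of that last point using plain pandas: once 'B' carries a CategoricalDtype listing all three categories, dummy-encoding any subset produces the same set of columns.
import pandas as pd

b = pd.Series(
    ["a", "b", "c", "c"],
    dtype=pd.CategoricalDtype(["a", "b", "c"]),
)

# Both subsets produce columns a, b, c, so the shapes line up
print(pd.get_dummies(b.iloc[:2]))
print(pd.get_dummies(b.iloc[2:]))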