Description

I get the error ValueError: buffer source array is read-only in the example below whenever I pass a DataFrame with around 200K rows and at least one column of dtype Object into GridSearchCV with n_jobs > 1. The error seems to be caused by passing a DataFrame that has Object columns into GridSearchCV.fit. My custom class, DataFrame_Encoder, properly encodes the Object columns (by dummy encoding them) when the pipeline executes, but this error occurs before the pipeline runs. Things work fine if I use a smaller dataset, drop the Object column from the DataFrame, or set n_jobs=1.

My minimal example to reproduce the bug is a bit lengthy, so I've also included a notebook with the code and some theories as to what is happening: https://github.com/stoddardg/sklearn_bug_example/blob/master/Bug%20Exploration.ipynb
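
For what it's worth, the comment inside sklearn's safe_indexing (visible in the traceback below) notes that the Cython typed memoryviews used internally by pandas do not support read-only buffers, and every large array in the sub-process traceback is a joblib memmap. Here is a minimal sketch of that suspected mechanism, using a manually set read-only flag to stand in for joblib's memmapping; this is my working theory, not a confirmed diagnosis:

import numpy as np
import pandas as pd

# Stand-in for the read-only memmap that joblib hands to worker processes
# (assumption: joblib memmaps large arrays, including the blocks backing
# a DataFrame, and opens those memmaps read-only).
values = np.random.randn(200000, 5)
values.setflags(write=False)

frame = pd.DataFrame(values, copy=False)  # backed by the read-only array
frame.take(np.arange(100))  # ValueError: buffer source array is read-only
                            # (on pandas versions affected by this issue)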
Steps/Code to Reproduce
Example:
import pandas as pd
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.datasets import make_classification
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction import DictVectorizer


class DataFrame_Encoder(BaseEstimator, TransformerMixin):
    def __init__(self, categorical_cols_=None, numeric_cols_=None):
        print("__init__ called")
        self.categorical_cols_ = categorical_cols_
        self.numeric_cols_ = numeric_cols_

    def fit(self, df, y=None):
        print("Fit called")
        ### df should be a dataframe that is a mix of categorical and numeric columns
        self.vec_ = DictVectorizer(sparse=False)
        temp_data = df[self.categorical_cols_].astype(str)
        self.vec_.fit(temp_data.to_dict('records'))
        self.feature_names_ = list(self.numeric_cols_) + list(self.vec_.feature_names_)
        return self

    def transform(self, df):
        ### df should be a dataframe that is a mix of categorical and numeric columns
        print("Transform called")
        temp_data = df[self.categorical_cols_].astype(str)
        categorical_data = self.vec_.transform(temp_data.to_dict('records'))
        categorical_df = pd.DataFrame(categorical_data,
                                      columns=self.vec_.feature_names_,
                                      index=df.index)
        new_data = pd.concat([df[self.numeric_cols_], categorical_df], axis=1)
        return new_data


x, y = make_classification(n_samples=200000, n_features=5)
numeric_features = ['x1', 'x2', 'x3', 'x4', 'x5']
string_features = ['category']

df = pd.DataFrame(data=x, columns=numeric_features)
df['category'] = 'a'

base_clf = RandomForestClassifier(n_jobs=4)
param_grid = {'clf__n_estimators': [10, 100]}

pipeline = Pipeline([
    ('feature_encoder', DataFrame_Encoder()),
    ('clf', base_clf)
])
pipeline.set_params(feature_encoder__categorical_cols_=string_features,
                    feature_encoder__numeric_cols_=numeric_features)

clf = GridSearchCV(pipeline, param_grid, cv=5, n_jobs=2, verbose=1)
clf.fit(df, y)
---------------------------------------------------------------------------
Sub-process traceback:
---------------------------------------------------------------------------
ValueError Thu Aug 3 16:29:11 2017
PID: 16736 Python 3.6.2: /volatile/le243287/miniconda3/bin/python
...........................................................................
/home/le243287/dev/alt-scikit-learn/sklearn/externals/joblib/parallel.py in __call__(self=<sklearn.externals.joblib.parallel.BatchedCalls object>)
126 def __init__(self, iterator_slice):
127 self.items = list(iterator_slice)
128 self._size = len(self.items)
129
130 def __call__(self):
--> 131 return [func(*args, **kwargs) for func, args, kwargs in self.items]
self.items = [(<function _fit_and_score>, (Pipeline(memory=None,
steps=[('feature_enco...None, verbose=0,
warm_start=False))]), x1 x2 x3 x4 ...124 0.815972 a
[200000 rows x 6 columns], memmap([0, 0, 1, ..., 1, 1, 1]), {'score': <function _passthrough_scorer>}, memmap([ 39843, 39844, 39846, ..., 199997, 199998, 199999]), array([ 0, 1, 2, ..., 40174, 40178, 40184]), 1, {'clf__n_estimators': 10}), {'error_score': 'raise', 'fit_params': {}, 'return_n_test_samples': True, 'return_parameters': False, 'return_times': True, 'return_train_score': True})]
132
133 def __len__(self):
134 return self._size
135
...........................................................................
/home/le243287/dev/alt-scikit-learn/sklearn/externals/joblib/parallel.py in <listcomp>(.0=<list_iterator object>)
126 def __init__(self, iterator_slice):
127 self.items = list(iterator_slice)
128 self._size = len(self.items)
129
130 def __call__(self):
--> 131 return [func(*args, **kwargs) for func, args, kwargs in self.items]
func = <function _fit_and_score>
args = (Pipeline(memory=None,
steps=[('feature_enco...None, verbose=0,
warm_start=False))]), x1 x2 x3 x4 ...124 0.815972 a
[200000 rows x 6 columns], memmap([0, 0, 1, ..., 1, 1, 1]), {'score': <function _passthrough_scorer>}, memmap([ 39843, 39844, 39846, ..., 199997, 199998, 199999]), array([ 0, 1, 2, ..., 40174, 40178, 40184]), 1, {'clf__n_estimators': 10})
kwargs = {'error_score': 'raise', 'fit_params': {}, 'return_n_test_samples': True, 'return_parameters': False, 'return_times': True, 'return_train_score': True}
132
133 def __len__(self):
134 return self._size
135
...........................................................................
/home/le243287/dev/alt-scikit-learn/sklearn/model_selection/_validation.py in _fit_and_score(estimator=Pipeline(memory=None,
steps=[('feature_enco...None, verbose=0,
warm_start=False))]), X= x1 x2 x3 x4 ...124 0.815972 a
[200000 rows x 6 columns], y=memmap([0, 0, 1, ..., 1, 1, 1]), scorer={'score': <function _passthrough_scorer>}, train=memmap([ 39843, 39844, 39846, ..., 199997, 199998, 199999]), test=array([ 0, 1, 2, ..., 40174, 40178, 40184]), verbose=1, parameters={'clf__n_estimators': 10}, fit_params={}, return_train_score=True, return_parameters=False, return_n_test_samples=True, return_times=True, error_score='raise')
422 if parameters is not None:
423 estimator.set_params(**parameters)
424
425 start_time = time.time()
426
--> 427 X_train, y_train = _safe_split(estimator, X, y, train)
X_train = undefined
y_train = undefined
estimator = Pipeline(memory=None,
steps=[('feature_enco...None, verbose=0,
warm_start=False))])
X = x1 x2 x3 x4 ...124 0.815972 a
[200000 rows x 6 columns]
y = memmap([0, 0, 1, ..., 1, 1, 1])
train = memmap([ 39843, 39844, 39846, ..., 199997, 199998, 199999])
428 X_test, y_test = _safe_split(estimator, X, y, test, train)
429
430 is_multimetric = not callable(scorer)
431 n_scorers = len(scorer.keys()) if is_multimetric else 1
...........................................................................
/home/le243287/dev/alt-scikit-learn/sklearn/utils/metaestimators.py in _safe_split(estimator=Pipeline(memory=None,
steps=[('feature_enco...None, verbose=0,
warm_start=False))]), X= x1 x2 x3 x4 ...124 0.815972 a
[200000 rows x 6 columns], y=memmap([0, 0, 1, ..., 1, 1, 1]), indices=memmap([ 39843, 39844, 39846, ..., 199997, 199998, 199999]), train_indices=None)
195 if train_indices is None:
196 X_subset = X[np.ix_(indices, indices)]
197 else:
198 X_subset = X[np.ix_(indices, train_indices)]
199 else:
--> 200 X_subset = safe_indexing(X, indices)
X_subset = undefined
X = x1 x2 x3 x4 ...124 0.815972 a
[200000 rows x 6 columns]
indices = memmap([ 39843, 39844, 39846, ..., 199997, 199998, 199999])
201
202 if y is not None:
203 y_subset = safe_indexing(y, indices)
204 else:
...........................................................................
/home/le243287/dev/alt-scikit-learn/sklearn/utils/__init__.py in safe_indexing(X= x1 x2 x3 x4 ...124 0.815972 a
[200000 rows x 6 columns], indices=memmap([ 39843, 39844, 39846, ..., 199997, 199998, 199999]))
148 except ValueError:
149 # Cython typed memoryviews internally used in pandas do not support
150 # readonly buffers.
151 warnings.warn("Copying input dataframe for slicing.",
152 DataConversionWarning)
--> 153 return X.copy().iloc[indices]
X.copy.iloc = undefined
indices = memmap([ 39843, 39844, 39846, ..., 199997, 199998, 199999])
154 elif hasattr(X, "shape"):
155 if hasattr(X, 'take') and (hasattr(indices, 'dtype') and
156 indices.dtype.kind == 'i'):
157 # This is often substantially faster than X[indices]
...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/core/indexing.py in __getitem__(self=<pandas.core.indexing._iLocIndexer object>, key=memmap([ 39843, 39844, 39846, ..., 199997, 199998, 199999]))
1323 except (KeyError, IndexError):
1324 pass
1325 return self._getitem_tuple(key)
1326 else:
1327 key = com._apply_if_callable(key, self.obj)
-> 1328 return self._getitem_axis(key, axis=0)
self._getitem_axis = <bound method _iLocIndexer._getitem_axis of <pandas.core.indexing._iLocIndexer object>>
key = memmap([ 39843, 39844, 39846, ..., 199997, 199998, 199999])
1329
1330 def _is_scalar_access(self, key):
1331 raise NotImplementedError()
1332
...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/core/indexing.py in _getitem_axis(self=<pandas.core.indexing._iLocIndexer object>, key=memmap([ 39843, 39844, 39846, ..., 199997, 199998, 199999]), axis=0)
1733 self._has_valid_type(key, axis)
1734 return self._getbool_axis(key, axis=axis)
1735
1736 # a list of integers
1737 elif is_list_like_indexer(key):
-> 1738 return self._get_list_axis(key, axis=axis)
self._get_list_axis = <bound method _iLocIndexer._get_list_axis of <pandas.core.indexing._iLocIndexer object>>
key = memmap([ 39843, 39844, 39846, ..., 199997, 199998, 199999])
axis = 0
1739
1740 # a single integer
1741 else:
1742 key = self._convert_scalar_indexer(key, axis)
...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/core/indexing.py in _get_list_axis(self=<pandas.core.indexing._iLocIndexer object>, key=memmap([ 39843, 39844, 39846, ..., 199997, 199998, 199999]), axis=0)
1710 Returns
1711 -------
1712 Series object
1713 """
1714 try:
-> 1715 return self.obj.take(key, axis=axis, convert=False)
self.obj.take = <bound method NDFrame.take of x1 ...24 0.815972 a
[200000 rows x 6 columns]>
key = memmap([ 39843, 39844, 39846, ..., 199997, 199998, 199999])
axis = 0
1716 except IndexError:
1717 # re-raise with different error message
1718 raise IndexError("positional indexers are out-of-bounds")
1719
...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/core/generic.py in take(self= x1 x2 x3 x4 ...124 0.815972 a
[200000 rows x 6 columns], indices=memmap([ 39843, 39844, 39846, ..., 199997, 199998, 199999]), axis=0, convert=False, is_copy=True, **kwargs={})
1923 """
1924 nv.validate_take(tuple(), kwargs)
1925 self._consolidate_inplace()
1926 new_data = self._data.take(indices,
1927 axis=self._get_block_manager_axis(axis),
-> 1928 convert=True, verify=True)
convert = False
1929 result = self._constructor(new_data).__finalize__(self)
1930
1931 # maybe set copy if we didn't actually change the index
1932 if is_copy:
...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/core/internals.py in take(self=BlockManager
Items: Index(['x1', 'x2', 'x3', 'x4...tBlock: slice(5, 6, 1), 1 x 200000, dtype: object, indexer=memmap([ 39843, 39844, 39846, ..., 199997, 199998, 199999]), axis=1, verify=True, convert=True)
4006 raise Exception('Indices must be nonzero and less than '
4007 'the axis length')
4008
4009 new_labels = self.axes[axis].take(indexer)
4010 return self.reindex_indexer(new_axis=new_labels, indexer=indexer,
-> 4011 axis=axis, allow_dups=True)
axis = 1
4012
4013 def merge(self, other, lsuffix='', rsuffix=''):
4014 if not self._is_indexed_like(other):
4015 raise AssertionError('Must have same axes to merge managers')
...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/core/internals.py in reindex_indexer(self=BlockManager
Items: Index(['x1', 'x2', 'x3', 'x4...tBlock: slice(5, 6, 1), 1 x 200000, dtype: object, new_axis=Int64Index([ 39843, 39844, 39846, 39848, 398...199999],
dtype='int64', length=160000), indexer=memmap([ 39843, 39844, 39846, ..., 199997, 199998, 199999]), axis=1, fill_value=None, allow_dups=True, copy=True)
3892 new_blocks = self._slice_take_blocks_ax0(indexer,
3893 fill_tuple=(fill_value,))
3894 else:
3895 new_blocks = [blk.take_nd(indexer, axis=axis, fill_tuple=(
3896 fill_value if fill_value is not None else blk.fill_value,))
-> 3897 for blk in self.blocks]
self.blocks = (FloatBlock: slice(0, 5, 1), 5 x 200000, dtype: float64, ObjectBlock: slice(5, 6, 1), 1 x 200000, dtype: object)
3898
3899 new_axes = list(self.axes)
3900 new_axes[axis] = new_axis
3901 return self.__class__(new_blocks, new_axes)
...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/core/internals.py in <listcomp>(.0=<tuple_iterator object>)
3892 new_blocks = self._slice_take_blocks_ax0(indexer,
3893 fill_tuple=(fill_value,))
3894 else:
3895 new_blocks = [blk.take_nd(indexer, axis=axis, fill_tuple=(
3896 fill_value if fill_value is not None else blk.fill_value,))
-> 3897 for blk in self.blocks]
blk = FloatBlock: slice(0, 5, 1), 5 x 200000, dtype: float64
3898
3899 new_axes = list(self.axes)
3900 new_axes[axis] = new_axis
3901 return self.__class__(new_blocks, new_axes)
...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/core/internals.py in take_nd(self=FloatBlock: slice(0, 5, 1), 5 x 200000, dtype: float64, indexer=memmap([ 39843, 39844, 39846, ..., 199997, 199998, 199999]), axis=1, new_mgr_locs=None, fill_tuple=(nan,))
1041 new_values = algos.take_nd(values, indexer, axis=axis,
1042 allow_fill=False)
1043 else:
1044 fill_value = fill_tuple[0]
1045 new_values = algos.take_nd(values, indexer, axis=axis,
-> 1046 allow_fill=True, fill_value=fill_value)
fill_value = nan
1047
1048 if new_mgr_locs is None:
1049 if axis == 0:
1050 slc = lib.indexer_as_slice(indexer)
...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/core/algorithms.py in take_nd(arr=memmap([[ 1.85430272, 0.02363887, -0.44955668, ... 0.22950348,
0.80573119, 0.81597234]]), indexer=memmap([ 39843, 39844, 39846, ..., 199997, 199998, 199999]), axis=1, out=array([[ 0., 0., 0., ..., 0., 0., 0.],
...0.],
[ 0., 0., 0., ..., 0., 0., 0.]]), fill_value=nan, mask_info=None, allow_fill=True)
1466 else:
1467 out = np.empty(out_shape, dtype=dtype)
1468
1469 func = _get_take_nd_function(arr.ndim, arr.dtype, out.dtype, axis=axis,
1470 mask_info=mask_info)
-> 1471 func(arr, indexer, out, fill_value)
func = <built-in function take_2d_axis1_float64_float64>
arr = memmap([[ 1.85430272, 0.02363887, -0.44955668, ... 0.22950348,
0.80573119, 0.81597234]])
indexer = memmap([ 39843, 39844, 39846, ..., 199997, 199998, 199999])
out = array([[ 0., 0., 0., ..., 0., 0., 0.],
...0.],
[ 0., 0., 0., ..., 0., 0., 0.]])
fill_value = nan
1472
1473 if flip_order:
1474 out = out.T
1475 return out
...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/_libs/algos.cpython-36m-x86_64-linux-gnu.so in pandas._libs.algos.take_2d_axis1_float64_float64 (pandas/_libs/algos.c:111160)()
...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/_libs/algos.cpython-36m-x86_64-linux-gnu.so in View.MemoryView.memoryview_cwrapper (pandas/_libs/algos.c:124730)()
...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/_libs/algos.cpython-36m-x86_64-linux-gnu.so in View.MemoryView.memoryview.__cinit__ (pandas/_libs/algos.c:120965)()
ValueError: buffer source array is read-only
___________________________________________________________________________
Expected Results
No error is thrown.
Actual Results
I get an incredibly long error message (reproduced above and viewable in the notebook), but the punchline is:
ValueError: buffer source array is read-only
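
As noted in the description, the error goes away with a smaller dataset, with the Object column dropped, or with n_jobs=1. A minimal sketch of that last workaround, reusing the pipeline, param_grid, df, and y from the reproduction code above:

# With a single job, joblib runs the fits in the parent process and never
# memmaps the inputs, so the read-only buffer problem cannot arise.
clf = GridSearchCV(pipeline, param_grid, cv=5, n_jobs=1, verbose=1)
clf.fit(df, y)  # completes without the error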
Versions
Darwin-15.6.0-x86_64-i386-64bit
Python 3.6.1 |Continuum Analytics, Inc.| (default, May 11 2017, 13:04:09)
[GCC 4.2.1 Compatible Apple LLVM 6.0 (clang-600.0.57)]
NumPy 1.13.1
SciPy 0.19.1
Scikit-Learn 0.18.2