
[MRG+3] ENH Caching Pipeline by memoizing transformer #7990


Merged
merged 5 commits into scikit-learn:master on Feb 13, 2017

Conversation

glemaitre
Member

Reference Issue

Address the discussions in #3951
Other related issues and PR: #2086 #5082 #5080

What does this implement/fix? Explain your changes.

It implements a version of Pipeline that allows caching transformers.

@glemaitre
Member Author

@jnothman @GaelVaroquaux @amueller

From #3951, this is what I could come up with. I will promptly add the tests from Pipeline and adapt them. From what I could check, this works for those tests.

However, I run into trouble when using the CachedPipeline in a GridSearchCV:

from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import samples_generator
from sklearn.decomposition import PCA
from sklearn.pipeline import CachedPipeline
from sklearn.model_selection import GridSearchCV

# generate some data to play with
X, y = samples_generator.make_classification(
    n_samples=100,
    n_informative=5, n_redundant=0, random_state=42)

pca = PCA()
clf = RandomForestClassifier()
pipeline = CachedPipeline([('pca', pca), ('rf', clf)],
                          memory='./')
parameters = {'pca__n_components': (.25, .5, .75),
              'rf__n_estimators': (10, 20, 30)}
grid_search = GridSearchCV(pipeline, parameters, n_jobs=1, verbose=1)

grid_search.fit(X, y)

After joblib loads the cached PCA, the transformer is seen as not fitted:

NotFittedError: This PCA instance is not fitted yet.
Call 'fit' with appropriate arguments before using this method.

I'm going to check why but if you have already an obvious answer, I would be happy to hear it.

@GaelVaroquaux
Member

GaelVaroquaux commented Dec 6, 2016 via email

@glemaitre
Member Author

Do you get the same problem if you take a fitted PCA, pickle it, unpickle it and try to transform on the data?

Nope, this works fine:

from sklearn.externals import joblib

pca.fit(X, y)
joblib.dump(pca, 'pca.p')
pickled_pca = joblib.load('pca.p')
pickled_pca.transform(X)

memory = self.memory
if isinstance(memory, six.string_types) or memory is None:
    memory = Memory(cachedir=memory, verbose=10)
self._fit_transform_one = memory.cache(_fit_transform_one)
Member

Don't do that: don't decorate methods. Use the pattern described in https://github.com/joblib/joblib/blob/master/doc/memory.rst, under the bullet point "caching methods".
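A minimal sketch of that pattern (the class and helper names here are illustrative, not from the PR): cache a module-level function and wrap it at call time, rather than storing a decorated method on the instance, so that self never becomes part of the cache key.

from sklearn.externals.joblib import Memory

def _fit_transform_one(transformer, X, y):
    # A plain module-level function hashes cleanly; a bound method would
    # drag the whole instance (self) into the cache key.
    return transformer.fit_transform(X, y), transformer

class CachedStep(object):
    def __init__(self, transformer, cachedir):
        self.transformer = transformer
        self.memory = Memory(cachedir=cachedir, verbose=0)

    def fit_transform(self, X, y=None):
        # Wrap the free function at call time instead of decorating a method.
        fit_transform_cached = self.memory.cache(_fit_transform_one)
        Xt, self.transformer = fit_transform_cached(self.transformer, X, y)
        return Xt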

@glemaitre
Member Author

For the first configuration of the grid search, the only PCA that is both dumped and fitted is the first one. For the two others, the PCA is dumped but not fitted.

Xt, transform = memory.cache(_fit_transform_one)(
    transform, name,
    None, Xt, y,
    **fit_params_steps[name])
Member

I would refactor Pipeline with a private class to avoid the code dupe.

@GaelVaroquaux
Member

GaelVaroquaux commented Dec 6, 2016 via email

@glemaitre
Member Author

glemaitre commented Dec 6, 2016

Do you get the same problem if you take a fitted PCA, pickle it, unpickle it and try to transform on the data?

Stupid mistake fixed: I forgot to assign the loaded pipeline.
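An illustrative guess at that kind of mistake, in a self-contained form (not the actual diff):

import numpy as np
from sklearn.decomposition import PCA
from sklearn.externals.joblib import Memory

memory = Memory(cachedir='./cachedir', verbose=0)
X = np.random.RandomState(0).rand(20, 5)

def fit_one(estimator, X):
    return estimator.fit(X)

pca = PCA(n_components=2)
# Buggy: works on a cache miss (fit mutates pca in-process), but on a cache
# hit joblib returns a reloaded copy and the local pca stays unfitted.
memory.cache(fit_one)(pca, X)
# Fixed: assign the memoized result back.
pca = memory.cache(fit_one)(pca, X)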

@glemaitre
Member Author

@agramfort @GaelVaroquaux Speaking of the private class, does the last commit implement what you had in mind?

@GaelVaroquaux

I would use a subclass of Pipeline, rather than Pipeline itself. The reason being that it is probably a good idea to clone the objects before memoizing the fit (in order to maximize the cache hits).

I must be missing something regarding the cloning of the objects. Could you elaborate on what is required in the implementation?

@glemaitre
Member Author

I would use a subclass of Pipeline, rather than Pipeline itself. The reason being that it is probably a good idea to clone the objects before memoizing the fit (in order to maximize the cache hits).

I must be missing something regarding the cloning of the objects. Could you elaborate on what is required in the implementation?

I think I got it. The cloning is to strip the fitting info so as not to cache the same estimator twice: once unfitted and once fitted.
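A rough sketch of that idea (the helper names are mine, not the PR's):

from sklearn.base import clone
from sklearn.externals.joblib import Memory

memory = Memory(cachedir='./cachedir', verbose=0)

def _fit_transform_one(transformer, X, y):
    return transformer.fit_transform(X, y), transformer

def cached_fit_transform(transformer, X, y):
    # clone() strips the fitted attributes, so a previously fitted estimator
    # hashes the same as a fresh one and hits the same cache entry.
    return memory.cache(_fit_transform_one)(clone(transformer), X, y)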

@glemaitre
Member Author

I added tests similar to the ones proposed in #3951.
However, these tests do not check that the cache has been loaded.
They only check the resulting array, which is kind of permissive.

Is there anything in joblib which tells whether the cache has been read?

@glemaitre
Member Author

For now, I am checking that the timestamp of a DummyTransformer, assigned at the first fit within the pipeline, is recovered across the subsequent fits.
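Roughly how such a check could look; this DummyTransformer is a sketch of the idea, not the actual test code:

import time

from sklearn.base import BaseEstimator, TransformerMixin

class DummyTransformer(BaseEstimator, TransformerMixin):
    def fit(self, X, y=None):
        # Set only when fit actually runs; a cache hit reloads the pickled
        # instance and therefore carries the original timestamp along.
        self.timestamp_ = time.time()
        return self

    def transform(self, X):
        return X

Fitting the pipeline twice on the same data and asserting that timestamp_ is unchanged then shows that the second fit came from the cache.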

@glemaitre
Member Author

@agramfort Do you see additional changes to do?

@lesteve
Member

lesteve commented Dec 20, 2016

You should probably change your title to [MRG].

Can you trigger a rebuild (e.g. by doing git commit --amend and force-pushing) to trigger a documentation build? It looks like at one point last week CircleCI stopped running on PRs.

with its name to another estimator, or a transformer removed by setting
to None.

Read more in the :ref:`User Guide <pipeline>`.
Member

Should you not link to the cached_pipeline section?

Member

+1 to add it to See Also

Usage
-----

Similarly to :class:`Piepeline`, the pipeline is built on the same manner. However,
Member

Actually you have a typo here in Pipeline. No need to do git commit --amend; changing this should trigger a CircleCI doc build.

@@ -0,0 +1,77 @@
#!/usr/bin/python
Member

Looks like the wrong "#!" line to me: it will always use the system Python, which is often not the right thing. The right way is "#!/usr/bin/env python"

Member

The shebang line is only used if the example is executable right? No example has the executable flag AFAICT. I would be in favour of removing the shebang line.

@GaelVaroquaux
Member

Any way that the new example could be integrated with an existing one? The danger is that we have too many examples.

The way that I would suggest doing it is by extending an existing one (if there is one that is relevant), and using sphinx-gallery's "notebook-style examples" to add extra cells at the bottom (with an extra title) without making the initial example more complicated. That aspect is important: the initial example should be left as simple as it is, without additional lines of code. Sphinx-gallery's notebook-style formatting can then be used to add discussion and code at the end.
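For reference, sphinx-gallery's notebook-style convention marks a new cell with a long comment separator, so extra discussion and code can be appended without touching the original example (an illustrative skeleton):

###############################################################################
# Caching transformers within the grid search
# -------------------------------------------
# Everything in this comment block is rendered as a new text cell with its
# own title, so the discussion of caching lives below the original example.

# ... code for the additional cell goes here ...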

@glemaitre
Member Author

Any way that the new example could be integrated with an existing one? The danger is that we have too many examples.

This example is a modified version of examples/plot_compare_reduction.py.
I'll try to merge both as suggested.

@lesteve
Member

lesteve commented Dec 20, 2016

@glemaitre the Travis failure in the doctest is because the default of decision_function_shape is None and not 'ovr' in master.

I can reproduce the failure locally. The reason we could not reproduce locally earlier was because we were trying on your branch rather than on the result of the merge of your branch into master, d'oh ...

@glemaitre
Member Author

@lesteve d'oh indeed ...

@glemaitre
Member Author

@GaelVaroquaux Does the example look like what you had in mind?

memory=memory)

# This time, a cached pipeline will be used within the grid search
grid = GridSearchCV(cached_pipe, cv=3, n_jobs=2, param_grid=param_grid)
Member

don't use n_jobs != 1 in examples

Member

This must still be changed.

with its name to another estimator, or a transformer removed by setting
to None.

Read more in the :ref:`User Guide <pipeline>`.
Member

+1 to add it to See Also

# Don't make the cachedir, Memory should be able to do that on the fly
print(80 * '_')
print('test_memory setup (%s)' % env['dir'])
print(80 * '_')
Member

why all these print statements?

Member Author

In fact, this is the way it was done in joblib.
However, I am having second thoughts about that. A simple try-finally statement, as here, would probably be enough for the purpose of the test.
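A minimal sketch of what such a try-finally based test could look like (the test name and assertions here are placeholders):

import shutil
import tempfile

def test_cached_pipeline():
    cachedir = tempfile.mkdtemp()
    try:
        # build a CachedPipeline with memory=cachedir, fit it twice, and
        # assert that the second fit reused the cached transformers
        pass
    finally:
        # remove the cache directory even if an assertion fails
        shutil.rmtree(cachedir)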

@@ -124,6 +125,40 @@ i.e. if the last estimator is a classifier, the :class:`Pipeline` can be used
as a classifier. If the last estimator is a transformer, again, so is the
pipeline.

.. _cached_pipeline:

CachedPipeline: memoizing transformers
Member

I think that I would put a title that is more focused towards the problem that it solves rather than the technique used. Something like "CachedPipeline: avoiding to repeat computation"


.. currentmodule:: sklearn.pipeline

:class:`CachedPipeline` can be used instead of :class:`Pipeline` to avoid to fit
Member

I would say "to avoid to compute the fit"

Selecting dimensionality reduction with Pipeline and GridSearchCV
=================================================================
=======================================================================
Selecting dimensionality reduction with Pipeline, CachedPipeline, and \
Member

I don't think that I would change the title here. It makes it too long, IMHO.

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.pipeline import Pipeline, CachedPipeline
Member

I would import this only later: trying to separate a bit the two parts of the example.

We have to keep in mind that every added piece of information to an example makes it harder to understand.

from sklearn.svm import LinearSVC
from sklearn.decomposition import PCA, NMF
from sklearn.feature_selection import SelectKBest, chi2
from sklearn.externals.joblib import Memory
Member

Same comment.

@@ -73,3 +90,29 @@
plt.ylim((0, 1))
plt.legend(loc='upper left')
plt.show()
Member

I think that "show" must be moved to the end.

Member Author

With the current scheme, it will generate the figure right after the code snippet and before the second section. It looks fine to me.

Member

It will block the execution of the script before the end. The convention is really to have it only at the end.

@codecov

codecov bot commented Feb 9, 2017

Codecov Report

❗ No coverage uploaded for pull request base (master@dfcf632).

@@            Coverage Diff            @@
##             master    #7990   +/-   ##
=========================================
  Coverage          ?   94.74%           
=========================================
  Files             ?      342           
  Lines             ?    60739           
  Branches          ?        0           
=========================================
  Hits              ?    57546           
  Misses            ?     3193           
  Partials          ?        0
Impacted Files                   Coverage   Δ
sklearn/tests/test_pipeline.py   99.61%     <100%>    (ø)
sklearn/pipeline.py              99.26%     <95.23%>  (ø)

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update dfcf632...1afb47f.

@jnothman jnothman changed the title [MRG+2] ENH Caching Pipeline by memoizing transformer [MRG+3] ENH Caching Pipeline by memoizing transformer Feb 9, 2017
@jnothman
Member

jnothman commented Feb 9, 2017

Are we good for merge? Should we wait for the joblib fixes?

@glemaitre
Member Author

@ogrisel wanted to wait for the joblib fixes.

@jnothman
Member

jnothman commented Feb 9, 2017 via email

Member

@ogrisel ogrisel left a comment

Actually, as the default setting will not trigger the joblib race condition, I think we can merge this as is.

@GaelVaroquaux
Member

Hurray. Merging. Good job, @glemaitre!

@GaelVaroquaux GaelVaroquaux merged commit b3a639f into scikit-learn:master Feb 13, 2017
@ogrisel
Member

ogrisel commented Feb 13, 2017

@jnothman w.r.t. your comment on GS loop ordering in #7990 (review), this would not impact the optimal design of this PR, right?

@raghavrv
Member

Hurray!! 🎉

@jnothman
Member

@jnothman w.r.t. your comment on GS loop ordering in #7990 (review), this would not impact the optimal design of this PR, right?

I don't know what you're asking, @ogrisel. The ordering issue doesn't stop this change being useful, but it makes this (and other memoisation) less useful in parallel, because the cache will be missed unnecessarily. The maximal number of cache hits is (n_candidates - 1) * n_splits, but under the current ordering it is more like (n_candidates - n_jobs) * n_splits; for instance, with 9 candidates, 3 splits and 2 jobs, that is 24 possible hits versus roughly 21.

@ogrisel
Member

ogrisel commented Feb 13, 2017

Agreed. Once the joblib race condition is fixed on Windows, we can reinvestigate that issue in GS.

@amueller
Member

OMG AMAZING!

sergeyf pushed a commit to sergeyf/scikit-learn that referenced this pull request Feb 28, 2017

* ENH Caching Pipeline by memoizing transformer

* Fix lesteve changes

* Fix comments

* Fix doc

* Fix jnothman comments
@Przemo10 Przemo10 mentioned this pull request Mar 17, 2017
@lsorber

lsorber commented Jun 6, 2017

Thanks for the nice work on this. May I suggest (optionally) also caching the pipeline's last step? The last step could itself be a transformer (e.g., in a pipeline of pipelines).

In fact, I'm not even sure what the downside of caching all steps by default is. If you want the current behaviour, you could simply create a cached pipeline of all the steps you want cached, and insert that into a non-caching pipeline for the steps you don't want cached. Conversely, it is much more tedious to create a fully cached pipeline with the current implementation.
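For instance, the workaround described above might look like this (a sketch; the step names and estimators are arbitrary):

from sklearn.pipeline import Pipeline, CachedPipeline
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.svm import LinearSVC

# Steps whose fits should be memoized go into the CachedPipeline...
cached_part = CachedPipeline([('scale', StandardScaler()), ('pca', PCA())],
                             memory='./cachedir')
# ...which is then nested as a single step of an ordinary Pipeline.
pipe = Pipeline([('cached', cached_part), ('svc', LinearSVC())])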

@amueller
Member

amueller commented Jun 6, 2017

@lsorber can you maybe open a new issue for that?

@lsorber

lsorber commented Jun 6, 2017

All right, will do.

@jnothman
Member

jnothman commented Jun 6, 2017 via email

@lsorber

lsorber commented Jun 6, 2017

@jnothman I was in fact suggesting to cache the full pipeline, including the last step. Some motivation in #9007.

Sundrique pushed a commit to Sundrique/scikit-learn that referenced this pull request Jun 14, 2017
NelleV pushed a commit to NelleV/scikit-learn that referenced this pull request Aug 11, 2017
paulha pushed a commit to paulha/scikit-learn that referenced this pull request Aug 19, 2017
maskani-moh pushed a commit to maskani-moh/scikit-learn that referenced this pull request Nov 15, 2017
lemonlaug pushed a commit to lemonlaug/scikit-learn that referenced this pull request Jan 6, 2021