WIP allow Pipeline to memoize partial results #2086

Closed
wants to merge 1 commit

Conversation

jnothman
Member

This PR adds a memory parameter to Pipeline which allows it to memoize the results of partial pipeline evaluations (fits, transforms). See [a request for this feature](http://www.mail-archive.com/[email protected]/msg07402.html).

Currently:

  • it is a prototype, untested and not exactly sparkling clean code
  • it requires storing the Pipeline's training data on the instance (perhaps a hash would suffice)
  • the implementation perhaps caches data more frequently than necessary
  • fit and transform are called separately even where fit_transform is implemented
  • it memoizes the last step (whether it be fit, transform, score or predict, etc.), perhaps unnecessarily
  • it does not support passing keyword arguments to steps' fit methods as Pipeline currently does (this is the failing test)
  • it does not take advantage of the fact that for some estimators only a subset of parameters produce distinct models (e.g. for SelectKBest, only score_func should be a key for fit, while (score_func, k) affect the result of transform)
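
For illustration, a rough sketch of the kind of usage this would enable. The exact signature is hypothetical (not necessarily what this PR implements); joblib's Memory is assumed as the caching backend, and the step choices are arbitrary:

```python
from joblib import Memory
from sklearn.datasets import make_classification
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC

X, y = make_classification(n_samples=200, n_features=50, random_state=0)

# Hypothetical: hand the Pipeline a joblib.Memory; refitting with identical
# upstream steps and data would then reuse cached partial results instead of
# recomputing them (e.g. across grid search candidates).
memory = Memory('/tmp/pipeline_cache', verbose=0)
pipe = Pipeline([('reduce', PCA(n_components=10)),
                 ('clf', LinearSVC())],
                memory=memory)
pipe.fit(X, y)
```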

@GaelVaroquaux, is this what you had in mind?
@amueller, this isn't as much a generalised CV solution (#1626) as #2000 was; to what extent does it satisfy your use-cases?

@jnothman
Member Author

I should clarify the mechanism a bit: for the ith step, the following are in the cache key:

  • the step names, classes and parameters up to step i
  • the most recent arguments to fit or fit_transform
  • the current arguments to whatever method is being called

For fit, the cache stores all models from the beginning of the Pipeline up to step i. For transform etc., it stores the output at step i.

Thus the cache is checked recursively from the end to the beginning of the pipeline (and filled from the beginning to the end).
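
To make that concrete, here is a minimal, hypothetical sketch of how such a key might be assembled. The function `partial_key` is made up for illustration and is not the code in this PR; `get_params` is the standard estimator API, and joblib's own hashing would normally replace the pickle/sha1 step:

```python
import hashlib
import pickle


def partial_key(steps, i, fit_args, current_args):
    """Hypothetical cache key for the i-th pipeline step.

    Covers (a) names, classes and constructor parameters of steps 0..i,
    (b) the arguments most recently passed to fit/fit_transform, and
    (c) the arguments of the method currently being called.
    """
    step_part = [(name, type(est).__name__,
                  sorted(est.get_params(deep=False).items()))
                 for name, est in steps[:i + 1]]
    # Sketch only: pickling assumes all parameters and arguments are picklable.
    payload = pickle.dumps((step_part, fit_args, current_args))
    return hashlib.sha1(payload).hexdigest()
```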

@jnothman
Member Author

I guess one thing I would like to know is: which transform methods in scikit-learn are substantially more expensive than loading a cached result?

@GaelVaroquaux
Member

> I guess one thing I would like to know is: which transform methods in
> scikit-learn are substantially more expensive than loading a cached result?

As this is estimator dependent, this could/should be sorted by adding a
'memory' keyword to the transformers, as in the case of the feature
agglomeration. The nice aspect is that it can then be chosen to be used
in the best place (think the SVD in the PCA for instance).

G
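
A minimal sketch of the per-transformer approach being described, assuming joblib.Memory and a made-up CachedProjector transformer (scikit-learn's feature agglomeration accepts a memory argument in a similar spirit; nothing here is that actual code):

```python
import numpy as np
from joblib import Memory
from sklearn.base import BaseEstimator, TransformerMixin


def _expensive_decomposition(X, n_components):
    # Stand-in for an expensive internal step (think the SVD inside a PCA).
    U, s, Vt = np.linalg.svd(X, full_matrices=False)
    return Vt[:n_components]


class CachedProjector(BaseEstimator, TransformerMixin):
    """Hypothetical transformer taking a `memory` constructor argument."""

    def __init__(self, n_components=2, memory=None):
        self.n_components = n_components
        self.memory = memory

    def fit(self, X, y=None):
        # Memory(None) disables caching, so the parameter is optional.
        memory = self.memory if self.memory is not None else Memory(None, verbose=0)
        # Only the expensive inner computation is memoized, so the author of
        # the estimator decides where caching actually pays off.
        self.components_ = memory.cache(_expensive_decomposition)(X, self.n_components)
        return self

    def transform(self, X):
        return X @ self.components_.T
```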

return self.steps[:len(step_states)]

@cache
def _transform(X, other_args, fit_args, step_states,
Member

Please don't define functions in closures. It makes code hard to debug.

Member Author

I couldn't work out another way to sensibly use Memory.cache to perform this operation. In particular, I do not want the pipeline object itself to be part of the cache key. I considered an approach that uses Memory.cache's ignore argument, but I can't remember why I decided against it.

If you have a neat alternative, let me know. But this code is currently intended only as a prototype of the functionality.

@jnothman
Member Author

> As this is estimator dependent, this could/should be sorted by adding a 'memory' keyword to the transformers, as in the case of the feature agglomeration.

I thought this might be what you intended. As far as I'm concerned, configuring (and providing) a memory parameter for each transformer is an annoyance, as would be a caching MetaEstimator.

I think there is a great usability advantage in providing an easy solution for avoiding unnecessary Pipeline subsequence fits and perhaps transforms (the value of both is estimator-dependent). The question I was getting at is whether there's much value saved by caching transforms as well as fits.

And if fine-grained control is actually necessary, one could implement cache_fits and cache_transforms parameters that are boolean or an array of booleans corresponding to the Pipeline steps. But I think this solution is overkill for a convenience implementation.

@GaelVaroquaux
Member

> As far as I'm concerned, configuring (and providing) a memory parameter for each transformer is an annoyance, as would be a caching MetaEstimator.

You are right in theory, but in practice, you will always get better
performance with more specific code.

> I think there is a great usability advantage in providing an easy
> solution for avoiding unnecessary Pipeline subsequence fits and perhaps
> transforms (the value of both is estimator-dependent). The question I
> was getting at is whether there's much value saved by caching
> transforms as well as fits.

I agree that generic memory in pipelines would be useful. I would think
that we want to cache transformers' fit, but probably not the transform
method.

@jnothman
Member Author

I'm considering what you wrote here:

> One problem that you face is that an estimator object can have estimated parameters. These may render the cache invalid, while they really shouldn't affect the fit.

My way of handling this was providing a closure whose arguments corresponded exactly to the cache key. Such closures are indeed not ideal, but I also think cloning the estimator is merely a workaround, not a clean solution. This is a limitation of Memory. Perhaps just as Memory.cache has an argument ignore, it should also have a way to get additional cache keys.

The other difficulty here is that fit updates the state of the object, and we don't actually want to cache its return value so much as the final state (even though they should generally be the same). This is the reason fit_transform cannot be cached.
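
To illustrate the first point, the tension between the two approaches might look roughly like this. The fit_step functions below are made up; only Memory.cache and its documented ignore argument are real joblib API, and neither option is the code in this PR:

```python
from joblib import Memory
from sklearn.base import clone

memory = Memory('/tmp/cache_demo', verbose=0)


# Option 1: a module-level function whose arguments *are* exactly the
# intended cache key. The estimator is rebuilt from its class and
# parameters, so its estimated attributes never leak into the key.
@memory.cache
def fit_step(estimator_class, params, X, y):
    return estimator_class(**params).fit(X, y)


# Option 2: pass the estimator in, but tell joblib to ignore it when
# hashing; the explicit params argument then stands in for it in the key.
@memory.cache(ignore=['estimator'])
def fit_step_ignore(estimator, params, X, y):
    return clone(estimator).set_params(**params).fit(X, y)
```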

@jnothman
Member Author

Your cache(fit)(clone(estimator), X, y) solution also will not handle the case of excluding parameters that don't affect fit and its learnt attributes. For example, using a LinearSVC as a feature selector, but wanting to play with the threshold in different grid search candidates, we don't want to refit when that threshold parameter changes.

As far as I'm concerned, that eliminates clone as an option. Would you rather a closure, or some extra argument to Memory.cache?
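
Purely to illustrate the kind of distinction meant here (nothing below exists in scikit-learn; the registry and helper are hypothetical):

```python
# Hypothetical registry: for each estimator class, which constructor
# parameters actually influence the fitted model. Parameters applied only
# at transform time (k for SelectKBest, a selection threshold for a
# LinearSVC-based selector) are left out of the fit cache key.
FIT_RELEVANT_PARAMS = {
    'SelectKBest': ['score_func'],
    'LinearSVC': ['C', 'penalty', 'loss'],
}


def fit_key_params(estimator):
    """Return only the parameters that should key the cached fit."""
    params = estimator.get_params(deep=False)
    relevant = FIT_RELEVANT_PARAMS.get(type(estimator).__name__)
    if relevant is None:
        return params  # fall back to keying on all parameters
    return {name: params[name] for name in relevant if name in params}
```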

@GaelVaroquaux
Member

> My way of handling this was providing a closure whose arguments corresponded
> exactly to the cache key. Such closures are indeed not ideal, but I also think
> cloning the estimator is merely a workaround, not a clean solution. This is a
> limitation of Memory. Perhaps just as Memory.cache has an argument ignore, it
> should also have a way to get additional cache keys.

My experience of software development (which is somewhat substantial, having worked on and followed many projects) is that complexity is our worst enemy, to a point which should not be underestimated. Indeed, as a project grows more and more complex, development slows down as adding each feature becomes harder, and fewer and fewer people are qualified to modify it.

Simpler code should be preferred to elegant code. Features that add a lot
of complexity should be included only if they are mission critical.

These rules, I believe, are excellent guidelines to making a project successful in the long run. And indeed, I review code with them in mind: if there is a simpler solution (fewer abstractions, fewer lines of code), I will always push for it.

@GaelVaroquaux
Member

> Your cache(fit)(clone(estimator), X, y) solution also will not handle
> the case of excluding parameters that don't affect fit and its learnt
> attributes. For example, using a LinearSVC as a feature selector, but
> wanting to play with the threshold in different grid search candidates,
> we don't want to refit when that threshold parameter changes.

Yes, that's correct. I don't want to build a full pipeline with parameter tracking. Experience shows that it is a very costly enterprise that really slows down package development. I would suggest that people with such needs implement a 'mem' inside the estimator object: a simpler solution that is not generic.

@jnothman
Member Author

I understand where you're coming from, but removing redundant work in a grid search is a frequent, and sensible, request. If you can consider a way to do it without modifying each underlying estimator (should users really be expected to do so?) that does not add complexity to the code or the interface, do let me know.

@ogrisel
Member

ogrisel commented Oct 31, 2013

I have been thinking about this use case and I think it should be possible to generate clean cache keys recursively to support nested estimators, see: https://gist.github.com/ogrisel/7091781
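
The gist is the actual proposal; very roughly, a recursive key over nested estimators could look like this sketch (illustrative only, not the gist's code):

```python
def estimator_key(obj):
    """Recursively describe an estimator by class and constructor parameters.

    Nested estimators (e.g. Pipeline steps or a base_estimator parameter)
    are expanded the same way, so the key reflects the whole construction
    rather than any fitted state.
    """
    if hasattr(obj, 'get_params'):
        params = obj.get_params(deep=False)
        return (type(obj).__name__,
                tuple(sorted((k, estimator_key(v)) for k, v in params.items())))
    if isinstance(obj, (list, tuple)):
        return tuple(estimator_key(v) for v in obj)
    return repr(obj)
```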

@briandastous

My feature extraction step takes up the vast bulk of my pipeline execution time. This feature would really help me out.

@jnothman
Member Author

One option you could use is here. Unfortunately it requires a small change to the scikit-learn codebase, after which you can just wrap those models you want to memoize the fitting of with remember_model (or remember_transform).

However, you might be better off just performing feature extraction as a preprocessing step.

@briandastous

I ended up writing a wrapper estimator that pickles the fitted estimator to a file on the first run and on subsequent runs just unpickles it, but a more general solution does seem like it would be useful in many situations.
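
For anyone wanting a similar stop-gap, a minimal sketch of such a wrapper (hypothetical, not part of scikit-learn). Unlike a keyed cache, it does not invalidate itself when the data or parameters change; the file has to be removed by hand:

```python
import os
import pickle

from sklearn.base import BaseEstimator


class PickledFitWrapper(BaseEstimator):
    """Fit the wrapped estimator once; later runs reload the pickled result."""

    def __init__(self, estimator, path):
        self.estimator = estimator
        self.path = path

    def fit(self, X, y=None):
        if os.path.exists(self.path):
            # Reuse the previously fitted estimator, ignoring X and y.
            with open(self.path, 'rb') as f:
                self.estimator_ = pickle.load(f)
        else:
            self.estimator_ = self.estimator.fit(X, y)
            with open(self.path, 'wb') as f:
                pickle.dump(self.estimator_, f)
        return self

    def transform(self, X):
        return self.estimator_.transform(X)
```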

@agramfort
Member

> Well, if there are no other comments, I'll merge in a day or so. The CI
> failures appear spurious.

please don't merge before getting two +1s from others.

@jnothman
Member Author

Ah, @agramfort, I intended that for another thread! :)

@jnothman
Member Author

> I ended up writing a wrapper estimator that pickles the fitted estimator to a file on the first run and on subsequent runs just unpickles it, but a more general solution does seem like it would be useful in many situations.

remember_model does basically that; however, in that case I found that the cloning behaviour in cross-validation broke it, so I had to handle it specially.

@jnothman
Member Author

Perhaps the cloning behaviour only broke when you wanted to memoize the fit, but there were parameters etc. that didn't depend on fit. There may be better ways to do remember_model in any case.

@amueller
Member

Sorry I didn't follow the discussion closely, could you summarize why you closed this one?

@jnothman
Member Author

I have PRs that are older than this one which I actually expect to receive reviews and be merged; this one I don't. The issue is still an issue: too much work is being redone in cross-validated pipelines by default, but I don't think this is the right solution. I think being able to tell any estimator to cache its model with respect to all or a subset of its parameters and training data (by way of a metaestimator, a mixin, an inbuilt parameter in BaseEstimator, or whatever) is a good way to keep the feature modular. Supporting a metaestimator form only requires a change to sklearn.base.clone to support polymorphism (see jnothman@de0f86d).


@simonzack

@jnothman I'm after this feature as well. So what was the problem with your other method (jnothman@de0f86d) that kept it from being merged? It looks like something I can already use in my project. The only problem I see is the possibly slow cache comparisons when the inputs (X) are large, since in pipelines only the initial input really needs to be compared.

@amueller
Member

I think we should actually investigate https://github.com/ContinuumIO/dask for doing this, though this would be a bit of a dependency.
