-
-
Notifications
You must be signed in to change notification settings - Fork 26.6k
ENH speedup coordinate descent by avoiding calls to axpy in innermost loop #31956
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
BenchmarkingBenchmarking code from #17021. DetailsTime Ratios Code Detailsfile mtl_bench.py """
Benchmark of MultiTaskLasso
"""
import gc
from itertools import product
from time import time
import numpy as np
import pandas as pd
from sklearn.datasets import make_regression
from sklearn.linear_model import MultiTaskLasso
def compute_bench(alpha, n_samples, n_features, n_tasks):
results = []
n_bench = len(n_samples) * len(n_features) * len(n_tasks)
for it, (ns, nf, nt) in enumerate(product(n_samples, n_features, n_tasks)):
print('==================')
print('Iteration %s of %s' % (it, n_bench))
print('==================')
n_informative = nf // 10
X, Y, coef_ = make_regression(n_samples=ns, n_features=nf,
n_informative=n_informative,
n_targets=nt,
noise=0.1, coef=True)
X /= np.sqrt(np.sum(X ** 2, axis=0)) # Normalize data
gc.collect()
clf = MultiTaskLasso(alpha=alpha, fit_intercept=False)
tstart = time()
clf.fit(X, Y)
results.append(
dict(n_samples=ns, n_features=nf, n_tasks=nt, time=time() - tstart)
)
return pd.DataFrame(results)
def compare_results():
results_new = pd.read_csv('mlt_new.csv').set_index(['n_samples', 'n_features', 'n_tasks'])
results_old = pd.read_csv('mlt_old.csv').set_index(['n_samples', 'n_features', 'n_tasks'])
results_ratio = (results_old / results_new)
results_ratio.columns = ['time (old) / time (new)']
print(results_new)
print(results_old)
print(results_ratio)
if __name__ == '__main__':
import matplotlib.pyplot as plt
alpha = 0.01 # regularization parameter
list_n_features = [300, 1000, 4000]
list_n_samples = [100, 500]
list_n_tasks = [2, 10, 20, 50]
results = compute_bench(alpha, list_n_samples,
list_n_features, list_n_tasks)
# results.to_csv('mlt_old.csv', index=False)
results.to_csv('mlt_new.csv', index=False)
compare_results() |
OmarManzoor
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR @lorentzenchr
Just a few comments otherwise looks nice
| Same for functions :func:`linear_model.enet_path` and | ||
| :func:`linear_model.lasso_path`. | ||
| By :user:`Christian Lorentzen <lorentzenchr>`. | ||
| By :user:`Christian Lorentzen <lorentzenchr>` :pr:`31956` and |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we are still waiting for confirmation of this format.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can have a look at the rendered docs in the CD, it looks fine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did check in the other PR where I agree it looked fine
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The alternative would be to create 31956.efficiency.rst with duplicate content (and without referring to any PR number in it). Towncrier should take care of merging entries with identical content and link to both PRs for the single resulting entry.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But I agree the rendering looks good, so this solution is fine for me as well.
OmarManzoor
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks @lorentzenchr
ogrisel
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks @lorentzenchr and @OmarManzoor.
Reference Issues/PRs
Similar to #31880.
Continues and fixes #15931.
What does this implement/fix? Explain your changes.
This PR avoids calls to
_axpyin the innermost loop of all coordinate descent solvers (Lasso and Enet), exceptenet_coordinate_descent_gramwhich was done in #31880.Any other comments?
Ironically, this improvement also reduces code size 😄
For reviewers: better merge #31957 first.