5 changes: 5 additions & 0 deletions .gitignore
@@ -68,3 +68,8 @@ benchmarks/bench_covertype_data/
.cache
.pytest_cache/
_configtest.o.d

# files generated from a template
sklearn/utils/seq_dataset.pyx
sklearn/utils/seq_dataset.pxd
sklearn/linear_model/sag_fast.pyx
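
Note: these .pyx/.pxd files are generated at build time from Tempita templates tracked next to them, one dtype variant per expansion, which is why the generated output is now ignored. A minimal sketch of the template pattern, with illustrative names rather than the exact template contents:

    {{py:
    # (name suffix, C type) pairs the template expands to
    dtypes = [('64', 'double'), ('32', 'float')]
    }}
    {{for suffix, c_type in dtypes}}
    cdef class ArrayDataset{{suffix}}(SequentialDataset{{suffix}}):
        # dense dataset backed by a {{c_type}} array
        cdef {{c_type}} *X_data_ptr
    {{endfor}}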
153 changes: 107 additions & 46 deletions benchmarks/bench_saga.py
@@ -1,11 +1,11 @@
"""Author: Arthur Mensch
"""Author: Arthur Mensch, Nelle Varoquaux

Benchmarks of sklearn SAGA vs lightning SAGA vs Liblinear. Shows the gain
in using multinomial logistic regression in terms of learning time.
"""
import json
import time
from os.path import expanduser
import os

import matplotlib.pyplot as plt
import numpy as np
@@ -21,7 +21,7 @@


def fit_single(solver, X, y, penalty='l2', single_target=True, C=1,
max_iter=10, skip_slow=False):
max_iter=10, skip_slow=False, dtype=np.float64):
if skip_slow and solver == 'lightning' and penalty == 'l1':
print('skipping l1 logistic regression with solver lightning.')
return
@@ -37,7 +37,8 @@ def fit_single(solver, X, y, penalty='l2', single_target=True, C=1,
multi_class = 'ovr'
else:
multi_class = 'multinomial'

X = X.astype(dtype)
y = y.astype(dtype)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42,
stratify=y)
n_samples = X_train.shape[0]
@@ -69,11 +70,15 @@ def fit_single(solver, X, y, penalty='l2', single_target=True, C=1,
multi_class=multi_class,
C=C,
penalty=penalty,
fit_intercept=False, tol=1e-24,
fit_intercept=False, tol=0,
max_iter=this_max_iter,
random_state=42,
)

# Touch the data once so the CPU cache is in the same state for every fit call
X_train.max()
t0 = time.clock()

lr.fit(X_train, y_train)
train_time = time.clock() - t0

@@ -106,9 +111,13 @@ def _predict_proba(lr, X):
return softmax(pred)


def exp(solvers, penalties, single_target, n_samples=30000, max_iter=20,
def exp(solvers, penalty, single_target,
n_samples=30000, max_iter=20,
dataset='rcv1', n_jobs=1, skip_slow=False):
mem = Memory(cachedir=expanduser('~/cache'), verbose=0)
dtypes_mapping = {
"float64": np.float64,
"float32": np.float32,
}

if dataset == 'rcv1':
rcv1 = fetch_rcv1()
@@ -151,21 +160,24 @@ def exp(solvers, penalties, single_target, n_samples=30000, max_iter=20,
X = X[:n_samples]
y = y[:n_samples]

cached_fit = mem.cache(fit_single)
out = Parallel(n_jobs=n_jobs, mmap_mode=None)(
delayed(cached_fit)(solver, X, y,
delayed(fit_single)(solver, X, y,
penalty=penalty, single_target=single_target,
dtype=dtype,
C=1, max_iter=max_iter, skip_slow=skip_slow)
for solver in solvers
for penalty in penalties)
for dtype in dtypes_mapping.values())

res = []
idx = 0
for solver in solvers:
for penalty in penalties:
if not (skip_slow and solver == 'lightning' and penalty == 'l1'):
# consume results in the same (solver, dtype) order as the Parallel generator above
for solver in solvers:
    for dtype_name in dtypes_mapping.keys():
if not (skip_slow and
solver == 'lightning' and
penalty == 'l1'):
lr, times, train_scores, test_scores, accuracies = out[idx]
this_res = dict(solver=solver, penalty=penalty,
dtype=dtype_name,
single_target=single_target,
times=times, train_scores=train_scores,
test_scores=test_scores,
@@ -177,68 +189,117 @@ def exp(solvers, penalty, single_target, n_samples=30000, max_iter=20,
json.dump(res, f)


def plot():
def plot(outname=None):
import pandas as pd
with open('bench_saga.json', 'r') as f:
f = json.load(f)
res = pd.DataFrame(f)
res.set_index(['single_target', 'penalty'], inplace=True)
res.set_index(['single_target'], inplace=True)

grouped = res.groupby(level=['single_target', 'penalty'])
grouped = res.groupby(level=['single_target'])

colors = {'saga': 'blue', 'liblinear': 'orange', 'lightning': 'green'}
colors = {'saga': 'C0', 'liblinear': 'C1', 'lightning': 'C2'}
linestyles = {"float32": "--", "float64": "-"}
alpha = {"float64": 0.5, "float32": 1}

for idx, group in grouped:
single_target, penalty = idx
fig = plt.figure(figsize=(12, 4))
ax = fig.add_subplot(131)

train_scores = group['train_scores'].values
ref = np.min(np.concatenate(train_scores)) * 0.999

for scores, times, solver in zip(group['train_scores'], group['times'],
group['solver']):
scores = scores / ref - 1
ax.plot(times, scores, label=solver, color=colors[solver])
single_target = idx
fig, axes = plt.subplots(figsize=(12, 4), ncols=4)
ax = axes[0]

for scores, times, solver, dtype in zip(group['train_scores'],
group['times'],
group['solver'],
group["dtype"]):
ax.plot(times, scores, label="%s - %s" % (solver, dtype),
color=colors[solver],
alpha=alpha[dtype],
marker=".",
linestyle=linestyles[dtype])
ax.axvline(times[-1], color=colors[solver],
alpha=alpha[dtype],
linestyle=linestyles[dtype])
ax.set_xlabel('Time (s)')
ax.set_ylabel('Training objective (relative to min)')
ax.set_yscale('log')

ax = fig.add_subplot(132)
ax = axes[1]

test_scores = group['test_scores'].values
ref = np.min(np.concatenate(test_scores)) * 0.999
for scores, times, solver, dtype in zip(group['test_scores'],
group['times'],
group['solver'],
group["dtype"]):
ax.plot(times, scores, label=solver, color=colors[solver],
linestyle=linestyles[dtype],
marker=".",
alpha=alpha[dtype])
ax.axvline(times[-1], color=colors[solver],
alpha=alpha[dtype],
linestyle=linestyles[dtype])

for scores, times, solver in zip(group['test_scores'], group['times'],
group['solver']):
scores = scores / ref - 1
ax.plot(times, scores, label=solver, color=colors[solver])
ax.set_xlabel('Time (s)')
ax.set_ylabel('Test objective (relative to min)')
ax.set_yscale('log')

ax = fig.add_subplot(133)
ax = axes[2]
for accuracy, times, solver, dtype in zip(group['accuracies'],
group['times'],
group['solver'],
group["dtype"]):
ax.plot(times, accuracy, label="%s - %s" % (solver, dtype),
alpha=alpha[dtype],
marker=".",
color=colors[solver], linestyle=linestyles[dtype])
ax.axvline(times[-1], color=colors[solver],
alpha=alpha[dtype],
linestyle=linestyles[dtype])

for accuracy, times, solver in zip(group['accuracies'], group['times'],
group['solver']):
ax.plot(times, accuracy, label=solver, color=colors[solver])
ax.set_xlabel('Time (s)')
ax.set_ylabel('Test accuracy')
ax.legend()
name = 'single_target' if single_target else 'multi_target'
name += '_%s' % penalty
plt.suptitle(name)
name += '.png'
if outname is None:
outname = name + '.png'
fig.tight_layout()
fig.subplots_adjust(top=0.9)
plt.savefig(name)
plt.close(fig)

ax = axes[3]
for scores, times, solver, dtype in zip(group['train_scores'],
group['times'],
group['solver'],
group["dtype"]):
ax.plot(np.arange(len(scores)),
scores, label="%s - %s" % (solver, dtype),
marker=".",
alpha=alpha[dtype],
color=colors[solver], linestyle=linestyles[dtype])

ax.set_yscale("log")
ax.set_xlabel('# iterations')
ax.set_ylabel('Objective function')
ax.legend()

plt.savefig(outname)


if __name__ == '__main__':
solvers = ['saga', 'liblinear', 'lightning']
solvers = ['saga', ]
penalties = ['l1', 'l2']
n_samples = [100000, 300000, 500000, 800000, None]
single_target = True
exp(solvers, penalties, single_target, n_samples=None, n_jobs=1,
dataset='20newspaper', max_iter=20)
plot()
for penalty in penalties:
for n_sample in n_samples:
exp(solvers, penalty, single_target,
n_samples=n_sample, n_jobs=1,
dataset='rcv1', max_iter=10)
if n_sample is not None:
outname = "figures/saga_%s_%d.png" % (penalty, n_sample)
else:
outname = "figures/saga_%s_all.png" % (penalty,)
try:
os.makedirs("figures")
except OSError:
pass
plot(outname)
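
For reference, each record that exp dumps to bench_saga.json carries the solver, penalty, dtype and per-epoch traces, so results can also be inspected outside of plot(). A minimal sketch, assuming exp has already been run in the working directory:

    import json

    import pandas as pd

    with open('bench_saga.json') as f:
        res = pd.DataFrame(json.load(f))

    # one record per (solver, dtype) run; times and scores are per-epoch lists
    print(res[['solver', 'penalty', 'dtype', 'single_target']])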
19 changes: 14 additions & 5 deletions sklearn/linear_model/base.py
@@ -32,7 +32,9 @@
from ..utils.extmath import safe_sparse_dot
from ..utils.sparsefuncs import mean_variance_axis, inplace_column_scale
from ..utils.fixes import sparse_lsqr
from ..utils.seq_dataset import ArrayDataset, CSRDataset
from ..utils.seq_dataset import ArrayDataset32, CSRDataset32
from ..utils.seq_dataset import ArrayDataset64 as ArrayDataset
from ..utils.seq_dataset import CSRDataset64 as CSRDataset
from ..utils.validation import check_is_fitted
from ..exceptions import NotFittedError
from ..preprocessing.data import normalize as f_normalize
@@ -76,15 +78,22 @@ def make_dataset(X, y, sample_weight, random_state=None):
"""

rng = check_random_state(random_state)
# seed should never be 0 in SequentialDataset
# seed should never be 0 in SequentialDataset64
seed = rng.randint(1, np.iinfo(np.int32).max)

if X.dtype == np.float32:
CSRData = CSRDataset32
ArrayData = ArrayDataset32
else:
CSRData = CSRDataset
ArrayData = ArrayDataset

if sp.issparse(X):
dataset = CSRDataset(X.data, X.indptr, X.indices, y, sample_weight,
seed=seed)
dataset = CSRData(X.data, X.indptr, X.indices, y, sample_weight,
seed=seed)
intercept_decay = SPARSE_INTERCEPT_DECAY
else:
dataset = ArrayDataset(X, y, sample_weight, seed=seed)
dataset = ArrayData(X, y, sample_weight, seed=seed)
intercept_decay = 1.0

return dataset, intercept_decay
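
The effect of this dispatch can be checked directly. A minimal sketch (make_dataset is a private helper, so the import path below may move between versions):

    import numpy as np
    import scipy.sparse as sp

    from sklearn.linear_model.base import make_dataset

    # float32 sparse input, so the 32-bit CSR dataset should be picked
    X = sp.random(100, 10, density=0.1, format='csr', dtype=np.float32)
    y = np.zeros(100, dtype=np.float32)
    sample_weight = np.ones(100, dtype=np.float32)

    dataset, intercept_decay = make_dataset(X, y, sample_weight, random_state=0)
    print(type(dataset).__name__)  # expected: CSRDataset32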
12 changes: 8 additions & 4 deletions sklearn/linear_model/logistic.py
@@ -738,7 +738,7 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,

elif solver in ['sag', 'saga']:
if multi_class == 'multinomial':
target = target.astype(np.float64)
target = target.astype(X.dtype, copy=False)
loss = 'multinomial'
else:
loss = 'log'
@@ -1206,6 +1206,10 @@ def fit(self, X, y, sample_weight=None):
Returns
-------
self : object

Notes
-----
The SAGA solver supports both float64 and float32 arrays.
"""
if not isinstance(self.C, numbers.Number) or self.C < 0:
raise ValueError("Penalty term must be positive; got (C=%r)"
@@ -1217,10 +1221,10 @@ def fit(self, X, y, sample_weight=None):
raise ValueError("Tolerance for stopping criteria must be "
"positive; got (tol=%r)" % self.tol)

if self.solver in ['newton-cg']:
_dtype = [np.float64, np.float32]
else:
if self.solver in ['lbfgs', 'liblinear']:

Member: It would be good to document in the Notes section of the object what solvers can handle what dtype.

Member Author: Done!

Member Author: Although I could improve the wording…

_dtype = np.float64
else:
_dtype = [np.float64, np.float32]

X, y = check_X_y(X, y, accept_sparse='csr', dtype=_dtype, order="C",
accept_large_sparse=self.solver != 'liblinear')
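
End to end, the dtype handling above means float32 input is preserved by the solvers that support it and upcast otherwise. A quick check through the public API (expected dtypes under this change noted in comments):

    import numpy as np

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    X, y = make_classification(n_samples=200, n_features=20, random_state=0)
    X32 = X.astype(np.float32)

    saga = LogisticRegression(solver='saga', max_iter=500).fit(X32, y)
    print(saga.coef_.dtype)  # float32: fit in the input precision

    libl = LogisticRegression(solver='liblinear').fit(X32, y)
    print(libl.coef_.dtype)  # float64: check_X_y upcasts for liblinear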