Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

[WIP] Adding tests for estimators implementing partial_fit and a few other related fixes / enhancements #3907

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions sklearn/linear_model/stochastic_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def _partial_fit(self, X, y, alpha, C,
n_samples, n_features = X.shape

self._validate_params()
_check_partial_fit_first_call(self, classes)
_check_partial_fit_first_call(self, y, classes)

n_classes = self.classes_.shape[0]

Expand All @@ -374,8 +374,8 @@ def _partial_fit(self, X, y, alpha, C,
self._allocate_parameter_mem(n_classes, n_features,
coef_init, intercept_init)
elif n_features != self.coef_.shape[-1]:
raise ValueError("Number of features %d does not match previous data %d."
% (n_features, self.coef_.shape[-1]))
raise ValueError("Number of features %d does not match previous "
"data %d." % (n_features, self.coef_.shape[-1]))

self.loss_function = self._get_loss_function(loss)
if self.t_ is None:
Expand Down Expand Up @@ -884,8 +884,8 @@ def _partial_fit(self, X, y, alpha, C, loss, learning_rate,
self._allocate_parameter_mem(1, n_features,
coef_init, intercept_init)
elif n_features != self.coef_.shape[-1]:
raise ValueError("Number of features %d does not match previous data %d."
% (n_features, self.coef_.shape[-1]))
raise ValueError(("Number of features %d does not match previous "
"data %d.") % (n_features, self.coef_.shape[-1]))
if self.average > 0 and self.average_coef_ is None:
self.average_coef_ = np.zeros(n_features,
dtype=np.float64,
Expand Down
15 changes: 3 additions & 12 deletions sklearn/naive_bayes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from .utils import check_X_y, check_array
from .utils.extmath import safe_sparse_dot, logsumexp
from .utils.multiclass import _check_partial_fit_first_call
from .utils.fixes import in1d
from .utils.validation import check_is_fitted
from .externals import six

Expand Down Expand Up @@ -325,7 +324,7 @@ def _partial_fit(self, X, y, classes=None, _refit=False,
if _refit:
self.classes_ = None

if _check_partial_fit_first_call(self, classes):
if _check_partial_fit_first_call(self, y, classes):
# This is the first call to partial_fit:
# initialize various cumulative counters
n_features = X.shape[1]
Expand All @@ -341,18 +340,10 @@ def _partial_fit(self, X, y, classes=None, _refit=False,
# Put epsilon back in each time
self.sigma_[:, :] -= epsilon

classes = self.classes_

unique_y = np.unique(y)
unique_y_in_classes = in1d(unique_y, classes)

if not np.all(unique_y_in_classes):
raise ValueError("The target label(s) %s in y do not exist in the "
"initial classes %s" %
(y[~unique_y_in_classes], classes))

for y_i in unique_y:
i = classes.searchsorted(y_i)
i = self.classes_.searchsorted(y_i)
X_i = X[y == y_i, :]

if sample_weight is not None:
Expand Down Expand Up @@ -453,7 +444,7 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
X = check_array(X, accept_sparse='csr', dtype=np.float64)
_, n_features = X.shape

if _check_partial_fit_first_call(self, classes):
if _check_partial_fit_first_call(self, y, classes):
# This is the first call to partial_fit:
# initialize various cumulative counters
n_effective_classes = len(classes) if len(classes) > 1 else 2
Expand Down
1 change: 1 addition & 0 deletions sklearn/neural_network/rbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def partial_fit(self, X, y=None):
self.h_samples_ = np.zeros((self.batch_size, self.n_components))

self._fit(X, self.random_state_)
return self

def _fit(self, v_pos, rng):
"""Inner fit for one mini-batch.
Expand Down
5 changes: 0 additions & 5 deletions sklearn/tests/test_naive_bayes.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,6 @@ def test_gnb():
y_pred_log_proba = clf.predict_log_proba(X)
assert_array_almost_equal(np.log(y_pred_proba), y_pred_log_proba, 8)

# Test whether label mismatch between target y and classes raises
# an Error
# FIXME Remove this test once the more general partial_fit tests are merged
assert_raises(ValueError, GaussianNB().partial_fit, X, y, classes=[0, 1])


def test_gnb_prior():
# Test whether class priors are properly set.
Expand Down
Loading