-
-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Partial centroid #13033
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Partial centroid #13033
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,15 +17,17 @@ | |
from ..preprocessing import LabelEncoder | ||
from ..utils.validation import check_array, check_X_y, check_is_fitted | ||
from ..utils.sparsefuncs import csc_median_axis_0 | ||
from ..utils.multiclass import check_classification_targets | ||
from ..utils.multiclass import check_classification_targets, _check_partial_fit_first_call | ||
from copy import deepcopy | ||
|
||
|
||
class NearestCentroid(BaseEstimator, ClassifierMixin): | ||
"""Nearest centroid classifier. | ||
|
||
Each class is represented by its centroid, with test samples classified to | ||
the class with the nearest centroid. | ||
|
||
Read more in the :ref:`User Guide <nearest_centroid_classifier>`. | ||
Read more in the :ref:` >`. | ||
|
||
Parameters | ||
---------- | ||
|
@@ -92,8 +94,65 @@ def fit(self, X, y): | |
Training vector, where n_samples is the number of samples and | ||
n_features is the number of features. | ||
Note that centroid shrinking cannot be used with sparse matrices. | ||
|
||
y : array, shape = [n_samples] | ||
Target values (integers) | ||
""" | ||
return self._partial_fit(X, y, np.unique(y), _refit=True) | ||
|
||
def partial_fit(self, X, y, classes=None): | ||
"""Incremental fit on a batch of samples. | ||
|
||
This method is expected to be called several times consecutively | ||
on different chunks of a dataset so as to implement out-of-core | ||
or online learning. | ||
|
||
This is especially useful when the whole dataset is too big to fit in | ||
memory at once. | ||
|
||
Parameters | ||
---------- | ||
X : {array-like, sparse matrix}, shape = [n_samples, n_features] | ||
Training vector, where n_samples is the number of samples and | ||
n_features is the number of features. | ||
Note that centroid shrinking cannot be used with sparse matrices. | ||
|
||
y : array, shape = [n_samples] | ||
Target values (integers) | ||
|
||
classes : array-like, shape (n_classes,), optional (default=None) | ||
List of all the classes that can possibly appear in the y vector. | ||
|
||
Must be provided at the first call to partial_fit, can be omitted | ||
in subsequent calls. | ||
""" | ||
if self.metric == 'manhattan': | ||
raise ValueError("Partial fitting with manhattan is not supported.") | ||
return self._partial_fit(X, y, classes, _refit=False) | ||
|
||
def _partial_fit(self, X, y, classes=None, _refit=False): | ||
""" | ||
Actual implementation of the Nearest Centroid fitting. | ||
|
||
Parameters | ||
---------- | ||
X : {array-like, sparse matrix}, shape = [n_samples, n_features] | ||
Training vector, where n_samples is the number of samples and | ||
n_features is the number of features. | ||
Note that centroid shrinking cannot be used with sparse matrices. | ||
|
||
y : array, shape = [n_samples] | ||
Target values (integers) | ||
|
||
classes : array-like, shape (n_classes,), optional (default=None) | ||
List of all the classes that can possibly appear in the y vector. | ||
|
||
Must be provided at the first call to partial_fit, can be omitted | ||
in subsequent calls. | ||
|
||
_refit : bool, optional (default=False) | ||
If true, act as though this were the first time we called | ||
_partial_fit (ie, throw away any past fitting and start over). | ||
""" | ||
if self.metric == 'precomputed': | ||
raise ValueError("Precomputed is not supported.") | ||
|
@@ -110,53 +169,90 @@ def fit(self, X, y): | |
check_classification_targets(y) | ||
|
||
n_samples, n_features = X.shape | ||
|
||
if _refit or _check_partial_fit_first_call(self, classes): | ||
self.true_classes_ = classes = np.asarray(classes) | ||
# Mask mapping each class to its members. | ||
self.true_centroids_ = np.zeros((classes.size, n_features), dtype=np.float64) | ||
# Number of clusters in each class. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As a user, I find the warning on lines 214-217 to be quite annoying. This code was present in the original implementation, but it might make sense to move it to the "init" instead so the user only gets the warning once. Or we could move it to the "if _refit or _check_partial_fit_first_call(self, classes)" conditional on line 173. Previously, we only received the warning when calling "fit", which isn't too often. But "partial_fit" is meant to be called often on small batches of data, and this warning gets printed every time, which is redundant and tiresome. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think having it only appear on the first partial_fit is reasonable. |
||
self.nk_ = np.zeros(classes.size) | ||
|
||
if self.shrink_threshold: | ||
self.ssd_ = np.zeros((classes.size, n_features), dtype=np.float64) | ||
self.dataset_centroid_ = np.mean(X, axis=0) | ||
|
||
le = LabelEncoder() | ||
y_ind = le.fit_transform(y) | ||
self.classes_ = classes = le.classes_ | ||
n_classes = classes.size | ||
le.fit(self.true_classes_) | ||
y_ind = le.transform(y) | ||
n_classes = self.true_classes_.size | ||
if n_classes < 2: | ||
raise ValueError('The number of classes has to be greater than' | ||
' one; got %d class' % (n_classes)) | ||
|
||
# Mask mapping each class to its members. | ||
self.centroids_ = np.empty((n_classes, n_features), dtype=np.float64) | ||
# Number of clusters in each class. | ||
nk = np.zeros(n_classes) | ||
old_nk = deepcopy(self.nk_) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please use |
||
old_centroids = deepcopy(self.true_centroids_) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same |
||
|
||
for cur_class in range(n_classes): | ||
center_mask = y_ind == cur_class | ||
nk[cur_class] = np.sum(center_mask) | ||
|
||
# Ignore if no data for this class | ||
if X[center_mask].size == 0: | ||
continue | ||
if is_X_sparse: | ||
center_mask = np.where(center_mask)[0] | ||
|
||
# XXX: Update other averaging methods according to the metrics. | ||
if self.metric == "manhattan": | ||
self.nk_[cur_class] += np.sum(center_mask) | ||
# NumPy does not calculate median of sparse matrices. | ||
if not is_X_sparse: | ||
self.centroids_[cur_class] = np.median(X[center_mask], axis=0) | ||
self.true_centroids_[cur_class] = np.median(X[center_mask], axis=0) | ||
else: | ||
self.centroids_[cur_class] = csc_median_axis_0(X[center_mask]) | ||
self.true_centroids_[cur_class] = csc_median_axis_0(X[center_mask]) | ||
else: | ||
if self.metric != 'euclidean': | ||
warnings.warn("Averaging for metrics other than " | ||
"euclidean and manhattan not supported. " | ||
"The average is set to be the mean." | ||
) | ||
self.centroids_[cur_class] = X[center_mask].mean(axis=0) | ||
# Update each centroid weighted by the number of samples | ||
self.true_centroids_[cur_class] = X[center_mask].mean(axis=0) * np.sum(center_mask) +\ | ||
self.true_centroids_[cur_class] * self.nk_[cur_class] | ||
self.nk_[cur_class] += np.sum(center_mask) | ||
self.true_centroids_[cur_class] /= self.nk_[cur_class] | ||
|
||
# Filtering out centroids without any data | ||
self.classes_ = self.true_classes_[self.nk_ != 0] | ||
self.centroids_ = self.true_centroids_[self.nk_ != 0] | ||
|
||
if self.shrink_threshold: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I recommend moving the two "deepcopy" calls on lines 192, 193 into the "if self.shrink_threshold" conditional on line 228. This way you only compute them when needed. |
||
dataset_centroid_ = np.mean(X, axis=0) | ||
n_total = np.sum(self.nk_) | ||
self.dataset_centroid_ = (self.dataset_centroid_ * old_nk.sum(axis=0) + np.sum(X, axis=0)) / n_total | ||
|
||
# Update sum of square distances of each class | ||
for cur_class in range(n_classes): | ||
n_old = old_nk[cur_class] | ||
n_new = self.nk_[cur_class] - n_old | ||
if n_new == 0: | ||
continue | ||
center_mask = y_ind == cur_class | ||
old_ssd = self.ssd_[cur_class] | ||
new_ssd = ((X[center_mask] - X[center_mask].mean(axis=0))**2).sum(axis=0) | ||
self.ssd_[cur_class] = (old_ssd + new_ssd + | ||
(n_old / float(n_new * (n_new + n_old))) * | ||
(n_new * old_centroids[cur_class] - n_new * X[center_mask].mean(axis=0)) ** 2) | ||
|
||
# m parameter for determining deviation | ||
m = np.sqrt((1. / nk) - (1. / n_samples)) | ||
m = np.sqrt((1. / self.nk_) - (1. / np.sum(self.nk_))) | ||
|
||
# Calculate deviation using the standard deviation of centroids. | ||
variance = (X - self.centroids_[y_ind]) ** 2 | ||
variance = variance.sum(axis=0) | ||
s = np.sqrt(variance / (n_samples - n_classes)) | ||
ssd = self.ssd_.sum(axis=0) | ||
s = np.sqrt(ssd / (n_total - n_classes)) | ||
s += np.median(s) # To deter outliers from affecting the results. | ||
mm = m.reshape(len(m), 1) # Reshape to allow broadcasting. | ||
ms = mm * s | ||
deviation = ((self.centroids_ - dataset_centroid_) / ms) | ||
deviation = ((self.true_centroids_ - self.dataset_centroid_) / ms) | ||
|
||
# Soft thresholding: if the deviation crosses 0 during shrinking, | ||
# it becomes zero. | ||
signs = np.sign(deviation) | ||
|
@@ -165,7 +261,7 @@ def fit(self, X, y): | |
deviation *= signs | ||
# Now adjust the centroids using the deviation | ||
msd = ms * deviation | ||
self.centroids_ = dataset_centroid_[np.newaxis, :] + msd | ||
self.centroids_ = self.dataset_centroid_[np.newaxis, :] + msd | ||
return self | ||
|
||
def predict(self, X): | ||
|
@@ -191,4 +287,4 @@ def predict(self, X): | |
|
||
X = check_array(X, accept_sparse='csr') | ||
return self.classes_[pairwise_distances( | ||
X, self.centroids_, metric=self.metric).argmin(axis=1)] | ||
X, self.centroids_, metric=self.metric).argmin(axis=1)] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,7 +27,6 @@ | |
iris.data = iris.data[perm] | ||
iris.target = iris.target[perm] | ||
|
||
|
||
def test_classification_toy(): | ||
# Check classification on a toy dataset, including sparse versions. | ||
clf = NearestCentroid() | ||
|
@@ -55,6 +54,9 @@ def test_classification_toy(): | |
assert_array_equal(clf.predict(T_csr.tolil()), true_result) | ||
|
||
|
||
|
||
|
||
|
||
def test_precomputed(): | ||
clf = NearestCentroid(metric='precomputed') | ||
with assert_raises(ValueError): | ||
|
@@ -102,7 +104,6 @@ def test_shrinkage_correct(): | |
# The expected result is calculated by R (pamr), | ||
# which is implemented by the author of the original paper. | ||
# (One need to modify the code to output the new centroid in pamr.predict) | ||
|
||
X = np.array([[0, 1], [1, 0], [1, 1], [2, 0], [6, 8]]) | ||
y = np.array([1, 1, 2, 2, 2]) | ||
clf = NearestCentroid(shrink_threshold=0.1) | ||
|
@@ -147,3 +148,37 @@ def test_manhattan_metric(): | |
clf.fit(X_csr, y) | ||
assert_array_equal(clf.centroids_, dense_centroid) | ||
assert_array_equal(dense_centroid, [[-1, -1], [1, 1]]) | ||
|
||
|
||
def test_partial_fit(): | ||
# Test the partial fitting | ||
|
||
clf = NearestCentroid() | ||
clf.partial_fit(X[:3], y[:3], classes=[-1, 1]) | ||
clf.partial_fit(X[3:], y[3:]) | ||
assert_array_equal(clf.predict(T), true_result) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we also check that the centroids match using |
||
|
||
X2 = [[-2, -1], [-1, -2], [1, 1], [-1, -1], [1, 2], [2, 1]] | ||
y2 = [-1, -1, 1, -1, 1, 1] | ||
clf = NearestCentroid() | ||
clf.partial_fit(X2[:3], y2[:3], classes=[-1, 1]) | ||
assert_array_equal(clf.predict(T), [-1, 1, 1]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But this is identical to true_result, so what have we shown? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we at least check that the centroids differ from the complete fit? |
||
clf.partial_fit(X2[3:], y2[3:]) | ||
assert_array_equal(clf.predict(T), true_result) | ||
|
||
|
||
def test_partial_shrinkage_correct(): | ||
# Ensure that the shrinking is correct. | ||
# The expected result is calculated by R (pamr), | ||
# which is implemented by the author of the original paper. | ||
# (One need to modify the code to output the new centroid in pamr.predict) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Modify which code? I don't understand this. You could just show the R code used to derive these numbers. |
||
X = np.array([[0, 1], [1, 0], [1, 1], [2, 0], [6, 8]]) | ||
y = np.array([1, 1, 2, 2, 2]) | ||
clf = NearestCentroid(shrink_threshold=0.1) | ||
clf.partial_fit(X[:3], y[:3], classes=[1, 2]) | ||
expected_result = np.array([[0.55773503, 0.55773503], [0.88452995, 0.88452995]]) | ||
np.testing.assert_array_almost_equal(clf.centroids_, expected_result) | ||
|
||
clf.partial_fit(X[3:], y[3:]) | ||
expected_result = np.array([[0.7787310, 0.8545292], [2.814179, 2.763647]]) | ||
np.testing.assert_array_almost_equal(clf.centroids_, expected_result) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please document the new attributes