Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 53234c5

Browse files
rprkhjjerphan
andauthored
ENH Add dtype preservation for Isomap (#24714)
Co-authored-by: Julien Jerphanion <[email protected]>
1 parent 1b36cbb commit 53234c5

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

doc/whats_new/v1.2.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,9 @@ Changelog
413413
`eigen_tol="auto"` in version 1.3.
414414
:pr:`23210` by :user:`Meekail Zain <micky774>`.
415415

416+
- |Enhancement| :class:`manifold.Isomap` now preserves
417+
dtype for `np.float32` inputs. :pr:`24714` by :user:`Rahil Parikh <rprkh>`.
418+
416419
:mod:`sklearn.metrics`
417420
......................
418421

sklearn/manifold/_isomap.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,11 @@ def _fit_transform(self, X):
294294

295295
self.dist_matrix_ = shortest_path(nbg, method=self.path_method, directed=False)
296296

297+
if self.nbrs_._fit_X.dtype == np.float32:
298+
self.dist_matrix_ = self.dist_matrix_.astype(
299+
self.nbrs_._fit_X.dtype, copy=False
300+
)
301+
297302
G = self.dist_matrix_**2
298303
G *= -0.5
299304

@@ -412,3 +417,6 @@ def transform(self, X):
412417
G_X *= -0.5
413418

414419
return self.kernel_pca_.transform(G_X)
420+
421+
def _more_tags(self):
422+
return {"preserves_dtype": [np.float64, np.float32]}

sklearn/manifold/tests/test_isomap.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,32 @@ def test_isomap_fit_precomputed_radius_graph():
250250
assert_allclose(precomputed_result, result)
251251

252252

253+
def test_isomap_fitted_attributes_dtype(global_dtype):
254+
"""Check that the fitted attributes are stored accordingly to the
255+
data type of X."""
256+
iso = manifold.Isomap(n_neighbors=2)
257+
258+
X = np.array([[1, 2], [3, 4], [5, 6]], dtype=global_dtype)
259+
260+
iso.fit(X)
261+
262+
assert iso.dist_matrix_.dtype == global_dtype
263+
assert iso.embedding_.dtype == global_dtype
264+
265+
266+
def test_isomap_dtype_equivalence():
267+
"""Check the equivalence of the results with 32 and 64 bits input."""
268+
iso_32 = manifold.Isomap(n_neighbors=2)
269+
X_32 = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)
270+
iso_32.fit(X_32)
271+
272+
iso_64 = manifold.Isomap(n_neighbors=2)
273+
X_64 = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float64)
274+
iso_64.fit(X_64)
275+
276+
assert_allclose(iso_32.dist_matrix_, iso_64.dist_matrix_)
277+
278+
253279
def test_isomap_raise_error_when_neighbor_and_radius_both_set():
254280
# Isomap.fit_transform must raise a ValueError if
255281
# radius and n_neighbors are provided.

0 commit comments

Comments
 (0)