Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e2e7d75

Browse files
authored
MAINT Extend dtype preserved common test to check transform (#24982)
1 parent b728b2e commit e2e7d75

File tree

2 files changed

+21
-13
lines changed

2 files changed

+21
-13
lines changed

sklearn/manifold/_isomap.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,13 @@ def transform(self, X):
409409

410410
n_samples_fit = self.nbrs_.n_samples_fit_
411411
n_queries = distances.shape[0]
412-
G_X = np.zeros((n_queries, n_samples_fit))
412+
413+
if hasattr(X, "dtype") and X.dtype == np.float32:
414+
dtype = np.float32
415+
else:
416+
dtype = np.float64
417+
418+
G_X = np.zeros((n_queries, n_samples_fit), dtype)
413419
for i in range(n_queries):
414420
G_X[i] = np.min(self.dist_matrix_[indices[i]] + distances[i][:, None], 0)
415421

sklearn/utils/estimator_checks.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1743,18 +1743,20 @@ def check_transformer_preserve_dtypes(name, transformer_orig):
17431743
X_cast = X.astype(dtype)
17441744
transformer = clone(transformer_orig)
17451745
set_random_state(transformer)
1746-
X_trans = transformer.fit_transform(X_cast, y)
1747-
1748-
if isinstance(X_trans, tuple):
1749-
# cross-decompostion returns a tuple of (x_scores, y_scores)
1750-
# when given y with fit_transform; only check the first element
1751-
X_trans = X_trans[0]
1752-
1753-
# check that the output dtype is preserved
1754-
assert X_trans.dtype == dtype, (
1755-
f"Estimator transform dtype: {X_trans.dtype} - "
1756-
f"original/expected dtype: {dtype.__name__}"
1757-
)
1746+
X_trans1 = transformer.fit_transform(X_cast, y)
1747+
X_trans2 = transformer.fit(X_cast, y).transform(X_cast)
1748+
1749+
for Xt, method in zip([X_trans1, X_trans2], ["fit_transform", "transform"]):
1750+
if isinstance(Xt, tuple):
1751+
# cross-decompostion returns a tuple of (x_scores, y_scores)
1752+
# when given y with fit_transform; only check the first element
1753+
Xt = Xt[0]
1754+
1755+
# check that the output dtype is preserved
1756+
assert Xt.dtype == dtype, (
1757+
f"{name} (method={method}) does not preserve dtype. "
1758+
f"Original/Expected dtype={dtype.__name__}, got dtype={Xt.dtype}."
1759+
)
17581760

17591761

17601762
@ignore_warnings(category=FutureWarning)

0 commit comments

Comments
 (0)