@@ -1743,18 +1743,20 @@ def check_transformer_preserve_dtypes(name, transformer_orig):
1743
1743
X_cast = X .astype (dtype )
1744
1744
transformer = clone (transformer_orig )
1745
1745
set_random_state (transformer )
1746
- X_trans = transformer .fit_transform (X_cast , y )
1747
-
1748
- if isinstance (X_trans , tuple ):
1749
- # cross-decompostion returns a tuple of (x_scores, y_scores)
1750
- # when given y with fit_transform; only check the first element
1751
- X_trans = X_trans [0 ]
1752
-
1753
- # check that the output dtype is preserved
1754
- assert X_trans .dtype == dtype , (
1755
- f"Estimator transform dtype: { X_trans .dtype } - "
1756
- f"original/expected dtype: { dtype .__name__ } "
1757
- )
1746
+ X_trans1 = transformer .fit_transform (X_cast , y )
1747
+ X_trans2 = transformer .fit (X_cast , y ).transform (X_cast )
1748
+
1749
+ for Xt , method in zip ([X_trans1 , X_trans2 ], ["fit_transform" , "transform" ]):
1750
+ if isinstance (Xt , tuple ):
1751
+ # cross-decompostion returns a tuple of (x_scores, y_scores)
1752
+ # when given y with fit_transform; only check the first element
1753
+ Xt = Xt [0 ]
1754
+
1755
+ # check that the output dtype is preserved
1756
+ assert Xt .dtype == dtype , (
1757
+ f"{ name } (method={ method } ) does not preserve dtype. "
1758
+ f"Original/Expected dtype={ dtype .__name__ } , got dtype={ Xt .dtype } ."
1759
+ )
1758
1760
1759
1761
1760
1762
@ignore_warnings (category = FutureWarning )
0 commit comments