@@ -324,6 +324,72 @@ def test_csr_preprocess_data():
324
324
assert_equal (csr_ .getformat (), 'csr' )
325
325
326
326
327
+ def test_dtype_preprocess_data ():
328
+ n_samples = 200
329
+ n_features = 2
330
+ X = rng .rand (n_samples , n_features )
331
+ y = rng .rand (n_samples )
332
+
333
+ X_32 = np .asarray (X , dtype = np .float32 )
334
+ y_32 = np .asarray (y , dtype = np .float32 )
335
+ X_64 = np .asarray (X , dtype = np .float64 )
336
+ y_64 = np .asarray (y , dtype = np .float64 )
337
+
338
+ for fit_intercept in [True , False ]:
339
+ for normalize in [True , False ]:
340
+
341
+ Xt_32 , yt_32 , X_mean_32 , y_mean_32 , X_norm_32 = _preprocess_data (
342
+ X_32 , y_32 , fit_intercept = fit_intercept , normalize = normalize ,
343
+ return_mean = True )
344
+
345
+ Xt_64 , yt_64 , X_mean_64 , y_mean_64 , X_norm_64 = _preprocess_data (
346
+ X_64 , y_64 , fit_intercept = fit_intercept , normalize = normalize ,
347
+ return_mean = True )
348
+
349
+ Xt_3264 , yt_3264 , X_mean_3264 , y_mean_3264 , X_norm_3264 = (
350
+ _preprocess_data (X_32 , y_64 , fit_intercept = fit_intercept ,
351
+ normalize = normalize , return_mean = True ))
352
+
353
+ Xt_6432 , yt_6432 , X_mean_6432 , y_mean_6432 , X_norm_6432 = (
354
+ _preprocess_data (X_64 , y_32 , fit_intercept = fit_intercept ,
355
+ normalize = normalize , return_mean = True ))
356
+
357
+ assert_equal (Xt_32 .dtype , np .float32 )
358
+ assert_equal (yt_32 .dtype , np .float32 )
359
+ assert_equal (X_mean_32 .dtype , np .float32 )
360
+ assert_equal (y_mean_32 .dtype , np .float32 )
361
+ assert_equal (X_norm_32 .dtype , np .float32 )
362
+
363
+ assert_equal (Xt_64 .dtype , np .float64 )
364
+ assert_equal (yt_64 .dtype , np .float64 )
365
+ assert_equal (X_mean_64 .dtype , np .float64 )
366
+ assert_equal (y_mean_64 .dtype , np .float64 )
367
+ assert_equal (X_norm_64 .dtype , np .float64 )
368
+
369
+ assert_equal (Xt_3264 .dtype , np .float32 )
370
+ assert_equal (yt_3264 .dtype , np .float32 )
371
+ assert_equal (X_mean_3264 .dtype , np .float32 )
372
+ assert_equal (y_mean_3264 .dtype , np .float32 )
373
+ assert_equal (X_norm_3264 .dtype , np .float32 )
374
+
375
+ assert_equal (Xt_6432 .dtype , np .float64 )
376
+ assert_equal (yt_6432 .dtype , np .float64 )
377
+ assert_equal (X_mean_6432 .dtype , np .float64 )
378
+ assert_equal (y_mean_6432 .dtype , np .float64 )
379
+ assert_equal (X_norm_6432 .dtype , np .float64 )
380
+
381
+ assert_equal (X_32 .dtype , np .float32 )
382
+ assert_equal (y_32 .dtype , np .float32 )
383
+ assert_equal (X_64 .dtype , np .float64 )
384
+ assert_equal (y_64 .dtype , np .float64 )
385
+
386
+ assert_array_almost_equal (Xt_32 , Xt_64 )
387
+ assert_array_almost_equal (yt_32 , yt_64 )
388
+ assert_array_almost_equal (X_mean_32 , X_mean_64 )
389
+ assert_array_almost_equal (y_mean_32 , y_mean_64 )
390
+ assert_array_almost_equal (X_norm_32 , X_norm_64 )
391
+
392
+
327
393
def test_rescale_data ():
328
394
n_samples = 200
329
395
n_features = 2
0 commit comments