@@ -1346,7 +1346,7 @@ def test_one_hot_encoder_sparse_deprecated():
1346
1346
1347
1347
# deliberately omit 'OS' and 'US' as invalid combos
1348
1348
@pytest .mark .parametrize (
1349
- "input_dtype, category_dtype" , ["OO" , "OU" , "UO" , "UU" , "US" , " SO" , "SU" , "SS" ]
1349
+ "input_dtype, category_dtype" , ["OO" , "OU" , "UO" , "UU" , "SO" , "SU" , "SS" ]
1350
1350
)
1351
1351
@pytest .mark .parametrize ("array_type" , ["list" , "array" , "dataframe" ])
1352
1352
def test_encoders_string_categories (input_dtype , category_dtype , array_type ):
@@ -1376,6 +1376,27 @@ def test_encoders_string_categories(input_dtype, category_dtype, array_type):
1376
1376
assert_array_equal (X_trans , expected )
1377
1377
1378
1378
1379
def test_mixed_string_bytes_categoricals():
    """Fitting must fail when bytes categories are paired with unicode data.

    The encoder receives predefined categories stored as bytes while the
    data passed to ``fit`` holds unicode strings. The two cannot easily be
    compared, so a ``ValueError`` with a descriptive message is expected.
    """
    # predefined categories stored as bytes
    bytes_categories = [np.array(["b", "a"], dtype="S")]
    # data stored as unicode
    unicode_data = np.array([["b"], ["a"]], dtype="U")
    encoder = OneHotEncoder(categories=bytes_categories, sparse_output=False)

    # escape the message so it is matched literally by ``pytest.raises``
    expected_message = re.escape(
        "In column 0, the predefined categories have type 'bytes' which is incompatible"
        " with values of type 'str_'."
    )
    with pytest.raises(ValueError, match=expected_message):
        encoder.fit(unicode_data)
1398
+
1399
+
1379
1400
@pytest .mark .parametrize ("missing_value" , [np .nan , None ])
1380
1401
def test_ohe_missing_values_get_feature_names (missing_value ):
1381
1402
# encoder with missing values with object dtypes
@@ -1939,3 +1960,20 @@ def test_ordinal_set_output():
1939
1960
1940
1961
assert_allclose (X_pandas .to_numpy (), X_default )
1941
1962
assert_array_equal (ord_pandas .get_feature_names_out (), X_pandas .columns )
1963
+
1964
+
1965
def test_predefined_categories_dtype():
    """Check that string categories are stored in `categories_` as `object`.

    Regression test for gh-25171.
    """
    predefined = [["as", "mmas", "eas", "ras", "acs"], ["1", "2"]]
    encoder = OneHotEncoder(categories=predefined)
    encoder.fit([["as", "1"]])

    fitted = encoder.categories_
    # one fitted array per predefined list; this also makes the zip below
    # cover every entry on both sides
    assert len(predefined) == len(fitted)
    for expected, actual in zip(predefined, fitted):
        assert actual.dtype == object
        assert_array_equal(expected, actual)
0 commit comments