
Preserving dtype for float32 / float64 in transformers #11000

Open
@glemaitre

Description

This is the issue which we want to tackle during the Man AHL Hackathon.

We would like transformers not to convert float32 input to float64 whenever possible. The transformers currently failing are:

We could consider extending this to integer dtypes whenever possible and applicable.

Also, the following transformers are not covered by the common tests; we should write specific tests for them:

# some strange ones
DONT_TEST = ['SparseCoder', 'DictVectorizer',
             'TfidfTransformer',
             'TfidfVectorizer',  # see #10443
             'IsotonicRegression',
             'CategoricalEncoder',
             'FeatureHasher',
             'TruncatedSVD', 'PolynomialFeatures',
             'GaussianRandomProjectionHash', 'HashingVectorizer',
             'CountVectorizer']

We could also check classifiers, regressors or clusterers (see #8769 for more context).

Below is the code executed to find the failures.

import numpy as np
from sklearn.base import clone
from sklearn.utils._testing import set_random_state

# X, y_ and transformer_orig are provided by the surrounding common-test loop.
# Let's check the 32 - 64 bits dtype conservation.
if isinstance(X, np.ndarray):
    for dtype in [np.float32, np.float64]:
        X_cast = X.astype(dtype)

        transformer = clone(transformer_orig)
        set_random_state(transformer)

        if hasattr(transformer, 'fit_transform'):
            X_trans = transformer.fit_transform(X_cast, y_)
        else:
            transformer.fit(X_cast, y_)
            X_trans = transformer.transform(X_cast)

        # FIXME: should we also check that the dtypes of fitted attributes
        # match the input dtype?
        assert X_trans.dtype == X_cast.dtype, (
            'transform dtype: {} - original dtype: {}'.format(
                X_trans.dtype, X_cast.dtype))
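The same check can also be run standalone against a single estimator. A minimal sketch, using StandardScaler (which already preserves the input dtype):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.RandomState(0)
X = rng.rand(20, 3)

for dtype in [np.float32, np.float64]:
    X_cast = X.astype(dtype)
    # StandardScaler returns an array with the same dtype as its input.
    X_trans = StandardScaler().fit_transform(X_cast)
    assert X_trans.dtype == X_cast.dtype, (
        'transform dtype: {} - original dtype: {}'.format(
            X_trans.dtype, X_cast.dtype))
```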

Tips to run the test for a specific transformer:

  • Choose a transformer, for instance FastICA
  • If this class does not already have a method named _more_tags: add the following code snippet at the bottom of the class definition:
    def _more_tags(self):
        return {"preserves_dtype": [np.float64, np.float32]}
  • Run the common tests for this specific class:
pytest sklearn/tests/test_common.py -k "FastICA and check_transformer_preserve_dtypes" -v
  • It should fail: read the error message and try to understand why the fit_transform method (if it exists) or the transform method returns a float64 data array when it is passed a float32 input array.
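A frequent cause of such failures is input validation that forces float64. As a hedged sketch (the actual fix depends on the estimator), check_array accepts a list of dtypes, in which case an input whose dtype is already in the list is left untouched:

```python
import numpy as np
from sklearn.utils.validation import check_array

X32 = np.ones((4, 2), dtype=np.float32)

# Many estimators force float64 during input validation ...
print(check_array(X32, dtype=np.float64).dtype)  # float64

# ... whereas a dtype list keeps float32 inputs as float32; the first
# entry is the fallback used for inputs not in the list (e.g. integers).
print(check_array(X32, dtype=[np.float64, np.float32]).dtype)  # float32
```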

It might be helpful to use a debugger, for instance by adding the line:

import pdb; pdb.set_trace()

at the beginning of the fit_transform method and then re-running pytest with:

pytest sklearn/tests/test_common.py -k "FastICA and check_transformer_preserve_dtypes" --pdb

Then use the l (list), n (next), s (step into a function call), p some_array_variable.dtype (p stands for print) and c (continue) commands to interactively debug the execution of this fit_transform call.
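While stepping through, a typical culprit to look for is an operation that mixes the float32 input with a float64 intermediate or fitted attribute: NumPy's type promotion then silently upcasts the result. A hedged illustration in plain NumPy (the variable names are made up for this sketch):

```python
import numpy as np

X32 = np.ones((3, 3), dtype=np.float32)

# e.g. a learned attribute that was stored as float64
mean64 = np.zeros(3, dtype=np.float64)

# float32 op float64 promotes to float64
print((X32 - mean64).dtype)                     # float64

# casting the attribute to the input dtype keeps the result in float32
print((X32 - mean64.astype(X32.dtype)).dtype)   # float32
```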

ping @rth feel free to edit this thread.

Metadata

Assignees: no one assigned

Labels: Hard (hard level of difficulty), Meta-issue (general issue associated with an identified list of tasks), Sprint, float32 (issues related to support for 32-bit data)
