diff --git a/feathr_project/feathr/definition/feature.py b/feathr_project/feathr/definition/feature.py index 5ba577498..0720aced7 100644 --- a/feathr_project/feathr/definition/feature.py +++ b/feathr_project/feathr/definition/feature.py @@ -30,6 +30,11 @@ def __init__(self, registry_tags: Optional[Dict[str, str]] = None, ): FeatureBase.validate_feature_name(name) + + # Validate the feature type + if not isinstance(feature_type, FeatureType): + raise KeyError(f'Feature type must be a FeatureType class, like INT32, but got {feature_type}') + self.name = name self.feature_type = feature_type self.registry_tags=registry_tags diff --git a/feathr_project/feathr/definition/typed_key.py b/feathr_project/feathr/definition/typed_key.py index 16274698d..c2732a476 100644 --- a/feathr_project/feathr/definition/typed_key.py +++ b/feathr_project/feathr/definition/typed_key.py @@ -20,6 +20,10 @@ def __init__(self, full_name: Optional[str] = None, description: Optional[str] = None, key_column_alias: Optional[str] = None) -> None: + # Validate the key_column type + if not isinstance(key_column_type, ValueType): + raise KeyError(f'key_column_type must be a ValueType, like Value.INT32, but got {key_column_type}') + self.key_column = key_column self.key_column_type = key_column_type self.full_name = full_name diff --git a/feathr_project/test/unit/test_dtype.py b/feathr_project/test/unit/test_dtype.py new file mode 100644 index 000000000..eb6aaf2ce --- /dev/null +++ b/feathr_project/test/unit/test_dtype.py @@ -0,0 +1,24 @@ +import pytest +from feathr import Feature, TypedKey, ValueType, INT32 + + +def test_key_type(): + key = TypedKey(key_column="key", key_column_type=ValueType.INT32) + assert key.key_column_type == ValueType.INT32 + + with pytest.raises(KeyError): + key = TypedKey(key_column="key", key_column_type=INT32) + +def test_feature_type(): + key = TypedKey(key_column="key", key_column_type=ValueType.INT32) + + feature = Feature(name="name", + key=key, + feature_type=INT32) + + assert feature.feature_type == INT32 + + with pytest.raises(KeyError): + feature = Feature(name="name", + key=key, + feature_type=ValueType.INT32) \ No newline at end of file