Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 82b6d7e

Browse files
fix(KDP): Preserve original dtype for PASSTHROUGH features (#30)
2 parents b9d237e + 6326dbf commit 82b6d7e

File tree

9 files changed

+1658
-867
lines changed

9 files changed

+1658
-867
lines changed

.github/workflows/UTESTS.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ jobs:
7171
strategy:
7272
fail-fast: false
7373
matrix:
74-
python-version: ["3.9", "3.10", "3.11"]
74+
python-version: ["3.10", "3.11", "3.12"]
7575
test-group: ["unit", "integration", "layers"]
7676
include:
7777
- python-version: "3.11"
@@ -87,7 +87,7 @@ jobs:
8787
python-version: ${{ matrix.python-version }}
8888

8989
- name: Cache Poetry dependencies
90-
uses: actions/cache@v3
90+
uses: actions/cache@v4
9191
with:
9292
path: |
9393
~/.cache/pypoetry
@@ -124,7 +124,7 @@ jobs:
124124
timeout-minutes: 8
125125

126126
- name: Upload test results
127-
uses: actions/upload-artifact@v3
127+
uses: actions/upload-artifact@v4
128128
if: always()
129129
with:
130130
name: test-results-${{ matrix.python-version }}-${{ matrix.test-group }}
@@ -157,15 +157,15 @@ jobs:
157157
timeout-minutes: 10
158158

159159
- name: Upload coverage to Codecov
160-
uses: codecov/codecov-action@v3
160+
uses: codecov/codecov-action@v4
161161
with:
162162
file: ./coverage.xml
163163
flags: unittests
164164
name: codecov-umbrella
165165
fail_ci_if_error: false
166166

167167
- name: Upload coverage report
168-
uses: actions/upload-artifact@v3
168+
uses: actions/upload-artifact@v4
169169
with:
170170
name: coverage-report
171171
path: htmlcov/
@@ -194,7 +194,7 @@ jobs:
194194
timeout-minutes: 10
195195

196196
- name: Upload benchmark results
197-
uses: actions/upload-artifact@v3
197+
uses: actions/upload-artifact@v4
198198
if: always()
199199
with:
200200
name: benchmark-results

kdp/layers/preserve_dtype.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import tensorflow as tf
2+
from tensorflow import keras
3+
4+
5+
@tf.keras.utils.register_keras_serializable(package="kdp.layers")
6+
class PreserveDtypeLayer(keras.layers.Layer):
7+
"""Custom Keras layer that preserves the original dtype of input tensors.
8+
9+
This is useful for passthrough features where we want to maintain the original
10+
data type without any casting.
11+
"""
12+
13+
def __init__(self, target_dtype=None, **kwargs):
14+
"""Initialize the layer.
15+
16+
Args:
17+
target_dtype: Optional target dtype to cast to. If None, preserves original dtype.
18+
**kwargs: Additional keyword arguments
19+
"""
20+
super().__init__(**kwargs)
21+
self.target_dtype = target_dtype
22+
23+
def call(self, inputs, **kwargs):
24+
"""Process the input tensor, optionally casting to target_dtype.
25+
26+
Args:
27+
inputs: Input tensor of any dtype
28+
**kwargs: Additional keyword arguments
29+
30+
Returns:
31+
Tensor with preserved or target dtype
32+
"""
33+
if self.target_dtype is not None:
34+
return tf.cast(inputs, self.target_dtype)
35+
return inputs
36+
37+
def get_config(self):
38+
"""Return the config dictionary for serialization.
39+
40+
Returns:
41+
A dictionary with the layer configuration
42+
"""
43+
config = super().get_config()
44+
config.update({"target_dtype": self.target_dtype})
45+
return config
46+
47+
@classmethod
48+
def from_config(cls, config):
49+
"""Create a new instance from the serialized configuration.
50+
51+
Args:
52+
config: Layer configuration dictionary
53+
54+
Returns:
55+
A new instance of the layer
56+
"""
57+
return cls(**config)

kdp/layers_factory.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from kdp.layers.text_preprocessing_layer import TextPreprocessingLayer
1111
from kdp.layers.cast_to_float import CastToFloat32Layer
12+
from kdp.layers.preserve_dtype import PreserveDtypeLayer
1213
from kdp.layers.date_parsing_layer import DateParsingLayer
1314
from kdp.layers.date_encoding_layer import DateEncodingLayer
1415
from kdp.layers.season_layer import SeasonLayer
@@ -183,6 +184,27 @@ def cast_to_float32_layer(
183184
**kwargs,
184185
)
185186

187+
@staticmethod
188+
def preserve_dtype_layer(
189+
name: str = "preserve_dtype", target_dtype=None, **kwargs: dict
190+
) -> tf.keras.layers.Layer:
191+
"""Create a PreserveDtypeLayer layer.
192+
193+
Args:
194+
name: The name of the layer.
195+
target_dtype: Optional target dtype to cast to. If None, preserves original dtype.
196+
**kwargs: Additional keyword arguments to pass to the layer constructor.
197+
198+
Returns:
199+
An instance of the PreserveDtypeLayer layer.
200+
"""
201+
return PreprocessorLayerFactory.create_layer(
202+
layer_class=PreserveDtypeLayer,
203+
name=name,
204+
target_dtype=target_dtype,
205+
**kwargs,
206+
)
207+
186208
@staticmethod
187209
def date_parsing_layer(
188210
name: str = "date_parsing_layer", **kwargs: dict

kdp/processor.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,6 +1126,15 @@ def _add_categorical_lookup(
11261126
if feature.category_encoding == CategoryEncodingOptions.HASHING:
11271127
return
11281128

1129+
# Handle empty vocabulary by providing a fallback
1130+
if not vocab:
1131+
logger.warning(
1132+
f"Empty vocabulary for categorical feature '{feature_name}'. "
1133+
"Using fallback vocabulary with placeholder values."
1134+
)
1135+
# Provide a minimal vocabulary with unknown/placeholder values
1136+
vocab = ["<UNK>"]
1137+
11291138
# Default behavior if no specific preprocessing is defined
11301139
if feature.feature_type == FeatureType.STRING_CATEGORICAL:
11311140
preprocessor.add_processing_step(
@@ -1414,10 +1423,12 @@ def _add_pipeline_passthrough(self, feature_name: str, input_layer) -> None:
14141423
feature_name=feature_name,
14151424
)
14161425
else:
1417-
# For passthrough features, we only ensure type consistency by casting to float32
1426+
# For passthrough features, preserve the original dtype or cast to specified dtype
1427+
target_dtype = getattr(_feature, "dtype", None)
14181428
preprocessor.add_processing_step(
1419-
layer_creator=PreprocessorLayerFactory.cast_to_float32_layer,
1420-
name=f"cast_to_float_{feature_name}",
1429+
layer_creator=PreprocessorLayerFactory.preserve_dtype_layer,
1430+
name=f"preserve_dtype_{feature_name}",
1431+
target_dtype=target_dtype,
14211432
)
14221433

14231434
# Optionally reshape if needed

0 commit comments

Comments
 (0)