Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 966434d

Browse files
fix(KDP): fixing failiing tests
1 parent 2965916 commit 966434d

File tree

1 file changed

+51
-28
lines changed

1 file changed

+51
-28
lines changed

kdp/processor.py

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,7 @@ def _process_features_parallel(self, features_dict: dict) -> None:
637637
categorical_features = []
638638
text_features = []
639639
date_features = []
640+
passthrough_features = []
640641

641642
for feature_name, stats in features_dict.items():
642643
if "mean" in stats:
@@ -647,6 +648,8 @@ def _process_features_parallel(self, features_dict: dict) -> None:
647648
text_features.append((feature_name, stats))
648649
elif feature_name in self.date_features:
649650
date_features.append((feature_name, stats))
651+
elif feature_name in self.passthrough_features:
652+
passthrough_features.append((feature_name, stats))
650653

651654
# Set up inputs in parallel
652655
self._parallel_setup_inputs(features_dict)
@@ -657,6 +660,7 @@ def _process_features_parallel(self, features_dict: dict) -> None:
657660
(categorical_features, "categorical"),
658661
(text_features, "text"),
659662
(date_features, "date"),
663+
(passthrough_features, "passthrough"),
660664
]
661665

662666
for features, feature_type in feature_groups:
@@ -700,7 +704,7 @@ def _apply_feature_selection(
700704
Args:
701705
feature_name: Name of the feature
702706
output_pipeline: The processed feature tensor
703-
feature_type: Type of the feature ('numeric', 'categorical', 'text', 'date')
707+
feature_type: Type of the feature ('numeric', 'categorical', 'text', 'date', 'passthrough')
704708
705709
Returns:
706710
The processed tensor, possibly with feature selection applied
@@ -737,6 +741,12 @@ def _apply_feature_selection(
737741
== FeatureSelectionPlacementOptions.DATE
738742
):
739743
apply_selection = True
744+
elif (
745+
feature_type == "passthrough"
746+
and self.feature_selection_placement
747+
== FeatureSelectionPlacementOptions.ALL_FEATURES
748+
):
749+
apply_selection = True
740750

741751
# Apply feature selection if enabled
742752
if apply_selection:
@@ -1285,18 +1295,20 @@ def _prepare_concat_mode_outputs(self) -> None:
12851295
logger.info("Concatenating outputs mode enabled")
12861296

12871297
def _group_features_by_type(self) -> Tuple[List, List]:
1288-
"""Group processed features by their type.
1298+
"""Group processed features by type for concatenation.
12891299
12901300
Returns:
1291-
Tuple containing lists of numeric and categorical features
1301+
Tuple of (numeric_features, categorical_features) lists
12921302
"""
1303+
# Initialize lists for features of different types
12931304
numeric_features = []
12941305
categorical_features = []
1306+
passthrough_features = []
12951307

1296-
# Process features based on their type
1308+
# Group processed features by type
12971309
for feature_name, feature in self.processed_features.items():
1298-
if feature is None:
1299-
logger.warning(f"Skipping {feature_name} as it is None")
1310+
# Skip feature weights
1311+
if feature_name.endswith("_weights"):
13001312
continue
13011313

13021314
# Add to appropriate list based on feature type
@@ -1317,9 +1329,16 @@ def _group_features_by_type(self) -> Tuple[List, List]:
13171329
):
13181330
logger.debug(f"Adding {feature_name} to categorical features")
13191331
categorical_features.append(feature)
1332+
elif feature_name in self.passthrough_features:
1333+
logger.debug(f"Adding {feature_name} to passthrough features")
1334+
passthrough_features.append(feature)
13201335
else:
13211336
logger.warning(f"Unknown feature type for {feature_name}")
13221337

1338+
# For concatenation purposes, add passthrough features to numeric features
1339+
if passthrough_features:
1340+
numeric_features.extend(passthrough_features)
1341+
13231342
return numeric_features, categorical_features
13241343

13251344
def _concatenate_numeric_features(
@@ -1797,19 +1816,42 @@ def build_preprocessor(self) -> dict:
17971816
self.features_stats = self.stats_instance.main()
17981817
logger.debug(f"Features Stats were calculated: {self.features_stats}")
17991818

1819+
# Set up inputs for all feature types BEFORE processing them
1820+
for feature_name in (
1821+
self.numeric_features
1822+
+ self.categorical_features
1823+
+ self.text_features
1824+
+ self.date_features
1825+
+ self.passthrough_features
1826+
):
1827+
if feature_name not in self.inputs:
1828+
# Get feature and its data type
1829+
feature = self.features_specs.get(feature_name)
1830+
if feature:
1831+
dtype = getattr(feature, "dtype", tf.float32)
1832+
self._add_input_column(feature_name=feature_name, dtype=dtype)
1833+
self._add_input_signature(
1834+
feature_name=feature_name, dtype=dtype
1835+
)
1836+
18001837
# Process features in batches by type
18011838
numeric_batch = []
18021839
categorical_batch = []
18031840
text_batch = []
18041841
date_batch = []
18051842
passthrough_batch = []
18061843

1844+
# Get the numeric stats from the correct location in features_stats
1845+
numeric_stats = self.features_stats.get("numeric_stats", {})
1846+
categorical_stats = self.features_stats.get("categorical_stats", {})
1847+
text_stats = self.features_stats.get("text", {})
1848+
18071849
for f_name in self.numeric_features:
1808-
numeric_batch.append((f_name, self.features_stats.get(f_name, {})))
1850+
numeric_batch.append((f_name, numeric_stats.get(f_name, {})))
18091851
for f_name in self.categorical_features:
1810-
categorical_batch.append((f_name, self.features_stats.get(f_name, {})))
1852+
categorical_batch.append((f_name, categorical_stats.get(f_name, {})))
18111853
for f_name in self.text_features:
1812-
text_batch.append((f_name, self.features_stats.get(f_name, {})))
1854+
text_batch.append((f_name, text_stats.get(f_name, {})))
18131855
for f_name in self.date_features:
18141856
date_batch.append((f_name, {}))
18151857
for f_name in self.passthrough_features:
@@ -1832,25 +1874,6 @@ def build_preprocessor(self) -> dict:
18321874
logger.info("Processing feature type: cross feature")
18331875
self._add_pipeline_cross()
18341876

1835-
# Prepare inputs for all feature types
1836-
# Set up inputs for each feature
1837-
for feature_name in (
1838-
self.numeric_features
1839-
+ self.categorical_features
1840-
+ self.text_features
1841-
+ self.date_features
1842-
+ self.passthrough_features
1843-
):
1844-
if feature_name not in self.inputs:
1845-
# Get feature and its data type
1846-
feature = self.features_specs.get(feature_name)
1847-
if feature:
1848-
dtype = getattr(feature, "dtype", tf.float32)
1849-
self._add_input_column(feature_name=feature_name, dtype=dtype)
1850-
self._add_input_signature(
1851-
feature_name=feature_name, dtype=dtype
1852-
)
1853-
18541877
# Prepare outputs based on mode
18551878
logger.info("Preparing outputs for the model")
18561879
self._prepare_outputs()

0 commit comments

Comments
 (0)