Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 23b36ce

Browse files
committed
fix(KDP): edited some of the tests to reflect the changes in processor.py
1 parent a170db8 commit 23b36ce

File tree

1 file changed

+31
-33
lines changed

1 file changed

+31
-33
lines changed

test/test_feature_selection_preprocessor.py

Lines changed: 31 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def _verify_feature_weights(self, feature_weights: dict, features: dict, placeme
6464
"""Helper method to verify feature weight properties.
6565
6666
Args:
67-
feature_weights: Dictionary of feature weights
67+
feature_weights: Dictionary of feature importances
6868
features: Dictionary of feature specifications
6969
placement: Where feature selection is applied ("all_features", "numeric", or "categorical")
7070
"""
@@ -82,23 +82,21 @@ def _verify_feature_weights(self, feature_weights: dict, features: dict, placeme
8282
or (placement == "categorical" and is_categorical)
8383
)
8484

85-
weight_key = f"{feature_name}_weights"
86-
if weight_key in feature_weights:
87-
weights = feature_weights[weight_key]
85+
if feature_name in feature_weights:
86+
weight = feature_weights[feature_name]
8887

89-
# Check that weights are finite
90-
self.assertTrue(tf.reduce_all(tf.math.is_finite(weights)))
88+
# Check that weight is finite
89+
self.assertTrue(tf.math.is_finite(weight))
9190

92-
# Check that weights have reasonable magnitude
93-
self.assertAllInRange(weights, -10.0, 10.0)
91+
# Check that weight has reasonable magnitude
92+
self.assertAllInRange([weight], -10.0, 10.0)
9493

95-
# Check variance based on whether feature should have weights
96-
weights_std = tf.math.reduce_std(weights)
97-
if should_have_weights and len(weights) > 1:
98-
# Should have non-zero variance
99-
self.assertGreater(weights_std, 0)
94+
# Check if feature should have weights
95+
if should_have_weights:
96+
# Should have non-zero weight
97+
self.assertNotEqual(weight, 0)
10098
else:
101-
# Might not have weights at all, or might have constant weights
99+
# Might not have weights at all
102100
pass
103101

104102
def test_feature_selection_weights(self):
@@ -135,16 +133,16 @@ def test_feature_selection_weights(self):
135133
# Build the preprocessor
136134
result = ppr.build_preprocessor()
137135

138-
# Get feature weights
139-
feature_weights = ppr._extract_feature_weights()
136+
# Get feature importances
137+
feature_importances = ppr.get_feature_importances()
140138

141139
# Verify weights exist for all features
142-
self.assertNotEmpty(feature_weights)
140+
self.assertNotEmpty(feature_importances)
143141
for feature_name in features:
144-
self.assertIn(f"{feature_name}_weights", feature_weights)
142+
self.assertIn(feature_name, feature_importances)
145143

146144
# Use helper method to verify weights
147-
self._verify_feature_weights(feature_weights, features)
145+
self._verify_feature_weights(feature_importances, features)
148146

149147
def test_feature_selection_with_tabular_attention(self):
150148
"""Test feature selection combined with tabular attention."""
@@ -178,9 +176,9 @@ def test_feature_selection_with_tabular_attention(self):
178176
)
179177

180178
result = ppr.build_preprocessor()
181-
feature_weights = ppr._extract_feature_weights()
179+
feature_importances = ppr.get_feature_importances()
182180

183-
self._verify_feature_weights(feature_weights, features)
181+
self._verify_feature_weights(feature_importances, features)
184182

185183
def test_feature_selection_with_transformer(self):
186184
"""Test feature selection combined with transformer blocks."""
@@ -215,9 +213,9 @@ def test_feature_selection_with_transformer(self):
215213
)
216214

217215
result = ppr.build_preprocessor()
218-
feature_weights = ppr._extract_feature_weights()
216+
feature_importances = ppr.get_feature_importances()
219217

220-
self._verify_feature_weights(feature_weights, features)
218+
self._verify_feature_weights(feature_importances, features)
221219

222220
def test_feature_selection_with_both(self):
223221
"""Test feature selection with both tabular attention and transformer blocks."""
@@ -258,9 +256,9 @@ def test_feature_selection_with_both(self):
258256
)
259257

260258
result = ppr.build_preprocessor()
261-
feature_weights = ppr._extract_feature_weights()
259+
feature_importances = ppr.get_feature_importances()
262260

263-
self._verify_feature_weights(feature_weights, features)
261+
self._verify_feature_weights(feature_importances, features)
264262

265263
def test_feature_selection_with_both_mixed_placement(self):
266264
"""Test feature selection with both tabular attention and transformer blocks."""
@@ -302,9 +300,9 @@ def test_feature_selection_with_both_mixed_placement(self):
302300
)
303301

304302
result = ppr.build_preprocessor()
305-
feature_weights = ppr._extract_feature_weights()
303+
feature_importances = ppr.get_feature_importances()
306304

307-
self._verify_feature_weights(feature_weights, features, placement="numeric")
305+
self._verify_feature_weights(feature_importances, features, placement="numeric")
308306

309307
def test_feature_selection_with_both_mixed_placement_v2(self):
310308
"""Test feature selection with both tabular attention and transformer blocks."""
@@ -346,9 +344,9 @@ def test_feature_selection_with_both_mixed_placement_v2(self):
346344
)
347345

348346
result = ppr.build_preprocessor()
349-
feature_weights = ppr._extract_feature_weights()
347+
feature_importances = ppr.get_feature_importances()
350348

351-
self._verify_feature_weights(feature_weights, features, placement="categorical")
349+
self._verify_feature_weights(feature_importances, features, placement="categorical")
352350

353351
def test_feature_selection_with_both_mixed_placement_v3(self):
354352
"""Test feature selection with both tabular attention and transformer blocks."""
@@ -390,9 +388,9 @@ def test_feature_selection_with_both_mixed_placement_v3(self):
390388
)
391389

392390
result = ppr.build_preprocessor()
393-
feature_weights = ppr._extract_feature_weights()
391+
feature_importances = ppr.get_feature_importances()
394392

395-
self._verify_feature_weights(feature_weights, features, placement="numeric")
393+
self._verify_feature_weights(feature_importances, features, placement="numeric")
396394

397395
def test_feature_selection_with_both_mixed_placement_v4(self):
398396
"""Test feature selection with mixed placement configuration.
@@ -441,6 +439,6 @@ def test_feature_selection_with_both_mixed_placement_v4(self):
441439
)
442440

443441
result = ppr.build_preprocessor()
444-
feature_weights = ppr._extract_feature_weights()
442+
feature_importances = ppr.get_feature_importances()
445443

446-
self._verify_feature_weights(feature_weights, features, placement="categorical")
444+
self._verify_feature_weights(feature_importances, features, placement="categorical")

0 commit comments

Comments
 (0)