@@ -64,7 +64,7 @@ def _verify_feature_weights(self, feature_weights: dict, features: dict, placeme
6464 """Helper method to verify feature weight properties.
6565
6666 Args:
67- feature_weights: Dictionary of feature weights
67+ feature_weights: Dictionary of feature importances
6868 features: Dictionary of feature specifications
6969 placement: Where feature selection is applied ("all_features", "numeric", or "categorical")
7070 """
@@ -82,23 +82,21 @@ def _verify_feature_weights(self, feature_weights: dict, features: dict, placeme
                 or (placement == "categorical" and is_categorical)
             )
 
-            weight_key = f"{feature_name}_weights"
-            if weight_key in feature_weights:
-                weights = feature_weights[weight_key]
+            if feature_name in feature_weights:
+                weight = feature_weights[feature_name]
 
-                # Check that weights are finite
-                self.assertTrue(tf.reduce_all(tf.math.is_finite(weights)))
+                # Check that weight is finite
+                self.assertTrue(tf.math.is_finite(weight))
 
-                # Check that weights have reasonable magnitude
-                self.assertAllInRange(weights, -10.0, 10.0)
+                # Check that weight has reasonable magnitude
+                self.assertAllInRange([weight], -10.0, 10.0)
 
-                # Check variance based on whether feature should have weights
-                weights_std = tf.math.reduce_std(weights)
-                if should_have_weights and len(weights) > 1:
-                    # Should have non-zero variance
-                    self.assertGreater(weights_std, 0)
+                # Check if feature should have weights
+                if should_have_weights:
+                    # Should have non-zero weight
+                    self.assertNotEqual(weight, 0)
                 else:
-                    # Might not have weights at all, or might have constant weights
+                    # Might not have weights at all
                     pass
 
     def test_feature_selection_weights(self):
@@ -135,16 +133,16 @@ def test_feature_selection_weights(self):
         # Build the preprocessor
         result = ppr.build_preprocessor()
 
-        # Get feature weights
-        feature_weights = ppr._extract_feature_weights()
+        # Get feature importances
+        feature_importances = ppr.get_feature_importances()
 
         # Verify weights exist for all features
-        self.assertNotEmpty(feature_weights)
+        self.assertNotEmpty(feature_importances)
         for feature_name in features:
-            self.assertIn(f"{feature_name}_weights", feature_weights)
+            self.assertIn(feature_name, feature_importances)
 
         # Use helper method to verify weights
-        self._verify_feature_weights(feature_weights, features)
+        self._verify_feature_weights(feature_importances, features)
 
     def test_feature_selection_with_tabular_attention(self):
         """Test feature selection combined with tabular attention."""
@@ -178,9 +176,9 @@ def test_feature_selection_with_tabular_attention(self):
         )
 
         result = ppr.build_preprocessor()
-        feature_weights = ppr._extract_feature_weights()
+        feature_importances = ppr.get_feature_importances()
 
-        self._verify_feature_weights(feature_weights, features)
+        self._verify_feature_weights(feature_importances, features)
 
     def test_feature_selection_with_transformer(self):
         """Test feature selection combined with transformer blocks."""
@@ -215,9 +213,9 @@ def test_feature_selection_with_transformer(self):
         )
 
         result = ppr.build_preprocessor()
-        feature_weights = ppr._extract_feature_weights()
+        feature_importances = ppr.get_feature_importances()
 
-        self._verify_feature_weights(feature_weights, features)
+        self._verify_feature_weights(feature_importances, features)
 
     def test_feature_selection_with_both(self):
         """Test feature selection with both tabular attention and transformer blocks."""
@@ -258,9 +256,9 @@ def test_feature_selection_with_both(self):
         )
 
         result = ppr.build_preprocessor()
-        feature_weights = ppr._extract_feature_weights()
+        feature_importances = ppr.get_feature_importances()
 
-        self._verify_feature_weights(feature_weights, features)
+        self._verify_feature_weights(feature_importances, features)
 
     def test_feature_selection_with_both_mixed_placement(self):
         """Test feature selection with both tabular attention and transformer blocks."""
@@ -302,9 +300,9 @@ def test_feature_selection_with_both_mixed_placement(self):
         )
 
         result = ppr.build_preprocessor()
-        feature_weights = ppr._extract_feature_weights()
+        feature_importances = ppr.get_feature_importances()
 
-        self._verify_feature_weights(feature_weights, features, placement="numeric")
+        self._verify_feature_weights(feature_importances, features, placement="numeric")
 
     def test_feature_selection_with_both_mixed_placement_v2(self):
         """Test feature selection with both tabular attention and transformer blocks."""
@@ -346,9 +344,9 @@ def test_feature_selection_with_both_mixed_placement_v2(self):
         )
 
         result = ppr.build_preprocessor()
-        feature_weights = ppr._extract_feature_weights()
+        feature_importances = ppr.get_feature_importances()
 
-        self._verify_feature_weights(feature_weights, features, placement="categorical")
+        self._verify_feature_weights(feature_importances, features, placement="categorical")
 
     def test_feature_selection_with_both_mixed_placement_v3(self):
         """Test feature selection with both tabular attention and transformer blocks."""
@@ -390,9 +388,9 @@ def test_feature_selection_with_both_mixed_placement_v3(self):
         )
 
         result = ppr.build_preprocessor()
-        feature_weights = ppr._extract_feature_weights()
+        feature_importances = ppr.get_feature_importances()
 
-        self._verify_feature_weights(feature_weights, features, placement="numeric")
+        self._verify_feature_weights(feature_importances, features, placement="numeric")
 
     def test_feature_selection_with_both_mixed_placement_v4(self):
         """Test feature selection with mixed placement configuration.
@@ -441,6 +439,6 @@ def test_feature_selection_with_both_mixed_placement_v4(self):
         )
 
         result = ppr.build_preprocessor()
-        feature_weights = ppr._extract_feature_weights()
+        feature_importances = ppr.get_feature_importances()
 
-        self._verify_feature_weights(feature_weights, features, placement="categorical")
+        self._verify_feature_weights(feature_importances, features, placement="categorical")