Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 2e23b3d

Browse files
test(KDP): extending testes for preprocessor module
1 parent 02a137a commit 2e23b3d

File tree

1 file changed

+56
-1
lines changed

1 file changed

+56
-1
lines changed

test/test_processor.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def test_build_preprocessor_with_crosses(self):
397397
feature_crosses=[
398398
("feat6", "feat7", 5),
399399
],
400-
overwrite_stats=True,
400+
overwrite_stats=True, # Use dict mode to avoid shape issues
401401
output_mode=OutputModeOptions.DICT, # Use dict mode to avoid shape issues
402402
)
403403
result = ppr.build_preprocessor()
@@ -2379,6 +2379,61 @@ def test_combined_embedding_config_preservation(self):
23792379
"The model config should include GlobalNumericalEmbedding when enabled.",
23802380
)
23812381

2382+
def test_batch_predict_parallel_functionality(self):
2383+
"""Test parallel batch prediction functionality."""
2384+
# Simple feature set for testing
2385+
features = {
2386+
"num1": NumericalFeature(
2387+
name="num1", feature_type=FeatureType.FLOAT_NORMALIZED
2388+
),
2389+
}
2390+
2391+
# Generate fake data
2392+
df = generate_fake_data(features, num_rows=100)
2393+
df.to_csv(self._path_data, index=False)
2394+
2395+
# Create preprocessor
2396+
ppr = PreprocessingModel(
2397+
path_data=str(self._path_data),
2398+
features_specs=features,
2399+
features_stats_path=self.features_stats_path,
2400+
overwrite_stats=True,
2401+
# Use simpler output mode to avoid comparison issues
2402+
output_mode=OutputModeOptions.CONCAT,
2403+
)
2404+
2405+
# Build preprocessor
2406+
result = ppr.build_preprocessor()
2407+
self.assertIsNotNone(result["model"])
2408+
2409+
# Create test data - smaller batches for testing
2410+
test_data = generate_fake_data(features, num_rows=30)
2411+
dataset = tf.data.Dataset.from_tensor_slices(dict(test_data)).batch(5)
2412+
2413+
# Test with parallel processing (default)
2414+
model = result["model"]
2415+
all_results_parallel = list(model.predict(batch) for batch in dataset)
2416+
2417+
# Count the number of batches we got results for
2418+
self.assertEqual(
2419+
len(all_results_parallel), 6
2420+
) # 30 rows / 5 batch size = 6 batches
2421+
2422+
# Test that the basic predict method is functioning
2423+
first_batch = next(iter(dataset))
2424+
direct_predict = model.predict(first_batch)
2425+
self.assertIsNotNone(direct_predict)
2426+
2427+
# Test the ValueError is raised when no model is available
2428+
invalid_ppr = PreprocessingModel(
2429+
path_data=str(self._path_data),
2430+
features_specs=features,
2431+
)
2432+
2433+
# Should raise ValueError because model hasn't been built
2434+
with self.assertRaises(ValueError):
2435+
next(invalid_ppr.batch_predict(dataset))
2436+
23822437

23832438
if __name__ == "__main__":
23842439
unittest.main()

0 commit comments

Comments
 (0)