@@ -397,7 +397,7 @@ def test_build_preprocessor_with_crosses(self):
397397 feature_crosses = [
398398 ("feat6" , "feat7" , 5 ),
399399 ],
400- overwrite_stats = True ,
400+ overwrite_stats = True ,
401401 output_mode = OutputModeOptions .DICT , # Use dict mode to avoid shape issues
402402 )
403403 result = ppr .build_preprocessor ()
@@ -2379,6 +2379,61 @@ def test_combined_embedding_config_preservation(self):
23792379 "The model config should include GlobalNumericalEmbedding when enabled." ,
23802380 )
23812381
def test_batch_predict_parallel_functionality(self):
    """Smoke-test batched prediction and the unbuilt-model error path.

    The test builds a preprocessor for a single numerical feature, runs the
    built model's ``predict`` on every batch of a small dataset, and checks
    that one result is produced per batch. It then verifies that
    ``batch_predict`` on a ``PreprocessingModel`` whose preprocessor was
    never built raises ``ValueError``.

    NOTE(review): despite the test name, the built model is only driven
    through ``model.predict`` batch by batch; ``batch_predict`` itself is
    exercised solely on the error path.
    """
    num_test_rows = 30
    batch_size = 5
    # Ceiling division: a trailing partial batch would still count as one.
    expected_batches = -(-num_test_rows // batch_size)

    # Simple feature set for testing
    features = {
        "num1": NumericalFeature(
            name="num1", feature_type=FeatureType.FLOAT_NORMALIZED
        ),
    }

    # Generate fake data and persist it where the preprocessor reads from
    df = generate_fake_data(features, num_rows=100)
    df.to_csv(self._path_data, index=False)

    # Create preprocessor
    ppr = PreprocessingModel(
        path_data=str(self._path_data),
        features_specs=features,
        features_stats_path=self.features_stats_path,
        overwrite_stats=True,
        # Use simpler output mode to avoid comparison issues
        output_mode=OutputModeOptions.CONCAT,
    )

    # Build preprocessor
    result = ppr.build_preprocessor()
    model = result["model"]
    self.assertIsNotNone(model)

    # Create test data - smaller batches for testing
    test_data = generate_fake_data(features, num_rows=num_test_rows)
    dataset = tf.data.Dataset.from_tensor_slices(dict(test_data)).batch(
        batch_size
    )

    # Run the built model on each batch in turn and keep every output
    all_results = [model.predict(batch) for batch in dataset]

    # One result per batch: 30 rows / 5 batch size = 6 batches
    self.assertEqual(len(all_results), expected_batches)

    # Test that the basic predict method is functioning
    first_batch = next(iter(dataset))
    direct_predict = model.predict(first_batch)
    self.assertIsNotNone(direct_predict)

    # A fresh instance whose preprocessor has not been built
    invalid_ppr = PreprocessingModel(
        path_data=str(self._path_data),
        features_specs=features,
    )

    # Should raise ValueError because model hasn't been built.
    # next() forces evaluation in case batch_predict is a lazy generator.
    with self.assertRaises(ValueError):
        next(invalid_ppr.batch_predict(dataset))
23822437
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments