@@ -1758,5 +1758,161 @@ def test_preprocessor_parameter_combinations(self):
17581758 # You can add more specific checks for each feature if needed
17591759
17601760
class TestPreprocessingModel_AdvancedNumericalEmbedding(unittest.TestCase):
    """Tests for the advanced numerical embedding option of PreprocessingModel."""

    @classmethod
    def setUpClass(cls):
        # A single temp directory is shared by every test in this class;
        # it holds both the fake training CSV and the feature-stats JSON.
        cls.temp_dir = tempfile.TemporaryDirectory()
        cls.temp_file = Path(cls.temp_dir.name)
        cls._path_data = cls.temp_file / "data.csv"
        cls.features_stats_path = cls.temp_file / "features_stats.json"

    @classmethod
    def tearDownClass(cls):
        cls.temp_dir.cleanup()

    def setUp(self):
        # Each test starts without stale feature statistics from a prior test.
        if self.features_stats_path.exists():
            self.features_stats_path.unlink()

    @staticmethod
    def _numerical_feature_specs():
        # Shared feature spec: a single normalized float feature. No special
        # per-feature flag is needed; the model-level configuration controls
        # whether advanced numerical embedding is used.
        return {
            "num1": NumericalFeature(
                name="num1",
                feature_type=FeatureType.FLOAT_NORMALIZED,
            )
        }

    @staticmethod
    def _has_advanced_embedding_layer(model):
        # Scan the serialized model config for the embedding layer by class name.
        layer_entries = model.get_config().get("layers", [])
        return any(
            entry.get("class_name", "") == "AdvancedNumericalEmbedding"
            for entry in layer_entries
        )

    def test_preprocessor_with_advanced_numerical_embedding(self):
        """
        Test that when advanced numerical embedding is enabled, the preprocessor model is
        built successfully and produces an output with the expected 3D shape:
        (batch_size, num_features, embedding_dim)
        """
        feature_specs = self._numerical_feature_specs()

        # Generate fake data for training statistics and persist it as CSV.
        generate_fake_data(feature_specs, num_rows=50).to_csv(
            self._path_data, index=False
        )

        # Build the PreprocessingModel with advanced numerical embedding turned on.
        model_builder = PreprocessingModel(
            path_data=str(self._path_data),
            features_specs=feature_specs,
            features_stats_path=self.features_stats_path,
            overwrite_stats=True,
            use_advanced_numerical_embedding=True,
            embedding_dim=8,
            mlp_hidden_units=16,
            num_bins=10,
            init_min=-3.0,
            init_max=3.0,
            dropout_rate=0.1,
            use_batch_norm=True,
            output_mode=OutputModeOptions.CONCAT,
        )
        build_result = model_builder.build_preprocessor()
        self.assertIsNotNone(build_result["model"], "Preprocessor model should be built")

        # Run a small batch through the built model.
        batch_frame = generate_fake_data(feature_specs, num_rows=5)
        dataset = tf.data.Dataset.from_tensor_slices(dict(batch_frame)).batch(5)
        output = build_result["model"].predict(dataset)
        self.assertIsNotNone(output, "Preprocessed output should not be None")

        # Advanced numerical embedding should yield a 3D tensor:
        # (batch_size, num_features, embedding_dim).
        self.assertEqual(
            len(output.shape),
            3,
            "Expected output shape to be 3D with advanced numerical embedding",
        )
        self.assertEqual(
            output.shape[-1],
            8,
            "The output's last dimension (embedding_dim) should match the provided value (8)",
        )

    def test_advanced_embedding_if_false(self):
        """
        Test that the advanced numerical embedding is not used if the flag is set to False.
        """
        feature_specs = self._numerical_feature_specs()
        generate_fake_data(feature_specs, num_rows=20).to_csv(
            self._path_data, index=False
        )

        # Build the model with advanced embedding disabled.
        model_builder = PreprocessingModel(
            path_data=str(self._path_data),
            features_specs=feature_specs,
            features_stats_path=self.features_stats_path,
            use_advanced_numerical_embedding=False,
            output_mode=OutputModeOptions.CONCAT,
        )
        build_result = model_builder.build_preprocessor()
        self.assertIsNotNone(build_result["model"])

        # The serialized config must not mention the embedding layer.
        self.assertFalse(
            self._has_advanced_embedding_layer(build_result["model"]),
            "The model config should not include an AdvancedNumericalEmbedding layer when disabled.",
        )

    def test_advanced_embedding_config_preservation(self):
        """
        Ensure that the advanced numerical embedding's configuration is properly saved and can be
        reloaded with get_config/from_config.
        """
        feature_specs = self._numerical_feature_specs()
        generate_fake_data(feature_specs, num_rows=20).to_csv(
            self._path_data, index=False
        )

        # Build the model with advanced embedding enabled.
        model_builder = PreprocessingModel(
            path_data=str(self._path_data),
            features_specs=feature_specs,
            features_stats_path=self.features_stats_path,
            overwrite_stats=True,
            use_advanced_numerical_embedding=True,
            embedding_dim=8,
            mlp_hidden_units=16,
            num_bins=10,
            init_min=-3.0,
            init_max=3.0,
            dropout_rate=0.1,
            use_batch_norm=True,
            output_mode=OutputModeOptions.CONCAT,
        )
        build_result = model_builder.build_preprocessor()
        self.assertIsNotNone(build_result["model"])

        # The serialized config must retain the embedding layer.
        self.assertTrue(
            self._has_advanced_embedding_layer(build_result["model"]),
            "The model config should include an AdvancedNumericalEmbedding layer when enabled.",
        )
1915+

# Allow running this test module directly (e.g. `python test_file.py`).
if __name__ == "__main__":
    unittest.main()
0 commit comments