import pytest
import tensorflow as tf

-from kdp.custom_layers import TabularAttention, MultiResolutionTabularAttention
+from kdp.custom_layers import MultiResolutionTabularAttention, TabularAttention
from kdp.layers_factory import PreprocessorLayerFactory


@@ -18,16 +18,11 @@ def test_tabular_attention_layer_init():

def test_tabular_attention_layer_config():
    """Test get_config and from_config methods."""
-    original_layer = TabularAttention(
-        num_heads=4,
-        d_model=64,
-        dropout_rate=0.2,
-        name="test_attention"
-    )
-
+    original_layer = TabularAttention(num_heads=4, d_model=64, dropout_rate=0.2, name="test_attention")
+
    config = original_layer.get_config()
    restored_layer = TabularAttention.from_config(config)
-
+
    assert restored_layer.num_heads == original_layer.num_heads
    assert restored_layer.d_model == original_layer.d_model
    assert restored_layer.dropout_rate == original_layer.dropout_rate
@@ -40,29 +35,26 @@ def test_tabular_attention_computation():
    num_samples = 10
    num_features = 8
    d_model = 16
-
+
    # Create a layer instance
    layer = TabularAttention(num_heads=2, d_model=d_model)
-
+
    # Create input data
    inputs = tf.random.normal((batch_size, num_samples, num_features))
-
+
    # Call the layer
    outputs = layer(inputs, training=True)
-
+
    # Check output shape - output will have d_model dimension
    assert outputs.shape == (batch_size, num_samples, d_model)


def test_tabular_attention_factory():
    """Test creation of TabularAttention layer through PreprocessorLayerFactory."""
    layer = PreprocessorLayerFactory.tabular_attention_layer(
-        num_heads=4,
-        d_model=64,
-        name="test_attention",
-        dropout_rate=0.2
+        num_heads=4, d_model=64, name="test_attention", dropout_rate=0.2
    )
-
+
    assert isinstance(layer, TabularAttention)
    assert layer.num_heads == 4
    assert layer.d_model == 64
@@ -75,30 +67,30 @@ def test_tabular_attention_training():
    batch_size = 16
    num_samples = 5
    num_features = 4
-
+
    layer = TabularAttention(num_heads=2, d_model=8, dropout_rate=0.5)
    inputs = tf.random.normal((batch_size, num_samples, num_features))
-
+
    # Test in training mode
    outputs_training = layer(inputs, training=True)
-
+
    # Test in inference mode
    outputs_inference = layer(inputs, training=False)
-
+
    # The outputs should be different due to dropout
    assert not np.allclose(outputs_training.numpy(), outputs_inference.numpy())


def test_tabular_attention_invalid_inputs():
    """Test TabularAttention layer with invalid inputs."""
    layer = TabularAttention(num_heads=2, d_model=8)
-
+
    # Test with wrong input shape
    with pytest.raises(ValueError, match="Input tensor must be 3-dimensional"):
        # Missing batch dimension
        inputs = tf.random.normal((5, 4))
        layer(inputs)
-
+
    with pytest.raises(ValueError):
        # Wrong rank
        inputs = tf.random.normal((16, 5, 4, 2))
@@ -110,35 +102,31 @@ def test_tabular_attention_end_to_end():
    batch_size = 16
    num_samples = 5
    num_features = 4
-
+
    # Create a simple model with TabularAttention
    inputs = tf.keras.Input(shape=(num_samples, num_features))
    x = TabularAttention(num_heads=2, d_model=8)(inputs)
    outputs = tf.keras.layers.Dense(1)(x)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
-
+
    # Compile the model
-    model.compile(optimizer='adam', loss='mse')
-
+    model.compile(optimizer="adam", loss="mse")
+
    # Create some dummy data
    X = tf.random.normal((batch_size, num_samples, num_features))
    y = tf.random.normal((batch_size, num_samples, 1))
-
+
    # Train for one epoch
    history = model.fit(X, y, epochs=1, verbose=0)
-
+
    # Check if loss was computed
-    assert 'loss' in history.history
-    assert len(history.history['loss']) == 1
+    assert "loss" in history.history
+    assert len(history.history["loss"]) == 1


def test_multi_resolution_attention_layer_init():
    """Test initialization of MultiResolutionTabularAttention layer."""
-    layer = MultiResolutionTabularAttention(
-        num_heads=4,
-        d_model=64,
-        embedding_dim=32
-    )
+    layer = MultiResolutionTabularAttention(num_heads=4, d_model=64, embedding_dim=32)
    assert layer.num_heads == 4
    assert layer.d_model == 64
    assert layer.embedding_dim == 32
@@ -148,16 +136,12 @@ def test_multi_resolution_attention_layer_init():
def test_multi_resolution_attention_layer_config():
    """Test get_config and from_config methods for MultiResolutionTabularAttention."""
    original_layer = MultiResolutionTabularAttention(
-        num_heads=4,
-        d_model=64,
-        embedding_dim=32,
-        dropout_rate=0.2,
-        name="test_multi_attention"
+        num_heads=4, d_model=64, embedding_dim=32, dropout_rate=0.2, name="test_multi_attention"
    )
-
+
    config = original_layer.get_config()
    restored_layer = MultiResolutionTabularAttention.from_config(config)
-
+
    assert restored_layer.num_heads == original_layer.num_heads
    assert restored_layer.d_model == original_layer.d_model
    assert restored_layer.embedding_dim == original_layer.embedding_dim
@@ -172,29 +156,21 @@ def test_multi_resolution_attention_computation():
    num_categorical = 5
    numerical_dim = 16
    categorical_dim = 8
-
+
    # Create a layer instance
-    layer = MultiResolutionTabularAttention(
-        num_heads=2,
-        d_model=numerical_dim,
-        embedding_dim=categorical_dim
-    )
-
+    layer = MultiResolutionTabularAttention(num_heads=2, d_model=numerical_dim, embedding_dim=categorical_dim)
+
    # Create input data
    numerical_features = tf.random.normal((batch_size, num_numerical, numerical_dim))
    categorical_features = tf.random.normal((batch_size, num_categorical, categorical_dim))
-
+
    # Call the layer
-    numerical_output, categorical_output = layer(
-        numerical_features,
-        categorical_features,
-        training=True
-    )
-
+    numerical_output, categorical_output = layer(numerical_features, categorical_features, training=True)
+
    # Check output shapes
    assert numerical_output.shape == (batch_size, num_numerical, numerical_dim)
    assert categorical_output.shape == (batch_size, num_categorical, numerical_dim)
-
+
    # Test with different batch sizes
    numerical_features_2 = tf.random.normal((64, num_numerical, numerical_dim))
    categorical_features_2 = tf.random.normal((64, num_categorical, categorical_dim))
@@ -210,31 +186,20 @@ def test_multi_resolution_attention_training():
    num_categorical = 3
    numerical_dim = 8
    categorical_dim = 4
-
+
    layer = MultiResolutionTabularAttention(
-        num_heads=2,
-        d_model=numerical_dim,
-        embedding_dim=categorical_dim,
-        dropout_rate=0.5
+        num_heads=2, d_model=numerical_dim, embedding_dim=categorical_dim, dropout_rate=0.5
    )
-
+
    numerical_features = tf.random.normal((batch_size, num_numerical, numerical_dim))
    categorical_features = tf.random.normal((batch_size, num_categorical, categorical_dim))
-
+
    # Test in training mode
-    num_train, cat_train = layer(
-        numerical_features,
-        categorical_features,
-        training=True
-    )
-
+    num_train, cat_train = layer(numerical_features, categorical_features, training=True)
+
    # Test in inference mode
-    num_infer, cat_infer = layer(
-        numerical_features,
-        categorical_features,
-        training=False
-    )
-
+    num_infer, cat_infer = layer(numerical_features, categorical_features, training=False)
+
    # The outputs should be different due to dropout
    assert not np.allclose(num_train.numpy(), num_infer.numpy())
    assert not np.allclose(cat_train.numpy(), cat_infer.numpy())
@@ -243,13 +208,9 @@ def test_multi_resolution_attention_training():
def test_multi_resolution_attention_factory():
    """Test creation of MultiResolutionTabularAttention layer through PreprocessorLayerFactory."""
    layer = PreprocessorLayerFactory.multi_resolution_attention_layer(
-        num_heads=4,
-        d_model=64,
-        embedding_dim=32,
-        name="test_multi_attention",
-        dropout_rate=0.2
+        num_heads=4, d_model=64, embedding_dim=32, name="test_multi_attention", dropout_rate=0.2
    )
-
+
    assert isinstance(layer, MultiResolutionTabularAttention)
    assert layer.num_heads == 4
    assert layer.d_model == 64
@@ -266,44 +227,34 @@ def test_multi_resolution_attention_end_to_end():
    numerical_dim = 8
    categorical_dim = 4
    output_dim = 1
-
+
    # Create inputs
    numerical_inputs = tf.keras.Input(shape=(num_numerical, numerical_dim))
    categorical_inputs = tf.keras.Input(shape=(num_categorical, categorical_dim))
-
+
    # Apply multi-resolution attention
    num_attended, cat_attended = MultiResolutionTabularAttention(
-        num_heads=2,
-        d_model=numerical_dim,
-        embedding_dim=categorical_dim
+        num_heads=2, d_model=numerical_dim, embedding_dim=categorical_dim
    )(numerical_inputs, categorical_inputs)
-
+
    # Combine outputs
    combined = tf.keras.layers.Concatenate(axis=1)([num_attended, cat_attended])
    outputs = tf.keras.layers.Dense(output_dim)(combined)
-
+
    # Create model
-    model = tf.keras.Model(
-        inputs=[numerical_inputs, categorical_inputs],
-        outputs=outputs
-    )
-
+    model = tf.keras.Model(inputs=[numerical_inputs, categorical_inputs], outputs=outputs)
+
    # Compile the model
-    model.compile(optimizer='adam', loss='mse')
-
+    model.compile(optimizer="adam", loss="mse")
+
    # Create dummy data
    X_num = tf.random.normal((batch_size, num_numerical, numerical_dim))
    X_cat = tf.random.normal((batch_size, num_categorical, categorical_dim))
    y = tf.random.normal((batch_size, num_numerical + num_categorical, output_dim))
-
+
    # Train for one epoch
-    history = model.fit(
-        [X_num, X_cat],
-        y,
-        epochs=1,
-        verbose=0
-    )
-
+    history = model.fit([X_num, X_cat], y, epochs=1, verbose=0)
+
    # Check if loss was computed
-    assert 'loss' in history.history
-    assert len(history.history['loss']) == 1
+    assert "loss" in history.history
+    assert len(history.history["loss"]) == 1