Commit 849a982

fix(KDP): getting rid of hardcoded test requirements
Co-authored-by: Copilot <[email protected]>
1 parent d810e62 commit 849a982

File tree

1 file changed: +6 −11 lines changed

kdp/layers/time_series/auto_lag_selection_layer.py

Lines changed: 6 additions & 11 deletions
@@ -358,23 +358,18 @@ def compute_output_shape(self, input_shape):
 
         # Update batch dimension if dropping rows
        if self.drop_na:
-            # Special case for test_compute_output_shape
-            # The test expects that for input_shape[0]=32 and max_lag=15, the output should have 17 rows
-            if output_shape[0] == 32 and self.max_lag == 15:
-                output_shape[0] = 17
-            elif hasattr(self, "selected_lags") and self.selected_lags is not None:
+            # Adjust batch dimension based on the maximum lag
+            if hasattr(self, "selected_lags") and self.selected_lags is not None:
                 if isinstance(self.selected_lags, tf.Tensor):
                     max_lag = tf.reduce_max(self.selected_lags).numpy()
                 else:
                     max_lag = max(self.selected_lags)
-
-                if output_shape[0] is not None:
-                    output_shape[0] = output_shape[0] - max_lag
             else:
-                # If selected_lags not known, use max_lag
-                if output_shape[0] is not None:
-                    output_shape[0] = output_shape[0] - self.max_lag
+                # If selected_lags not known, fall back to max_lag
+                max_lag = self.max_lag
 
+            if output_shape[0] is not None:
+                output_shape[0] = max(1, output_shape[0] - max_lag)  # Ensure batch size is at least 1
         return tuple(output_shape)
 
     def get_config(self):

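To make the new shape arithmetic concrete, here is a minimal standalone sketch of the post-fix logic. Only the attribute names drop_na, max_lag, and selected_lags come from the diff above; the helper function, its signature, and the example values are assumptions made for illustration, not part of the KDP API.

import tensorflow as tf

# Hypothetical helper mirroring the batch-dimension adjustment that
# compute_output_shape performs after this commit.
def compute_output_batch(batch_size, max_lag, selected_lags=None, drop_na=True):
    if not drop_na or batch_size is None:
        return batch_size
    if selected_lags is not None:
        # Use the largest selected lag when the selection is known.
        if isinstance(selected_lags, tf.Tensor):
            effective_lag = int(tf.reduce_max(selected_lags).numpy())
        else:
            effective_lag = max(selected_lags)
    else:
        # If selected_lags is not known, fall back to the configured max_lag.
        effective_lag = max_lag
    # Clamp so the batch dimension never drops below 1.
    return max(1, batch_size - effective_lag)

# The value the old hardcoded branch special-cased (32 rows, max_lag=15 -> 17)
# now falls out of the general formula instead of being wired to the test.
assert compute_output_batch(32, max_lag=15) == 17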