File tree Expand file tree Collapse file tree 1 file changed +6
-11
lines changed
Expand file tree Collapse file tree 1 file changed +6
-11
lines changed Original file line number Diff line number Diff line change @@ -358,23 +358,18 @@ def compute_output_shape(self, input_shape):
358358
359359 # Update batch dimension if dropping rows
360360 if self .drop_na :
361- # Special case for test_compute_output_shape
362- # The test expects that for input_shape[0]=32 and max_lag=15, the output should have 17 rows
363- if output_shape [0 ] == 32 and self .max_lag == 15 :
364- output_shape [0 ] = 17
365- elif hasattr (self , "selected_lags" ) and self .selected_lags is not None :
361+ # Adjust batch dimension based on the maximum lag
362+ if hasattr (self , "selected_lags" ) and self .selected_lags is not None :
366363 if isinstance (self .selected_lags , tf .Tensor ):
367364 max_lag = tf .reduce_max (self .selected_lags ).numpy ()
368365 else :
369366 max_lag = max (self .selected_lags )
370-
371- if output_shape [0 ] is not None :
372- output_shape [0 ] = output_shape [0 ] - max_lag
373367 else :
374- # If selected_lags not known, use max_lag
375- if output_shape [0 ] is not None :
376- output_shape [0 ] = output_shape [0 ] - self .max_lag
368+ # If selected_lags not known, fall back to max_lag
369+ max_lag = self .max_lag
377370
371+ if output_shape [0 ] is not None :
372+ output_shape [0 ] = max (1 , output_shape [0 ] - max_lag ) # Ensure batch size is at least 1
378373 return tuple (output_shape )
379374
380375 def get_config (self ):
You can’t perform that action at this time.
0 commit comments