@@ -104,12 +104,16 @@ def call(self, inputs: tf.Tensor) -> tf.Tensor:
104104 inputs (tf.Tensor): Tensor with input data.
105105
106106 Returns:
107- tf.Tensor: processed date tensor with all components [year, month, day_of_month, day_of_week].
107+ tf.Tensor: processed date tensor with all components
108+ [year, month, day_of_month, day_of_week].
108109 """
109110
110111 def parse_date (date_str : str ) -> tf .Tensor :
111112 # Handle missing/invalid dates
112- is_valid = tf .strings .regex_full_match (date_str , r"^\d{1,4}[-/]\d{1,2}[-/]\d{1,2}$" )
113+ is_valid = tf .strings .regex_full_match (
114+ date_str ,
115+ r"^\d{1,4}[-/]\d{1,2}[-/]\d{1,2}$" ,
116+ )
113117 tf .debugging .assert_equal (
114118 is_valid ,
115119 True ,
@@ -126,24 +130,48 @@ def parse_date(date_str: str) -> tf.Tensor:
126130
127131 # Validate date components
128132 # Validate year is in reasonable range
129- tf .debugging .assert_greater_equal (year , 1000 , message = "Year must be >= 1000" )
130- tf .debugging .assert_less_equal (year , 2200 , message = "Year must be <= 2200" )
133+ tf .debugging .assert_greater_equal (
134+ year ,
135+ 1000 ,
136+ message = "Year must be >= 1000" ,
137+ )
138+ tf .debugging .assert_less_equal (
139+ year ,
140+ 2200 ,
141+ message = "Year must be <= 2200" ,
142+ )
131143
132144 # Validate month is between 1-12
133- tf .debugging .assert_greater_equal (month , 1 , message = "Month must be >= 1" )
134- tf .debugging .assert_less_equal (month , 12 , message = "Month must be <= 12" )
145+ tf .debugging .assert_greater_equal (
146+ month ,
147+ 1 ,
148+ message = "Month must be >= 1" ,
149+ )
150+ tf .debugging .assert_less_equal (
151+ month ,
152+ 12 ,
153+ message = "Month must be <= 12" ,
154+ )
135155
136156 # Validate day is between 1-31
137- tf .debugging .assert_greater_equal (day_of_month , 1 , message = "Day must be >= 1" )
138- tf .debugging .assert_less_equal (day_of_month , 31 , message = "Day must be <= 31" )
157+ tf .debugging .assert_greater_equal (
158+ day_of_month ,
159+ 1 ,
160+ message = "Day must be >= 1" ,
161+ )
162+ tf .debugging .assert_less_equal (
163+ day_of_month ,
164+ 31 ,
165+ message = "Day must be <= 31" ,
166+ )
139167
140168 # Calculate day of week using Zeller's congruence
141169 y = tf .where (month < 3 , year - 1 , year )
142170 m = tf .where (month < 3 , month + 12 , month )
143171 k = y % 100
144172 j = y // 100
145173 h = (day_of_month + ((13 * (m + 1 )) // 5 ) + k + (k // 4 ) + (j // 4 ) - (2 * j )) % 7
146- day_of_week = tf .where (h == 0 , 6 , h - 1 ) # Adjust to 0-6 range where 0 is Monday
174+ day_of_week = tf .where (h == 0 , 6 , h - 1 ) # Adjust to 0-6 range where 0 is Sunday
147175
148176 return tf .stack ([year , month , day_of_month , day_of_week ])
149177
@@ -268,10 +296,11 @@ class SeasonLayer(tf.keras.layers.Layer):
268296 Spring, Summer, and Fall. The one-hot encoding is appended to the input tensor.
269297
270298 Required Input Format:
271- - A tensor of shape [batch_size, 3 ], where each row contains:
299+ - A tensor of shape [batch_size, 4 ], where each row contains:
272300 - year (int): Year as a numerical value.
273301 - month (int): Month as an integer from 1 to 12.
274- - day_of_week (int): Day of the week as an integer from 0 to 6 (where 0=Monday).
302+ - day_of_month (int): Day of the month as an integer from 1 to 31.
303+ - day_of_week (int): Day of the week as an integer from 0 to 6 (where 0=Sunday).
275304 """
276305
277306 def __init__ (self , ** kwargs ):
@@ -283,14 +312,15 @@ def call(self, inputs: tf.Tensor) -> tf.Tensor:
283312 """Adds seasonal one-hot encoding to the input tensor.
284313
285314 Args:
286- inputs (tf.Tensor): A tensor of shape [batch_size, 3] where each row contains [year, month, day_of_week].
315+ inputs (tf.Tensor): A tensor of shape [batch_size, 4] where each row contains
316+ [year, month, day_of_month, day_of_week].
287317
288318 Returns:
289- tf.Tensor: A tensor of shape [batch_size, 7 ] with the original features
319+ tf.Tensor: A tensor of shape [batch_size, 8 ] with the original features
290320 plus the one-hot encoded season information.
291321
292322 Raises:
293- ValueError: If the input tensor does not have shape [batch_size, 3 ] or contains invalid month values.
323+ ValueError: If the input tensor does not have shape [batch_size, 4 ] or contains invalid month values.
294324 """
295325 # Ensure inputs is 2D
296326 if len (tf .shape (inputs )) == 1 :
@@ -312,8 +342,8 @@ def call(self, inputs: tf.Tensor) -> tf.Tensor:
312342 + tf .cast (is_fall , tf .int32 ) * 3
313343 )
314344
315- # Convert season to one-hot encoding
316- season_one_hot = tf .one_hot (season , depth = 4 )
345+ # Convert season to one-hot encoding and cast to int32 to match input type
346+ season_one_hot = tf .cast ( tf . one_hot (season , depth = 4 ), tf . int32 )
317347
318348 return tf .concat ([inputs , season_one_hot ], axis = - 1 )
319349
0 commit comments