@@ -90,6 +90,7 @@ def __init__(self, date_format: str = "YYYY-MM-DD", **kwargs) -> None:
9090
9191 Args:
9292 date_format (str): format of the string encoded date to parse.
93+ Supported formats: YYYY-MM-DD, YYYY/MM/DD
9394 kwargs (dict): other params to pass to the class.
9495 """
9596 super ().__init__ (** kwargs )
@@ -103,31 +104,83 @@ def call(self, inputs: tf.Tensor) -> tf.Tensor:
103104 inputs (tf.Tensor): Tensor with input data.
104105
105106 Returns:
106- tf.Tensor: processed date tensor with all cyclic components.
107+ tf.Tensor: processed date tensor with all components
108+ [year, month, day_of_month, day_of_week].
107109 """
108110
def parse_date(date_str: str) -> tf.Tensor:
    """Parse a single date string into integer calendar components.

    Args:
        date_str: scalar string tensor holding a date written as
            YYYY-MM-DD or YYYY/MM/DD.

    Returns:
        tf.Tensor: int32 tensor ``[year, month, day_of_month, day_of_week]``,
        with day_of_week in 0-6 where 0 is Sunday.

    Raises:
        tf.errors.InvalidArgumentError: raised by the ``tf.debugging``
            assertions when the string does not match the expected pattern
            or a component falls outside its allowed range.
    """
    # Reject missing/invalid dates before any numeric conversion happens.
    matches_pattern = tf.strings.regex_full_match(
        date_str,
        r"^\d{1,4}[-/]\d{1,2}[-/]\d{1,2}$",
    )
    tf.debugging.assert_equal(
        matches_pattern,
        True,
        message="Invalid date format. Expected YYYY-MM-DD or YYYY/MM/DD",
    )

    # Normalise '/' to '-' so one split handles both supported layouts.
    date_str = tf.strings.regex_replace(date_str, "/", "-")
    fields = tf.strings.split(date_str, "-")
    year, month, day_of_month = (
        tf.strings.to_number(fields[idx], out_type=tf.int32) for idx in range(3)
    )

    # Inclusive range checks, applied in order: year, month, day.
    # NOTE: the day check is a plain 1-31 range test; it is not aware of
    # month lengths or leap years, so e.g. February 31 is not rejected here.
    range_checks = (
        (year, 1000, 2200, "Year must be >= 1000", "Year must be <= 2200"),
        (month, 1, 12, "Month must be >= 1", "Month must be <= 12"),
        (day_of_month, 1, 31, "Day must be >= 1", "Day must be <= 31"),
    )
    for value, lowest, highest, too_low_msg, too_high_msg in range_checks:
        tf.debugging.assert_greater_equal(value, lowest, message=too_low_msg)
        tf.debugging.assert_less_equal(value, highest, message=too_high_msg)

    # Zeller's congruence: January and February are counted as months
    # 13 and 14 of the previous year.
    is_early_month = month < 3
    shifted_year = tf.where(is_early_month, year - 1, year)
    shifted_month = tf.where(is_early_month, month + 12, month)
    year_of_century = shifted_year % 100
    century = shifted_year // 100
    month_term = (13 * (shifted_month + 1)) // 5
    zeller = (
        day_of_month
        + month_term
        + year_of_century
        + (year_of_century // 4)
        + (century // 4)
        - (2 * century)
    ) % 7

    # Shift the Zeller result into the 0-6 range where 0 is Sunday;
    # a result of 0 wraps around to 6.
    day_of_week = tf.where(zeller == 0, 6, zeller - 1)

    return tf.stack([year, month, day_of_month, day_of_week])
124177
125178 parsed_dates = tf .map_fn (parse_date , tf .squeeze (inputs ), fn_output_signature = tf .int32 )
126179 return parsed_dates
127180
def compute_output_shape(self, input_shape: "tuple | tf.TensorShape") -> tf.TensorShape:
    """Return the output shape produced for a given input shape.

    Args:
        input_shape: shape of the layer input; only its first entry
            (the batch dimension) is read.

    Returns:
        tf.TensorShape: ``[batch, 4]``, one column each for year, month,
        day_of_month and day_of_week.
    """
    batch_dim = input_shape[0]
    return tf.TensorShape([batch_dim, 4])
131184
132185 def get_config (self ) -> dict :
133186 """Saving configuration."""
@@ -171,14 +224,14 @@ def cyclic_encoding(self, value: tf.Tensor, period: float) -> tuple[tf.Tensor, t
171224
172225 @tf .function
173226 def call (self , inputs : tf .Tensor ) -> tf .Tensor :
174- """Splits the date into 3 components: year, month and day and
227+ """Splits the date into 4 components: year, month, day and day of the week and
175228 encodes it into sin and cos cyclical projections.
176229
177230 Args:
178- inputs (tf.Tensor): input data.
231+ inputs (tf.Tensor): input data [year, month, day_of_month, day_of_week] .
179232
180233 Returns:
181- ( tf.Tensor) : cyclically encoded data (sin and cos).
234+ tf.Tensor: cyclically encoded data (sin and cos components ).
182235 """
183236 # Reshape input if necessary
184237 input_shape = tf .shape (inputs )
@@ -188,19 +241,22 @@ def call(self, inputs: tf.Tensor) -> tf.Tensor:
188241 # Extract features
189242 year = inputs [:, 0 ]
190243 month = inputs [:, 1 ]
191- day_of_week = inputs [:, 2 ]
244+ day_of_month = inputs [:, 2 ] # New: day of month
245+ day_of_week = inputs [:, 3 ] # Now at index 3
192246
193- # Cyclical encoding
247+ # Convert to float
194248 year_float = tf .cast (year , tf .float32 )
195249 month_float = tf .cast (month , tf .float32 )
250+ day_of_month_float = tf .cast (day_of_month , tf .float32 )
196251 day_of_week_float = tf .cast (day_of_week , tf .float32 )
197252
198253 # Ensure inputs are in the correct range
199254 year_float = self .normalize_year (year_float )
200255
201- # Encode each feature
256+ # Encode each feature in cyclic projections
202257 year_sin , year_cos = self .cyclic_encoding (year_float , period = 1.0 )
203258 month_sin , month_cos = self .cyclic_encoding (month_float , period = 12.0 )
259+ day_of_month_sin , day_of_month_cos = self .cyclic_encoding (day_of_month_float , period = 31.0 )
204260 day_of_week_sin , day_of_week_cos = self .cyclic_encoding (day_of_week_float , period = 7.0 )
205261
206262 encoded = tf .stack (
@@ -209,6 +265,8 @@ def call(self, inputs: tf.Tensor) -> tf.Tensor:
209265 year_cos ,
210266 month_sin ,
211267 month_cos ,
268+ day_of_month_sin , # New
269+ day_of_month_cos , # New
212270 day_of_week_sin ,
213271 day_of_week_cos ,
214272 ],
@@ -219,7 +277,7 @@ def call(self, inputs: tf.Tensor) -> tf.Tensor:
219277
def compute_output_shape(self, input_shape: "tuple | tf.TensorShape") -> tf.TensorShape:
    """Return the output shape produced for a given input shape.

    Args:
        input_shape: shape of the layer input; only its first entry
            (the batch dimension) is read.

    Returns:
        tf.TensorShape: ``[batch, 8]``, a sin and a cos column for each of
        the 4 date features.
    """
    batch_dim = input_shape[0]
    return tf.TensorShape([batch_dim, 8])
223281
224282 def get_config (self ) -> dict :
225283 """Returns the configuration of the layer as a dictionary."""
@@ -238,10 +296,11 @@ class SeasonLayer(tf.keras.layers.Layer):
238296 Spring, Summer, and Fall. The one-hot encoding is appended to the input tensor.
239297
240298 Required Input Format:
241- - A tensor of shape [batch_size, 3 ], where each row contains:
299+ - A tensor of shape [batch_size, 4 ], where each row contains:
242300 - year (int): Year as a numerical value.
243301 - month (int): Month as an integer from 1 to 12.
244- - day_of_week (int): Day of the week as an integer from 0 to 6 (where 0=Monday).
302+ - day_of_month (int): Day of the month as an integer from 1 to 31.
303+ - day_of_week (int): Day of the week as an integer from 0 to 6 (where 0=Sunday).
245304 """
246305
247306 def __init__ (self , ** kwargs ):
@@ -253,14 +312,15 @@ def call(self, inputs: tf.Tensor) -> tf.Tensor:
253312 """Adds seasonal one-hot encoding to the input tensor.
254313
255314 Args:
256- inputs (tf.Tensor): A tensor of shape [batch_size, 3] where each row contains [year, month, day_of_week].
315+ inputs (tf.Tensor): A tensor of shape [batch_size, 4] where each row contains
316+ [year, month, day_of_month, day_of_week].
257317
258318 Returns:
259- tf.Tensor: A tensor of shape [batch_size, 7 ] with the original features
319+ tf.Tensor: A tensor of shape [batch_size, 8 ] with the original features
260320 plus the one-hot encoded season information.
261321
262322 Raises:
263- ValueError: If the input tensor does not have shape [batch_size, 3 ] or contains invalid month values.
323+ ValueError: If the input tensor does not have shape [batch_size, 4 ] or contains invalid month values.
264324 """
265325 # Ensure inputs is 2D
266326 if len (tf .shape (inputs )) == 1 :
@@ -282,8 +342,8 @@ def call(self, inputs: tf.Tensor) -> tf.Tensor:
282342 + tf .cast (is_fall , tf .int32 ) * 3
283343 )
284344
285- # Convert season to one-hot encoding
286- season_one_hot = tf .one_hot (season , depth = 4 )
345+ # Convert season to one-hot encoding and cast to int32 to match input type
346+ season_one_hot = tf .cast ( tf . one_hot (season , depth = 4 ), tf . int32 )
287347
288348 return tf .concat ([inputs , season_one_hot ], axis = - 1 )
289349
0 commit comments