Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d0303cb

Browse files
committed
test(validation): added unit tests and fixed a little type mismatch
1 parent fa88c24 commit d0303cb

File tree

2 files changed

+317
-17
lines changed

2 files changed

+317
-17
lines changed

kdp/custom_layers.py

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,16 @@ def call(self, inputs: tf.Tensor) -> tf.Tensor:
104104
inputs (tf.Tensor): Tensor with input data.
105105
106106
Returns:
107-
tf.Tensor: processed date tensor with all components [year, month, day_of_month, day_of_week].
107+
tf.Tensor: processed date tensor with all components
108+
[year, month, day_of_month, day_of_week].
108109
"""
109110

110111
def parse_date(date_str: str) -> tf.Tensor:
111112
# Handle missing/invalid dates
112-
is_valid = tf.strings.regex_full_match(date_str, r"^\d{1,4}[-/]\d{1,2}[-/]\d{1,2}$")
113+
is_valid = tf.strings.regex_full_match(
114+
date_str,
115+
r"^\d{1,4}[-/]\d{1,2}[-/]\d{1,2}$",
116+
)
113117
tf.debugging.assert_equal(
114118
is_valid,
115119
True,
@@ -126,24 +130,48 @@ def parse_date(date_str: str) -> tf.Tensor:
126130

127131
# Validate date components
128132
# Validate year is in reasonable range
129-
tf.debugging.assert_greater_equal(year, 1000, message="Year must be >= 1000")
130-
tf.debugging.assert_less_equal(year, 2200, message="Year must be <= 2200")
133+
tf.debugging.assert_greater_equal(
134+
year,
135+
1000,
136+
message="Year must be >= 1000",
137+
)
138+
tf.debugging.assert_less_equal(
139+
year,
140+
2200,
141+
message="Year must be <= 2200",
142+
)
131143

132144
# Validate month is between 1-12
133-
tf.debugging.assert_greater_equal(month, 1, message="Month must be >= 1")
134-
tf.debugging.assert_less_equal(month, 12, message="Month must be <= 12")
145+
tf.debugging.assert_greater_equal(
146+
month,
147+
1,
148+
message="Month must be >= 1",
149+
)
150+
tf.debugging.assert_less_equal(
151+
month,
152+
12,
153+
message="Month must be <= 12",
154+
)
135155

136156
# Validate day is between 1-31
137-
tf.debugging.assert_greater_equal(day_of_month, 1, message="Day must be >= 1")
138-
tf.debugging.assert_less_equal(day_of_month, 31, message="Day must be <= 31")
157+
tf.debugging.assert_greater_equal(
158+
day_of_month,
159+
1,
160+
message="Day must be >= 1",
161+
)
162+
tf.debugging.assert_less_equal(
163+
day_of_month,
164+
31,
165+
message="Day must be <= 31",
166+
)
139167

140168
# Calculate day of week using Zeller's congruence
141169
y = tf.where(month < 3, year - 1, year)
142170
m = tf.where(month < 3, month + 12, month)
143171
k = y % 100
144172
j = y // 100
145173
h = (day_of_month + ((13 * (m + 1)) // 5) + k + (k // 4) + (j // 4) - (2 * j)) % 7
146-
day_of_week = tf.where(h == 0, 6, h - 1) # Adjust to 0-6 range where 0 is Monday
174+
day_of_week = tf.where(h == 0, 6, h - 1) # Adjust to 0-6 range where 0 is Sunday
147175

148176
return tf.stack([year, month, day_of_month, day_of_week])
149177

@@ -268,10 +296,11 @@ class SeasonLayer(tf.keras.layers.Layer):
268296
Spring, Summer, and Fall. The one-hot encoding is appended to the input tensor.
269297
270298
Required Input Format:
271-
- A tensor of shape [batch_size, 3], where each row contains:
299+
- A tensor of shape [batch_size, 4], where each row contains:
272300
- year (int): Year as a numerical value.
273301
- month (int): Month as an integer from 1 to 12.
274-
- day_of_week (int): Day of the week as an integer from 0 to 6 (where 0=Monday).
302+
- day_of_month (int): Day of the month as an integer from 1 to 31.
303+
- day_of_week (int): Day of the week as an integer from 0 to 6 (where 0=Sunday).
275304
"""
276305

277306
def __init__(self, **kwargs):
@@ -283,14 +312,15 @@ def call(self, inputs: tf.Tensor) -> tf.Tensor:
283312
"""Adds seasonal one-hot encoding to the input tensor.
284313
285314
Args:
286-
inputs (tf.Tensor): A tensor of shape [batch_size, 3] where each row contains [year, month, day_of_week].
315+
inputs (tf.Tensor): A tensor of shape [batch_size, 4] where each row contains
316+
[year, month, day_of_month, day_of_week].
287317
288318
Returns:
289-
tf.Tensor: A tensor of shape [batch_size, 7] with the original features
319+
tf.Tensor: A tensor of shape [batch_size, 8] with the original features
290320
plus the one-hot encoded season information.
291321
292322
Raises:
293-
ValueError: If the input tensor does not have shape [batch_size, 3] or contains invalid month values.
323+
ValueError: If the input tensor does not have shape [batch_size, 4] or contains invalid month values.
294324
"""
295325
# Ensure inputs is 2D
296326
if len(tf.shape(inputs)) == 1:
@@ -312,8 +342,8 @@ def call(self, inputs: tf.Tensor) -> tf.Tensor:
312342
+ tf.cast(is_fall, tf.int32) * 3
313343
)
314344

315-
# Convert season to one-hot encoding
316-
season_one_hot = tf.one_hot(season, depth=4)
345+
# Convert season to one-hot encoding and cast to int32 to match input type
346+
season_one_hot = tf.cast(tf.one_hot(season, depth=4), tf.int32)
317347

318348
return tf.concat([inputs, season_one_hot], axis=-1)
319349

0 commit comments

Comments
 (0)