Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit a1b3018

Browse files
committed
test(kdp): add end to end tests
1 parent 3042e2a commit a1b3018

File tree

2 files changed

+165
-3
lines changed

2 files changed

+165
-3
lines changed

kdp/custom_layers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,9 +342,13 @@ def call(self, inputs: tf.Tensor) -> tf.Tensor:
342342
+ tf.cast(is_fall, tf.int32) * 3
343343
)
344344

345-
# Convert season to one-hot encoding and cast to int32 to match input type
346-
season_one_hot = tf.cast(tf.one_hot(season, depth=4), tf.int32)
345+
# Convert season to one-hot encoding and cast to float32 to match input type
346+
season_one_hot = tf.cast(tf.one_hot(season, depth=4), tf.float32)
347347

348+
# Just in case it comes as int32, cast inputs to float32
349+
inputs = tf.cast(inputs, tf.float32)
350+
351+
# Now both tensors are float32, concatenation will work
348352
return tf.concat([inputs, season_one_hot], axis=-1)
349353

350354
def compute_output_shape(self, input_shape: int) -> int:

test/test_processor.py

Lines changed: 159 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,13 @@ class instances (NumericalFeature, CategoricalFeature, TextFeature, DateFeature)
6464
sentences = ["First sentence special x", "Second sentence special y"]
6565
data[feature] = np.random.choice(sentences, size=num_rows)
6666
elif feature_type == FeatureType.DATE:
67+
# Generate dates and convert them to string format
6768
start_date = pd.Timestamp("2020-01-01")
6869
end_date = pd.Timestamp("2023-01-01")
6970
date_range = pd.date_range(start=start_date, end=end_date, freq="D")
70-
data[feature] = np.random.choice(date_range, size=num_rows)
71+
dates = pd.Series(np.random.choice(date_range, size=num_rows))
72+
data[feature] = dates.dt.strftime("%Y-%m-%d")
73+
7174
return pd.DataFrame(data)
7275

7376

@@ -479,6 +482,161 @@ def test_caching_functionality(self):
479482
)
480483
self.assertIsNone(model_no_cache._preprocessed_cache)
481484

485+
def test_end_to_end_feature_combinations(self):
486+
"""Test different combinations of features with dates."""
487+
488+
test_cases = [
489+
{
490+
"name": "numeric_and_dates",
491+
"features": {
492+
"num1": FeatureType.FLOAT_NORMALIZED,
493+
"num2": FeatureType.FLOAT_RESCALED,
494+
"date1": DateFeature(name="date1", feature_type=FeatureType.DATE, add_season=True),
495+
},
496+
},
497+
{
498+
"name": "numeric_categorical_dates",
499+
"features": {
500+
"num1": FeatureType.FLOAT_NORMALIZED,
501+
"cat1": FeatureType.STRING_CATEGORICAL,
502+
"date1": DateFeature(name="date1", feature_type=FeatureType.DATE, add_season=True),
503+
},
504+
},
505+
{
506+
"name": "categorical_and_dates",
507+
"features": {
508+
"cat1": FeatureType.STRING_CATEGORICAL,
509+
"cat2": FeatureType.INTEGER_CATEGORICAL,
510+
"date1": DateFeature(name="date1", feature_type=FeatureType.DATE, add_season=True),
511+
},
512+
},
513+
{
514+
"name": "dates_and_text",
515+
"features": {
516+
"text1": TextFeature(
517+
name="text1",
518+
max_tokens=100,
519+
),
520+
"date1": DateFeature(name="date1", feature_type=FeatureType.DATE, add_season=True),
521+
},
522+
},
523+
{
524+
"name": "all_features_with_transformer",
525+
"features": {
526+
"num1": FeatureType.FLOAT_NORMALIZED,
527+
"cat1": FeatureType.STRING_CATEGORICAL,
528+
"text1": TextFeature(name="text1", max_tokens=100),
529+
"date1": DateFeature(name="date1", feature_type=FeatureType.DATE, add_season=True),
530+
},
531+
"use_transformer": True,
532+
},
533+
{
534+
"name": "multiple_dates",
535+
"features": {
536+
"date1": DateFeature(name="date1", feature_type=FeatureType.DATE, add_season=True),
537+
"date2": DateFeature(name="date2", feature_type=FeatureType.DATE, output_format="year"),
538+
"date3": DateFeature(name="date3", feature_type=FeatureType.DATE, output_format="month"),
539+
},
540+
},
541+
]
542+
543+
for test_case in test_cases:
544+
with self.subTest(msg=f"Testing {test_case['name']}"):
545+
# Generate fake data
546+
df = generate_fake_data(test_case["features"], num_rows=100)
547+
548+
df.to_csv(self._path_data, index=False)
549+
550+
# Create preprocessor
551+
ppr = PreprocessingModel(
552+
path_data=str(self._path_data),
553+
features_specs=test_case["features"],
554+
features_stats_path=self.features_stats_path,
555+
overwrite_stats=True,
556+
output_mode=OutputModeOptions.CONCAT,
557+
# Add transformer blocks if specified
558+
transfo_nr_blocks=2 if test_case.get("use_transformer") else None,
559+
transfo_nr_heads=4 if test_case.get("use_transformer") else None,
560+
transfo_ff_units=32 if test_case.get("use_transformer") else None,
561+
)
562+
563+
# Build and verify preprocessor
564+
result = ppr.build_preprocessor()
565+
self.assertIsNotNone(result["model"])
566+
567+
# Create a small batch of test data
568+
test_data = generate_fake_data(test_case["features"], num_rows=5)
569+
dataset = tf.data.Dataset.from_tensor_slices(dict(test_data))
570+
571+
# Test preprocessing
572+
preprocessed = ppr.batch_predict(dataset)
573+
self.assertIsNotNone(preprocessed)
574+
575+
# Additional checks based on feature combination
576+
if "date1" in test_case["features"]:
577+
date_feature = test_case["features"]["date1"]
578+
if getattr(date_feature, "add_season", False):
579+
# Check if output shape includes seasonal encoding
580+
self.assertGreaterEqual(preprocessed.shape[-1], 4) # At least 4 dims for season
581+
582+
if test_case.get("use_transformer"):
583+
# Verify transformer layers are present
584+
self.assertTrue(any("transformer" in layer.name.lower() for layer in result["model"].layers))
585+
586+
def test_date_feature_variations(self):
587+
"""Test different date feature configurations."""
588+
589+
date_configs = [
590+
{
591+
"name": "basic_date",
592+
"config": DateFeature(
593+
name="date",
594+
feature_type=FeatureType.DATE,
595+
),
596+
},
597+
{
598+
"name": "date_with_season",
599+
"config": DateFeature(name="date", feature_type=FeatureType.DATE, add_season=True),
600+
},
601+
{
602+
"name": "custom_format_date",
603+
"config": DateFeature(name="date", feature_type=FeatureType.DATE, date_format="%m/%d/%Y"),
604+
},
605+
{
606+
"name": "date_year_only",
607+
"config": DateFeature(name="date", feature_type=FeatureType.DATE, output_format="year"),
608+
},
609+
]
610+
611+
for config in date_configs:
612+
with self.subTest(msg=f"Testing {config['name']}"):
613+
features_specs = {"date": config["config"]}
614+
615+
# Generate and save test data
616+
df = generate_fake_data(features_specs, num_rows=50)
617+
df.to_csv(self._path_data, index=False)
618+
619+
# Create and build preprocessor
620+
ppr = PreprocessingModel(
621+
path_data=str(self._path_data),
622+
features_specs=features_specs,
623+
features_stats_path=self.features_stats_path,
624+
overwrite_stats=True,
625+
)
626+
627+
result = ppr.build_preprocessor()
628+
self.assertIsNotNone(result["model"])
629+
630+
# Test with specific date formats
631+
if config["name"] == "custom_format_date":
632+
test_data = pd.DataFrame({"date": ["01/15/2023", "12/31/2022"]})
633+
else:
634+
test_data = pd.DataFrame({"date": ["2023-01-15", "2022-12-31"]})
635+
636+
dataset = tf.data.Dataset.from_tensor_slices(dict(test_data))
637+
preprocessed = ppr.batch_predict(dataset)
638+
self.assertIsNotNone(preprocessed)
639+
482640

483641
if __name__ == "__main__":
484642
unittest.main()

0 commit comments

Comments (0)