Commit 1cf5d09

test(KDP): add end-to-end and unit tests for the TabularAttention and MultiResolutionTabularAttention functionality (which looks for patterns within and between features) + add a Dense layer for projecting numerical_features to d_model size
1 parent a1b3018 commit 1cf5d09

3 files changed: +452, -7 lines

kdp/custom_layers.py

Lines changed: 11 additions & 6 deletions
@@ -615,6 +615,10 @@ def __init__(self, num_heads: int, d_model: int, embedding_dim: int, dropout_rat
         self.embedding_dim = embedding_dim
         self.dropout_rate = dropout_rate

+        # Create projection layers during initialization
+        self.numerical_projection = tf.keras.layers.Dense(d_model)
+        self.categorical_projection = tf.keras.layers.Dense(embedding_dim)
+
         # Numerical attention
         self.numerical_attention = tf.keras.layers.MultiHeadAttention(
             num_heads=num_heads,
@@ -633,7 +637,6 @@ def __init__(self, num_heads: int, d_model: int, embedding_dim: int, dropout_rat
         self.numerical_dropout2 = tf.keras.layers.Dropout(dropout_rate)

         # Categorical attention
-        self.categorical_projection = tf.keras.layers.Dense(embedding_dim)
         self.categorical_attention = tf.keras.layers.MultiHeadAttention(
             num_heads=num_heads,
             key_dim=embedding_dim // num_heads,
@@ -680,15 +683,17 @@ def call(
             - numerical_output: Tensor of shape (batch_size, num_numerical, d_model)
             - categorical_output: Tensor of shape (batch_size, num_categorical, d_model)
         """
-        # Process numerical features
+        # Use the pre-initialized projection layer
+        numerical_projected = self.numerical_projection(numerical_features)
+        # Now process with attention
         numerical_attn = self.numerical_attention(
-            numerical_features,
-            numerical_features,
-            numerical_features,
+            numerical_projected,
+            numerical_projected,
+            numerical_projected,
             training=training,
         )
         numerical_1 = self.numerical_layernorm1(
-            numerical_features + self.numerical_dropout1(numerical_attn, training=training),
+            numerical_projected + self.numerical_dropout1(numerical_attn, training=training),
         )
         numerical_ffn = self.numerical_ffn(numerical_1)
         numerical_2 = self.numerical_layernorm2(
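
The change above follows a common pattern: create the Dense projection once in __init__ and apply it in call so that the attention inputs and the residual connection share the d_model width; without the projection, numerical_features + attention_output only works when the raw feature width happens to equal d_model. Below is a minimal, standalone sketch of that pattern (illustrative shapes, hyperparameters, and variable names are assumptions, not KDP's actual class):

import tensorflow as tf

# Sketch of the projection-before-attention pattern from the diff above.
# The shapes (32, 10, 8) and hyperparameters are illustrative assumptions.
num_heads, d_model, dropout_rate = 4, 16, 0.1

numerical_projection = tf.keras.layers.Dense(d_model)
numerical_attention = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)
numerical_dropout = tf.keras.layers.Dropout(dropout_rate)
numerical_layernorm = tf.keras.layers.LayerNormalization()

# (batch_size, num_numerical, raw_feature_dim) -> (batch_size, num_numerical, d_model)
numerical_features = tf.random.normal((32, 10, 8))
numerical_projected = numerical_projection(numerical_features)

# Self-attention over the projected features; the residual now has matching shapes.
attn = numerical_attention(numerical_projected, numerical_projected, numerical_projected, training=True)
out = numerical_layernorm(numerical_projected + numerical_dropout(attn, training=True))
assert out.shape == (32, 10, d_model)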

test/test_custom_layers.py

Lines changed: 321 additions & 0 deletions
@@ -1,5 +1,6 @@
 """Tests for custom layers in the KDP package."""

+import logging
 import math
 from datetime import datetime

@@ -528,3 +529,323 @@ def test_full_year_cycle(self):

         # First winter and next winter should have same encoding
         assert tf.reduce_all(result[0, -4:] == result[4, -4:])
+
+
+def test_tabular_attention_shapes():
+    """Test that TabularAttention produces correct output shapes."""
+    # Setup
+    batch_size = 32
+    num_samples = 10
+    num_features = 8
+    d_model = 16
+    num_heads = 4
+
+    layer = TabularAttention(num_heads=num_heads, d_model=d_model)
+
+    # Create sample inputs
+    inputs = tf.random.normal((batch_size, num_samples, num_features))
+
+    # Process features
+    outputs = layer(inputs, training=True)
+
+    # Check shapes
+    assert outputs.shape == (batch_size, num_samples, d_model)
+
+    # Test with different input shapes
+    inputs_2d = tf.random.normal((batch_size, num_features))
+    with pytest.raises(ValueError):
+        layer(inputs_2d)  # Should raise error for 2D input
+
+
+def test_tabular_attention_training_modes():
+    """Test TabularAttention behavior in training vs inference modes."""
+    batch_size = 16
+    num_samples = 8
+    num_features = 12
+    d_model = 24
+    num_heads = 3
+    dropout_rate = 0.5  # High dropout for visible effect
+
+    layer = TabularAttention(num_heads=num_heads, d_model=d_model, dropout_rate=dropout_rate)
+
+    # Create inputs
+    inputs = tf.random.normal((batch_size, num_samples, num_features))
+
+    # Get outputs in training mode
+    train_output = layer(inputs, training=True)
+
+    # Get outputs in inference mode
+    infer_output = layer(inputs, training=False)
+
+    # Check that outputs are different due to dropout
+    assert not np.allclose(train_output.numpy(), infer_output.numpy())
+
+
+def test_tabular_attention_feature_interactions():
+    """Test that TabularAttention captures feature interactions."""
+    batch_size = 8
+    num_samples = 4
+    num_features = 6
+    d_model = 12
+    num_heads = 2
+
+    layer = TabularAttention(num_heads=num_heads, d_model=d_model)
+
+    # Create correlated features
+    base_feature = tf.random.normal((batch_size, num_samples, 1))
+    correlated_features = tf.concat(
+        [
+            base_feature,
+            base_feature * 2 + tf.random.normal((batch_size, num_samples, 1), stddev=0.1),
+            tf.random.normal((batch_size, num_samples, num_features - 2)),
+        ],
+        axis=-1,
+    )
+
+    # Process features
+    output_correlated = layer(correlated_features)
+
+    # Create uncorrelated features
+    uncorrelated_features = tf.random.normal((batch_size, num_samples, num_features))
+    output_uncorrelated = layer(uncorrelated_features)
+
+    # The attention patterns should be different
+    assert not np.allclose(output_correlated.numpy(), output_uncorrelated.numpy(), rtol=1e-3)
+
+
+def test_tabular_attention_config():
+    """Test configuration saving and loading."""
+    original_layer = TabularAttention(num_heads=4, d_model=32, dropout_rate=0.2)
+
+    config = original_layer.get_config()
+    restored_layer = TabularAttention.from_config(config)
+
+    assert restored_layer.num_heads == original_layer.num_heads
+    assert restored_layer.d_model == original_layer.d_model
+    assert restored_layer.dropout_rate == original_layer.dropout_rate
+
+
+def test_tabular_attention_end_to_end():
+    """Test TabularAttention in a simple end-to-end model."""
+    batch_size = 16
+    num_samples = 6
+    num_features = 8
+    d_model = 16
+    num_heads = 2
+
+    # Create a simple model
+    inputs = tf.keras.Input(shape=(num_samples, num_features))
+    attention_layer = TabularAttention(num_heads=num_heads, d_model=d_model)
+
+    x = attention_layer(inputs)
+    x = tf.keras.layers.GlobalAveragePooling1D()(x)
+    outputs = tf.keras.layers.Dense(1)(x)
+
+    model = tf.keras.Model(inputs=inputs, outputs=outputs)
+
+    # Compile model
+    model.compile(optimizer="adam", loss="mse")
+
+    # Create dummy data
+    X = tf.random.normal((batch_size, num_samples, num_features))
+    y = tf.random.normal((batch_size, 1))
+
+    # Train for one epoch
+    history = model.fit(X, y, epochs=1, verbose=0)
+
+    # Check if loss was computed
+    assert "loss" in history.history
+    assert len(history.history["loss"]) == 1
+
+
+def test_tabular_attention_masking():
+    """Test TabularAttention with masked inputs."""
+    batch_size = 8
+    num_samples = 5
+    num_features = 4
+    d_model = 8
+    num_heads = 2
+
+    layer = TabularAttention(num_heads=num_heads, d_model=d_model)
+
+    # Create inputs with masked values
+    inputs = tf.random.normal((batch_size, num_samples, num_features))
+    mask = tf.random.uniform((batch_size, num_samples)) > 0.3
+    masked_inputs = tf.where(tf.expand_dims(mask, -1), inputs, tf.zeros_like(inputs))
+
+    # Process both masked and unmasked inputs
+    output_masked = layer(masked_inputs)
+    output_unmasked = layer(inputs)
+
+    # Outputs should be different
+    assert not np.allclose(output_masked.numpy(), output_unmasked.numpy())
+
+
+def test_multi_resolution_attention_shapes():
+    """Test that MultiResolutionTabularAttention produces correct output shapes."""
+    # Setup
+    batch_size = 32
+    num_numerical = 10
+    num_categorical = 5
+    numerical_features_num = 8
+    categorical_features_num = 7
+    d_model = 16
+    embedding_dim = 16
+    num_heads = 4
+
+    layer = MultiResolutionTabularAttention(num_heads=num_heads, d_model=d_model, embedding_dim=embedding_dim)
+
+    # Create sample inputs
+    numerical_features = tf.random.normal((batch_size, num_numerical, numerical_features_num))
+    categorical_features = tf.random.normal((batch_size, num_categorical, categorical_features_num))
+
+    # Process features
+    numerical_output, categorical_output = layer(numerical_features, categorical_features, training=True)
+
+    # Check shapes
+    assert numerical_output.shape == (batch_size, num_numerical, d_model)
+    assert categorical_output.shape == (batch_size, num_categorical, d_model)
+
+
+def test_multi_resolution_attention_training():
+    """Test MultiResolutionTabularAttention behavior in training vs inference modes."""
+    # Setup
+    batch_size = 16
+    num_numerical = 8
+    num_categorical = 4
+    numerical_dim = 24
+    categorical_dim = 6
+    d_model = 24
+    embedding_dim = 12
+    num_heads = 3
+    dropout_rate = 0.5  # High dropout for visible effect
+
+    layer = MultiResolutionTabularAttention(
+        num_heads=num_heads, d_model=d_model, embedding_dim=embedding_dim, dropout_rate=dropout_rate
+    )
+
+    # Create inputs
+    numerical_features = tf.random.normal((batch_size, num_numerical, numerical_dim))
+    categorical_features = tf.random.normal((batch_size, num_categorical, categorical_dim))
+
+    # Get outputs in training mode
+    num_train, cat_train = layer(numerical_features, categorical_features, training=True)
+
+    # Get outputs in inference mode
+    num_infer, cat_infer = layer(numerical_features, categorical_features, training=False)
+
+    # Check that outputs are different due to dropout
+    assert not np.allclose(num_train.numpy(), num_infer.numpy())
+    assert not np.allclose(cat_train.numpy(), cat_infer.numpy())
+
+
+def test_multi_resolution_attention_cross_attention():
+    """Test that cross-attention is working between numerical and categorical features."""
+
+    # Setup
+    batch_size = 8
+    num_numerical = 4
+    num_categorical = 2
+    numerical_dim = 8
+    categorical_dim = 4
+    d_model = 8
+    embedding_dim = 8
+    num_heads = 2
+
+    layer = MultiResolutionTabularAttention(
+        num_heads=num_heads,
+        d_model=d_model,
+        embedding_dim=embedding_dim,
+        dropout_rate=0.0,  # Disable dropout for deterministic testing
+    )
+
+    # Create numerical features
+    numerical_features = tf.random.normal((batch_size, num_numerical, numerical_dim))
+
+    # Create contrasting categorical patterns using string colors
+    colors1 = tf.constant([["blue", "green"] for _ in range(batch_size)])  # Cool colors
+    colors2 = tf.constant([["red", "yellow"] for _ in range(batch_size)])  # Warm colors
+
+    # Convert strings to one-hot encodings
+    all_colors = ["red", "blue", "green", "yellow"]
+    table = tf.lookup.StaticHashTable(
+        tf.lookup.KeyValueTensorInitializer(all_colors, tf.range(len(all_colors), dtype=tf.int64)), default_value=-1
+    )

+    categorical_pattern1 = tf.one_hot(table.lookup(colors1), categorical_dim)
+    categorical_pattern2 = tf.one_hot(table.lookup(colors2), categorical_dim)
+
+    # Process with contrasting categorical patterns
+    num_output1, cat_output1 = layer(numerical_features, categorical_pattern1, training=False)
+    num_output2, cat_output2 = layer(numerical_features, categorical_pattern2, training=False)
+
+    # Check numerical outputs are different due to contrasting categorical patterns
+    num_mean_diff = tf.reduce_mean(tf.abs(num_output1 - num_output2))
+    assert num_mean_diff > 1e-3, "Numerical outputs are too similar - cross attention may not be working"
+
+    # Check categorical outputs are different
+    cat_mean_diff = tf.reduce_mean(tf.abs(cat_output1 - cat_output2))
+    assert cat_mean_diff > 1e-3, "Categorical outputs are too similar - cross attention may not be working"
+
+    # Check shapes are correct
+    assert cat_output1.shape == cat_output2.shape
+    assert cat_output1.shape[0] == batch_size
+    assert cat_output1.shape[1] == num_categorical
+    assert cat_output1.shape[2] == d_model
+
+    # Check outputs are in reasonable range
+    assert tf.reduce_all(tf.abs(cat_output1) < 10), "Categorical outputs 1 have unexpectedly large values"
+    assert tf.reduce_all(tf.abs(cat_output2) < 10), "Categorical outputs 2 have unexpectedly large values"
+
+
+def test_multi_resolution_attention_config():
+    """Test configuration saving and loading."""
+    original_layer = MultiResolutionTabularAttention(num_heads=4, d_model=32, embedding_dim=16, dropout_rate=0.2)
+
+    config = original_layer.get_config()
+    restored_layer = MultiResolutionTabularAttention.from_config(config)
+
+    assert restored_layer.num_heads == original_layer.num_heads
+    assert restored_layer.d_model == original_layer.d_model
+    assert restored_layer.embedding_dim == original_layer.embedding_dim
+    assert restored_layer.dropout_rate == original_layer.dropout_rate
+
+
+def test_multi_resolution_attention_end_to_end():
+    """Test MultiResolutionTabularAttention in a simple end-to-end model."""
+    # Setup
+    batch_size = 16
+    num_numerical = 100
+    num_categorical = 10
+    numerical_dim = 16
+    categorical_dim = 4
+    d_model = 8
+    embedding_dim = 8
+    num_heads = 2
+
+    # Create a simple model
+    numerical_inputs = tf.keras.Input(shape=(num_numerical, numerical_dim))
+    categorical_inputs = tf.keras.Input(shape=(num_categorical, categorical_dim))
+
+    attention_layer = MultiResolutionTabularAttention(num_heads=num_heads, d_model=d_model, embedding_dim=embedding_dim)
+
+    num_output, cat_output = attention_layer(numerical_inputs, categorical_inputs)
+    combined = tf.keras.layers.Concatenate(axis=1)([num_output, cat_output])
+    outputs = tf.keras.layers.Dense(1)(combined)
+
+    model = tf.keras.Model(inputs=[numerical_inputs, categorical_inputs], outputs=outputs)
+
+    # Compile model
+    model.compile(optimizer="adam", loss="mse")
+
+    # Create dummy data
+    X_num = tf.random.normal((batch_size, num_numerical, numerical_dim))
+    X_cat = tf.random.normal((batch_size, num_categorical, categorical_dim))
+    y = tf.random.normal((batch_size, num_numerical + num_categorical, 1))
+
+    # Train for one epoch
+    history = model.fit([X_num, X_cat], y, epochs=1, verbose=0)
+
+    # Check if loss was computed
+    assert "loss" in history.history
+    assert len(history.history["loss"]) == 1
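
To run only the new attention tests locally, something like the snippet below should work; it assumes pytest is installed and that the test file lives at test/test_custom_layers.py as in this commit (the -k expression simply matches the new test names by substring):

import sys

import pytest

# Run only the tests added in this commit (matched by name substring).
sys.exit(pytest.main(["-v", "-k", "tabular_attention or multi_resolution", "test/test_custom_layers.py"]))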
