@@ -382,5 +382,271 @@ def test_invalid_input(self):
382382 self .encoder (tf .constant ([["1" , "2" ], ["3" , "4" ]]))
383383
384384
385+ class TestAdvancedOptionsDistributionAwareEncoder (tf .test .TestCase ):
386+ def setUp (self ):
387+ super ().setUp ()
388+ # Create an instance of the DistributionAwareEncoder with advanced features enabled.
389+ self .encoder = DistributionAwareEncoder (
390+ name = "distribution_aware_encoder" ,
391+ num_bins = 1000 ,
392+ epsilon = 1e-6 ,
393+ detect_periodicity = True ,
394+ handle_sparsity = True ,
395+ adaptive_binning = True ,
396+ mixture_components = 3 ,
397+ trainable = True ,
398+ )
399+
400+ def test_config_serialization (self ):
401+ """Test that the encoder's configuration is correctly saved and restored."""
402+ config = self .encoder .get_config ()
403+ new_encoder = DistributionAwareEncoder .from_config (config )
404+ self .assertEqual (new_encoder .num_bins , self .encoder .num_bins )
405+ self .assertEqual (new_encoder .epsilon , self .encoder .epsilon )
406+ self .assertEqual (
407+ new_encoder .detect_periodicity , self .encoder .detect_periodicity
408+ )
409+ self .assertEqual (new_encoder .handle_sparsity , self .encoder .handle_sparsity )
410+ self .assertEqual (new_encoder .adaptive_binning , self .encoder .adaptive_binning )
411+ self .assertEqual (
412+ new_encoder .mixture_components , self .encoder .mixture_components
413+ )
414+ self .assertTrue (new_encoder .trainable )
415+
416+ def test_periodic_processing (self ):
417+ """Test that periodic input data is encoded with the periodic branch."""
418+ # Create periodic data: sin wave with some noise.
419+ t = np .linspace (0 , 4 * np .pi , 100 ).astype (np .float32 )
420+ data = np .sin (t ) + 0.05 * np .random .normal (0 , 1 , 100 ).astype (np .float32 )
421+ inputs = tf .convert_to_tensor (data )
422+ outputs = self .encoder (inputs , training = False )
423+
424+ # With detect_periodicity=True, the output is expected to be concatenated
425+ # (e.g., sin/cos branches) doubling the dimensionality.
426+ self .assertEqual (
427+ outputs .shape [0 ],
428+ inputs .shape [0 ] * 2 ,
429+ "Periodicity encoding failed to double output dimensions." ,
430+ )
431+
432+ def test_sparsity_handling (self ):
433+ """Test that sparse inputs (mostly zeros) produce near-zero outputs in those positions."""
434+ data = np .zeros (100 , dtype = np .float32 )
435+ # Set a few indices to non-zero values.
436+ indices = np .random .choice (np .arange (100 ), size = 10 , replace = False )
437+ data [indices ] = np .random .normal (1 , 0.1 , size = 10 )
438+ inputs = tf .convert_to_tensor (data )
439+ outputs = self .encoder (inputs , training = False )
440+
441+ # In regions where input values are near zero the encoder should preserve sparsity.
442+ zero_mask = np .abs (data ) < self .encoder .epsilon
443+ outputs_val = outputs .numpy ()
444+ self .assertTrue (
445+ np .all (np .abs (outputs_val [zero_mask ]) < self .encoder .epsilon ),
446+ "Sparse inputs not preserved as near-zero in outputs." ,
447+ )
448+
449+
450+ class TestEncoderConfigurations (tf .test .TestCase ):
451+ def test_detect_periodicity_true (self ):
452+ """When detect_periodicity is True, periodic inputs should produce an output with doubled dimensions."""
453+ encoder = DistributionAwareEncoder (
454+ name = "encoder_periodic_true" ,
455+ num_bins = 1000 ,
456+ epsilon = 1e-6 ,
457+ detect_periodicity = True ,
458+ handle_sparsity = True ,
459+ adaptive_binning = True ,
460+ mixture_components = 3 ,
461+ trainable = True ,
462+ )
463+ # Create a sinusoidal input signal.
464+ t = np .linspace (0 , 4 * np .pi , 100 ).astype (np .float32 )
465+ data = np .sin (t )
466+ inputs = tf .convert_to_tensor (data )
467+ outputs = encoder (inputs , training = False )
468+ # With periodic detection enabled, the encoder output is expected to be (input_length * 2,)
469+ self .assertEqual (
470+ outputs .shape ,
471+ (inputs .shape [0 ] * 2 ,),
472+ "Expected output shape to be twice the input length when detecting periodicity." ,
473+ )
474+
475+ def test_detect_periodicity_false (self ):
476+ """When detect_periodicity is False, the output shape should match the input."""
477+ encoder = DistributionAwareEncoder (
478+ name = "encoder_periodic_false" ,
479+ num_bins = 1000 ,
480+ epsilon = 1e-6 ,
481+ detect_periodicity = False ,
482+ handle_sparsity = True ,
483+ adaptive_binning = True ,
484+ mixture_components = 3 ,
485+ trainable = True ,
486+ )
487+ # Use a sinusoidal input.
488+ t = np .linspace (0 , 4 * np .pi , 100 ).astype (np .float32 )
489+ data = np .sin (t )
490+ inputs = tf .convert_to_tensor (data )
491+ outputs = encoder (inputs , training = False )
492+ self .assertEqual (
493+ outputs .shape ,
494+ inputs .shape ,
495+ "Expected output shape to be the same as input when periodicity detection is disabled." ,
496+ )
497+
498+ def test_handle_sparsity_true (self ):
499+ """When handle_sparsity is True, input values near zero should be preserved as near-zero in the output."""
500+ encoder = DistributionAwareEncoder (
501+ name = "encoder_sparsity_true" ,
502+ num_bins = 1000 ,
503+ epsilon = 1e-6 ,
504+ detect_periodicity = False ,
505+ handle_sparsity = True ,
506+ adaptive_binning = True ,
507+ mixture_components = 3 ,
508+ trainable = True ,
509+ )
510+ # Generate sparse input data: mostly zeros with some non-zero values.
511+ data = np .zeros (200 , dtype = np .float32 )
512+ np .random .seed (42 )
513+ indices = np .random .choice (200 , size = 20 , replace = False )
514+ data [indices ] = np .random .normal (0 , 1 , size = 20 )
515+ inputs = tf .convert_to_tensor (data )
516+ outputs = encoder (inputs , training = False )
517+
518+ # For sparsity handling, zeros (or near-zero) in the input should give near-zero outputs.
519+ zero_mask = np .abs (data ) < encoder .epsilon
520+ outputs_np = outputs .numpy ()
521+ self .assertTrue (
522+ np .all (np .abs (outputs_np [zero_mask ]) < encoder .epsilon ),
523+ "When handle_sparsity is True, inputs near zero should produce near-zero outputs." ,
524+ )
525+
526+ def test_handle_sparsity_false (self ):
527+ """When handle_sparsity is False, there is no requirement to preserve zeros."""
528+ encoder = DistributionAwareEncoder (
529+ name = "encoder_sparsity_false" ,
530+ num_bins = 1000 ,
531+ epsilon = 1e-6 ,
532+ detect_periodicity = False ,
533+ handle_sparsity = False ,
534+ adaptive_binning = True ,
535+ mixture_components = 3 ,
536+ trainable = True ,
537+ )
538+ # Generate similar sparse input.
539+ data = np .zeros (200 , dtype = np .float32 )
540+ np .random .seed (42 )
541+ indices = np .random .choice (200 , size = 20 , replace = False )
542+ data [indices ] = np .random .normal (0 , 1 , size = 20 )
543+ inputs = tf .convert_to_tensor (data )
544+ outputs = encoder (inputs , training = False )
545+
546+ # When handle_sparsity is False, we do not insist on preserving zeros; instead, we can check that
547+ # at least some non-zero output is produced for non-zero input.
548+ non_zero_mask = np .abs (data ) > encoder .epsilon
549+ outputs_np = outputs .numpy ()
550+ self .assertTrue (
551+ np .any (np .abs (outputs_np [non_zero_mask ]) > encoder .epsilon ),
552+ "When handle_sparsity is False, non-zero inputs should result in non-zero outputs." ,
553+ )
554+
555+ def test_adaptive_binning_flag (self ):
556+ """Test that the adaptive_binning flag is stored correctly."""
557+ encoder_true = DistributionAwareEncoder (
558+ name = "encoder_adaptive_true" ,
559+ num_bins = 1000 ,
560+ epsilon = 1e-6 ,
561+ detect_periodicity = False ,
562+ handle_sparsity = True ,
563+ adaptive_binning = True ,
564+ mixture_components = 3 ,
565+ trainable = True ,
566+ )
567+ encoder_false = DistributionAwareEncoder (
568+ name = "encoder_adaptive_false" ,
569+ num_bins = 1000 ,
570+ epsilon = 1e-6 ,
571+ detect_periodicity = False ,
572+ handle_sparsity = True ,
573+ adaptive_binning = False ,
574+ mixture_components = 3 ,
575+ trainable = True ,
576+ )
577+ self .assertTrue (
578+ encoder_true .adaptive_binning , "Encoder should have adaptive_binning=True."
579+ )
580+ self .assertFalse (
581+ encoder_false .adaptive_binning ,
582+ "Encoder should have adaptive_binning=False." ,
583+ )
584+
585+ def test_mixture_components (self ):
586+ """Test that the mixture_components parameter is correctly stored."""
587+ encoder = DistributionAwareEncoder (
588+ name = "encoder_mixture" ,
589+ num_bins = 1000 ,
590+ epsilon = 1e-6 ,
591+ detect_periodicity = False ,
592+ handle_sparsity = True ,
593+ adaptive_binning = True ,
594+ mixture_components = 5 ,
595+ trainable = True ,
596+ )
597+ self .assertEqual (
598+ encoder .mixture_components ,
599+ 5 ,
600+ "The mixture_components parameter should be correctly set to 5." ,
601+ )
602+
603+ def test_trainable_flag (self ):
604+ """Test that setting the trainable flag correctly updates the layer's trainability."""
605+ encoder_trainable = DistributionAwareEncoder (
606+ name = "encoder_trainable_true" ,
607+ num_bins = 1000 ,
608+ epsilon = 1e-6 ,
609+ detect_periodicity = False ,
610+ handle_sparsity = True ,
611+ adaptive_binning = True ,
612+ mixture_components = 3 ,
613+ trainable = True ,
614+ )
615+ encoder_non_trainable = DistributionAwareEncoder (
616+ name = "encoder_trainable_false" ,
617+ num_bins = 1000 ,
618+ epsilon = 1e-6 ,
619+ detect_periodicity = False ,
620+ handle_sparsity = True ,
621+ adaptive_binning = True ,
622+ mixture_components = 3 ,
623+ trainable = False ,
624+ )
625+ self .assertTrue (
626+ encoder_trainable .trainable ,
627+ "Encoder should be trainable when trainable=True." ,
628+ )
629+ self .assertFalse (
630+ encoder_non_trainable .trainable ,
631+ "Encoder should not be trainable when trainable=False." ,
632+ )
633+
634+ def test_num_bins_parameter (self ):
635+ """Test that the num_bins parameter is correctly set and stored."""
636+ encoder = DistributionAwareEncoder (
637+ name = "encoder_num_bins" ,
638+ num_bins = 500 ,
639+ epsilon = 1e-6 ,
640+ detect_periodicity = False ,
641+ handle_sparsity = True ,
642+ adaptive_binning = True ,
643+ mixture_components = 3 ,
644+ trainable = True ,
645+ )
646+ self .assertEqual (
647+ encoder .num_bins , 500 , "The num_bins parameter should be set to 500."
648+ )
649+
650+
385651if __name__ == "__main__" :
386652 tf .test .main ()
0 commit comments