@@ -1981,15 +1981,27 @@ class AdvancedNumericalEmbedding(layers.Layer):
19811981
19821982 def __init__ (
19831983 self ,
1984- embedding_dim : int ,
1985- mlp_hidden_units : int ,
1986- num_bins : int ,
1987- init_min ,
1988- init_max ,
1989- dropout_rate : float = 0.0 ,
1990- use_batch_norm : bool = False ,
1984+ embedding_dim : int = 8 ,
1985+ mlp_hidden_units : int = 16 ,
1986+ num_bins : int = 10 ,
1987+ init_min : float | list [ float ] = - 3.0 ,
1988+ init_max : float | list [ float ] = 3.0 ,
1989+ dropout_rate : float = 0.1 ,
1990+ use_batch_norm : bool = True ,
19911991 ** kwargs ,
19921992 ):
1993+ """Initialize the AdvancedNumericalEmbedding layer.
1994+
1995+ Args:
1996+ embedding_dim: Dimension of the output embedding for each feature.
1997+ mlp_hidden_units: Number of hidden units in the MLP.
1998+ num_bins: Number of bins for discretization.
1999+ init_min: Minimum value(s) for initialization. Can be a single float or list of floats.
2000+ init_max: Maximum value(s) for initialization. Can be a single float or list of floats.
2001+ dropout_rate: Dropout rate for regularization.
2002+ use_batch_norm: Whether to use batch normalization.
2003+ **kwargs: Additional layer arguments.
2004+ """
19932005 super ().__init__ (** kwargs )
19942006 self .embedding_dim = embedding_dim
19952007 self .mlp_hidden_units = mlp_hidden_units
@@ -2152,16 +2164,29 @@ class GlobalAdvancedNumericalEmbedding(tf.keras.layers.Layer):
21522164
21532165 def __init__ (
21542166 self ,
2155- global_embedding_dim : int ,
2156- global_mlp_hidden_units : int ,
2157- global_num_bins : int ,
2158- global_init_min ,
2159- global_init_max ,
2160- global_dropout_rate : float ,
2161- global_use_batch_norm : bool ,
2167+ global_embedding_dim : int = 8 ,
2168+ global_mlp_hidden_units : int = 16 ,
2169+ global_num_bins : int = 10 ,
2170+ global_init_min : float | list [ float ] = - 3.0 ,
2171+ global_init_max : float | list [ float ] = 3.0 ,
2172+ global_dropout_rate : float = 0.1 ,
2173+ global_use_batch_norm : bool = True ,
21622174 global_pooling : str = "average" ,
21632175 ** kwargs ,
21642176 ):
2177+ """Initialize the GlobalAdvancedNumericalEmbedding layer.
2178+
2179+ Args:
2180+ global_embedding_dim: Dimension of the final global embedding.
2181+ global_mlp_hidden_units: Number of hidden units in the global MLP.
2182+ global_num_bins: Number of bins for discretization.
2183+ global_init_min: Minimum value(s) for initialization. Can be a single float or list of floats.
2184+ global_init_max: Maximum value(s) for initialization. Can be a single float or list of floats.
2185+ global_dropout_rate: Dropout rate for regularization.
2186+ global_use_batch_norm: Whether to use batch normalization.
2187+ global_pooling: Pooling method to use ("average" or "max").
2188+ **kwargs: Additional layer arguments.
2189+ """
21652190 super ().__init__ (** kwargs )
21662191 self .global_embedding_dim = global_embedding_dim
21672192 self .global_mlp_hidden_units = global_mlp_hidden_units
0 commit comments