Commit 0eb968d
fix(KDP): added docstrings
1 parent 0bf0cc3

kdp/custom_layers.py

Lines changed: 39 additions & 14 deletions
@@ -1981,15 +1981,27 @@ class AdvancedNumericalEmbedding(layers.Layer):
 
     def __init__(
         self,
-        embedding_dim: int,
-        mlp_hidden_units: int,
-        num_bins: int,
-        init_min,
-        init_max,
-        dropout_rate: float = 0.0,
-        use_batch_norm: bool = False,
+        embedding_dim: int = 8,
+        mlp_hidden_units: int = 16,
+        num_bins: int = 10,
+        init_min: float | list[float] = -3.0,
+        init_max: float | list[float] = 3.0,
+        dropout_rate: float = 0.1,
+        use_batch_norm: bool = True,
         **kwargs,
     ):
+        """Initialize the AdvancedNumericalEmbedding layer.
+
+        Args:
+            embedding_dim: Dimension of the output embedding for each feature.
+            mlp_hidden_units: Number of hidden units in the MLP.
+            num_bins: Number of bins for discretization.
+            init_min: Minimum value(s) for initialization. Can be a single float or list of floats.
+            init_max: Maximum value(s) for initialization. Can be a single float or list of floats.
+            dropout_rate: Dropout rate for regularization.
+            use_batch_norm: Whether to use batch normalization.
+            **kwargs: Additional layer arguments.
+        """
         super().__init__(**kwargs)
         self.embedding_dim = embedding_dim
         self.mlp_hidden_units = mlp_hidden_units
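
With these defaults, every constructor argument of AdvancedNumericalEmbedding becomes optional. A minimal usage sketch, assuming the layer is importable from kdp.custom_layers (the file touched by this commit) and takes a dense batch of numeric features; the actual build/call contract is defined elsewhere in the file, not in this hunk:

import tensorflow as tf
from kdp.custom_layers import AdvancedNumericalEmbedding  # import path inferred from the diff's file name

# Rely entirely on the new defaults (embedding_dim=8, mlp_hidden_units=16, num_bins=10, ...).
default_layer = AdvancedNumericalEmbedding()

# Override selected arguments; lists for init_min/init_max are allowed per the new docstring
# (presumably one bound per input feature).
custom_layer = AdvancedNumericalEmbedding(
    embedding_dim=16,
    init_min=[-5.0, -3.0, -1.0],
    init_max=[5.0, 3.0, 1.0],
    use_batch_norm=False,
)

# Illustrative forward pass on a batch of 3 numeric features; the layer's call logic
# lives elsewhere in custom_layers.py and is unchanged by this commit.
x = tf.random.normal((32, 3))
embedded = custom_layer(x)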
@@ -2152,16 +2164,29 @@ class GlobalAdvancedNumericalEmbedding(tf.keras.layers.Layer):
 
     def __init__(
         self,
-        global_embedding_dim: int,
-        global_mlp_hidden_units: int,
-        global_num_bins: int,
-        global_init_min,
-        global_init_max,
-        global_dropout_rate: float,
-        global_use_batch_norm: bool,
+        global_embedding_dim: int = 8,
+        global_mlp_hidden_units: int = 16,
+        global_num_bins: int = 10,
+        global_init_min: float | list[float] = -3.0,
+        global_init_max: float | list[float] = 3.0,
+        global_dropout_rate: float = 0.1,
+        global_use_batch_norm: bool = True,
         global_pooling: str = "average",
         **kwargs,
     ):
+        """Initialize the GlobalAdvancedNumericalEmbedding layer.
+
+        Args:
+            global_embedding_dim: Dimension of the final global embedding.
+            global_mlp_hidden_units: Number of hidden units in the global MLP.
+            global_num_bins: Number of bins for discretization.
+            global_init_min: Minimum value(s) for initialization. Can be a single float or list of floats.
+            global_init_max: Maximum value(s) for initialization. Can be a single float or list of floats.
+            global_dropout_rate: Dropout rate for regularization.
+            global_use_batch_norm: Whether to use batch normalization.
+            global_pooling: Pooling method to use ("average" or "max").
+            **kwargs: Additional layer arguments.
+        """
         super().__init__(**kwargs)
         self.global_embedding_dim = global_embedding_dim
         self.global_mlp_hidden_units = global_mlp_hidden_units
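
GlobalAdvancedNumericalEmbedding gets the same treatment, plus a pooling choice. A minimal sketch under the same assumptions (import path inferred from the file name; input shape is illustrative only):

import tensorflow as tf
from kdp.custom_layers import GlobalAdvancedNumericalEmbedding  # import path inferred from the diff's file name

# All arguments now have defaults; only the pooling strategy is overridden here
# ("average" or "max" per the new docstring).
global_layer = GlobalAdvancedNumericalEmbedding(global_pooling="max")

# Illustrative forward pass; the layer's actual input contract lives outside this hunk.
x = tf.random.normal((32, 5))
pooled = global_layer(x)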
