1+ """Modality translation network for WiFi-DensePose system."""
2+
3+ import torch
4+ import torch .nn as nn
5+ import torch .nn .functional as F
6+ from typing import Dict , Any
7+
8+
class ModalityTranslationNetwork(nn.Module):
    """Neural network for translating CSI data to visual feature space.

    The encoder progressively downsamples the CSI input while doubling
    the channel count (capped at ``hidden_dim``); the decoder mirrors it
    with transposed convolutions and projects to ``output_dim`` channels
    squashed into [-1, 1] by a final Tanh.
    """

    def __init__(self, config: Dict[str, Any]):
        """Initialize modality translation network.

        Args:
            config: Configuration dictionary with the keys
                'input_channels', 'hidden_dim', 'output_dim',
                'num_layers' and 'dropout_rate'.
        """
        super().__init__()

        self.input_channels = config['input_channels']
        self.hidden_dim = config['hidden_dim']
        self.output_dim = config['output_dim']
        self.num_layers = config['num_layers']
        self.dropout_rate = config['dropout_rate']

        # BUGFIX: compute the channel count the encoder actually ends
        # with instead of assuming it equals hidden_dim.  The doubling
        # schedule 64 -> 128 -> ... (capped at hidden_dim) can stop
        # short of hidden_dim — e.g. hidden_dim=256 with num_layers=2
        # ends at 128 — which previously made the decoder's first layer
        # expect the wrong number of input channels and crash at runtime.
        channels = 64
        for _ in range(self.num_layers - 1):
            channels = min(channels * 2, self.hidden_dim)
        self.encoder_out_channels = channels

        # Encoder: CSI -> feature space
        self.encoder = self._build_encoder()

        # Decoder: feature space -> visual-like features
        self.decoder = self._build_decoder()

        # Initialize weights
        self._initialize_weights()

    def _build_encoder(self) -> nn.Module:
        """Build the downsampling encoder (conv -> BN -> ReLU -> dropout)."""
        layers = [
            # Initial convolution keeps the spatial size (stride 1).
            nn.Conv2d(self.input_channels, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Dropout2d(self.dropout_rate),
        ]

        # Progressive downsampling: halve H/W, double channels (capped
        # at hidden_dim).
        in_channels = 64
        for _ in range(self.num_layers - 1):
            out_channels = min(in_channels * 2, self.hidden_dim)
            layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3,
                          stride=2, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Dropout2d(self.dropout_rate),
            ])
            in_channels = out_channels

        return nn.Sequential(*layers)

    def _build_decoder(self) -> nn.Module:
        """Build the upsampling decoder mirroring the encoder."""
        layers = []

        # Start from the channel count the encoder actually produces
        # (computed in __init__), not a hard-coded hidden_dim.
        in_channels = self.encoder_out_channels

        # Progressive upsampling: double H/W, halve channels (floored
        # at 64).  output_padding=1 makes stride-2 transposed convs
        # exactly invert the encoder's stride-2 downsampling for
        # even-sized inputs.
        for _ in range(self.num_layers - 1):
            out_channels = max(in_channels // 2, 64)
            layers.extend([
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3,
                                   stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Dropout2d(self.dropout_rate),
            ])
            in_channels = out_channels

        # Final projection to output_dim channels, normalized to [-1, 1].
        layers.append(nn.Conv2d(in_channels, self.output_dim,
                                kernel_size=3, padding=1))
        layers.append(nn.Tanh())

        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Initialize conv weights (Kaiming) and BatchNorm affine params."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                # fan_out + relu matches the ReLU activations used above.
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the network.

        Args:
            x: Input CSI tensor of shape (batch_size, input_channels,
                height, width).

        Returns:
            Translated features tensor of shape (batch_size, output_dim,
            H', W') with values in [-1, 1]; H'/W' equal the input H/W
            when they are divisible by 2**(num_layers - 1).

        Raises:
            RuntimeError: If ``x`` does not have ``input_channels``
                channels in dimension 1.
        """
        # Validate input shape before running any layers so the error
        # message is actionable.
        if x.shape[1] != self.input_channels:
            raise RuntimeError(
                f"Expected {self.input_channels} input channels, "
                f"got {x.shape[1]}")

        # Encode CSI data, then decode to visual-like features.
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)

        return decoded
0 commit comments