@@ -82,14 +82,16 @@ impl WiFiDensePoseModel {
8282 let root = vs. root ( ) ;
8383
8484 // Compute the flattened CSI input size used by the modality translator.
85- let flat_csi = ( config. window_frames
85+ let n_ant = ( config. window_frames
8686 * config. num_antennas_tx
87- * config. num_antennas_rx
88- * config. num_subcarriers ) as i64 ;
87+ * config. num_antennas_rx ) as i64 ;
88+ let n_sc = config. num_subcarriers as i64 ;
89+ let flat_csi = n_ant * n_sc;
8990
9091 let num_parts = config. num_body_parts as i64 ;
9192
92- let translator = ModalityTranslator :: new ( & root / "translator" , flat_csi) ;
93+ let translator =
94+ ModalityTranslator :: new ( & root / "translator" , flat_csi, n_ant, n_sc) ;
9395 let backbone = Backbone :: new ( & root / "backbone" , config. backbone_channels as i64 ) ;
9496 let kp_head = KeypointHead :: new (
9597 & root / "kp_head" ,
@@ -255,17 +257,154 @@ fn phase_sanitize(phase: &Tensor) -> Tensor {
255257 Tensor :: cat ( & [ zeros, diff] , 2 )
256258}
257259
260+ // ---------------------------------------------------------------------------
261+ // ruvector attention helpers
262+ // ---------------------------------------------------------------------------
263+
264+ /// Apply min-cut gated attention over the antenna-path dimension.
265+ ///
266+ /// Treats each antenna path as a "token" and subcarriers as the feature
267+ /// dimension. Uses `attn_mincut` to gate irrelevant antenna-pair correlations,
268+ /// which is equivalent to automatic antenna selection.
269+ ///
270+ /// # Arguments
271+ ///
272+ /// - `x`: CSI tensor `[B, n_ant, n_sc]` — amplitude or phase
273+ /// - `lambda`: min-cut threshold (0.3 = moderate pruning)
274+ ///
275+ /// # Returns
276+ ///
277+ /// Attended tensor `[B, n_ant, n_sc]` with irrelevant antenna paths suppressed.
278+ fn apply_antenna_attention ( x : & Tensor , lambda : f32 ) -> Tensor {
279+ let sizes = x. size ( ) ;
280+ let n_ant = sizes[ 1 ] ;
281+ let n_sc = sizes[ 2 ] ;
282+
283+ // Skip trivial cases where attention is a no-op.
284+ if n_ant <= 1 || n_sc <= 1 {
285+ return x. shallow_clone ( ) ;
286+ }
287+
288+ let b = sizes[ 0 ] as usize ;
289+ let n_ant_usize = n_ant as usize ;
290+ let n_sc_usize = n_sc as usize ;
291+
292+ let device = x. device ( ) ;
293+ let kind = x. kind ( ) ;
294+
295+ // Process each batch element independently (attn_mincut operates on 2D inputs).
296+ let mut results: Vec < Tensor > = Vec :: with_capacity ( b) ;
297+
298+ for bi in 0 ..b {
299+ // Extract [n_ant, n_sc] slice for this batch element.
300+ let xi = x. select ( 0 , bi as i64 ) ; // [n_ant, n_sc]
301+
302+ // Move to CPU and convert to f32 for the pure-Rust attention kernel.
303+ let flat: Vec < f32 > =
304+ Vec :: from ( xi. to_kind ( Kind :: Float ) . to_device ( Device :: Cpu ) . contiguous ( ) ) ;
305+
306+ // Q = K = V = the antenna features (self-attention over antenna paths).
307+ let out = attn_mincut (
308+ & flat, // q: [n_ant * n_sc]
309+ & flat, // k: [n_ant * n_sc]
310+ & flat, // v: [n_ant * n_sc]
311+ n_sc_usize, // d: feature dim = n_sc subcarriers
312+ n_ant_usize, // seq_len: number of antenna paths
313+ lambda, // lambda: min-cut threshold
314+ 1 , // tau: no temporal hysteresis (single-frame)
315+ 1e-6 , // eps: numerical epsilon
316+ ) ;
317+
318+ let attended = Tensor :: from_slice ( & out. output )
319+ . reshape ( [ n_ant, n_sc] )
320+ . to_device ( device)
321+ . to_kind ( kind) ;
322+
323+ results. push ( attended) ;
324+ }
325+
326+ Tensor :: stack ( & results, 0 ) // [B, n_ant, n_sc]
327+ }
328+
329+ /// Apply scaled dot-product attention over spatial locations.
330+ ///
331+ /// Input: `[B, C, H, W]` feature map — each spatial location (H×W) becomes a
332+ /// token; C is the feature dimension. Captures long-range spatial dependencies
333+ /// between antenna-footprint regions.
334+ ///
335+ /// Returns `[B, C, H, W]` with spatial attention applied.
336+ ///
337+ /// This function can be applied after backbone features when long-range spatial
338+ /// context is needed. It is defined here for completeness and may be called
339+ /// from head implementations or future backbone variants.
340+ #[ allow( dead_code) ]
341+ fn apply_spatial_attention ( x : & Tensor ) -> Tensor {
342+ let sizes = x. size ( ) ;
343+ let ( b, c, h, w) = ( sizes[ 0 ] , sizes[ 1 ] , sizes[ 2 ] , sizes[ 3 ] ) ;
344+ let n_spatial = ( h * w) as usize ;
345+ let d = c as usize ;
346+
347+ let device = x. device ( ) ;
348+ let kind = x. kind ( ) ;
349+
350+ let attn = ScaledDotProductAttention :: new ( d) ;
351+
352+ let mut results: Vec < Tensor > = Vec :: with_capacity ( b as usize ) ;
353+
354+ for bi in 0 ..b {
355+ // Extract [C, H*W] and transpose to [H*W, C].
356+ let xi = x. select ( 0 , bi) . reshape ( [ c, h * w] ) . transpose ( 0 , 1 ) ; // [H*W, C]
357+ let flat: Vec < f32 > =
358+ Vec :: from ( xi. to_kind ( Kind :: Float ) . to_device ( Device :: Cpu ) . contiguous ( ) ) ;
359+
360+ // Build token slices — one per spatial position.
361+ let tokens: Vec < & [ f32 ] > = ( 0 ..n_spatial)
362+ . map ( |i| & flat[ i * d..( i + 1 ) * d] )
363+ . collect ( ) ;
364+
365+ // For each spatial token as query, compute attended output.
366+ let mut out_flat = vec ! [ 0.0f32 ; n_spatial * d] ;
367+ for i in 0 ..n_spatial {
368+ let query = & flat[ i * d..( i + 1 ) * d] ;
369+ match attn. compute ( query, & tokens, & tokens) {
370+ Ok ( attended) => {
371+ out_flat[ i * d..( i + 1 ) * d] . copy_from_slice ( & attended) ;
372+ }
373+ Err ( _) => {
374+ // Fallback: identity — keep original features unchanged.
375+ out_flat[ i * d..( i + 1 ) * d] . copy_from_slice ( query) ;
376+ }
377+ }
378+ }
379+
380+ let out_tensor = Tensor :: from_slice ( & out_flat)
381+ . reshape ( [ h * w, c] )
382+ . transpose ( 0 , 1 ) // [C, H*W]
383+ . reshape ( [ c, h, w] ) // [C, H, W]
384+ . to_device ( device)
385+ . to_kind ( kind) ;
386+
387+ results. push ( out_tensor) ;
388+ }
389+
390+ Tensor :: stack ( & results, 0 ) // [B, C, H, W]
391+ }
392+
258393// ---------------------------------------------------------------------------
259394// Modality Translator
260395// ---------------------------------------------------------------------------
261396
262397/// Translates flattened (amplitude, phase) CSI vectors into a pseudo-image.
263398///
264399/// ```text
265- /// amplitude [B, flat_csi] ─► amp_fc1 ► relu ► amp_fc2 ► relu ─┐
266- /// ├─► fuse_fc ► reshape ► spatial_conv ► [B, 3, 48, 48]
267- /// phase [B, flat_csi] ─► ph_fc1 ► relu ► ph_fc2 ► relu ─┘
400+ /// amplitude [B, flat_csi] ─► attn_mincut ─► amp_fc1 ► relu ► amp_fc2 ► relu ─┐
401+ /// ├─► fuse_fc ► reshape ► spatial_conv ► [B, 3, 48, 48]
402+ /// phase [B, flat_csi] ─► attn_mincut ─► ph_fc1 ► relu ► ph_fc2 ► relu ─┘
268403/// ```
404+ ///
405+ /// The `attn_mincut` step performs self-attention over the antenna-path dimension
406+ /// (`n_ant` tokens, each with `n_sc` subcarrier features) to gate out irrelevant
407+ /// antenna-pair correlations before the FC fusion layers.
269408struct ModalityTranslator {
270409 amp_fc1 : nn:: Linear ,
271410 amp_fc2 : nn:: Linear ,
@@ -276,10 +415,14 @@ struct ModalityTranslator {
276415 sp_conv1 : nn:: Conv2D ,
277416 sp_bn1 : nn:: BatchNorm ,
278417 sp_conv2 : nn:: Conv2D ,
418+ /// Number of antenna paths: T * n_tx * n_rx (used for attention reshape).
419+ n_ant : i64 ,
420+ /// Number of subcarriers per antenna path (used for attention reshape).
421+ n_sc : i64 ,
279422}
280423
281424impl ModalityTranslator {
282- fn new ( vs : nn:: Path , flat_csi : i64 ) -> Self {
425+ fn new ( vs : nn:: Path , flat_csi : i64 , n_ant : i64 , n_sc : i64 ) -> Self {
283426 let amp_fc1 = nn:: linear ( & vs / "amp_fc1" , flat_csi, 512 , Default :: default ( ) ) ;
284427 let amp_fc2 = nn:: linear ( & vs / "amp_fc2" , 512 , 256 , Default :: default ( ) ) ;
285428 let ph_fc1 = nn:: linear ( & vs / "ph_fc1" , flat_csi, 512 , Default :: default ( ) ) ;
@@ -320,22 +463,38 @@ impl ModalityTranslator {
320463 sp_conv1,
321464 sp_bn1,
322465 sp_conv2,
466+ n_ant,
467+ n_sc,
323468 }
324469 }
325470
326471 fn forward_t ( & self , amp : & Tensor , ph : & Tensor , train : bool ) -> Tensor {
327472 let b = amp. size ( ) [ 0 ] ;
328473
329- // Amplitude branch
330- let a = amp
474+ // === ruvector-attn-mincut: gate irrelevant antenna paths ===
475+ //
476+ // Reshape from [B, flat_csi] to [B, n_ant, n_sc], apply min-cut
477+ // self-attention over the antenna-path dimension (antenna paths are
478+ // "tokens", subcarrier responses are "features"), then flatten back.
479+ let amp_3d = amp. reshape ( [ b, self . n_ant , self . n_sc ] ) ;
480+ let ph_3d = ph. reshape ( [ b, self . n_ant , self . n_sc ] ) ;
481+
482+ let amp_attended = apply_antenna_attention ( & amp_3d, 0.3 ) ;
483+ let ph_attended = apply_antenna_attention ( & ph_3d, 0.3 ) ;
484+
485+ let amp_flat = amp_attended. reshape ( [ b, -1 ] ) ; // [B, flat_csi]
486+ let ph_flat = ph_attended. reshape ( [ b, -1 ] ) ; // [B, flat_csi]
487+
488+ // Amplitude branch (uses attended input)
489+ let a = amp_flat
331490 . apply ( & self . amp_fc1 )
332491 . relu ( )
333492 . dropout ( 0.2 , train)
334493 . apply ( & self . amp_fc2 )
335494 . relu ( ) ;
336495
337- // Phase branch
338- let p = ph
496+ // Phase branch (uses attended input)
497+ let p = ph_flat
339498 . apply ( & self . ph_fc1 )
340499 . relu ( )
341500 . dropout ( 0.2 , train)
0 commit comments