Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 45f0304

Browse files
committed
fix: Review fixes for end-to-end training pipeline
- Snapshot best-epoch weights during training and restore before checkpoint/RVF export (prevents exporting overfit final-epoch params)
- Add CsiToPoseTransformer::zeros() for fast zero-init when weights will be overwritten, avoiding wasteful Xavier init during gradient estimation (~2*param_count transformer constructions per batch)
- Deduplicate synthetic data generation in main.rs training mode

Co-Authored-By: claude-flow <[email protected]>
1 parent 4cabffa commit 45f0304

3 files changed

Lines changed: 57 additions & 52 deletions

File tree

rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/graph_transformer.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,23 @@ impl CsiToPoseTransformer {
452452
config,
453453
}
454454
}
455+
/// Construct with zero-initialized weights (faster than Xavier init).
456+
/// Use with `unflatten_weights()` when you plan to overwrite all weights.
457+
pub fn zeros(config: TransformerConfig) -> Self {
458+
let d = config.d_model;
459+
let bg = BodyGraph::new();
460+
let kq = vec![vec![0.0f32; d]; config.n_keypoints];
461+
Self {
462+
csi_embed: Linear::zeros(config.n_subcarriers, d),
463+
keypoint_queries: kq,
464+
cross_attn: CrossAttention::new(d, config.n_heads), // small; kept for correct structure
465+
gnn: GnnStack::new(d, d, config.n_gnn_layers, &bg),
466+
xyz_head: Linear::zeros(d, 3),
467+
conf_head: Linear::zeros(d, 1),
468+
config,
469+
}
470+
}
471+
455472
/// csi_features [n_antenna_pairs, n_subcarriers] -> PoseOutput with 17 keypoints.
456473
pub fn forward(&self, csi_features: &[Vec<f32>]) -> PoseOutput {
457474
let embedded: Vec<Vec<f32>> = csi_features.iter()

rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/main.rs

Lines changed: 28 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1551,61 +1551,42 @@ async fn main() {
15511551
..Default::default()
15521552
});
15531553

1554-
// Load samples
1554+
// Generate synthetic training data (50 samples with deterministic CSI + keypoints)
1555+
let generate_synthetic = || -> Vec<dataset::TrainingSample> {
1556+
(0..50).map(|i| {
1557+
let csi: Vec<Vec<f32>> = (0..4).map(|a| {
1558+
(0..56).map(|s| ((i * 7 + a * 13 + s) as f32 * 0.31).sin() * 0.5).collect()
1559+
}).collect();
1560+
let mut kps = [(0.0f32, 0.0f32, 1.0f32); 17];
1561+
for (k, kp) in kps.iter_mut().enumerate() {
1562+
kp.0 = (k as f32 * 0.1 + i as f32 * 0.02).sin() * 100.0 + 320.0;
1563+
kp.1 = (k as f32 * 0.15 + i as f32 * 0.03).cos() * 80.0 + 240.0;
1564+
}
1565+
dataset::TrainingSample {
1566+
csi_window: csi,
1567+
pose_label: dataset::PoseLabel {
1568+
keypoints: kps,
1569+
body_parts: Vec::new(),
1570+
confidence: 1.0,
1571+
},
1572+
source: "synthetic",
1573+
}
1574+
}).collect()
1575+
};
1576+
1577+
// Load samples (fall back to synthetic if dataset missing/empty)
15551578
let samples = match pipeline.load() {
15561579
Ok(s) if !s.is_empty() => {
15571580
eprintln!("Loaded {} samples from {}", s.len(), ds_path.display());
15581581
s
15591582
}
15601583
Ok(_) => {
1561-
eprintln!("No samples found at {}. Generating synthetic training data...", ds_path.display());
1562-
// Generate synthetic samples for testing the pipeline
1563-
let mut synth = Vec::new();
1564-
for i in 0..50 {
1565-
let csi: Vec<Vec<f32>> = (0..4).map(|a| {
1566-
(0..56).map(|s| ((i * 7 + a * 13 + s) as f32 * 0.31).sin() * 0.5).collect()
1567-
}).collect();
1568-
let mut kps = [(0.0f32, 0.0f32, 1.0f32); 17];
1569-
for (k, kp) in kps.iter_mut().enumerate() {
1570-
kp.0 = (k as f32 * 0.1 + i as f32 * 0.02).sin() * 100.0 + 320.0;
1571-
kp.1 = (k as f32 * 0.15 + i as f32 * 0.03).cos() * 80.0 + 240.0;
1572-
}
1573-
synth.push(dataset::TrainingSample {
1574-
csi_window: csi,
1575-
pose_label: dataset::PoseLabel {
1576-
keypoints: kps,
1577-
body_parts: Vec::new(),
1578-
confidence: 1.0,
1579-
},
1580-
source: "synthetic",
1581-
});
1582-
}
1583-
synth
1584+
eprintln!("No samples found at {}. Using synthetic data.", ds_path.display());
1585+
generate_synthetic()
15841586
}
15851587
Err(e) => {
1586-
eprintln!("Failed to load dataset: {e}");
1587-
eprintln!("Generating synthetic training data...");
1588-
let mut synth = Vec::new();
1589-
for i in 0..50 {
1590-
let csi: Vec<Vec<f32>> = (0..4).map(|a| {
1591-
(0..56).map(|s| ((i * 7 + a * 13 + s) as f32 * 0.31).sin() * 0.5).collect()
1592-
}).collect();
1593-
let mut kps = [(0.0f32, 0.0f32, 1.0f32); 17];
1594-
for (k, kp) in kps.iter_mut().enumerate() {
1595-
kp.0 = (k as f32 * 0.1 + i as f32 * 0.02).sin() * 100.0 + 320.0;
1596-
kp.1 = (k as f32 * 0.15 + i as f32 * 0.03).cos() * 80.0 + 240.0;
1597-
}
1598-
synth.push(dataset::TrainingSample {
1599-
csi_window: csi,
1600-
pose_label: dataset::PoseLabel {
1601-
keypoints: kps,
1602-
body_parts: Vec::new(),
1603-
confidence: 1.0,
1604-
},
1605-
source: "synthetic",
1606-
});
1607-
}
1608-
synth
1588+
eprintln!("Failed to load dataset: {e}. Using synthetic data.");
1589+
generate_synthetic()
16091590
}
16101591
};
16111592

rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/trainer.rs

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,8 @@ pub struct Trainer {
398398
best_val_loss: f32,
399399
best_epoch: usize,
400400
epochs_without_improvement: usize,
401+
/// Snapshot of params at the best validation loss epoch.
402+
best_params: Vec<f32>,
401403
/// When set, predict_keypoints delegates to the transformer's forward().
402404
transformer: Option<CsiToPoseTransformer>,
403405
/// Transformer config (needed for unflatten during gradient estimation).
@@ -411,10 +413,11 @@ impl Trainer {
411413
config.warmup_epochs, config.lr, config.min_lr, config.epochs,
412414
);
413415
let params: Vec<f32> = (0..64).map(|i| (i as f32 * 0.7 + 0.3).sin() * 0.1).collect();
416+
let best_params = params.clone();
414417
Self {
415418
config, optimizer, scheduler, params, history: Vec::new(),
416419
best_val_loss: f32::MAX, best_epoch: 0, epochs_without_improvement: 0,
417-
transformer: None, transformer_config: None,
420+
best_params, transformer: None, transformer_config: None,
418421
}
419422
}
420423

@@ -427,10 +430,11 @@ impl Trainer {
427430
config.warmup_epochs, config.lr, config.min_lr, config.epochs,
428431
);
429432
let tc = transformer.config().clone();
433+
let best_params = params.clone();
430434
Self {
431435
config, optimizer, scheduler, params, history: Vec::new(),
432436
best_val_loss: f32::MAX, best_epoch: 0, epochs_without_improvement: 0,
433-
transformer: Some(transformer), transformer_config: Some(tc),
437+
best_params, transformer: Some(transformer), transformer_config: Some(tc),
434438
}
435439
}
436440

@@ -523,12 +527,15 @@ impl Trainer {
523527
if val_loss < self.best_val_loss {
524528
self.best_val_loss = val_loss;
525529
self.best_epoch = stats.epoch;
530+
self.best_params = self.params.clone();
526531
self.epochs_without_improvement = 0;
527532
} else {
528533
self.epochs_without_improvement += 1;
529534
}
530535
if self.should_stop() { break; }
531536
}
537+
// Restore best-epoch params for checkpoint and downstream use
538+
self.params = self.best_params.clone();
532539
let best = self.best_metrics().cloned().unwrap_or(EpochStats {
533540
epoch: 0, train_loss: f32::MAX, val_loss: f32::MAX, pck_02: 0.0,
534541
oks_map: 0.0, lr: self.config.lr, loss_components: LossComponents::default(),
@@ -625,12 +632,12 @@ impl Trainer {
625632
}).collect()
626633
}
627634

628-
/// Predict keypoints using the graph transformer. Creates a temporary
629-
/// transformer with the given params and runs forward().
635+
/// Predict keypoints using the graph transformer. Uses zero-init
636+
/// constructor (fast) then overwrites all weights from params.
630637
fn predict_keypoints_transformer(
631638
params: &[f32], sample: &TrainingSample, tc: &TransformerConfig,
632639
) -> Vec<(f32, f32, f32)> {
633-
let mut t = CsiToPoseTransformer::new(tc.clone());
640+
let mut t = CsiToPoseTransformer::zeros(tc.clone());
634641
if t.unflatten_weights(params).is_err() {
635642
return Self::predict_keypoints(params, sample);
636643
}

0 commit comments

Comments (0)