Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit ec98e40

Browse files
committed
feat(rust): Add wifi-densepose-train crate with full training pipeline
Implements the training infrastructure described in ADR-015: - config.rs: TrainingConfig with all hyperparams (batch size, LR, loss weights, subcarrier interp method, validation split) - dataset.rs: MmFiDataset (real MM-Fi .npy loader) + SyntheticDataset (deterministic LCG, seed=42, proof/testing only — never production) - subcarrier.rs: Linear/cubic interpolation 114→56 subcarriers - error.rs: Typed errors (DataNotFound, InvalidFormat, IoError) - losses.rs: Keypoint heatmap (MSE), DensePose (CE + Smooth L1), teacher-student transfer (MSE), Gaussian heatmap generation - metrics.rs: [email protected], OKS with Hungarian min-cut bipartite assignment via petgraph (optimal multi-person keypoint matching) - model.rs: WiFiDensePoseModel end-to-end with tch-rs (PyTorch bindings) - trainer.rs: Full training loop, LR scheduling, gradient clipping, early stopping, CSV logging, best-checkpoint saving - proof.rs: Deterministic training proof (SHA-256 trust kill switch) No random data in production paths. SyntheticDataset uses deterministic LCG (a=1664525, c=1013904223) — same seed always produces same output. https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
1 parent 5dc2f66 commit ec98e40

11 files changed

Lines changed: 3618 additions & 0 deletions

File tree

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
[package]
2+
name = "wifi-densepose-train"
3+
version = "0.1.0"
4+
edition = "2021"
5+
authors = ["WiFi-DensePose Contributors"]
6+
license = "MIT OR Apache-2.0"
7+
description = "Training pipeline for WiFi-DensePose pose estimation"
8+
keywords = ["wifi", "training", "pose-estimation", "deep-learning"]
9+
10+
[[bin]]
11+
name = "train"
12+
path = "src/bin/train.rs"
13+
14+
[[bin]]
15+
name = "verify-training"
16+
path = "src/bin/verify_training.rs"
17+
18+
[features]
19+
default = ["tch-backend"]
20+
tch-backend = ["tch"]
21+
cuda = ["tch-backend"]
22+
23+
[dependencies]
24+
# Internal crates
25+
wifi-densepose-signal = { path = "../wifi-densepose-signal" }
26+
wifi-densepose-nn = { path = "../wifi-densepose-nn", default-features = false }
27+
28+
# Core
29+
thiserror = "1.0"
30+
anyhow = "1.0"
31+
serde = { version = "1.0", features = ["derive"] }
32+
serde_json = "1.0"
33+
34+
# Tensor / math
35+
ndarray = { version = "0.15", features = ["serde"] }
36+
ndarray-linalg = { version = "0.16", features = ["openblas-static"] }
37+
num-complex = "0.4"
38+
num-traits = "0.2"
39+
40+
# PyTorch bindings (training)
41+
tch = { version = "0.14", optional = true }
42+
43+
# Graph algorithms (min-cut for optimal keypoint assignment)
44+
petgraph = "0.6"
45+
46+
# Data loading
47+
ndarray-npy = "0.8"
48+
memmap2 = "0.9"
49+
walkdir = "2.4"
50+
51+
# Serialization
52+
csv = "1.3"
53+
toml = "0.8"
54+
55+
# Logging / progress
56+
tracing = "0.1"
57+
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
58+
indicatif = "0.17"
59+
60+
# Async
61+
tokio = { version = "1.35", features = ["rt", "rt-multi-thread", "macros", "fs"] }
62+
63+
# Crypto (for proof hash)
64+
sha2 = "0.10"
65+
66+
# CLI
67+
clap = { version = "4.4", features = ["derive"] }
68+
69+
# Time
70+
chrono = { version = "0.4", features = ["serde"] }
71+
72+
[dev-dependencies]
73+
criterion = { version = "0.5", features = ["html_reports"] }
74+
proptest = "1.4"
75+
tempfile = "3.10"
76+
approx = "0.5"
77+
78+
[[bench]]
79+
name = "training_bench"
80+
harness = false

0 commit comments

Comments
 (0)