# Training Pipeline Domain Model

The Training & ML Pipeline is the subsystem of WiFi-DensePose that turns raw public CSI datasets into a trained pose estimation model and its downstream derivatives: contrastive embeddings, domain-generalized weights, and deterministic proof bundles. It is the bridge between research data and deployable inference.

This document defines the system using [Domain-Driven Design](https://martinfowler.com/bliki/DomainDrivenDesign.html) (DDD): bounded contexts that own their data and rules, aggregate roots that enforce invariants, value objects that carry meaning, and domain events that connect everything. The goal is to make the pipeline's structure match the physics and mathematics it implements -- so that anyone reading the code (or an AI agent modifying it) understands *why* each piece exists, not just *what* it does.

**Bounded Contexts:**

| # | Context | Responsibility | Key ADRs | Code |
|---|---------|----------------|----------|------|
| 1 | [Dataset Management](#1-dataset-management-context) | Load, validate, normalize, and preprocess training data from MM-Fi and Wi-Pose | [ADR-015](../adr/ADR-015-public-dataset-training-strategy.md) | `train/src/dataset.rs`, `train/src/subcarrier.rs` |
| 2 | [Model Architecture](#2-model-architecture-context) | Define the neural network, forward pass, attention mechanisms, and spatial decoding | [ADR-016](../adr/ADR-016-ruvector-integration.md), [ADR-020](../adr/ADR-020-rust-ruvector-ai-model-migration.md) | `train/src/model.rs`, `train/src/graph_transformer.rs` |
| 3 | [Training Orchestration](#3-training-orchestration-context) | Run the training loop, compute composite loss, checkpoint, and verify deterministic proofs | [ADR-015](../adr/ADR-015-public-dataset-training-strategy.md), [ADR-016](../adr/ADR-016-ruvector-integration.md) | `train/src/trainer.rs`, `train/src/losses.rs`, `train/src/metrics.rs`, `train/src/proof.rs` |
| 4 | [Embedding & Transfer](#4-embedding--transfer-context) | Produce AETHER contrastive embeddings, MERIDIAN domain-generalized features, and LoRA adapters | [ADR-024](../adr/ADR-024-contrastive-csi-embedding-model.md), [ADR-027](../adr/ADR-027-cross-environment-domain-generalization.md) | `train/src/embedding.rs`, `train/src/domain.rs`, `train/src/sona.rs` |

All code paths shown are relative to `rust-port/wifi-densepose-rs/crates/wifi-densepose-` unless otherwise noted.

---

## Domain-Driven Design Specification

### Ubiquitous Language

| Term | Definition |
|------|------------|
| **Training Run** | A complete training session: configuration, epoch loop, checkpoint history, and final model weights |
| **Epoch** | One full pass through the training dataset; produces train loss and validation metrics |
| **Checkpoint** | A snapshot of model weights at a given epoch, identified by SHA-256 hash and validation PCK |
| **CSI Sample** | A single observation: amplitude + phase tensors, ground-truth keypoints, and visibility flags |
| **Subcarrier Interpolation** | Resampling CSI from the source subcarrier count to the canonical 56 (114->56 for MM-Fi, 30->56 for Wi-Pose) |
| **Teacher-Student** | Training regime in which a camera-based RGB model generates pseudo-labels; at inference the camera is removed |
| **Pseudo-Label** | DensePose UV surface coordinates generated by Detectron2 from paired RGB frames |
| **PCK@0.2** | Percentage of Correct Keypoints within 20% of torso diameter; primary accuracy metric |
| **OKS** | Object Keypoint Similarity; per-keypoint Gaussian-weighted distance used in COCO evaluation |
| **MPJPE** | Mean Per Joint Position Error in millimeters; 3D accuracy metric |
| **Hungarian Assignment** | Bipartite matching of predicted persons to ground truth using min-cost assignment |
| **Dynamic Min-Cut** | Subquadratic O(n^1.5 log n) person-to-GT assignment maintained across frames |
| **Compressed CSI Buffer** | Tiered-quantization temporal window: hot frames at 8-bit, warm at 5/7-bit, cold at 3-bit |
| **Proof Verification** | Deterministic check: fixed seed -> N training steps -> loss decreases AND SHA-256 hash matches |
| **AETHER Embedding** | 128-dim L2-normalized contrastive vector from the CsiToPoseTransformer backbone |
| **InfoNCE Loss** | Contrastive loss that pulls same-identity embeddings together and pushes different-identity embeddings apart |
| **HNSW Index** | Hierarchical Navigable Small World graph for approximate nearest-neighbor embedding search |
| **Domain Factorizer** | Splits latent features into pose-invariant (h_pose) and environment-specific (h_env) components |
| **Gradient Reversal Layer** | Identity in the forward pass; multiplies the gradient by -lambda in the backward pass to force domain invariance |
| **GRL Lambda** | Adversarial weight annealed from 0.0 to 1.0 over the first 20 epochs |
| **FiLM Conditioning** | Feature-wise Linear Modulation: gamma * features + beta, conditioned on a geometry encoding |
| **Hardware Normalizer** | Resamples CSI from any chipset to the canonical 56 subcarriers with z-score amplitude normalization |
| **LoRA Adapter** | Low-Rank Adaptation weights (rank r, alpha) for few-shot environment-specific fine-tuning |
| **Rapid Adaptation** | 10-second unlabeled calibration producing a per-room LoRA adapter via contrastive test-time training |

---

## Bounded Contexts

### 1. Dataset Management Context

**Responsibility:** Load raw CSI data from public datasets (MM-Fi, Wi-Pose), validate structural invariants, resample subcarriers to the canonical 56, apply phase sanitization, and present typed samples to the training loop. Memory efficiency comes from tiered temporal compression.
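The tiered compression trade-off can be illustrated with a uniform-quantization round trip (a self-contained sketch; the real buffer delegates to `TemporalTensorCompressor`, and `quantize_roundtrip` is a hypothetical helper showing only the per-tier quantization step -- 8-bit hot, 3-bit cold):

```rust
/// Round-trip one amplitude value through uniform `bits`-bit quantization
/// over [min, max] -- a stand-in for a single tier of the buffer.
/// Hypothetical helper, not the crate API.
pub fn quantize_roundtrip(x: f32, min: f32, max: f32, bits: u32) -> f32 {
    let levels = ((1u32 << bits) - 1) as f32; // 255 for hot, 7 for cold
    let t = ((x - min) / (max - min)).clamp(0.0, 1.0);
    let code = (t * levels).round();          // integer code actually stored
    min + (code / levels) * (max - min)       // dequantized reconstruction
}
```

Worst-case reconstruction error is half a quantization step: roughly 0.2% of the value range at 8 bits (inside the hot-tier <1% fidelity bound) but about 7% at 3 bits, which is why only cold frames are stored that aggressively.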
```
+----------------------------------------------------------+
|                Dataset Management Context                |
+----------------------------------------------------------+
|                                                          |
|  +---------------+      +---------------+                |
|  | MM-Fi Loader  |      | Wi-Pose       |                |
|  | (.npy files,  |      | Loader        |                |
|  |  114 sub,     |      | (.mat files,  |                |
|  |  40 subjects) |      |  30 sub,      |                |
|  +-------+-------+      |  12 subjects) |                |
|          |              +-------+-------+                |
|          |                      |                        |
|          +--------+-------------+                        |
|                   v                                      |
|          +----------------+                              |
|          | Subcarrier     |                              |
|          | Interpolator   |                              |
|          | (114->56 or    |                              |
|          |  30->56)       |                              |
|          +--------+-------+                              |
|                   v                                      |
|          +----------------+                              |
|          | Phase          |                              |
|          | Sanitizer      |                              |
|          | (SOTA algs     |                              |
|          |  from signal)  |                              |
|          +--------+-------+                              |
|                   v                                      |
|          +----------------+                              |
|          | Compressed CSI |--> CsiSample                 |
|          | Buffer         |                              |
|          | (tiered quant) |                              |
|          +----------------+                              |
|                                                          |
+----------------------------------------------------------+
```

**Aggregates:**

- `MmFiDataset` (Aggregate Root) -- Manages the MM-Fi data lifecycle
- `WiPoseDataset` (Aggregate Root) -- Manages the Wi-Pose data lifecycle

**Value Objects:**

- `CsiSample` -- Single observation with amplitude, phase, keypoints, visibility
- `SubcarrierConfig` -- Source count, target count, interpolation method
- `DatasetSplit` -- Train / Validation / Test subject partitioning
- `CompressedCsiBuffer` -- Tiered temporal window backed by `TemporalTensorCompressor`

**Domain Services:**

- `SubcarrierInterpolationService` -- Resamples subcarriers via sparse least-squares or linear fallback
- `PhaseSanitizationService` -- Applies SpotFi / MUSIC phase correction from `wifi-densepose-signal`
- `TeacherLabelService` -- Runs Detectron2 on paired RGB frames to produce DensePose UV pseudo-labels
- `HardwareNormalizerService` -- Z-score normalization + chipset-invariant phase sanitization

**RuVector Integration:**

- `ruvector-solver` -> `NeumannSolver` for sparse O(sqrt(n)) subcarrier interpolation (114->56)
- `ruvector-temporal-tensor` -> `TemporalTensorCompressor` for 50-75% memory reduction in CSI windows

---

### 2. Model Architecture Context

**Responsibility:** Define the WiFiDensePoseModel: CSI embedding, cross-attention between keypoint queries and CSI features, GNN message passing, attention-gated modality fusion, and spatial decoding heads for keypoints and DensePose UV.

```
+----------------------------------------------------------+
|                Model Architecture Context                |
+----------------------------------------------------------+
|                                                          |
|  +---------------+      +---------------+                |
|  | CSI Embed     |      | Keypoint      |                |
|  | (Linear       |      | Queries       |                |
|  |  56 -> d)     |      | (17 learned   |                |
|  +-------+-------+      |  embeddings)  |                |
|          |              +-------+-------+                |
|          |                      |                        |
|          +--------+-------------+                        |
|                   v                                      |
|          +----------------+                              |
|          | Cross-Attention|                              |
|          | (Q=queries,    |                              |
|          |  K,V=csi)      |                              |
|          +--------+-------+                              |
|                   v                                      |
|          +----------------+                              |
|          | GNN Stack      |                              |
|          | (2-layer GCN   |                              |
|          |  skeleton      |                              |
|          |  adjacency)    |                              |
|          +--------+-------+                              |
|                   v                                      |
|     body_part_features [17 x d_model]                    |
|          |                                               |
|     +----+---+--------+--------+                         |
|     v        v        v        v                         |
|  +----------+ +------+ +-----+ +-------+                 |
|  | Modality | | xyz  | | UV  | |Spatial|                 |
|  | Transl.  | | Head | | Head| |Attn   |                 |
|  | (attn    | |      | |     | |Decoder|                 |
|  |  mincut) | |      | |     | |       |                 |
|  +----------+ +------+ +-----+ +-------+                 |
|                                                          |
+----------------------------------------------------------+
```

**Aggregates:**

- `WiFiDensePoseModel` (Aggregate Root) -- The complete model graph

**Entities:**

- `ModalityTranslator` -- Attention-gated CSI fusion using min-cut
- `CsiToPoseTransformer` -- Cross-attention + GNN backbone
- `KeypointHead` -- Regresses 17 x (x, y, z, confidence) from body_part_features
- `DensePoseHead` -- Predicts body part labels and UV surface coordinates

**Value Objects:**

- `ModelConfig` -- Architecture hyperparameters (d_model, n_heads, n_gnn_layers)
- `AttentionOutput` -- Attended values + gating result from min-cut attention
- `BodyPartFeatures` -- [17 x d_model] intermediate representation

**Domain Services:**

- `AttentionGatingService` -- Applies `attn_mincut` to prune irrelevant antenna paths
- `SpatialDecodingService` -- Graph-based spatial attention among feature map locations

**RuVector Integration:**

- `ruvector-attn-mincut` -> `attn_mincut` for antenna-path gating in ModalityTranslator
- `ruvector-attention` -> `ScaledDotProductAttention` for spatial decoder long-range dependencies

---

### 3. Training Orchestration Context

**Responsibility:** Run the training loop across epochs, compute the composite loss (keypoint MSE + DensePose part CE + UV Smooth L1 + transfer MSE), evaluate validation metrics (PCK@0.2, OKS, MPJPE), manage checkpoints, and verify deterministic proof correctness.
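The composite loss is a weighted sum of the four components named above. A scalar sketch using the default weights documented for `LossWeights` (1.0, 0.5, 0.5, 0.2); the actual `LossComputationService` operates on tensors and adds the optional AETHER/MERIDIAN terms:

```rust
/// Weighted sum of the four always-on loss components.
/// Sketch only; weights follow the documented LossWeights defaults.
pub fn composite_loss(kp_mse: f32, part_ce: f32, uv_l1: f32, transfer_mse: f32) -> f32 {
    const W_KP: f32 = 1.0; // keypoint heatmap MSE
    const W_CE: f32 = 0.5; // DensePose body part cross-entropy
    const W_L1: f32 = 0.5; // DensePose UV Smooth L1
    const W_TR: f32 = 0.2; // teacher-student transfer MSE
    W_KP * kp_mse + W_CE * part_ce + W_L1 * uv_l1 + W_TR * transfer_mse
}
```

With unit component losses the total is 1.0 + 0.5 + 0.5 + 0.2 = 2.2, which makes the relative pull of each head easy to reason about when retuning weights.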
```
+----------------------------------------------------------+
|              Training Orchestration Context              |
+----------------------------------------------------------+
|                                                          |
|  +---------------+      +---------------+                |
|  | Training Loop |      | Loss Computer |                |
|  | (epoch iter,  |      | (composite:   |                |
|  |  batch fwd/   |      |  kp_mse +     |                |
|  |  bwd, optim)  |      |  part_ce +    |                |
|  +-------+-------+      |  uv_l1 +      |                |
|          |              |  transfer)    |                |
|          |              +-------+-------+                |
|          +--------+-------------+                        |
|                   v                                      |
|          +----------------+                              |
|          | Metric         |                              |
|          | Evaluator      |                              |
|          | (PCK, OKS,     |                              |
|          |  MPJPE,        |                              |
|          |  Hungarian)    |                              |
|          +--------+-------+                              |
|                   v                                      |
|     +-------------+-------------+                        |
|     v                           v                        |
|  +----------------+     +----------------+               |
|  | Checkpoint     |     | Proof Verifier |               |
|  | Manager        |     | (fixed seed,   |               |
|  | (best-by-PCK,  |     |  50 steps,     |               |
|  |  SHA-256 hash) |     |  loss + hash)  |               |
|  +----------------+     +----------------+               |
|                                                          |
+----------------------------------------------------------+
```

**Aggregates:**

- `TrainingRun` (Aggregate Root) -- The complete training session

**Entities:**

- `CheckpointManager` -- Persists and selects model snapshots
- `ProofVerifier` -- Deterministic verification against stored hashes

**Value Objects:**

- `TrainingConfig` -- Epochs, batch_size, learning_rate, loss_weights, optimizer params
- `Checkpoint` -- Epoch number, model weights SHA-256, validation PCK at that epoch
- `LossWeights` -- Relative weights for each loss component
- `CompositeTrainingLoss` -- Combined scalar loss with per-component breakdown
- `OksScore` -- Per-keypoint Object Keypoint Similarity with sigma values
- `PckScore` -- Percentage of Correct Keypoints at threshold 0.2
- `MpjpeScore` -- Mean Per Joint Position Error in millimeters
- `ProofResult` -- Seed, steps, loss_decreased flag, hash_matches flag

**Domain Services:**

- `LossComputationService` -- Computes composite loss from model outputs and ground truth
- `MetricEvaluationService` -- Computes PCK, OKS, MPJPE over validation set
- `HungarianAssignmentService` -- Bipartite matching for multi-person evaluation
- `DynamicPersonMatcherService` -- Frame-persistent assignment via `ruvector-mincut`
- `ProofVerificationService` -- Fixed-seed training + SHA-256 verification

**RuVector Integration:**

- `ruvector-mincut` -> `DynamicMinCut` for O(n^1.5 log n) multi-person assignment in metrics
- Original `hungarian_assignment` kept for single-frame static matching in proof verification

---

### 4. Embedding & Transfer Context

**Responsibility:** Produce AETHER contrastive embeddings from the model backbone, train domain-adversarial features via MERIDIAN, manage the HNSW embedding index for re-ID and fingerprinting, and generate LoRA adapters for few-shot environment adaptation.

```
+----------------------------------------------------------+
|               Embedding & Transfer Context               |
+----------------------------------------------------------+
|                                                          |
|     body_part_features [17 x d_model]                    |
|                   |                                      |
|          +--------+-----------+                          |
|          v                    v                          |
|  +---------------+    +---------------+                  |
|  | AETHER        |    | MERIDIAN      |                  |
|  | Projection    |    | Domain        |                  |
|  | Head          |    | Factorizer    |                  |
|  | (MeanPool ->  |    | (PoseEncoder  |                  |
|  |  fc -> 128d)  |    |  + EnvEncoder)|                  |
|  +-------+-------+    +-------+-------+                  |
|          |                    |                          |
|          v                    v                          |
|  +---------------+    +---------------+                  |
|  | InfoNCE Loss  |    | Gradient      |                  |
|  | + Hard Neg    |    | Reversal      |                  |
|  | Mining (HNSW) |    | Layer (GRL)   |                  |
|  +-------+-------+    +-------+-------+                  |
|          |                    |                          |
|          v                    v                          |
|  +---------------+    +---------------+                  |
|  | Embedding     |    | Geometry      |                  |
|  | Index (HNSW)  |    | Encoder +     |                  |
|  | (fingerprint  |    | FiLM Cond.    |                  |
|  |  store)       |    | (zero-shot)   |                  |
|  +---------------+    +-------+-------+                  |
|                               |                          |
|                               v                          |
|                       +---------------+                  |
|                       | Rapid Adapt.  |                  |
|                       | (LoRA + TTT,  |                  |
|                       |  10-sec cal.) |                  |
|                       +---------------+                  |
|                                                          |
+----------------------------------------------------------+
```

**Aggregates:**

- `EmbeddingIndex` (Aggregate Root) -- HNSW-indexed store of AETHER fingerprints
- `DomainAdaptationState` (Aggregate Root) -- Tracks GRL lambda, domain classifier accuracy, factorization quality

**Entities:**

- `ProjectionHead` -- MLP mapping body_part_features to 128-dim embedding space
- `DomainFactorizer` -- Splits features into h_pose and h_env
- `DomainClassifier` -- Classifies domain from h_pose (trained adversarially via GRL)
- `GeometryEncoder` -- Fourier positional encoding + DeepSets for AP positions
- `LoraAdapter` -- Low-rank adaptation weights for environment-specific fine-tuning

**Value Objects:**

- `AetherEmbedding` -- 128-dim L2-normalized contrastive vector
- `FingerprintType` -- ReIdentification / RoomFingerprint / PersonFingerprint
- `DomainLabel` -- Environment identifier for adversarial training
- `GrlSchedule` -- Lambda annealing parameters (max_lambda, warmup_epochs)
- `GeometryInput` -- AP positions in meters relative to room origin
- `FilmParameters` -- Gamma (scale) and beta (shift) vectors from geometry conditioning
- `LoraConfig` -- Rank, alpha, target layers
- `AdaptationLoss` -- ContrastiveTTT / EntropyMin / Combined

**Domain Services:**

- `ContrastiveLossService` -- Computes InfoNCE loss with temperature scaling
- `HardNegativeMiningService` -- HNSW k-NN search for difficult negative pairs
- `DomainAdversarialService` -- Manages GRL annealing and domain classification
- `GeometryConditioningService` -- Encodes AP layout and produces FiLM parameters
- `VirtualDomainAugmentationService` -- Generates synthetic environment shifts for training diversity
- `RapidAdaptationService` -- Produces LoRA adapter from 10-second unlabeled calibration

---

## Core Domain Entities

### TrainingRun (Aggregate Root)

```rust
pub struct TrainingRun {
    /// Unique run identifier
    pub id: TrainingRunId,
    /// Full training configuration
    pub config: TrainingConfig,
    /// Datasets loaded for this run
    pub datasets: Vec<DatasetType>,
    /// Ordered history of per-epoch metrics
    pub epoch_history: Vec<EpochMetrics>,
    /// Best checkpoint by validation PCK
    pub best_checkpoint: Option<Checkpoint>,
    /// Current epoch (0-indexed)
    pub current_epoch: usize,
    /// Run status
    pub status: RunStatus,
    /// Proof verification result (if run)
    pub proof_result: Option<ProofResult>,
}

pub enum RunStatus {
    Initializing,
    Training,
    Completed,
    Failed { reason: String },
    ProofVerified,
}
```

**Invariants:**

- Must have at least 1 dataset loaded before transitioning to `Training`
- `best_checkpoint` is updated only when a new epoch's validation PCK exceeds all prior epochs
- `proof_result` can only be set once and is immutable after verification

### MmFiDataset (Aggregate Root)

```rust
pub struct MmFiDataset {
    /// Root directory containing .npy files
    pub data_root: PathBuf,
    /// Subject IDs in this split
    pub subject_ids: Vec<u32>,
    /// Number of action classes
    pub n_actions: usize, // 27
    /// Source subcarrier count
    pub source_subcarriers: usize, // 114
    /// Target subcarrier count after interpolation
    pub target_subcarriers: usize, // 56
    /// Antenna configuration: 1 TX x 3 RX
    pub antenna_pairs: usize, // 3
    /// Sampling rate in Hz
    pub sample_rate_hz: f32, // 100.0
    /// Temporal window size (frames per sample)
    pub window_frames: usize, // 10
    /// Compressed buffer for memory-efficient storage
    pub buffer: CompressedCsiBuffer,
    /// Total loaded samples
    pub n_samples: usize,
}
```

### WiPoseDataset (Aggregate Root)

```rust
pub struct WiPoseDataset {
    /// Root directory containing .mat files
    pub data_root: PathBuf,
    /// Subject IDs in this split
    pub subject_ids: Vec<u32>,
    /// Source subcarrier count
    pub source_subcarriers: usize, // 30
    /// Target subcarrier count after zero-padding
    pub target_subcarriers: usize, // 56
    /// Antenna configuration: 3 TX x 3 RX
    pub antenna_pairs: usize, // 9
    /// Keypoint count (18 AlphaPose, mapped to 17 COCO)
    pub source_keypoints: usize, // 18
    /// Compressed buffer
    pub buffer: CompressedCsiBuffer,
    /// Total loaded samples
    pub n_samples: usize,
}
```

### WiFiDensePoseModel (Aggregate Root)

```rust
pub struct WiFiDensePoseModel {
    /// CSI embedding layer: Linear(56, d_model)
    pub csi_embed: Linear,
    /// Learned keypoint query embeddings [17 x d_model]
    pub keypoint_queries: Tensor,
    /// Cross-attention: Q=queries, K,V=csi_embed
    pub cross_attention: MultiHeadAttention,
    /// GNN message passing on skeleton graph
    pub gnn_stack: GnnStack,
    /// Modality translator with attention-gated fusion
    pub modality_translator: ModalityTranslator,
    /// Keypoint regression head
    pub keypoint_head: KeypointHead,
    /// DensePose UV prediction head
    pub densepose_head: DensePoseHead,
    /// Spatial attention decoder
    pub spatial_decoder: SpatialAttentionDecoder,
    /// Model dimensionality
    pub d_model: usize, // 64
}
```

### EmbeddingIndex (Aggregate Root)

```rust
pub struct EmbeddingIndex {
    /// HNSW graph for approximate nearest-neighbor search
    pub hnsw: HnswIndex,
    /// Stored embeddings with metadata
    pub entries: Vec<EmbeddingEntry>,
    /// Embedding dimensionality
    pub dim: usize, // 128
    /// Number of indexed embeddings
    pub count: usize,
    /// HNSW construction parameters
    pub ef_construction: usize, // 200
    pub m_connections: usize, // 16
}

pub struct EmbeddingEntry {
    pub id: EmbeddingId,
    pub embedding: Vec<f32>, // [128], L2-normalized
    pub fingerprint_type: FingerprintType,
    pub source_domain: Option<DomainLabel>,
    pub created_at: u64,
}

pub enum FingerprintType {
    ReIdentification,
    RoomFingerprint,
    PersonFingerprint,
}
```

---

## Value Objects

### CsiSample

```rust
pub struct CsiSample {
    /// Amplitude tensor [n_antenna_pairs x n_subcarriers x n_time_frames]
    pub amplitude: Vec<f32>,
    /// Phase tensor [n_antenna_pairs x n_subcarriers x n_time_frames]
    pub phase: Vec<f32>,
    /// Ground-truth 3D keypoints [17 x 3] (x, y, z in meters)
    pub keypoints: [[f32; 3]; 17],
    /// Per-keypoint visibility flags
    pub visibility: [f32; 17],
    /// DensePose UV pseudo-labels (optional, from teacher model)
    pub densepose_uv: Option<DensePoseLabels>,
    /// Domain label for adversarial training
    pub domain_label: Option<DomainLabel>,
    /// Hardware source type
    pub hardware_type: HardwareType,
}
```

### TrainingConfig

```rust
pub struct TrainingConfig {
    /// Number of training epochs
    pub epochs: usize,
    /// Mini-batch size
    pub batch_size: usize,
    /// Initial learning rate
    pub learning_rate: f64, // 1e-3
    /// Learning rate schedule: step decay at these epochs
    pub lr_decay_epochs: Vec<usize>, // [40, 80]
    /// Learning rate decay factor
    pub lr_decay_factor: f64, // 0.1
    /// Loss component weights
    pub loss_weights: LossWeights,
    /// Optimizer (Adam)
    pub optimizer: OptimizerConfig,
    /// Validation subject IDs (MM-Fi: 33-40)
    pub val_subjects: Vec<u32>,
    /// Random seed for reproducibility
    pub seed: u64,
    /// Enable MERIDIAN domain-adversarial training
    pub meridian_enabled: bool,
    /// Enable AETHER contrastive learning
    pub aether_enabled: bool,
}

pub struct LossWeights {
    /// Keypoint heatmap MSE weight
    pub keypoint_mse: f32, // 1.0
    /// DensePose body part cross-entropy weight
    pub densepose_part_ce: f32, // 0.5
    /// DensePose UV Smooth L1 weight
    pub uv_smooth_l1: f32, // 0.5
    /// Teacher-student transfer MSE weight
    pub transfer_mse: f32, // 0.2
    /// AETHER contrastive loss weight (ADR-024)
    pub contrastive: f32, // 0.1
    /// MERIDIAN domain adversarial weight (ADR-027)
    pub domain_adversarial: f32, // annealed 0.0 -> 1.0
}
```

### Checkpoint

```rust
pub struct Checkpoint {
    /// Epoch at which this checkpoint was saved
    pub epoch: usize,
    /// SHA-256 hash of serialized model weights
    pub weights_hash: String,
    /// Validation PCK@0.2 at this epoch
    pub validation_pck: f64,
    /// Validation OKS at this epoch
    pub validation_oks: f64,
    /// File path to saved weights
    pub path: PathBuf,
    /// Timestamp
    pub created_at: u64,
}
```

### ProofResult

```rust
pub struct ProofResult {
    /// Seed used for model initialization
    pub model_seed: u64, // MODEL_SEED = 0
    /// Seed used for proof data generation
    pub proof_seed: u64, // PROOF_SEED = 42
    /// Number of training steps in proof
    pub steps: usize, // 50
    /// Whether loss decreased monotonically
    pub loss_decreased: bool,
    /// Whether final weights hash matches stored expected hash
    pub hash_matches: bool,
    /// The computed SHA-256 hash
    pub computed_hash: String,
    /// The expected SHA-256 hash (from file)
    pub expected_hash: String,
}
```

### LoraAdapter

```rust
pub struct LoraAdapter {
    /// Low-rank decomposition rank
    pub rank: usize, // 4
    /// LoRA alpha scaling factor
    pub alpha: f32, // 1.0
    /// Per-layer weight matrices (A and B for each adapted layer)
    pub weights: Vec<LoraLayerWeights>,
    /// Source domain this adapter was calibrated for
    pub source_domain: DomainLabel,
    /// Calibration duration in seconds
    pub calibration_duration_secs: f32,
    /// Number of calibration frames used
    pub calibration_frames: usize,
}

pub struct LoraLayerWeights {
    /// Layer name in the model
    pub layer_name: String,
    /// Down-projection: [d_model x rank]
    pub a: Vec<f32>,
    /// Up-projection: [rank x d_model]
    pub b: Vec<f32>,
}
```

---

## Domain Events

### Dataset Events

```rust
pub enum DatasetEvent {
    /// Dataset loaded and validated
    DatasetLoaded {
        dataset_type: DatasetType,
        n_samples: usize,
        n_subjects: u32,
        source_subcarriers: usize,
        timestamp: u64,
    },
    /// Subcarrier interpolation completed for a dataset
    SubcarrierInterpolationComplete {
        dataset_type: DatasetType,
        source_count: usize,
        target_count: usize,
        method: InterpolationMethod,
        timestamp: u64,
    },
    /// Teacher pseudo-labels generated for a batch
    PseudoLabelsGenerated {
        n_samples: usize,
        n_with_uv: usize,
        timestamp: u64,
    },
}

pub enum DatasetType {
    MmFi,
    WiPose,
    Synthetic,
}

pub enum InterpolationMethod {
    /// ruvector-solver NeumannSolver sparse least-squares
    SparseNeumannSolver,
    /// Fallback linear interpolation
    LinearInterpolation,
    /// Wi-Pose zero-padding
    ZeroPad,
}
```

### Training Events

```rust
pub enum TrainingEvent {
    /// One epoch of training completed
    EpochCompleted {
        epoch: usize,
        train_loss: f64,
        val_pck: f64,
        val_oks: f64,
        val_mpjpe_mm: f64,
        learning_rate: f64,
        grl_lambda: f32,
        timestamp: u64,
    },
    /// New best checkpoint saved
    CheckpointSaved {
        epoch: usize,
        weights_hash: String,
        validation_pck: f64,
        path: String,
        timestamp: u64,
    },
    /// Deterministic proof verification completed
    ProofVerified {
        model_seed: u64,
        proof_seed: u64,
        steps: usize,
        loss_decreased: bool,
        hash_matches: bool,
        timestamp: u64,
    },
    /// Training run completed or failed
    TrainingRunFinished {
        run_id: String,
        status: RunStatus,
        total_epochs: usize,
        best_pck: f64,
        best_oks: f64,
        timestamp: u64,
    },
}
```

### Embedding Events

```rust
pub enum EmbeddingEvent {
    /// New AETHER embedding indexed
    EmbeddingIndexed {
        embedding_id: String,
        fingerprint_type: FingerprintType,
        nearest_neighbor_distance: f32,
        index_size: usize,
        timestamp: u64,
    },
    /// Hard negative pair discovered during mining
    HardNegativeFound {
        anchor_id: String,
        negative_id: String,
        similarity: f32,
        timestamp: u64,
    },
    /// Domain adaptation completed for a target environment
    DomainAdaptationComplete {
        source_domain: String,
        target_domain: String,
        pck_before: f64,
        pck_after: f64,
        adaptation_method: String,
        timestamp: u64,
    },
    /// LoRA adapter generated via rapid calibration
    LoraAdapterGenerated {
        domain: String,
        rank: usize,
        calibration_frames: usize,
        calibration_seconds: f32,
        timestamp: u64,
    },
}
```

---

## Invariants

### Dataset Management

- MM-Fi samples must be interpolated from 114 to 56 subcarriers before use in training
- Wi-Pose samples must be zero-padded from 30 to 56 subcarriers before use in training
- Wi-Pose keypoints must be mapped from 18 (AlphaPose) to 17 (COCO) by dropping neck index 1
- All CSI amplitudes must be finite and non-negative after loading
- Phase values must be in [-pi, pi] after sanitization
- Validation subjects (MM-Fi: 33-40) must never appear in the training split
- `CompressedCsiBuffer` must preserve signal fidelity within quantization error bounds (hot: <1% error)

### Model Architecture

- `csi_embed` input dimension must equal the canonical 56 subcarriers
- `keypoint_queries` must have exactly 17 entries (one per COCO keypoint)
- `attn_mincut` seq_len must equal n_antenna_pairs * n_time_frames
- GNN adjacency matrix must encode the human skeleton topology (17 nodes, 16 edges)
- Spatial attention decoder must preserve spatial resolution (no information loss in reshape)

### Training Orchestration

- TrainingRun must have at least 1 dataset loaded before `start()` is called
- Proof verification requires fixed seeds: MODEL_SEED=0, PROOF_SEED=42
- Proof verification uses exactly 50 training steps on deterministic SyntheticDataset
- Loss must decrease over proof steps (otherwise proof fails)
- SHA-256 hash of final weights must match stored expected hash (otherwise proof fails)
- `best_checkpoint` is updated if and only if current val_pck > all previous val_pck values
- Learning rate decays by factor 0.1 at epochs 40 and 80 (step schedule)
- Hungarian assignment for static single-frame matching must use the deterministic implementation (not DynamicMinCut) during proof verification

### Embedding & Transfer

- AETHER embeddings must be L2-normalized (unit norm) before indexing in HNSW
- InfoNCE temperature must be > 0 (typically 0.07)
- HNSW index ef_search must be >= k for k-NN queries
- MERIDIAN GRL lambda must anneal from 0.0 to 1.0 over the first 20 epochs using the schedule: lambda(p) = 2 / (1 + exp(-10 * p)) - 1, where p = epoch / 20
- GRL lambda must not exceed 1.0 at any epoch
- `DomainFactorizer` output dimensions: h_pose = [17 x 64], h_env = [32]
- `GeometryEncoder` must be permutation-invariant with respect to AP ordering (DeepSets guarantee)
- LoRA adapter rank must be <= d_model / 4 (default rank=4 for d_model=64)
- Rapid adaptation requires at least 200 CSI frames (10 seconds at 20 Hz)

---

## Domain Services

### SubcarrierInterpolationService

Resamples CSI subcarriers from source to target count using physically motivated sparse interpolation.
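The O(n) linear fallback can be sketched as a standalone function (illustrative only; the trait method below additionally carries `&self` and an explicit `source_count`):

```rust
/// O(n) linear resampling of one antenna pair's subcarrier axis.
/// Endpoints map exactly; interior targets blend the two nearest
/// source bins. Standalone sketch of the fallback path only.
pub fn interpolate_linear_sketch(source: &[f32], target_count: usize) -> Vec<f32> {
    assert!(source.len() >= 2 && target_count >= 2);
    let scale = (source.len() - 1) as f32 / (target_count - 1) as f32;
    (0..target_count)
        .map(|i| {
            let pos = i as f32 * scale;          // position on the source axis
            let lo = pos.floor() as usize;
            let hi = (lo + 1).min(source.len() - 1);
            let frac = pos - lo as f32;
            source[lo] * (1.0 - frac) + source[hi] * frac
        })
        .collect()
}
```

For the MM-Fi case this maps 114 source bins onto 56 targets; the sparse `NeumannSolver` path is preferred because it exploits the least-squares structure rather than blending adjacent bins.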
```rust
pub trait SubcarrierInterpolationService {
    /// Sparse interpolation via NeumannSolver (O(sqrt(n)), preferred)
    fn interpolate_sparse(
        &self,
        source: &[f32],
        source_count: usize,
        target_count: usize,
        tolerance: f64,
    ) -> Result<Vec<f32>, InterpolationError>;

    /// Linear interpolation fallback (O(n))
    fn interpolate_linear(
        &self,
        source: &[f32],
        source_count: usize,
        target_count: usize,
    ) -> Vec<f32>;

    /// Zero-pad for Wi-Pose (30 -> 56)
    fn zero_pad(
        &self,
        source: &[f32],
        target_count: usize,
    ) -> Vec<f32>;
}
```

### LossComputationService

Computes the composite training loss from model outputs and ground truth.

```rust
pub trait LossComputationService {
    /// Compute composite loss with per-component breakdown
    fn compute(
        &self,
        predictions: &ModelOutput,
        targets: &GroundTruth,
        weights: &LossWeights,
    ) -> CompositeTrainingLoss;
}

pub struct CompositeTrainingLoss {
    /// Total weighted loss (scalar for backprop)
    pub total: f64,
    /// Keypoint heatmap MSE component
    pub keypoint_mse: f64,
    /// DensePose body part cross-entropy component
    pub densepose_part_ce: f64,
    /// DensePose UV Smooth L1 component
    pub uv_smooth_l1: f64,
    /// Teacher-student transfer MSE component
    pub transfer_mse: f64,
    /// AETHER contrastive loss (if enabled)
    pub contrastive: Option<f64>,
    /// MERIDIAN domain adversarial loss (if enabled)
    pub domain_adversarial: Option<f64>,
}
```

### MetricEvaluationService

Evaluates model accuracy on the validation set using standard pose estimation metrics.
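PCK@0.2 itself reduces to a per-keypoint threshold test. A scalar sketch for a single person (`pck_sketch` is a hypothetical helper; visibility masking and multi-person Hungarian assignment are omitted here and handled by the service):

```rust
/// PCK@alpha for one person: the fraction of keypoints whose Euclidean
/// error falls within alpha * torso_diameter. Sketch only.
pub fn pck_sketch(pred: &[[f32; 2]], gt: &[[f32; 2]], torso_diameter: f32, alpha: f32) -> f32 {
    let thresh = alpha * torso_diameter;
    let correct = pred
        .iter()
        .zip(gt)
        .filter(|(p, g)| ((p[0] - g[0]).powi(2) + (p[1] - g[1]).powi(2)).sqrt() <= thresh)
        .count();
    correct as f32 / gt.len() as f32
}
```

With torso diameter 1.0 m the PCK@0.2 threshold is 0.2 m: a prediction 0.1 m from its ground-truth keypoint counts as correct, one 0.5 m away does not.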
```rust pub trait MetricEvaluationService { /// PCK@0.2: fraction of keypoints within 20% of torso diameter fn compute_pck(&self, predictions: &[PosePrediction], targets: &[PoseTarget], threshold: f64) -> PckScore; /// OKS: Object Keypoint Similarity with per-keypoint sigmas fn compute_oks(&self, predictions: &[PosePrediction], targets: &[PoseTarget]) -> OksScore; /// MPJPE: Mean Per Joint Position Error in millimeters fn compute_mpjpe(&self, predictions: &[PosePrediction], targets: &[PoseTarget]) -> MpjpeScore; /// Multi-person assignment via Hungarian (static, deterministic) fn assign_hungarian(&self, pred: &[PosePrediction], gt: &[PoseTarget]) -> Vec<(usize, usize)>; /// Multi-person assignment via DynamicMinCut (persistent, O(n^1.5 log n)) fn assign_dynamic(&mut self, pred: &[PosePrediction], gt: &[PoseTarget]) -> Vec<(usize, usize)>; } ``` ### DomainAdversarialService Manages the MERIDIAN gradient reversal training regime. ```rust pub trait DomainAdversarialService { /// Compute GRL lambda for the current epoch fn grl_lambda(&self, epoch: usize, max_warmup_epochs: usize) -> f32; /// Forward pass through domain classifier with gradient reversal fn classify_domain( &self, h_pose: &Tensor, lambda: f32, ) -> Tensor; /// Compute domain adversarial loss (cross-entropy on domain logits) fn domain_loss( &self, domain_logits: &Tensor, domain_labels: &Tensor, ) -> f64; } ``` --- ## Context Map ``` +------------------------------------------------------------------+ | Training Pipeline System | +------------------------------------------------------------------+ | | | +------------------+ CsiSample +------------------+ | | | Dataset |-------------->| Training | | | | Management | | Orchestration | | | | Context | | Context | | | +--------+---------+ +--------+-----------+ | | | | | | | Publishes | Publishes | | | DatasetEvent | TrainingEvent | | v v | | +------------------------------------------------------+ | | | Event Bus (Domain Events) | | | 
+------------------------------------------------------+ | | | | | | v v | | +------------------+ +------------------+ | | | Model |<-------------| Embedding & | | | | Architecture | body_part_ | Transfer | | | | Context | features | Context | | | +------------------+ +------------------+ | | | +------------------------------------------------------------------+ | UPSTREAM (Conformist) | | +--------------+ +--------------+ +--------------+ | | |wifi-densepose| |wifi-densepose| |wifi-densepose| | | | -signal | | -nn | | -core | | | | (phase algs,| | (ONNX, | | (CsiFrame, | | | | SpotFi) | | Candle) | | error) | | | +--------------+ +--------------+ +--------------+ | | | +------------------------------------------------------------------+ | SIBLING (Partnership) | | +--------------+ +--------------+ +--------------+ | | | RuvSense | | MAT | | Sensing | | | | (pose | | (triage, | | Server | | | | tracker, | | survivor) | | (inference | | | | field | | | | deployment) | | | | model) | | | | | | | +--------------+ +--------------+ +--------------+ | | | +------------------------------------------------------------------+ | EXTERNAL (Published Language) | | +--------------+ +--------------+ +--------------+ | | | MM-Fi | | Wi-Pose | | Detectron2 | | | | (NeurIPS | | (NjtechCV | | (teacher | | | | dataset) | | dataset) | | labels) | | | +--------------+ +--------------+ +--------------+ | +------------------------------------------------------------------+ ``` **Relationship Types:** - Dataset Management -> Training Orchestration: **Customer/Supplier** (Dataset produces CsiSamples; Orchestration consumes) - Model Architecture -> Training Orchestration: **Partnership** (tight bidirectional coupling: Orchestration drives forward/backward; Architecture defines the computation graph) - Model Architecture -> Embedding & Transfer: **Customer/Supplier** (Architecture produces body_part_features; Embedding consumes for contrastive/adversarial heads) - Embedding & Transfer -> 
  Training Orchestration: **Partnership** (contrastive and adversarial losses feed into the composite loss)
- Training Pipeline -> Upstream crates: **Conformist** (adapts to wifi-densepose-signal, -nn, -core types)
- Training Pipeline -> RuvSense/MAT/Server: **Partnership** (trained model weights flow downstream)
- Training Pipeline -> External datasets: **Anti-Corruption Layer** (dataset loaders translate external formats into domain types)

---

## Anti-Corruption Layer

### MM-Fi Adapter (Dataset Management -> External MM-Fi format)

```rust
/// Translates raw MM-Fi numpy files into domain CsiSample values.
/// Handles the 114->56 subcarrier interpolation and 1TX/3RX antenna layout.
pub struct MmFiAdapter {
    /// Subcarrier interpolation service
    interpolator: Box<dyn SubcarrierInterpolator>,
    /// Phase sanitizer from wifi-densepose-signal
    phase_sanitizer: PhaseSanitizer,
    /// Hardware normalizer for z-score normalization
    normalizer: HardwareNormalizer,
}

impl MmFiAdapter {
    /// Load a single MM-Fi sample from .npy tensors and produce a CsiSample.
    /// Steps:
    /// 1. Read amplitude [3, 114, 10] and phase [3, 114, 10]
    /// 2. Interpolate 114 -> 56 subcarriers per antenna pair
    /// 3. Sanitize phase (remove linear offset, unwrap)
    /// 4. Z-score normalize amplitude per frame
    /// 5. Read 17-keypoint COCO annotations
    pub fn adapt(&self, raw: &MmFiRawSample) -> Result<CsiSample, AdapterError>;
}
```

### Wi-Pose Adapter (Dataset Management -> External Wi-Pose format)

```rust
/// Translates Wi-Pose .mat files into domain CsiSample values.
/// Handles 30->56 zero-padding and 18->17 keypoint mapping.
pub struct WiPoseAdapter {
    /// Zero-padding service
    interpolator: Box<dyn SubcarrierInterpolator>,
    /// Phase sanitizer
    phase_sanitizer: PhaseSanitizer,
}

impl WiPoseAdapter {
    /// Load a Wi-Pose sample from .mat format and produce a CsiSample.
    /// Steps:
    /// 1. Read CSI [9, 30] (3x3 antenna pairs, 30 subcarriers)
    /// 2. Zero-pad 30 -> 56 subcarriers (high-frequency padding)
    /// 3. Sanitize phase
    /// 4. Map 18 AlphaPose keypoints -> 17 COCO (drop neck, index 1)
    pub fn adapt(&self, raw: &WiPoseRawSample) -> Result<CsiSample, AdapterError>;
}
```

### Teacher Model Adapter (Dataset Management -> Detectron2)

```rust
/// Adapts Detectron2 DensePose outputs into domain DensePoseLabels.
/// Used during teacher-student pseudo-label generation.
pub struct TeacherModelAdapter;

impl TeacherModelAdapter {
    /// Run Detectron2 DensePose on an RGB frame and produce pseudo-labels.
    /// Output: (part_labels [H x W], u_coords [H x W], v_coords [H x W])
    pub fn generate_pseudo_labels(
        &self,
        rgb_frame: &RgbFrame,
    ) -> Result<DensePoseLabels, AdapterError>;
}
```

### RuVector Adapter (Model Architecture -> ruvector crates)

```rust
/// Adapts the ruvector-attn-mincut API to the model's tensor format.
/// Handles the Tensor <-> Vec conversion overhead per batch element.
pub struct AttnMinCutAdapter;

impl AttnMinCutAdapter {
    /// Apply min-cut gated attention to antenna-path features.
    /// Converts the [B, n_ant, n_sc] tensor to a flat Vec per batch element,
    /// calls attn_mincut, and reshapes the output back to a tensor.
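    ///
    /// Illustrative shape walk-through (example sizes, not taken from the
    /// codebase): a batch [B=2, n_ant=3, n_sc=56] is split into 2 flat
    /// Vec<f32> buffers of length 3 * 56 = 168, each buffer is gated by
    /// attn_mincut with the given lambda, and the gated outputs are
    /// stacked back into a [2, 3, 56] tensor.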
    pub fn apply(
        &self,
        features: &Tensor,
        n_antenna_paths: usize,
        n_subcarriers: usize,
        lambda: f32,
    ) -> Result<Tensor, AdapterError>;
}
```

---

## Repository Interfaces

```rust
/// Persists and retrieves training run state
pub trait TrainingRunRepository {
    fn save(&self, run: &TrainingRun) -> Result<(), RepositoryError>;
    fn find_by_id(&self, id: &TrainingRunId) -> Result<Option<TrainingRun>, RepositoryError>;
    fn find_latest(&self) -> Result<Option<TrainingRun>, RepositoryError>;
    fn list_completed(&self) -> Result<Vec<TrainingRun>, RepositoryError>;
}

/// Persists model checkpoints
pub trait CheckpointRepository {
    fn save(&self, checkpoint: &Checkpoint) -> Result<(), RepositoryError>;
    fn find_best(&self, run_id: &TrainingRunId) -> Result<Option<Checkpoint>, RepositoryError>;
    fn find_by_epoch(&self, run_id: &TrainingRunId, epoch: usize) -> Result<Option<Checkpoint>, RepositoryError>;
    fn list_all(&self, run_id: &TrainingRunId) -> Result<Vec<Checkpoint>, RepositoryError>;
}

/// Persists the AETHER embedding index
pub trait EmbeddingRepository {
    fn save_index(&self, index: &EmbeddingIndex) -> Result<(), RepositoryError>;
    fn load_index(&self) -> Result<Option<EmbeddingIndex>, RepositoryError>;
    fn add_entry(&self, entry: &EmbeddingEntry) -> Result<(), RepositoryError>;
    fn search_knn(&self, query: &[f32], k: usize) -> Result<Vec<EmbeddingEntry>, RepositoryError>;
}

/// Persists LoRA adapters for environment-specific fine-tuning
pub trait LoraRepository {
    fn save(&self, adapter: &LoraAdapter) -> Result<(), RepositoryError>;
    fn find_by_domain(&self, domain: &DomainLabel) -> Result<Option<LoraAdapter>, RepositoryError>;
    fn list_all(&self) -> Result<Vec<LoraAdapter>, RepositoryError>;
}
```

---

## References

- ADR-015: Public Dataset Strategy (MM-Fi, Wi-Pose, teacher-student training)
- ADR-016: RuVector Integration (5 crate integration points in the training pipeline)
- ADR-020: Rust Migration (training pipeline in the wifi-densepose-train crate)
- ADR-024: AETHER Contrastive CSI Embeddings (128-dim fingerprints, InfoNCE, HNSW)
- ADR-027: MERIDIAN Cross-Environment Domain Generalization (GRL, FiLM, LoRA)
- Yang et al., "MM-Fi: Multi-Modal Non-Intrusive 4D Human Dataset" (NeurIPS 2023)
- NjtechCVLab, "Wi-Pose Dataset" (CSI-Former, MDPI Entropy 2023)
- Geng et al., "DensePose From WiFi" (CMU, arXiv:2301.00250, 2023)
- Ganin et al., "Domain-Adversarial Training of Neural Networks" (JMLR 2016)
- Perez et al., "FiLM: Visual Reasoning with a General Conditioning Layer" (AAAI 2018)