Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 2c5ca30

Browse files
committed
feat(rust): Add workspace deps, tests, and refine training modules
- Cargo.toml: Add wifi-densepose-train to workspace members; add petgraph, ndarray-npy, walkdir, sha2, csv, indicatif, clap to workspace dependencies - error.rs: Slim down to focused error types (TrainError, DatasetError) - lib.rs: Wire up all module re-exports correctly - losses.rs: Add generate_gaussian_heatmaps implementation - tests/test_config.rs: Deterministic config roundtrip and validation tests https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
1 parent ec98e40 commit 2c5ca30

5 files changed

Lines changed: 643 additions & 290 deletions

File tree

rust-port/wifi-densepose-rs/Cargo.toml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ members = [
1111
"crates/wifi-densepose-wasm",
1212
"crates/wifi-densepose-cli",
1313
"crates/wifi-densepose-mat",
14+
"crates/wifi-densepose-train",
1415
]
1516

1617
[workspace.package]
@@ -73,6 +74,25 @@ getrandom = { version = "0.2", features = ["js"] }
7374
serialport = "4.3"
7475
pcap = "1.1"
7576

77+
# Graph algorithms (for min-cut assignment in metrics)
78+
petgraph = "0.6"
79+
80+
# Data loading
81+
ndarray-npy = "0.8"
82+
walkdir = "2.4"
83+
84+
# Hashing (for proof)
85+
sha2 = "0.10"
86+
87+
# CSV logging
88+
csv = "1.3"
89+
90+
# Progress bars
91+
indicatif = "0.17"
92+
93+
# CLI
94+
clap = { version = "4.4", features = ["derive"] }
95+
7696
# Testing
7797
criterion = { version = "0.5", features = ["html_reports"] }
7898
proptest = "1.4"
Lines changed: 18 additions & 287 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
11
//! Error types for the WiFi-DensePose training pipeline.
22
//!
3-
//! This module defines a hierarchy of errors covering every failure mode in
4-
//! the training pipeline: configuration validation, dataset I/O, subcarrier
5-
//! interpolation, and top-level training orchestration.
3+
//! This module provides:
4+
//!
5+
//! - [`TrainError`]: top-level error aggregating all training failure modes.
6+
//! - [`TrainResult`]: convenient `Result` alias using `TrainError`.
7+
//!
8+
//! Module-local error types live in their respective modules:
9+
//!
10+
//! - [`crate::config::ConfigError`]: configuration validation errors.
11+
//! - [`crate::dataset::DatasetError`]: dataset loading/access errors.
12+
//!
13+
//! All are re-exported at the crate root for ergonomic use.
614
715
use thiserror::Error;
816
use std::path::PathBuf;
917

18+
// Import module-local error types so TrainError can wrap them via #[from].
19+
use crate::config::ConfigError;
20+
use crate::dataset::DatasetError;
21+
1022
// ---------------------------------------------------------------------------
1123
// Top-level training error
1224
// ---------------------------------------------------------------------------
@@ -16,8 +28,9 @@ pub type TrainResult<T> = Result<T, TrainError>;
1628

1729
/// Top-level error type for the training pipeline.
1830
///
19-
/// Every public function in this crate that can fail returns
20-
/// `TrainResult<T>`, which is `Result<T, TrainError>`.
31+
/// Every orchestration-level function returns `TrainResult<T>`. Lower-level
32+
/// functions in [`crate::config`] and [`crate::dataset`] return their own
33+
/// module-specific error types which are automatically coerced via `#[from]`.
2134
#[derive(Debug, Error)]
2235
pub enum TrainError {
2336
/// Configuration is invalid or internally inconsistent.
@@ -28,10 +41,6 @@ pub enum TrainError {
2841
#[error("Dataset error: {0}")]
2942
Dataset(#[from] DatasetError),
3043

31-
/// Subcarrier interpolation / resampling failed.
32-
#[error("Subcarrier interpolation error: {0}")]
33-
Subcarrier(#[from] SubcarrierError),
34-
3544
/// An underlying I/O error not covered by a more specific variant.
3645
#[error("I/O error: {0}")]
3746
Io(#[from] std::io::Error),
@@ -40,14 +49,6 @@ pub enum TrainError {
4049
#[error("JSON error: {0}")]
4150
Json(#[from] serde_json::Error),
4251

43-
/// TOML (de)serialization error.
44-
#[error("TOML deserialization error: {0}")]
45-
TomlDe(#[from] toml::de::Error),
46-
47-
/// TOML serialization error.
48-
#[error("TOML serialization error: {0}")]
49-
TomlSer(#[from] toml::ser::Error),
50-
5152
/// An operation was attempted on an empty dataset.
5253
#[error("Dataset is empty")]
5354
EmptyDataset,
@@ -112,273 +113,3 @@ impl TrainError {
112113
TrainError::ShapeMismatch { expected, actual }
113114
}
114115
}
115-
116-
// ---------------------------------------------------------------------------
117-
// Configuration errors
118-
// ---------------------------------------------------------------------------
119-
120-
/// Errors produced when validating or loading a [`TrainingConfig`].
121-
///
122-
/// [`TrainingConfig`]: crate::config::TrainingConfig
123-
#[derive(Debug, Error)]
124-
pub enum ConfigError {
125-
/// A required field has a value that violates a constraint.
126-
#[error("Invalid value for field `{field}`: {reason}")]
127-
InvalidValue {
128-
/// Name of the configuration field.
129-
field: &'static str,
130-
/// Human-readable reason the value is invalid.
131-
reason: String,
132-
},
133-
134-
/// The configuration file could not be read.
135-
#[error("Cannot read configuration file `{path}`: {source}")]
136-
FileRead {
137-
/// Path that was being read.
138-
path: PathBuf,
139-
/// Underlying I/O error.
140-
#[source]
141-
source: std::io::Error,
142-
},
143-
144-
/// The configuration file contains invalid TOML.
145-
#[error("Cannot parse configuration file `{path}`: {source}")]
146-
ParseError {
147-
/// Path that was being parsed.
148-
path: PathBuf,
149-
/// Underlying TOML parse error.
150-
#[source]
151-
source: toml::de::Error,
152-
},
153-
154-
/// A path specified in the config does not exist.
155-
#[error("Path `{path}` specified in config does not exist")]
156-
PathNotFound {
157-
/// The missing path.
158-
path: PathBuf,
159-
},
160-
}
161-
162-
impl ConfigError {
163-
/// Construct an [`ConfigError::InvalidValue`] error.
164-
pub fn invalid_value<S: Into<String>>(field: &'static str, reason: S) -> Self {
165-
ConfigError::InvalidValue {
166-
field,
167-
reason: reason.into(),
168-
}
169-
}
170-
}
171-
172-
// ---------------------------------------------------------------------------
173-
// Dataset errors
174-
// ---------------------------------------------------------------------------
175-
176-
/// Errors produced while loading or accessing dataset samples.
177-
#[derive(Debug, Error)]
178-
pub enum DatasetError {
179-
/// The requested data file or directory was not found.
180-
///
181-
/// Production training data is mandatory; this error is never silently
182-
/// suppressed. Use [`SyntheticDataset`] only for proof/testing.
183-
///
184-
/// [`SyntheticDataset`]: crate::dataset::SyntheticDataset
185-
#[error("Data not found at `{path}`: {message}")]
186-
DataNotFound {
187-
/// Path that was expected to contain data.
188-
path: PathBuf,
189-
/// Additional context.
190-
message: String,
191-
},
192-
193-
/// A file was found but its format is incorrect or unexpected.
194-
///
195-
/// This covers malformed numpy arrays, unexpected shapes, bad JSON
196-
/// metadata, etc.
197-
#[error("Invalid data format in `{path}`: {message}")]
198-
InvalidFormat {
199-
/// Path of the malformed file.
200-
path: PathBuf,
201-
/// Description of the format problem.
202-
message: String,
203-
},
204-
205-
/// A low-level I/O error while reading a data file.
206-
#[error("I/O error reading `{path}`: {source}")]
207-
IoError {
208-
/// Path being read when the error occurred.
209-
path: PathBuf,
210-
/// Underlying I/O error.
211-
#[source]
212-
source: std::io::Error,
213-
},
214-
215-
/// The number of subcarriers in the data file does not match the
216-
/// configuration expectation (before or after interpolation).
217-
#[error(
218-
"Subcarrier count mismatch in `{path}`: \
219-
file has {found} subcarriers, expected {expected}"
220-
)]
221-
SubcarrierMismatch {
222-
/// Path of the offending file.
223-
path: PathBuf,
224-
/// Number of subcarriers found in the file.
225-
found: usize,
226-
/// Number of subcarriers expected by the configuration.
227-
expected: usize,
228-
},
229-
230-
/// A sample index was out of bounds.
231-
#[error("Index {index} is out of bounds for dataset of length {len}")]
232-
IndexOutOfBounds {
233-
/// The requested index.
234-
index: usize,
235-
/// Total number of samples.
236-
len: usize,
237-
},
238-
239-
/// A numpy array could not be read.
240-
#[error("NumPy array read error in `{path}`: {message}")]
241-
NpyReadError {
242-
/// Path of the `.npy` file.
243-
path: PathBuf,
244-
/// Error description.
245-
message: String,
246-
},
247-
248-
/// A metadata file (e.g., `meta.json`) is missing or malformed.
249-
#[error("Metadata error for subject {subject_id}: {message}")]
250-
MetadataError {
251-
/// Subject whose metadata could not be read.
252-
subject_id: u32,
253-
/// Description of the problem.
254-
message: String,
255-
},
256-
257-
/// No subjects matching the requested IDs were found in the data directory.
258-
#[error(
259-
"No subjects found in `{data_dir}` matching the requested IDs: {requested:?}"
260-
)]
261-
NoSubjectsFound {
262-
/// Root data directory that was scanned.
263-
data_dir: PathBuf,
264-
/// Subject IDs that were requested.
265-
requested: Vec<u32>,
266-
},
267-
268-
/// A subcarrier interpolation error occurred during sample loading.
269-
#[error("Subcarrier interpolation failed while loading sample {sample_idx}: {source}")]
270-
InterpolationError {
271-
/// The sample index being loaded.
272-
sample_idx: usize,
273-
/// Underlying interpolation error.
274-
#[source]
275-
source: SubcarrierError,
276-
},
277-
}
278-
279-
impl DatasetError {
280-
/// Construct a [`DatasetError::DataNotFound`] error.
281-
pub fn not_found<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
282-
DatasetError::DataNotFound {
283-
path: path.into(),
284-
message: msg.into(),
285-
}
286-
}
287-
288-
/// Construct a [`DatasetError::InvalidFormat`] error.
289-
pub fn invalid_format<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
290-
DatasetError::InvalidFormat {
291-
path: path.into(),
292-
message: msg.into(),
293-
}
294-
}
295-
296-
/// Construct a [`DatasetError::IoError`] error.
297-
pub fn io_error(path: impl Into<PathBuf>, source: std::io::Error) -> Self {
298-
DatasetError::IoError {
299-
path: path.into(),
300-
source,
301-
}
302-
}
303-
304-
/// Construct a [`DatasetError::SubcarrierMismatch`] error.
305-
pub fn subcarrier_mismatch(path: impl Into<PathBuf>, found: usize, expected: usize) -> Self {
306-
DatasetError::SubcarrierMismatch {
307-
path: path.into(),
308-
found,
309-
expected,
310-
}
311-
}
312-
313-
/// Construct a [`DatasetError::NpyReadError`] error.
314-
pub fn npy_read<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
315-
DatasetError::NpyReadError {
316-
path: path.into(),
317-
message: msg.into(),
318-
}
319-
}
320-
}
321-
322-
// ---------------------------------------------------------------------------
323-
// Subcarrier interpolation errors
324-
// ---------------------------------------------------------------------------
325-
326-
/// Errors produced by the subcarrier resampling functions.
327-
#[derive(Debug, Error)]
328-
pub enum SubcarrierError {
329-
/// The source or destination subcarrier count is zero.
330-
#[error("Subcarrier count must be at least 1, got {count}")]
331-
ZeroCount {
332-
/// The offending count.
333-
count: usize,
334-
},
335-
336-
/// The input array has an unexpected shape.
337-
#[error(
338-
"Input array shape mismatch: expected last dimension {expected_sc}, \
339-
got {actual_sc} (full shape: {shape:?})"
340-
)]
341-
InputShapeMismatch {
342-
/// Expected number of subcarriers (last dimension).
343-
expected_sc: usize,
344-
/// Actual number of subcarriers found.
345-
actual_sc: usize,
346-
/// Full shape of the input array.
347-
shape: Vec<usize>,
348-
},
349-
350-
/// The requested interpolation method is not implemented.
351-
#[error("Interpolation method `{method}` is not yet implemented")]
352-
MethodNotImplemented {
353-
/// Name of the unimplemented method.
354-
method: String,
355-
},
356-
357-
/// Source and destination subcarrier counts are already equal.
358-
///
359-
/// Callers should check [`TrainingConfig::needs_subcarrier_interp`] before
360-
/// calling the interpolation routine to avoid this error.
361-
///
362-
/// [`TrainingConfig::needs_subcarrier_interp`]:
363-
/// crate::config::TrainingConfig::needs_subcarrier_interp
364-
#[error(
365-
"Source and destination subcarrier counts are equal ({count}); \
366-
no interpolation is needed"
367-
)]
368-
NopInterpolation {
369-
/// The equal count.
370-
count: usize,
371-
},
372-
373-
/// A numerical error occurred during interpolation (e.g., division by zero
374-
/// due to coincident knot positions).
375-
#[error("Numerical error during interpolation: {0}")]
376-
NumericalError(String),
377-
}
378-
379-
impl SubcarrierError {
380-
/// Construct a [`SubcarrierError::NumericalError`].
381-
pub fn numerical<S: Into<String>>(msg: S) -> Self {
382-
SubcarrierError::NumericalError(msg.into())
383-
}
384-
}

rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/lib.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ pub mod subcarrier;
5252
pub mod trainer;
5353

5454
// Convenient re-exports at the crate root.
55-
pub use config::TrainingConfig;
56-
pub use dataset::{CsiDataset, CsiSample, DataLoader, MmFiDataset, SyntheticCsiDataset, SyntheticConfig};
57-
pub use error::{ConfigError, DatasetError, SubcarrierError, TrainError, TrainResult};
55+
pub use config::{ConfigError, TrainingConfig};
56+
pub use dataset::{CsiDataset, CsiSample, DataLoader, DatasetError, MmFiDataset, SyntheticCsiDataset, SyntheticConfig};
57+
pub use error::{TrainError, TrainResult};
5858
pub use subcarrier::{compute_interp_weights, interpolate_subcarriers, select_subcarriers_by_variance};
5959

6060
/// Crate version string.

0 commit comments

Comments
 (0)