@@ -49,48 +49,89 @@ pub struct AdaptationResult {
4949 pub adaptation_epochs : usize ,
5050}
5151
52+ /// Error type for rapid adaptation.
53+ #[ derive( Debug , Clone ) ]
54+ pub enum AdaptError {
55+ /// Not enough calibration frames.
56+ InsufficientFrames {
57+ /// Frames currently buffered.
58+ have : usize ,
59+ /// Minimum required.
60+ need : usize ,
61+ } ,
62+ /// LoRA rank must be at least 1.
63+ InvalidRank ,
64+ }
65+
66+ impl std:: fmt:: Display for AdaptError {
67+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
68+ match self {
69+ Self :: InsufficientFrames { have, need } =>
70+ write ! ( f, "insufficient calibration frames: have {have}, need at least {need}" ) ,
71+ Self :: InvalidRank => write ! ( f, "lora_rank must be >= 1" ) ,
72+ }
73+ }
74+ }
75+
76+ impl std:: error:: Error for AdaptError { }
77+
5278/// Few-shot rapid adaptation engine.
5379///
5480/// Accumulates unlabeled CSI calibration frames and runs test-time training
55- /// to produce LoRA weight deltas.
81+ /// to produce LoRA weight deltas. Buffer is capped at `max_buffer_frames`
82+ /// (default 10 000) to prevent unbounded memory growth.
5683///
5784/// ```rust
5885/// use wifi_densepose_train::rapid_adapt::{RapidAdaptation, AdaptationLoss};
5986/// let loss = AdaptationLoss::Combined { epochs: 5, lr: 0.001, lambda_ent: 0.5 };
6087/// let mut ra = RapidAdaptation::new(10, 4, loss);
6188/// for i in 0..10 { ra.push_frame(&vec![i as f32; 8]); }
6289/// assert!(ra.is_ready());
63- /// let r = ra.adapt();
90+ /// let r = ra.adapt().unwrap() ;
6491/// assert_eq!(r.frames_used, 10);
6592/// ```
6693pub struct RapidAdaptation {
6794 /// Minimum frames before adaptation (default 200 = 10 s @ 20 Hz).
6895 pub min_calibration_frames : usize ,
69- /// LoRA factorization rank (default 4 ).
96+ /// LoRA factorization rank (must be >= 1 ).
7097 pub lora_rank : usize ,
7198 /// Loss variant for test-time training.
7299 pub adaptation_loss : AdaptationLoss ,
100+ /// Maximum buffer size (ring-buffer eviction beyond this cap).
101+ pub max_buffer_frames : usize ,
73102 calibration_buffer : Vec < Vec < f32 > > ,
74103}
75104
105+ /// Default maximum calibration buffer size.
106+ const DEFAULT_MAX_BUFFER : usize = 10_000 ;
107+
76108impl RapidAdaptation {
77109 /// Create a new adaptation engine.
78110 pub fn new ( min_calibration_frames : usize , lora_rank : usize , adaptation_loss : AdaptationLoss ) -> Self {
79- Self { min_calibration_frames, lora_rank, adaptation_loss, calibration_buffer : Vec :: new ( ) }
111+ Self { min_calibration_frames, lora_rank, adaptation_loss, max_buffer_frames : DEFAULT_MAX_BUFFER , calibration_buffer : Vec :: new ( ) }
112+ }
113+ /// Push a single unlabeled CSI frame. Evicts oldest frame when buffer is full.
114+ pub fn push_frame ( & mut self , frame : & [ f32 ] ) {
115+ if self . calibration_buffer . len ( ) >= self . max_buffer_frames {
116+ self . calibration_buffer . remove ( 0 ) ;
117+ }
118+ self . calibration_buffer . push ( frame. to_vec ( ) ) ;
80119 }
81- /// Push a single unlabeled CSI frame.
82- pub fn push_frame ( & mut self , frame : & [ f32 ] ) { self . calibration_buffer . push ( frame. to_vec ( ) ) ; }
83120 /// True when buffer >= min_calibration_frames.
84121 pub fn is_ready ( & self ) -> bool { self . calibration_buffer . len ( ) >= self . min_calibration_frames }
85122 /// Number of buffered frames.
86123 pub fn buffer_len ( & self ) -> usize { self . calibration_buffer . len ( ) }
87124
88125 /// Run test-time adaptation producing LoRA weight deltas.
89126 ///
90- /// # Panics
91- /// Panics if the calibration buffer is empty.
92- pub fn adapt ( & self ) -> AdaptationResult {
93- assert ! ( !self . calibration_buffer. is_empty( ) , "empty calibration buffer" ) ;
127+ /// Returns an error if the calibration buffer is empty or lora_rank is 0.
128+ pub fn adapt ( & self ) -> Result < AdaptationResult , AdaptError > {
129+ if self . calibration_buffer . is_empty ( ) {
130+ return Err ( AdaptError :: InsufficientFrames { have : 0 , need : 1 } ) ;
131+ }
132+ if self . lora_rank == 0 {
133+ return Err ( AdaptError :: InvalidRank ) ;
134+ }
94135 let ( n, fdim) = ( self . calibration_buffer . len ( ) , self . calibration_buffer [ 0 ] . len ( ) ) ;
95136 let lora_sz = 2 * fdim * self . lora_rank ;
96137 let mut w = vec ! [ 0.01_f32 ; lora_sz] ;
@@ -112,7 +153,7 @@ impl RapidAdaptation {
112153 for ( wi, gi) in w. iter_mut ( ) . zip ( g. iter ( ) ) { * wi -= lr * gi; }
113154 final_loss = loss;
114155 }
115- AdaptationResult { lora_weights : w, final_loss, frames_used : n, adaptation_epochs : epochs }
156+ Ok ( AdaptationResult { lora_weights : w, final_loss, frames_used : n, adaptation_epochs : epochs } )
116157 }
117158
118159 fn contrastive_step ( & self , w : & [ f32 ] , fdim : usize , grad : & mut [ f32 ] ) -> f32 {
@@ -207,7 +248,7 @@ mod tests {
207248 let ( fdim, rank) = ( 16 , 4 ) ;
208249 let mut a = RapidAdaptation :: new ( 10 , rank, AdaptationLoss :: ContrastiveTTT { epochs : 3 , lr : 0.01 } ) ;
209250 for i in 0 ..10 { a. push_frame ( & vec ! [ i as f32 * 0.1 ; fdim] ) ; }
210- let r = a. adapt ( ) ;
251+ let r = a. adapt ( ) . unwrap ( ) ;
211252 assert_eq ! ( r. lora_weights. len( ) , 2 * fdim * rank) ;
212253 assert_eq ! ( r. frames_used, 10 ) ;
213254 assert_eq ! ( r. adaptation_epochs, 3 ) ;
@@ -219,7 +260,7 @@ mod tests {
219260 let mk = |ep| {
220261 let mut a = RapidAdaptation :: new ( 20 , rank, AdaptationLoss :: ContrastiveTTT { epochs : ep, lr : 0.01 } ) ;
221262 for i in 0 ..20 { let v = i as f32 * 0.1 ; a. push_frame ( & ( 0 ..fdim) . map ( |d| v + d as f32 * 0.01 ) . collect :: < Vec < _ > > ( ) ) ; }
222- a. adapt ( ) . final_loss
263+ a. adapt ( ) . unwrap ( ) . final_loss
223264 } ;
224265 assert ! ( mk( 10 ) <= mk( 1 ) + 1e-6 , "10 epochs should yield <= 1 epoch loss" ) ;
225266 }
@@ -229,14 +270,35 @@ mod tests {
229270 let ( fdim, rank) = ( 16 , 4 ) ;
230271 let mut a = RapidAdaptation :: new ( 10 , rank, AdaptationLoss :: Combined { epochs : 5 , lr : 0.001 , lambda_ent : 0.5 } ) ;
231272 for i in 0 ..10 { a. push_frame ( & ( 0 ..fdim) . map ( |d| ( ( i * fdim + d) as f32 ) . sin ( ) ) . collect :: < Vec < _ > > ( ) ) ; }
232- let r = a. adapt ( ) ;
273+ let r = a. adapt ( ) . unwrap ( ) ;
233274 assert_eq ! ( r. frames_used, 10 ) ;
234275 assert_eq ! ( r. adaptation_epochs, 5 ) ;
235276 assert ! ( r. final_loss. is_finite( ) ) ;
236277 assert_eq ! ( r. lora_weights. len( ) , 2 * fdim * rank) ;
237278 assert ! ( r. lora_weights. iter( ) . all( |w| w. is_finite( ) ) ) ;
238279 }
239280
281+ #[ test]
282+ fn adapt_empty_buffer_returns_error ( ) {
283+ let a = RapidAdaptation :: new ( 10 , 4 , AdaptationLoss :: ContrastiveTTT { epochs : 1 , lr : 0.01 } ) ;
284+ assert ! ( a. adapt( ) . is_err( ) ) ;
285+ }
286+
287+ #[ test]
288+ fn adapt_zero_rank_returns_error ( ) {
289+ let mut a = RapidAdaptation :: new ( 1 , 0 , AdaptationLoss :: ContrastiveTTT { epochs : 1 , lr : 0.01 } ) ;
290+ a. push_frame ( & [ 1.0 , 2.0 ] ) ;
291+ assert ! ( a. adapt( ) . is_err( ) ) ;
292+ }
293+
294+ #[ test]
295+ fn buffer_cap_evicts_oldest ( ) {
296+ let mut a = RapidAdaptation :: new ( 2 , 4 , AdaptationLoss :: ContrastiveTTT { epochs : 1 , lr : 0.01 } ) ;
297+ a. max_buffer_frames = 3 ;
298+ for i in 0 ..5 { a. push_frame ( & [ i as f32 ] ) ; }
299+ assert_eq ! ( a. buffer_len( ) , 3 ) ;
300+ }
301+
240302 #[ test]
241303 fn l2_distance_tests ( ) {
242304 assert ! ( l2_dist( & [ 1.0 , 2.0 , 3.0 ] , & [ 1.0 , 2.0 , 3.0 ] ) . abs( ) < 1e-10 ) ;
0 commit comments