Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 43e92c5

Browse files
committed
Add batch processing methods for CSI data in CSIProcessor and PhaseSanitizer
1 parent cbebdd6 commit 43e92c5

2 files changed

Lines changed: 63 additions & 1 deletion

File tree

src/core/csi_processor.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,36 @@ def process_raw_csi(self, raw_data: np.ndarray) -> np.ndarray:
4444
if processed.std() > 0:
4545
processed = (processed - processed.mean()) / processed.std()
4646

47-
return processed
47+
return processed
48+
49+
def process_csi_batch(self, csi_data: np.ndarray) -> torch.Tensor:
50+
"""Process a batch of CSI data for neural network input.
51+
52+
Args:
53+
csi_data: Complex CSI data array of shape (batch, antennas, subcarriers, time)
54+
55+
Returns:
56+
Processed CSI tensor ready for neural network input
57+
"""
58+
if csi_data.ndim != 4:
59+
raise ValueError(f"Expected 4D input (batch, antennas, subcarriers, time), got {csi_data.ndim}D")
60+
61+
batch_size, num_antennas, num_subcarriers, time_samples = csi_data.shape
62+
63+
# Extract amplitude and phase
64+
amplitude = np.abs(csi_data)
65+
phase = np.angle(csi_data)
66+
67+
# Process each component
68+
processed_amplitude = self.process_raw_csi(amplitude)
69+
processed_phase = self.process_raw_csi(phase)
70+
71+
# Stack amplitude and phase as separate channels
72+
processed_data = np.stack([processed_amplitude, processed_phase], axis=1)
73+
74+
# Reshape to (batch, channels, antennas, subcarriers, time)
75+
# Then flatten spatial dimensions for CNN input
76+
processed_data = processed_data.reshape(batch_size, 2 * num_antennas, num_subcarriers, time_samples)
77+
78+
# Convert to tensor
79+
return torch.from_numpy(processed_data).float()

src/core/phase_sanitizer.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Phase sanitizer for WiFi-DensePose CSI phase data processing."""
22

33
import numpy as np
4+
import torch
45
from typing import Optional
56
from scipy import signal
67

@@ -60,6 +61,35 @@ def remove_outliers(self, phase_data: np.ndarray) -> np.ndarray:
6061

6162
return result
6263

64+
def sanitize_phase_batch(self, processed_csi: torch.Tensor) -> torch.Tensor:
65+
"""Sanitize phase information in a batch of processed CSI data.
66+
67+
Args:
68+
processed_csi: Processed CSI tensor from CSI processor
69+
70+
Returns:
71+
CSI tensor with sanitized phase information
72+
"""
73+
if not isinstance(processed_csi, torch.Tensor):
74+
raise ValueError("Input must be a torch.Tensor")
75+
76+
# Convert to numpy for processing
77+
csi_numpy = processed_csi.detach().cpu().numpy()
78+
79+
# The processed CSI has shape (batch, channels, subcarriers, time)
80+
# where channels = 2 * antennas (amplitude and phase interleaved)
81+
batch_size, channels, subcarriers, time_samples = csi_numpy.shape
82+
83+
# Process phase channels (odd indices contain phase information)
84+
for batch_idx in range(batch_size):
85+
for ch_idx in range(1, channels, 2): # Phase channels are at odd indices
86+
phase_data = csi_numpy[batch_idx, ch_idx, :, :]
87+
sanitized_phase = self.sanitize(phase_data)
88+
csi_numpy[batch_idx, ch_idx, :, :] = sanitized_phase
89+
90+
# Convert back to tensor
91+
return torch.from_numpy(csi_numpy).float()
92+
6393
def smooth_phase(self, phase_data: np.ndarray) -> np.ndarray:
6494
"""Apply smoothing filter to reduce noise in phase data.
6595

0 commit comments

Comments
 (0)