|
1 | 1 | """Phase sanitizer for WiFi-DensePose CSI phase data processing.""" |
2 | 2 |
|
3 | 3 | import numpy as np |
| 4 | +import torch |
4 | 5 | from typing import Optional |
5 | 6 | from scipy import signal |
6 | 7 |
|
@@ -60,6 +61,35 @@ def remove_outliers(self, phase_data: np.ndarray) -> np.ndarray: |
60 | 61 |
|
61 | 62 | return result |
62 | 63 |
|
| 64 | + def sanitize_phase_batch(self, processed_csi: torch.Tensor) -> torch.Tensor: |
| 65 | + """Sanitize phase information in a batch of processed CSI data. |
| 66 | + |
| 67 | + Args: |
| 68 | + processed_csi: Processed CSI tensor from CSI processor |
| 69 | + |
| 70 | + Returns: |
| 71 | + CSI tensor with sanitized phase information |
| 72 | + """ |
| 73 | + if not isinstance(processed_csi, torch.Tensor): |
| 74 | + raise ValueError("Input must be a torch.Tensor") |
| 75 | + |
| 76 | + # Convert to numpy for processing |
| 77 | + csi_numpy = processed_csi.detach().cpu().numpy() |
| 78 | + |
| 79 | + # The processed CSI has shape (batch, channels, subcarriers, time) |
| 80 | + # where channels = 2 * antennas (amplitude and phase interleaved) |
| 81 | + batch_size, channels, subcarriers, time_samples = csi_numpy.shape |
| 82 | + |
| 83 | + # Process phase channels (odd indices contain phase information) |
| 84 | + for batch_idx in range(batch_size): |
| 85 | + for ch_idx in range(1, channels, 2): # Phase channels are at odd indices |
| 86 | + phase_data = csi_numpy[batch_idx, ch_idx, :, :] |
| 87 | + sanitized_phase = self.sanitize(phase_data) |
| 88 | + csi_numpy[batch_idx, ch_idx, :, :] = sanitized_phase |
| 89 | + |
| 90 | + # Convert back to tensor |
| 91 | + return torch.from_numpy(csi_numpy).float() |
| 92 | + |
63 | 93 | def smooth_phase(self, phase_data: np.ndarray) -> np.ndarray: |
64 | 94 | """Apply smoothing filter to reduce noise in phase data. |
65 | 95 | |
|
0 commit comments