Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 08afb54

Browse files
committed
fixing some test cases
1 parent 5769733 commit 08afb54

File tree

4 files changed

+285
-28
lines changed

4 files changed

+285
-28
lines changed

src/eegprep/eeg_eegrej.py

Lines changed: 201 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,204 @@
11
import numpy as np
22
from copy import deepcopy
3-
from eegprep.eegrej import eegrej # expects: eegrej(data, regions, xdur, events) -> (data_out, xmax_rel, event2, boundevents)
3+
import numpy as np
4+
from typing import List, Dict, Optional, Tuple
5+
6+
def _is_boundary_event(event: Dict) -> bool:
7+
t = event.get("type")
8+
if isinstance(t, str):
9+
return t.lower() == "boundary"
10+
if isinstance(t, (int, float)):
11+
try:
12+
return int(t) == -99
13+
except Exception:
14+
return False
15+
return False
16+
17+
def _eegrej(indata, regions, timelength, events: Optional[List[Dict]] = None) -> Tuple[np.ndarray, float, List[Dict], np.ndarray]:
18+
"""
19+
Remove [beg end] sample ranges (1-based, inclusive) from continuous data
20+
and update events (list of dictionaries) in the MATLAB EEGLAB style.
21+
22+
Inputs
23+
- indata: 2D array shaped (channels, frames)
24+
- regions: array-like with shape (n_regions, 2), 1-based [beg end] per row
25+
- timelength: total duration of the original data in seconds
26+
- events: list of dicts with at least key 'latency'; optional keys include
27+
'type' and 'duration'. If None or empty, boundary events will
28+
still be inserted based on regions.
29+
30+
Returns
31+
- outdata: data with columns removed
32+
- newt: new total time in seconds
33+
- events_out: updated events list of dictionaries (with inserted boundaries)
34+
- boundevents: boundary latencies (float, 1-based, with +0.5 convention)
35+
"""
36+
x = np.asarray(indata)
37+
if x.ndim != 2:
38+
raise ValueError("indata must be 2D (channels, frames)")
39+
n = x.shape[1]
40+
41+
r = np.asarray(regions, dtype=float)
42+
if r.size == 0:
43+
# nothing to remove; still ensure events sorted and valid
44+
events_out = [] if events is None else [dict(ev) for ev in events]
45+
# Sort events by latency if present
46+
if events_out and all("latency" in ev for ev in events_out):
47+
events_out.sort(key=lambda ev: ev.get("latency", float("inf")))
48+
boundevents = np.array([], dtype=float)
49+
return x, float(timelength), events_out, boundevents
50+
51+
if r.ndim != 2 or r.shape[1] != 2:
52+
raise ValueError("regions must be of shape (n_regions, 2)")
53+
54+
# Round, clamp to [1, n], sort each row then sort rows (EEGLAB parity)
55+
r = np.rint(r).astype(int)
56+
r[:, 0] = np.clip(r[:, 0], 1, n)
57+
r[:, 1] = np.clip(r[:, 1], 1, n)
58+
r.sort(axis=1)
59+
r = r[np.lexsort((r[:, 1], r[:, 0]))]
60+
61+
# Enforce non-overlap by shifting starts forward (like MATLAB)
62+
for i in range(1, r.shape[0]):
63+
if r[i - 1, 1] >= r[i, 0]:
64+
r[i, 0] = r[i - 1, 1] + 1
65+
# Drop empty or inverted regions after adjustment
66+
r = r[r[:, 0] <= r[:, 1]]
67+
if r.size == 0:
68+
events_out = [] if events is None else [dict(ev) for ev in events]
69+
if events_out and all("latency" in ev for ev in events_out):
70+
events_out.sort(key=lambda ev: ev.get("latency", float("inf")))
71+
boundevents = np.array([], dtype=float)
72+
return x, float(timelength), events_out, boundevents
73+
74+
# Build reject mask (convert 1-based to 0-based slices)
75+
# MATLAB: reject(beg:end) = 1 (includes both beg and end, 1-based)
76+
# Python: reject[beg-1:end] = True (includes beg-1 to end-1, since end is exclusive in Python slicing)
77+
# To match MATLAB's inclusive end, we need reject[beg-1:end] where end is inclusive
78+
reject = np.zeros(n, dtype=bool)
79+
for beg, end in r:
80+
reject[beg - 1:end] = True # This matches MATLAB reject(beg:end) when end is already the inclusive end
81+
82+
# Prepare events
83+
ori_events: List[Dict] = [] if events is None else [dict(ev) for ev in events]
84+
events_out: List[Dict] = [dict(ev) for ev in ori_events]
85+
86+
# Recompute event latencies (if events have 'latency') and remove events strictly inside regions
87+
if events_out and all("latency" in ev for ev in events_out):
88+
ori_lat = np.array([float(ev.get("latency", float("nan"))) for ev in events_out], dtype=float)
89+
lat = ori_lat.copy()
90+
rejected_per_region: List[List[int]] = []
91+
for beg, end in r:
92+
# indices strictly inside (beg, end)
93+
rej_idx = np.where((ori_lat > beg) & (ori_lat < end))[0].tolist()
94+
rejected_per_region.append(rej_idx)
95+
# subtract span from latencies whose original latency is strictly after region start
96+
span = int(end - beg + 1)
97+
lat[ori_lat > beg] -= span
98+
99+
# Apply updated latencies
100+
for i, ev in enumerate(events_out):
101+
ev["latency"] = float(lat[i])
102+
103+
# Remove events inside rejected regions
104+
rm_idx = sorted(set(idx for group in rejected_per_region for idx in group))
105+
if rm_idx:
106+
keep_mask = np.ones(len(events_out), dtype=bool)
107+
keep_mask[rm_idx] = False
108+
events_out = [ev for j, ev in enumerate(events_out) if keep_mask[j]]
109+
110+
# Boundary latencies: start-1, then subtract cumulative prior durations, then +0.5
111+
base_durations = (r[:, 1] - r[:, 0] + 1).astype(int)
112+
113+
# If we have original events and they include type/duration, add nested boundary durations
114+
durations = base_durations.astype(float).copy()
115+
if ori_events and all("latency" in ev for ev in ori_events):
116+
ori_lat = np.array([float(ev.get("latency", float("nan"))) for ev in ori_events], dtype=float)
117+
for i_region, (beg, end) in enumerate(r):
118+
inside_mask = (ori_lat > beg) & (ori_lat < end)
119+
selected_events = [ori_events[i] for i, m in enumerate(inside_mask) if m]
120+
extra = 0.0
121+
for ev in selected_events:
122+
if _is_boundary_event(ev):
123+
extra += float(ev.get("duration", 0.0) or 0.0)
124+
durations[i_region] += extra
125+
126+
# Compute boundevents considering prior removals
127+
boundevents = r[:, 0].astype(float) - 1.0
128+
if len(durations) > 1:
129+
cums = np.concatenate([[0.0], np.cumsum(durations[:-1])])
130+
boundevents = boundevents - cums
131+
boundevents = boundevents + 0.5
132+
boundevents = boundevents[boundevents >= 0]
133+
134+
# Excise samples
135+
newx = x[:, ~reject]
136+
newn = int(newx.shape[1])
137+
138+
# Update total time proportionally
139+
newt = float(timelength) * (newn / float(n))
140+
141+
# Remove boundary events that would fall exactly after the last sample + 0.5
142+
boundevents = boundevents[boundevents < (newn + 1)]
143+
144+
# Merge duplicate boundary latencies and sum durations for duplicates
145+
if boundevents.size:
146+
rounded = np.round(boundevents, 12)
147+
merged_be: List[float] = []
148+
merged_du: List[float] = []
149+
for i, be in enumerate(rounded):
150+
if not merged_be:
151+
merged_be.append(be)
152+
merged_du.append(float(durations[i]))
153+
else:
154+
if np.isclose(be, merged_be[-1]):
155+
merged_du[-1] += float(durations[i])
156+
else:
157+
merged_be.append(be)
158+
merged_du.append(float(durations[i]))
159+
boundevents = np.asarray(merged_be, dtype=float)
160+
durations = np.asarray(merged_du, dtype=float)
161+
else:
162+
durations = np.asarray([], dtype=float)
163+
164+
# Insert boundary events into events list only if input events were provided
165+
if ori_events:
166+
bound_type = "boundary"
167+
for i in range(len(boundevents)):
168+
be = float(boundevents[i])
169+
if be > 0 and be < (newn + 1):
170+
events_out.append({
171+
"type": bound_type,
172+
"latency": be,
173+
"duration": float(durations[i] if i < len(durations) else (base_durations[i] if i < len(base_durations) else 0.0)),
174+
})
175+
176+
# Remove events with latency out of bound (> newn+1)
177+
filtered: List[Dict] = []
178+
for ev in events_out:
179+
latv = float(ev.get("latency", float("inf")))
180+
if latv <= (newn + 1):
181+
filtered.append(ev)
182+
events_out = filtered
183+
184+
# Sort by latency
185+
events_out.sort(key=lambda ev: ev.get("latency", float("inf")))
186+
187+
# Handle contiguous boundary events with same latency: merge durations
188+
if events_out:
189+
merged_events: List[Dict] = []
190+
for ev in events_out:
191+
if merged_events and _is_boundary_event(ev) and _is_boundary_event(merged_events[-1]) \
192+
and np.isclose(float(ev.get("latency", 0.0)), float(merged_events[-1].get("latency", 0.0))):
193+
prev_dur = float(merged_events[-1].get("duration", 0.0) or 0.0)
194+
cur_dur = float(ev.get("duration", 0.0) or 0.0)
195+
merged_events[-1]["duration"] = prev_dur + cur_dur
196+
else:
197+
merged_events.append(ev)
198+
events_out = merged_events
199+
200+
return newx, newt, events_out, boundevents
201+
4202

5203
def eeg_eegrej(EEG, regions):
6204
EEG = deepcopy(EEG)
@@ -26,9 +224,9 @@ def eeg_eegrej(EEG, regions):
26224
# Use original events; backend will handle pruning, shifting, and boundary insertion
27225
events = list(EEG.get("event", []))
28226

29-
# call eegrej backend
227+
# call _eegrej backend
30228
xdur = float(EEG["xmax"] - EEG["xmin"])
31-
data_out, xmax_rel, event2, boundevents = eegrej(EEG["data"], regions, xdur, events)
229+
data_out, xmax_rel, event2, boundevents = _eegrej(EEG["data"], regions, xdur, events)
32230

33231
# finalize core fields
34232
old_pnts = int(EEG["pnts"])

src/eegprep/pop_saveset.py

Lines changed: 73 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -157,29 +157,74 @@ def pop_saveset(EEG, file_name):
157157
'icawinv' : EEG['icawinv'],
158158
'icasphere' : EEG['icasphere'],
159159
'icaweights' : EEG['icaweights'],
160-
'icachansind' : _as_array_or_empty(EEG['icachansind']),
160+
'icachansind' : EEG['icachansind'] if EEG['icachansind'] is not None else {},
161161
'chanlocs' : EEG['chanlocs'],
162162
'urchanlocs' : EEG['urchanlocs'],
163163
'chaninfo' : EEG['chaninfo'],
164164
'ref' : EEG['ref'],
165-
'event' : _as_array_or_empty(EEG['event']),
166-
'urevent' : _as_array_or_empty(EEG['urevent']),
167-
'eventdescription': _as_array_or_empty(EEG['eventdescription']),
168-
'epoch' : _as_array_or_empty(EEG['epoch']),
169-
'epochdescription': _as_array_or_empty(EEG['epochdescription']),
170-
'reject' : _as_array_or_empty(EEG['reject']),
171-
'stats' : _as_array_or_empty(EEG['stats']),
172-
'specdata' : _as_array_or_empty(EEG['specdata']),
173-
'specicaact' : _as_array_or_empty(EEG['specicaact']),
174-
'splinefile' : _as_array_or_empty(EEG['splinefile']),
175-
'icasplinefile' : _as_array_or_empty(EEG['icasplinefile']),
176-
'dipfit' : _as_array_or_empty(EEG['dipfit']),
165+
'event' : EEG['event'] if EEG['event'] is not None else {},
166+
'urevent' : EEG['urevent'] if EEG['urevent'] is not None else {},
167+
'eventdescription': EEG['eventdescription'] if EEG['eventdescription'] is not None else {},
168+
'epoch' : EEG['epoch'] if EEG['epoch'] is not None else {},
169+
'epochdescription': EEG['epochdescription'] if EEG['epochdescription'] is not None else {},
170+
'reject' : EEG['reject'] if EEG['reject'] is not None else {},
171+
'stats' : EEG['stats'] if EEG['stats'] is not None else {},
172+
'specdata' : EEG['specdata'] if EEG['specdata'] is not None else {},
173+
'specicaact' : EEG['specicaact'] if EEG['specicaact'] is not None else {},
174+
'splinefile' : EEG['splinefile'] if EEG['splinefile'] is not None else {},
175+
'icasplinefile' : EEG['icasplinefile'] if EEG['icasplinefile'] is not None else {},
176+
'dipfit' : EEG['dipfit'] if EEG['dipfit'] is not None else {},
177177
'history' : EEG['history'],
178178
'saved' : EEG['saved'],
179179
'etc' : EEG['etc'],
180-
'run' : _as_array_or_empty(EEG['run']),
181-
'roi' : _as_array_or_empty(EEG['roi']),
182-
}
180+
'run' : EEG['run'] if EEG['run'] is not None else {},
181+
'roi' : EEG['roi'] if EEG['roi'] is not None else {}
182+
}
183+
184+
# eeglab_dict = {
185+
# 'setname' : '',
186+
# 'filename' : '',
187+
# 'filepath' : '',
188+
# 'subject' : '',
189+
# 'group' : '',
190+
# 'condition' : '',
191+
# 'session' : np.array([]),
192+
# 'comments' : '',
193+
# 'nbchan' : float(EEG['nbchan']),
194+
# 'trials' : float(EEG['trials']),
195+
# 'pnts' : float(EEG['pnts']),
196+
# 'srate' : float(EEG['srate']),
197+
# 'xmin' : float(EEG['xmin']),
198+
# 'xmax' : float(EEG['xmax']),
199+
# 'times' : EEG['times'],
200+
# 'data' : EEG['data'],
201+
# 'icaact' : EEG['icaact'],
202+
# 'icawinv' : EEG['icawinv'],
203+
# 'icasphere' : EEG['icasphere'],
204+
# 'icaweights' : EEG['icaweights'],
205+
# 'icachansind' : _as_array_or_empty(EEG['icachansind']),
206+
# 'chanlocs' : EEG['chanlocs'],
207+
# 'urchanlocs' : EEG['urchanlocs'],
208+
# 'chaninfo' : EEG['chaninfo'],
209+
# 'ref' : EEG['ref'],
210+
# 'event' : _as_array_or_empty(EEG['event']),
211+
# 'urevent' : _as_array_or_empty(EEG['urevent']),
212+
# 'eventdescription': _as_array_or_empty(EEG['eventdescription']),
213+
# 'epoch' : _as_array_or_empty(EEG['epoch']),
214+
# 'epochdescription': _as_array_or_empty(EEG['epochdescription']),
215+
# 'reject' : _as_array_or_empty(EEG['reject']),
216+
# 'stats' : _as_array_or_empty(EEG['stats']),
217+
# 'specdata' : _as_array_or_empty(EEG['specdata']),
218+
# 'specicaact' : _as_array_or_empty(EEG['specicaact']),
219+
# 'splinefile' : _as_array_or_empty(EEG['splinefile']),
220+
# 'icasplinefile' : _as_array_or_empty(EEG['icasplinefile']),
221+
# 'dipfit' : _as_array_or_empty(EEG['dipfit']),
222+
# 'history' : EEG['history'],
223+
# 'saved' : EEG['saved'],
224+
# 'etc' : EEG['etc'],
225+
# 'run' : _as_array_or_empty(EEG['run']),
226+
# 'roi' : _as_array_or_empty(EEG['roi']),
227+
# }
183228

184229
# add 1 to EEG['icachansind'] to make it 1-based
185230
if ('icachansind' in eeglab_dict and
@@ -242,8 +287,19 @@ def pop_saveset(EEG, file_name):
242287
eeglab_dict['event'] = np.array(eeglab_dict['event'])
243288

244289
for key in eeglab_dict:
245-
if isinstance(eeglab_dict[key], np.ndarray) and not(is_effectively_empty(eeglab_dict[key])) and len(eeglab_dict[key]) > 0 and isinstance(eeglab_dict[key][0], dict):
290+
if isinstance(eeglab_dict[key], np.ndarray) and len(eeglab_dict[key]) > 0 and isinstance(eeglab_dict[key][0], dict):
246291
eeglab_dict[key] = flatten_dict(eeglab_dict[key])
292+
# for key in eeglab_dict:
293+
# arr = eeglab_dict[key]
294+
# if isinstance(arr, np.ndarray) and not is_effectively_empty(arr):
295+
# if not arr.ndim == 0:
296+
# if arr.shape != () and arr.shape[0] > 0 and isinstance(arr[0], dict):
297+
# eeglab_dict[key] = flatten_dict(arr)
298+
# else:
299+
# elem = arr.item()
300+
# if isinstance(elem, dict):
301+
# eeglab_dict[key] = flatten_dict([elem]) # wrap single dict
302+
247303
# # Step 4: Save the EEGLAB dataset as a .mat file
248304
scipy.io.savemat(file_name, eeglab_dict, appendmat=False)
249305

0 commit comments

Comments
 (0)