Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 053850e

Browse files
committed
fixing eegrej.py and test
1 parent b916872 commit 053850e

File tree

5 files changed

+374
-76
lines changed

5 files changed

+374
-76
lines changed

src/eegprep/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,5 @@
3333
from .bids_list_eeg_files import bids_list_eeg_files
3434
from .bids_preproc import bids_preproc
3535
from .eeg_decodechan import eeg_decodechan
36-
from .eegrej import eegrej
36+
from .eegrej import eegrej
37+
from .eeg_eegrej import eeg_eegrej

src/eegprep/eeg_eegrej.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import numpy as np
2+
from copy import deepcopy
3+
from eegprep.eegrej import eegrej # expects: eegrej(data, regions, xdur, events) -> (data_out, xmax_rel, event2, boundevents)
4+
5+
def eeg_eegrej(EEG, regions):
6+
EEG = deepcopy(EEG)
7+
if regions is None or len(regions) == 0:
8+
return EEG
9+
10+
regions = np.asarray(regions, dtype=np.int64)
11+
# sort rows like MATLAB
12+
if regions.shape[1] > 2:
13+
regions = regions[np.argsort(regions[:, 2])]
14+
else:
15+
regions = regions[np.argsort(regions[:, 0])]
16+
17+
# handle eegplot-style regions [.. .. beg end]
18+
if regions.shape[1] > 2:
19+
regions = regions[:, 2:4]
20+
21+
regions = _combine_regions(regions)
22+
23+
# remove events that fall within any region, except boundary events
24+
events = list(EEG.get("event", []))
25+
if events:
26+
ev_lats = np.array([float(e["latency"]) for e in events])
27+
kill = np.zeros(len(events), dtype=bool)
28+
for beg, end in regions:
29+
kill |= (ev_lats >= beg) & (ev_lats <= end)
30+
bidx = _find_boundary_event_indices(events)
31+
kill[bidx] = False
32+
events = [ev for i, ev in enumerate(events) if not kill[i]]
33+
34+
# call eegrej backend
35+
xdur = float(EEG["xmax"] - EEG["xmin"])
36+
data_out, xmax_rel, event2, boundevents = eegrej(EEG["data"], regions, xdur, events)
37+
38+
# finalize core fields
39+
old_pnts = int(EEG["pnts"])
40+
EEG["data"] = data_out
41+
EEG["pnts"] = int(data_out.shape[1])
42+
EEG["xmax"] = float(EEG["xmax"] + EEG["xmin"])
43+
44+
# insert boundary events into our pruned events, then consistency trims
45+
EEG["event"] = _insert_boundaries(events, old_pnts, regions)
46+
EEG["event"].sort(key=lambda e: e.get("latency", float("inf")))
47+
48+
if len(EEG["event"]) > 1 and EEG["event"][-1].get("latency", 0) - 0.5 > EEG["pnts"] and EEG.get("trials", 1) == 1:
49+
EEG["event"].pop()
50+
51+
# light duplicate cleanup mirroring MATLAB edge cases
52+
if len(EEG["event"]) > 1 and EEG["event"][0].get("latency") == 0:
53+
EEG["event"] = EEG["event"][1:]
54+
if len(EEG["event"]) > 1 and EEG["event"][-1].get("latency") == EEG["pnts"]:
55+
EEG["event"] = EEG["event"][:-1]
56+
if len(EEG["event"]) > 2:
57+
if EEG["event"][-1].get("latency") == EEG["event"][-2].get("latency"):
58+
if EEG["event"][-1].get("type") == EEG["event"][-2].get("type"):
59+
EEG["event"].pop()
60+
61+
return EEG
62+
63+
def _combine_regions(regs):
64+
if len(regs) == 0:
65+
return regs
66+
regs = np.array(sorted(regs.tolist(), key=lambda r: (r[0], r[1])), dtype=np.int64)
67+
merged = [regs[0].tolist()]
68+
for beg, end in regs[1:]:
69+
mbeg, mend = merged[-1]
70+
if beg <= mend + 1:
71+
merged[-1][1] = max(mend, end)
72+
else:
73+
merged.append([beg, end])
74+
newregs = np.asarray(merged, dtype=np.int64)
75+
if newregs.shape[0] != regs.shape[0]:
76+
print("Warning: overlapping regions detected and fixed in eeg_eegrej")
77+
return newregs
78+
79+
def _find_boundary_event_indices(events):
80+
idx = []
81+
for i, ev in enumerate(events):
82+
t = ev.get("type")
83+
if isinstance(t, str) and t.lower() == "boundary":
84+
idx.append(i)
85+
elif isinstance(t, (int, float)) and int(t) == -99:
86+
idx.append(i)
87+
return np.array(idx, dtype=int)
88+
89+
def _insert_boundaries(events, old_pnts, regions):
90+
# Build kept segments in 1-based indices
91+
kept = []
92+
cursor = 1
93+
for beg, end in regions:
94+
if cursor <= beg - 1:
95+
kept.append([cursor, beg - 1])
96+
cursor = end + 1
97+
if cursor <= old_pnts:
98+
kept.append([cursor, old_pnts])
99+
100+
out = [dict(ev) for ev in events]
101+
run_len = 0
102+
for i in range(len(kept) - 1):
103+
seg_len = kept[i][1] - kept[i][0] + 1
104+
run_len += seg_len
105+
rem_beg, rem_end = regions[i]
106+
rem_len = int(rem_end - rem_beg + 1)
107+
out.append({
108+
"type": "boundary",
109+
"latency": float(run_len + 1),
110+
"duration": float(rem_len),
111+
})
112+
return out

src/eegprep/eegrej.py

Lines changed: 141 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,37 @@
11
import numpy as np
2+
from typing import List, Dict, Optional, Tuple
23

3-
def eegrej(indata, regions, timelength, eventlatencies=None):
4+
5+
def _is_boundary_event(event: Dict) -> bool:
6+
t = event.get("type")
7+
if isinstance(t, str):
8+
return t.lower() == "boundary"
9+
if isinstance(t, (int, float)):
10+
try:
11+
return int(t) == -99
12+
except Exception:
13+
return False
14+
return False
15+
16+
17+
def eegrej(indata, regions, timelength, events: Optional[List[Dict]] = None) -> Tuple[np.ndarray, float, List[Dict], np.ndarray]:
418
"""
5-
Remove [beg end] sample ranges (1-based, inclusive) from continuous data.
19+
Remove [beg end] sample ranges (1-based, inclusive) from continuous data
20+
and update events (list of dictionaries) in the MATLAB EEGLAB style.
621
722
Inputs
8-
indata: 2D array shaped (channels, frames)
9-
regions: array-like with shape (n_regions, 2), 1-based [beg end] per row
10-
timelength: total duration of the original data in seconds
11-
eventlatencies: iterable of event latencies in samples (1-based). Optional.
23+
- indata: 2D array shaped (channels, frames)
24+
- regions: array-like with shape (n_regions, 2), 1-based [beg end] per row
25+
- timelength: total duration of the original data in seconds
26+
- events: list of dicts with at least key 'latency'; optional keys include
27+
'type' and 'duration'. If None or empty, boundary events will
28+
still be inserted based on regions.
1229
1330
Returns
14-
outdata: data with columns removed
15-
newt: new total time in seconds
16-
newevents: adjusted event latencies (NaN for events inside removed regions)
17-
boundevents: boundary latencies (float, 1-based, with +0.5 convention)
31+
- outdata: data with columns removed
32+
- newt: new total time in seconds
33+
- events_out: updated events list of dictionaries (with inserted boundaries)
34+
- boundevents: boundary latencies (float, 1-based, with +0.5 convention)
1835
"""
1936
x = np.asarray(indata)
2037
if x.ndim != 2:
@@ -23,91 +40,158 @@ def eegrej(indata, regions, timelength, eventlatencies=None):
2340

2441
r = np.asarray(regions, dtype=float)
2542
if r.size == 0:
26-
# nothing to remove
27-
newx = x
28-
newt = float(timelength)
29-
if eventlatencies is None:
30-
newevents = None
31-
else:
32-
newevents = np.asarray(eventlatencies, dtype=float)
43+
# nothing to remove; still ensure events sorted and valid
44+
events_out = [] if events is None else [dict(ev) for ev in events]
45+
# Sort events by latency if present
46+
if events_out and all("latency" in ev for ev in events_out):
47+
events_out.sort(key=lambda ev: ev.get("latency", float("inf")))
3348
boundevents = np.array([], dtype=float)
34-
return newx, newt, newevents, boundevents
49+
return x, float(timelength), events_out, boundevents
3550

3651
if r.ndim != 2 or r.shape[1] != 2:
3752
raise ValueError("regions must be of shape (n_regions, 2)")
3853

39-
# Round, clamp to [1, n], sort each row then sort rows
54+
# Round, clamp to [1, n], sort each row then sort rows (EEGLAB parity)
4055
r = np.rint(r).astype(int)
4156
r[:, 0] = np.clip(r[:, 0], 1, n)
4257
r[:, 1] = np.clip(r[:, 1], 1, n)
4358
r.sort(axis=1)
4459
r = r[np.lexsort((r[:, 1], r[:, 0]))]
4560

46-
# Enforce non-overlap by shifting starts forward
61+
# Enforce non-overlap by shifting starts forward (like MATLAB)
4762
for i in range(1, r.shape[0]):
4863
if r[i - 1, 1] >= r[i, 0]:
4964
r[i, 0] = r[i - 1, 1] + 1
5065
# Drop empty or inverted regions after adjustment
5166
r = r[r[:, 0] <= r[:, 1]]
5267
if r.size == 0:
53-
newx = x
54-
newt = float(timelength)
55-
if eventlatencies is None:
56-
newevents = None
57-
else:
58-
newevents = np.asarray(eventlatencies, dtype=float)
68+
events_out = [] if events is None else [dict(ev) for ev in events]
69+
if events_out and all("latency" in ev for ev in events_out):
70+
events_out.sort(key=lambda ev: ev.get("latency", float("inf")))
5971
boundevents = np.array([], dtype=float)
60-
return newx, newt, newevents, boundevents
72+
return x, float(timelength), events_out, boundevents
6173

6274
# Build reject mask (convert 1-based to 0-based slices)
6375
reject = np.zeros(n, dtype=bool)
6476
for beg, end in r:
6577
reject[beg - 1:end] = True
6678

67-
# Recompute event latencies
68-
if eventlatencies is None:
69-
newevents = None
70-
rejected_events_masks = None
71-
else:
72-
ev = np.asarray(eventlatencies, dtype=float).copy()
73-
newevents = ev.copy()
74-
durations = (r[:, 1] - r[:, 0] + 1)
75-
# Mark events inside any region as NaN
76-
inside = np.zeros(ev.shape, dtype=bool)
77-
for beg, end in r:
78-
inside |= (ev >= beg) & (ev <= end)
79-
newevents[inside] = np.nan
80-
# Shift remaining events left by total removed samples preceding them
81-
# Use original ev for comparisons, per EEGLAB behavior
79+
# Prepare events
80+
ori_events: List[Dict] = [] if events is None else [dict(ev) for ev in events]
81+
events_out: List[Dict] = [dict(ev) for ev in ori_events]
82+
83+
# Recompute event latencies (if events have 'latency') and remove events strictly inside regions
84+
if events_out and all("latency" in ev for ev in events_out):
85+
ori_lat = np.array([float(ev.get("latency", float("nan"))) for ev in events_out], dtype=float)
86+
lat = ori_lat.copy()
87+
rejected_per_region: List[List[int]] = []
8288
for beg, end in r:
83-
span = end - beg + 1
84-
affected = ev > end # original latency strictly after the region
85-
newevents[affected] -= span
89+
# indices strictly inside (beg, end)
90+
rej_idx = np.where((ori_lat > beg) & (ori_lat < end))[0].tolist()
91+
rejected_per_region.append(rej_idx)
92+
# subtract span from latencies whose original latency is strictly after region start
93+
span = int(end - beg + 1)
94+
lat[ori_lat > beg] -= span
95+
96+
# Apply updated latencies
97+
for i, ev in enumerate(events_out):
98+
ev["latency"] = float(lat[i])
99+
100+
# Remove events inside rejected regions
101+
rm_idx = sorted(set(idx for group in rejected_per_region for idx in group))
102+
if rm_idx:
103+
keep_mask = np.ones(len(events_out), dtype=bool)
104+
keep_mask[rm_idx] = False
105+
events_out = [ev for j, ev in enumerate(events_out) if keep_mask[j]]
86106

87-
# Boundary latencies: start-1, then account for prior removals, then +0.5
88-
durations = (r[:, 1] - r[:, 0] + 1).astype(int)
107+
# Boundary latencies: start-1, then subtract cumulative prior durations, then +0.5
108+
base_durations = (r[:, 1] - r[:, 0] + 1).astype(int)
109+
110+
# If we have original events and they include type/duration, add nested boundary durations
111+
durations = base_durations.astype(float).copy()
112+
if ori_events and all("latency" in ev for ev in ori_events):
113+
ori_lat = np.array([float(ev.get("latency", float("nan"))) for ev in ori_events], dtype=float)
114+
for i_region, (beg, end) in enumerate(r):
115+
inside_mask = (ori_lat > beg) & (ori_lat < end)
116+
selected_events = [ori_events[i] for i, m in enumerate(inside_mask) if m]
117+
extra = 0.0
118+
for ev in selected_events:
119+
if _is_boundary_event(ev):
120+
extra += float(ev.get("duration", 0.0) or 0.0)
121+
durations[i_region] += extra
122+
123+
# Compute boundevents considering prior removals
89124
boundevents = r[:, 0].astype(float) - 1.0
90-
# subtract cumulative durations of earlier regions
91-
cums = np.concatenate([[0], np.cumsum(durations[:-1])]).astype(float)
92-
boundevents = boundevents - cums
125+
if len(durations) > 1:
126+
cums = np.concatenate([[0.0], np.cumsum(durations[:-1])])
127+
boundevents = boundevents - cums
93128
boundevents = boundevents + 0.5
94129
boundevents = boundevents[boundevents >= 0]
95130

96131
# Excise samples
97132
newx = x[:, ~reject]
98-
newn = newx.shape[1]
133+
newn = int(newx.shape[1])
99134

100135
# Update total time proportionally
101136
newt = float(timelength) * (newn / float(n))
102137

103138
# Remove boundary events that would fall exactly after the last sample + 0.5
104139
boundevents = boundevents[boundevents < (newn + 1)]
105140

106-
# Merge duplicate boundary latencies (rare after de-overlap, but keep parity with EEGLAB)
141+
# Merge duplicate boundary latencies and sum durations for duplicates
107142
if boundevents.size:
108-
be = boundevents
109-
# Since we do not track duration objects here, just unique latencies preserving order
110-
_, idx = np.unique(np.round(be, 12), return_index=True)
111-
boundevents = be[np.sort(idx)]
143+
rounded = np.round(boundevents, 12)
144+
merged_be: List[float] = []
145+
merged_du: List[float] = []
146+
for i, be in enumerate(rounded):
147+
if not merged_be:
148+
merged_be.append(be)
149+
merged_du.append(float(durations[i]))
150+
else:
151+
if np.isclose(be, merged_be[-1]):
152+
merged_du[-1] += float(durations[i])
153+
else:
154+
merged_be.append(be)
155+
merged_du.append(float(durations[i]))
156+
boundevents = np.asarray(merged_be, dtype=float)
157+
durations = np.asarray(merged_du, dtype=float)
158+
else:
159+
durations = np.asarray([], dtype=float)
160+
161+
# Insert boundary events into events list only if input events were provided
162+
if ori_events:
163+
bound_type = "boundary"
164+
for i in range(len(boundevents)):
165+
be = float(boundevents[i])
166+
if be > 0 and be < (newn + 1):
167+
events_out.append({
168+
"type": bound_type,
169+
"latency": be,
170+
"duration": float(durations[i] if i < len(durations) else (base_durations[i] if i < len(base_durations) else 0.0)),
171+
})
172+
173+
# Remove events with latency out of bound (> newn+1)
174+
filtered: List[Dict] = []
175+
for ev in events_out:
176+
latv = float(ev.get("latency", float("inf")))
177+
if latv <= (newn + 1):
178+
filtered.append(ev)
179+
events_out = filtered
180+
181+
# Sort by latency
182+
events_out.sort(key=lambda ev: ev.get("latency", float("inf")))
183+
184+
# Handle contiguous boundary events with same latency: merge durations
185+
if events_out:
186+
merged_events: List[Dict] = []
187+
for ev in events_out:
188+
if merged_events and _is_boundary_event(ev) and _is_boundary_event(merged_events[-1]) \
189+
and np.isclose(float(ev.get("latency", 0.0)), float(merged_events[-1].get("latency", 0.0))):
190+
prev_dur = float(merged_events[-1].get("duration", 0.0) or 0.0)
191+
cur_dur = float(ev.get("duration", 0.0) or 0.0)
192+
merged_events[-1]["duration"] = prev_dur + cur_dur
193+
else:
194+
merged_events.append(ev)
195+
events_out = merged_events
112196

113-
return newx, newt, newevents, boundevents
197+
return newx, newt, events_out, boundevents

0 commit comments

Comments
 (0)