1
1
import numpy as np
2
+ from typing import List , Dict , Optional , Tuple
2
3
3
- def eegrej (indata , regions , timelength , eventlatencies = None ):
4
+
5
+ def _is_boundary_event (event : Dict ) -> bool :
6
+ t = event .get ("type" )
7
+ if isinstance (t , str ):
8
+ return t .lower () == "boundary"
9
+ if isinstance (t , (int , float )):
10
+ try :
11
+ return int (t ) == - 99
12
+ except Exception :
13
+ return False
14
+ return False
15
+
16
+
17
+ def eegrej (indata , regions , timelength , events : Optional [List [Dict ]] = None ) -> Tuple [np .ndarray , float , List [Dict ], np .ndarray ]:
4
18
"""
5
- Remove [beg end] sample ranges (1-based, inclusive) from continuous data.
19
+ Remove [beg end] sample ranges (1-based, inclusive) from continuous data
20
+ and update events (list of dictionaries) in the MATLAB EEGLAB style.
6
21
7
22
Inputs
8
- indata: 2D array shaped (channels, frames)
9
- regions: array-like with shape (n_regions, 2), 1-based [beg end] per row
10
- timelength: total duration of the original data in seconds
11
- eventlatencies: iterable of event latencies in samples (1-based). Optional.
23
+ - indata: 2D array shaped (channels, frames)
24
+ - regions: array-like with shape (n_regions, 2), 1-based [beg end] per row
25
+ - timelength: total duration of the original data in seconds
26
+ - events: list of dicts with at least key 'latency'; optional keys include
27
+ 'type' and 'duration'. If None or empty, boundary events will
28
+ still be inserted based on regions.
12
29
13
30
Returns
14
- outdata: data with columns removed
15
- newt: new total time in seconds
16
- newevents: adjusted event latencies (NaN for events inside removed regions )
17
- boundevents: boundary latencies (float, 1-based, with +0.5 convention)
31
+ - outdata: data with columns removed
32
+ - newt: new total time in seconds
33
+ - events_out: updated events list of dictionaries (with inserted boundaries )
34
+ - boundevents: boundary latencies (float, 1-based, with +0.5 convention)
18
35
"""
19
36
x = np .asarray (indata )
20
37
if x .ndim != 2 :
@@ -23,91 +40,158 @@ def eegrej(indata, regions, timelength, eventlatencies=None):
23
40
24
41
r = np .asarray (regions , dtype = float )
25
42
if r .size == 0 :
26
- # nothing to remove
27
- newx = x
28
- newt = float (timelength )
29
- if eventlatencies is None :
30
- newevents = None
31
- else :
32
- newevents = np .asarray (eventlatencies , dtype = float )
43
+ # nothing to remove; still ensure events sorted and valid
44
+ events_out = [] if events is None else [dict (ev ) for ev in events ]
45
+ # Sort events by latency if present
46
+ if events_out and all ("latency" in ev for ev in events_out ):
47
+ events_out .sort (key = lambda ev : ev .get ("latency" , float ("inf" )))
33
48
boundevents = np .array ([], dtype = float )
34
- return newx , newt , newevents , boundevents
49
+ return x , float ( timelength ), events_out , boundevents
35
50
36
51
if r .ndim != 2 or r .shape [1 ] != 2 :
37
52
raise ValueError ("regions must be of shape (n_regions, 2)" )
38
53
39
- # Round, clamp to [1, n], sort each row then sort rows
54
+ # Round, clamp to [1, n], sort each row then sort rows (EEGLAB parity)
40
55
r = np .rint (r ).astype (int )
41
56
r [:, 0 ] = np .clip (r [:, 0 ], 1 , n )
42
57
r [:, 1 ] = np .clip (r [:, 1 ], 1 , n )
43
58
r .sort (axis = 1 )
44
59
r = r [np .lexsort ((r [:, 1 ], r [:, 0 ]))]
45
60
46
- # Enforce non-overlap by shifting starts forward
61
+ # Enforce non-overlap by shifting starts forward (like MATLAB)
47
62
for i in range (1 , r .shape [0 ]):
48
63
if r [i - 1 , 1 ] >= r [i , 0 ]:
49
64
r [i , 0 ] = r [i - 1 , 1 ] + 1
50
65
# Drop empty or inverted regions after adjustment
51
66
r = r [r [:, 0 ] <= r [:, 1 ]]
52
67
if r .size == 0 :
53
- newx = x
54
- newt = float (timelength )
55
- if eventlatencies is None :
56
- newevents = None
57
- else :
58
- newevents = np .asarray (eventlatencies , dtype = float )
68
+ events_out = [] if events is None else [dict (ev ) for ev in events ]
69
+ if events_out and all ("latency" in ev for ev in events_out ):
70
+ events_out .sort (key = lambda ev : ev .get ("latency" , float ("inf" )))
59
71
boundevents = np .array ([], dtype = float )
60
- return newx , newt , newevents , boundevents
72
+ return x , float ( timelength ), events_out , boundevents
61
73
62
74
# Build reject mask (convert 1-based to 0-based slices)
63
75
reject = np .zeros (n , dtype = bool )
64
76
for beg , end in r :
65
77
reject [beg - 1 :end ] = True
66
78
67
- # Recompute event latencies
68
- if eventlatencies is None :
69
- newevents = None
70
- rejected_events_masks = None
71
- else :
72
- ev = np .asarray (eventlatencies , dtype = float ).copy ()
73
- newevents = ev .copy ()
74
- durations = (r [:, 1 ] - r [:, 0 ] + 1 )
75
- # Mark events inside any region as NaN
76
- inside = np .zeros (ev .shape , dtype = bool )
77
- for beg , end in r :
78
- inside |= (ev >= beg ) & (ev <= end )
79
- newevents [inside ] = np .nan
80
- # Shift remaining events left by total removed samples preceding them
81
- # Use original ev for comparisons, per EEGLAB behavior
79
+ # Prepare events
80
+ ori_events : List [Dict ] = [] if events is None else [dict (ev ) for ev in events ]
81
+ events_out : List [Dict ] = [dict (ev ) for ev in ori_events ]
82
+
83
+ # Recompute event latencies (if events have 'latency') and remove events strictly inside regions
84
+ if events_out and all ("latency" in ev for ev in events_out ):
85
+ ori_lat = np .array ([float (ev .get ("latency" , float ("nan" ))) for ev in events_out ], dtype = float )
86
+ lat = ori_lat .copy ()
87
+ rejected_per_region : List [List [int ]] = []
82
88
for beg , end in r :
83
- span = end - beg + 1
84
- affected = ev > end # original latency strictly after the region
85
- newevents [affected ] -= span
89
+ # indices strictly inside (beg, end)
90
+ rej_idx = np .where ((ori_lat > beg ) & (ori_lat < end ))[0 ].tolist ()
91
+ rejected_per_region .append (rej_idx )
92
+ # subtract span from latencies whose original latency is strictly after region start
93
+ span = int (end - beg + 1 )
94
+ lat [ori_lat > beg ] -= span
95
+
96
+ # Apply updated latencies
97
+ for i , ev in enumerate (events_out ):
98
+ ev ["latency" ] = float (lat [i ])
99
+
100
+ # Remove events inside rejected regions
101
+ rm_idx = sorted (set (idx for group in rejected_per_region for idx in group ))
102
+ if rm_idx :
103
+ keep_mask = np .ones (len (events_out ), dtype = bool )
104
+ keep_mask [rm_idx ] = False
105
+ events_out = [ev for j , ev in enumerate (events_out ) if keep_mask [j ]]
86
106
87
- # Boundary latencies: start-1, then account for prior removals, then +0.5
88
- durations = (r [:, 1 ] - r [:, 0 ] + 1 ).astype (int )
107
+ # Boundary latencies: start-1, then subtract cumulative prior durations, then +0.5
108
+ base_durations = (r [:, 1 ] - r [:, 0 ] + 1 ).astype (int )
109
+
110
+ # If we have original events and they include type/duration, add nested boundary durations
111
+ durations = base_durations .astype (float ).copy ()
112
+ if ori_events and all ("latency" in ev for ev in ori_events ):
113
+ ori_lat = np .array ([float (ev .get ("latency" , float ("nan" ))) for ev in ori_events ], dtype = float )
114
+ for i_region , (beg , end ) in enumerate (r ):
115
+ inside_mask = (ori_lat > beg ) & (ori_lat < end )
116
+ selected_events = [ori_events [i ] for i , m in enumerate (inside_mask ) if m ]
117
+ extra = 0.0
118
+ for ev in selected_events :
119
+ if _is_boundary_event (ev ):
120
+ extra += float (ev .get ("duration" , 0.0 ) or 0.0 )
121
+ durations [i_region ] += extra
122
+
123
+ # Compute boundevents considering prior removals
89
124
boundevents = r [:, 0 ].astype (float ) - 1.0
90
- # subtract cumulative durations of earlier regions
91
- cums = np .concatenate ([[0 ], np .cumsum (durations [:- 1 ])]). astype ( float )
92
- boundevents = boundevents - cums
125
+ if len ( durations ) > 1 :
126
+ cums = np .concatenate ([[0.0 ], np .cumsum (durations [:- 1 ])])
127
+ boundevents = boundevents - cums
93
128
boundevents = boundevents + 0.5
94
129
boundevents = boundevents [boundevents >= 0 ]
95
130
96
131
# Excise samples
97
132
newx = x [:, ~ reject ]
98
- newn = newx .shape [1 ]
133
+ newn = int ( newx .shape [1 ])
99
134
100
135
# Update total time proportionally
101
136
newt = float (timelength ) * (newn / float (n ))
102
137
103
138
# Remove boundary events that would fall exactly after the last sample + 0.5
104
139
boundevents = boundevents [boundevents < (newn + 1 )]
105
140
106
- # Merge duplicate boundary latencies (rare after de-overlap, but keep parity with EEGLAB)
141
+ # Merge duplicate boundary latencies and sum durations for duplicates
107
142
if boundevents .size :
108
- be = boundevents
109
- # Since we do not track duration objects here, just unique latencies preserving order
110
- _ , idx = np .unique (np .round (be , 12 ), return_index = True )
111
- boundevents = be [np .sort (idx )]
143
+ rounded = np .round (boundevents , 12 )
144
+ merged_be : List [float ] = []
145
+ merged_du : List [float ] = []
146
+ for i , be in enumerate (rounded ):
147
+ if not merged_be :
148
+ merged_be .append (be )
149
+ merged_du .append (float (durations [i ]))
150
+ else :
151
+ if np .isclose (be , merged_be [- 1 ]):
152
+ merged_du [- 1 ] += float (durations [i ])
153
+ else :
154
+ merged_be .append (be )
155
+ merged_du .append (float (durations [i ]))
156
+ boundevents = np .asarray (merged_be , dtype = float )
157
+ durations = np .asarray (merged_du , dtype = float )
158
+ else :
159
+ durations = np .asarray ([], dtype = float )
160
+
161
+ # Insert boundary events into events list only if input events were provided
162
+ if ori_events :
163
+ bound_type = "boundary"
164
+ for i in range (len (boundevents )):
165
+ be = float (boundevents [i ])
166
+ if be > 0 and be < (newn + 1 ):
167
+ events_out .append ({
168
+ "type" : bound_type ,
169
+ "latency" : be ,
170
+ "duration" : float (durations [i ] if i < len (durations ) else (base_durations [i ] if i < len (base_durations ) else 0.0 )),
171
+ })
172
+
173
+ # Remove events with latency out of bound (> newn+1)
174
+ filtered : List [Dict ] = []
175
+ for ev in events_out :
176
+ latv = float (ev .get ("latency" , float ("inf" )))
177
+ if latv <= (newn + 1 ):
178
+ filtered .append (ev )
179
+ events_out = filtered
180
+
181
+ # Sort by latency
182
+ events_out .sort (key = lambda ev : ev .get ("latency" , float ("inf" )))
183
+
184
+ # Handle contiguous boundary events with same latency: merge durations
185
+ if events_out :
186
+ merged_events : List [Dict ] = []
187
+ for ev in events_out :
188
+ if merged_events and _is_boundary_event (ev ) and _is_boundary_event (merged_events [- 1 ]) \
189
+ and np .isclose (float (ev .get ("latency" , 0.0 )), float (merged_events [- 1 ].get ("latency" , 0.0 ))):
190
+ prev_dur = float (merged_events [- 1 ].get ("duration" , 0.0 ) or 0.0 )
191
+ cur_dur = float (ev .get ("duration" , 0.0 ) or 0.0 )
192
+ merged_events [- 1 ]["duration" ] = prev_dur + cur_dur
193
+ else :
194
+ merged_events .append (ev )
195
+ events_out = merged_events
112
196
113
- return newx , newt , newevents , boundevents
197
+ return newx , newt , events_out , boundevents
0 commit comments