1
+ import sys
2
+ import math
3
+ from collections .abc import Sequence
4
+ import numpy as np
5
+
6
+ def eeg_compare (eeg1 , eeg2 ):
7
+
8
+ def isequaln (a , b ):
9
+ """Treat None and NaN as equal, otherwise compare by value."""
10
+ # both None
11
+ if a is None and b is None :
12
+ return True
13
+ # None vs NaN
14
+ if a is None and isinstance (b , float ) and math .isnan (b ):
15
+ return True
16
+ if b is None and isinstance (a , float ) and math .isnan (a ):
17
+ return True
18
+ # both NaN
19
+ if isinstance (a , float ) and isinstance (b , float ) and math .isnan (a ) and math .isnan (b ):
20
+ return True
21
+ # arrays with NaN
22
+ if isinstance (a , np .ndarray ) or isinstance (b , np .ndarray ):
23
+ try :
24
+ return bool (np .array_equal (np .array (a ), np .array (b ), equal_nan = True ))
25
+ except :
26
+ pass
27
+ # Handle numpy arrays in general comparison
28
+ if isinstance (a , np .ndarray ) and isinstance (b , np .ndarray ):
29
+ try :
30
+ return bool (np .array_equal (a , b , equal_nan = True ))
31
+ except :
32
+ pass
33
+ # Handle scalar vs array comparisons
34
+ if isinstance (a , np .ndarray ) and np .isscalar (b ):
35
+ try :
36
+ return bool (np .all (a == b ))
37
+ except :
38
+ pass
39
+ if isinstance (b , np .ndarray ) and np .isscalar (a ):
40
+ try :
41
+ return bool (np .all (b == a ))
42
+ except :
43
+ pass
44
+ # Final comparison - ensure we return a boolean
45
+ try :
46
+ result = a == b
47
+ if isinstance (result , np .ndarray ):
48
+ return bool (result .all ())
49
+ return bool (result )
50
+ except :
51
+ return False
52
+
53
+ """Compare two EEG-like structures, reporting differences to stderr."""
54
+ print ('\n Field analysis: (no entries means OK)' )
55
+ # Handle both dictionary-like objects and objects with __dict__
56
+ if hasattr (eeg1 , 'keys' ):
57
+ # Dictionary-like object
58
+ fields1 = eeg1 .keys ()
59
+ get_val1 = lambda f : eeg1 .get (f , None )
60
+ has_field2 = lambda f : f in eeg2
61
+ get_val2 = lambda f : eeg2 .get (f , None )
62
+ else :
63
+ # Object with __dict__
64
+ fields1 = getattr (eeg1 , '__dict__' , {}).keys ()
65
+ get_val1 = lambda f : getattr (eeg1 , f , None )
66
+ has_field2 = lambda f : hasattr (eeg2 , f )
67
+ get_val2 = lambda f : getattr (eeg2 , f , None )
68
+
69
+ for field in fields1 :
70
+ if not has_field2 (field ):
71
+ print (f' Field { field } missing in second dataset' , file = sys .stderr )
72
+ else :
73
+ v1 = get_val1 (field )
74
+ v2 = get_val2 (field )
75
+ if not isequaln (v1 , v2 ):
76
+ name = field .lower ()
77
+ if any (sub in name for sub in ('filename' , 'datfile' )):
78
+ print (f' Field { field } differs (ok, supposed to differ)' )
79
+ elif any (sub in name for sub in ('subject' , 'session' , 'run' , 'task' )):
80
+ print (f' Field { field } differs ("{ v1 } " vs "{ v2 } ")' , file = sys .stderr )
81
+ elif any (sub in name for sub in ('chanlocs' , 'event' , 'reject' )):
82
+ pass
83
+ # For complex nested structures, provide more detailed info
84
+ elif any (sub in name for sub in ('eventdescription' )):
85
+ n1 = len (v1 ) if isinstance (v1 , Sequence ) else 1
86
+ n2 = len (v2 ) if isinstance (v2 , Sequence ) else 1
87
+ print (f' Field { field } differs (n={ n1 } vs n={ n2 } )' , file = sys .stderr )
88
+ else :
89
+ print (f' Field { field } differs' , file = sys .stderr )
90
+ # compare xmin/xmax
91
+ for attr in ('xmin' , 'xmax' ):
92
+ x1 = get_val1 (attr )
93
+ x2 = get_val2 (attr )
94
+ if not isequaln (x1 , x2 ):
95
+ diff = (x1 or 0 ) - (x2 or 0 )
96
+ print (f' Difference between { attr } is { diff :1.6f} sec' , file = sys .stderr )
97
+
98
+ # channel locations
99
+ print ('Chanlocs analysis:' )
100
+ chans1 = eeg1 ['chanlocs' ] # need to fuse with chaninfo
101
+ chans2 = eeg2 ['chanlocs' ] # need to fuse with chaninfo
102
+ if len (chans1 ) == len (chans2 ):
103
+ coord_diff = label_diff = 0
104
+ for c1 , c2 in zip (chans1 , chans2 ):
105
+ c1_xyz = (c1 ['X' ], c1 ['Y' ], c1 ['Z' ])
106
+ c2_xyz = (c2 ['X' ], c2 ['Y' ], c2 ['Z' ])
107
+ if (any (v is None for v in c1_xyz ) and not any (v is None for v in c2_xyz )) \
108
+ or (any (v is None for v in c2_xyz ) and not any (v is None for v in c1_xyz )) \
109
+ or (all (v is not None for v in (* c1_xyz ,)) and
110
+ sum (abs (a - b ) for a , b in zip (c1_xyz , c2_xyz )) > 1e-12 ):
111
+ coord_diff += 1
112
+ if c1 ['labels' ] != c2 ['labels' ]:
113
+ label_diff += 1
114
+ if coord_diff :
115
+ print (f' { coord_diff } channel coordinates differ' , file = sys .stderr )
116
+ else :
117
+ print (' All channel coordinates are OK' )
118
+ if label_diff :
119
+ print (f' { label_diff } channel label(s) differ' , file = sys .stderr )
120
+ else :
121
+ print (' All channel labels are OK' )
122
+ else :
123
+ print (' Different numbers of channels' , file = sys .stderr )
124
+
125
+ # events
126
+ print ('Event analysis:' )
127
+ ev1 , ev2 = eeg1 ['event' ], eeg2 ['event' ]
128
+ if len (ev1 ) != len (ev2 ):
129
+ print (' Different numbers of events' , file = sys .stderr )
130
+ else :
131
+ f1 = set (ev1 [0 ].keys ())
132
+ f2 = set (ev2 [0 ].keys ())
133
+ if f1 != f2 :
134
+ print (' Not the same number of event fields' , file = sys .stderr )
135
+ for fld in f1 :
136
+ diffs = []
137
+ if fld .lower () == 'latency' :
138
+ diffs = [e1 ['latency' ] - e2 ['latency' ] for e1 , e2 in zip (ev1 , ev2 )]
139
+ nonzero = [d for d in diffs if d != 0 ]
140
+ if nonzero :
141
+ pct = len (nonzero ) / len (diffs ) * 100
142
+ avg = sum (abs (d ) for d in nonzero ) / len (nonzero )
143
+ print (f' Event latency ({ pct :2.1f} %) not OK (abs diff { avg :1.4f} samples)' , file = sys .stderr )
144
+ # print(' ******** (see plot)')
145
+ # import matplotlib.pyplot as plt
146
+ # plt.plot(diffs)
147
+ # plt.show()
148
+ else :
149
+ diffs = [not isequaln (getattr (e1 , fld , None ), getattr (e2 , fld , None )) for e1 , e2 in zip (ev1 , ev2 )]
150
+ if any (diffs ):
151
+ pct = sum (diffs ) / len (diffs ) * 100
152
+ print (f' Event fields "{ fld } " are NOT OK ({ pct :2.1f} % of them)' , file = sys .stderr )
153
+ print (' All other events OK' )
154
+
155
+ # epochs
156
+ # if 'epoch' in eeg1:
157
+ # print('Epoch analysis:')
158
+ # ep1, ep2 = eeg1['epoch'], eeg2['epoch']
159
+ # if len(ep1) != len(ep2):
160
+ # print(' Different numbers of epochs', file=sys.stderr)
161
+ # else:
162
+ # fields = ep1[0].keys()
163
+ # all_ok = True
164
+ # for fld in fields:
165
+ # diffs = [not isequaln(getattr(e1, fld, None), getattr(e2, fld, None)) for e1, e2 in zip(ep1, ep2)]
166
+ # if any(diffs):
167
+ # pct = sum(diffs) / len(diffs) * 100
168
+ # print(f' Epoch fields "{fld}" are NOT OK ({pct:2.1f} % of them)', file=sys.stderr)
169
+ # all_ok = False
170
+ # if all_ok:
171
+ # print(' All epoch and all epoch fields are OK')
172
+
173
+ return True
174
+
175
+ # add test data and compare with it
176
+
177
+ # load test data
178
+ if __name__ == '__main__' :
179
+ from eegprep import pop_loadset
180
+ eeg1 = pop_loadset ('../../data/eeglab_data_tmp.set' )
181
+ eeg2 = pop_loadset ('../../data/eeglab_data_tmp.set' )
182
+
183
+ # compare
184
+ eeg_compare (eeg1 , eeg2 )
0 commit comments