Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 964fade

Browse files
committed
clean code. isinstance instead of type
1 parent 2a75b21 commit 964fade

File tree

5 files changed

+54
-56
lines changed

5 files changed

+54
-56
lines changed

wfdb/plot/plots.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,12 @@ def plotrec(record=None, title = None, annotation = None, timeunits='samples', s
5757
siglen, nsig = record.p_signals.shape
5858

5959
# Expand list styles
60-
if type(sigstyle) == str:
60+
if isinstance(sigstyle, str):
6161
sigstyle = [sigstyle]*record.nsig
6262
else:
6363
if len(sigstyle) < record.nsig:
6464
sigstyle = sigstyle+['']*(record.nsig-len(sigstyle))
65-
if type(annstyle) == str:
65+
if isinstance(annstyle, str):
6666
annstyle = [annstyle]*record.nsig
6767
else:
6868
if len(annstyle) < record.nsig:
@@ -182,9 +182,9 @@ def calc_ecg_grids(minsig, maxsig, units, fs, maxt, timeunits):
182182
def checkplotitems(record, title, annotation, timeunits, sigstyle, annstyle):
183183

184184
# signals
185-
if type(record) != records.Record:
185+
if not isinstance(record, records.Record):
186186
raise TypeError("The 'record' argument must be a valid wfdb.Record object")
187-
if type(record.p_signals) != np.ndarray or record.p_signals.ndim != 2:
187+
if not isinstance(record.p_signals, np.ndarray) or record.p_signals.ndim != 2:
188188
raise TypeError("The plotted signal 'record.p_signals' must be a 2d numpy array")
189189

190190
siglen, nsig = record.p_signals.shape
@@ -197,7 +197,7 @@ def checkplotitems(record, title, annotation, timeunits, sigstyle, annstyle):
197197
if timeunits == 'samples':
198198
t = np.linspace(0, siglen-1, siglen)
199199
else:
200-
if type(record.fs) not in _headers.floattypes:
200+
if not isinstance(record.fs, _headers.floattypes):
201201
raise TypeError("The 'fs' field must be a number")
202202

203203
if timeunits == 'seconds':
@@ -211,7 +211,7 @@ def checkplotitems(record, title, annotation, timeunits, sigstyle, annstyle):
211211
if record.units is None:
212212
record.units = ['NU']*nsig
213213
else:
214-
if type(record.units) != list or len(record.units)!= nsig:
214+
if not isinstance(record.units, list) or len(record.units)!= nsig:
215215
raise ValueError("The 'units' parameter must be a list of strings with length equal to the number of signal channels")
216216
for ch in range(nsig):
217217
if record.units[ch] is None:
@@ -221,26 +221,26 @@ def checkplotitems(record, title, annotation, timeunits, sigstyle, annstyle):
221221
if record.signame is None:
222222
record.signame = ['ch'+str(ch) for ch in range(1, nsig+1)]
223223
else:
224-
if type(record.signame) != list or len(record.signame)!= nsig:
224+
if not isinstance(record.signame, list) or len(record.signame)!= nsig:
225225
raise ValueError("The 'signame' parameter must be a list of strings, with length equal to the number of signal channels")
226226

227227
# title
228-
if title is not None and type(title) != str:
228+
if title is not None and not isinstance(title, str):
229229
raise TypeError("The 'title' field must be a string")
230230

231231
# signal line style
232-
if type(sigstyle) == str:
232+
if isinstance(sigstyle, str):
233233
pass
234-
elif type(sigstyle) == list:
234+
elif isinstance(sigstyle, list):
235235
if len(sigstyle) > record.nsig:
236236
raise ValueError("The 'sigstyle' list cannot have more elements than the number of record channels")
237237
else:
238238
raise TypeError("The 'sigstyle' field must be a string or a list of strings")
239239

240240
# annotation plot style
241-
if type(annstyle) == str:
241+
if isinstance(annstyle, str):
242242
pass
243-
elif type(annstyle) == list:
243+
elif isinstance(annstyle, list):
244244
if len(annstyle) > record.nsig:
245245
raise ValueError("The 'annstyle' list cannot have more elements than the number of record channels")
246246
else:
@@ -254,21 +254,21 @@ def checkplotitems(record, title, annotation, timeunits, sigstyle, annstyle):
254254
annplot = [None]*record.nsig
255255

256256
# Move single channel annotations to channel 0
257-
if type(annotation) == annotations.Annotation:
257+
if isinstance(annotation, annotations.Annotation):
258258
annplot[0] = annotation.sample
259-
elif type(annotation) == np.ndarray:
259+
elif isinstance(annotation, np.ndarray):
260260
annplot[0] = annotation
261261
# Ready list.
262-
elif type(annotation) == list:
262+
elif isinstance(annotation, list):
263263
if len(annotation) > record.nsig:
264264
raise ValueError("The number of annotation series to plot cannot be more than the number of channels")
265265
if len(annotation) < record.nsig:
266266
annotation = annotation+[None]*(record.nsig-len(annotation))
267267
# Check elements. Copy over to new list.
268268
for ch in range(record.nsig):
269-
if type(annotation[ch]) == annotations.Annotation:
269+
if isinstance(annotation[ch], annotations.Annotation):
270270
annplot[ch] = annotation[ch].sample
271-
elif type(annotation[ch]) == np.ndarray:
271+
elif isinstance(annotation[ch], np.ndarray):
272272
annplot[ch] = annotation[ch]
273273
elif annotation[ch] is None:
274274
pass
@@ -353,7 +353,7 @@ def plotann(annotation, title = None, timeunits = 'samples', returnfig = False):
353353
def checkannplotitems(annotation, title, timeunits):
354354

355355
# signals
356-
if type(annotation)!= annotations.Annotation:
356+
if not isinstance(annotation, annotations.Annotation):
357357
raise TypeError("The 'annotation' field must be a 'wfdb.Annotation' object")
358358

359359
# fs and timeunits
@@ -363,7 +363,7 @@ def checkannplotitems(annotation, title, timeunits):
363363

364364
# fs must be valid when plotting time
365365
if timeunits != 'samples':
366-
if type(annotation.fs) not in _headers.floattypes:
366+
if not isinstance(annotation.fs, _headers.floattypes):
367367
raise Exception("In order to plot time units, the Annotation object must have a valid 'fs' attribute")
368368

369369
# Get x axis values to plot
@@ -377,7 +377,7 @@ def checkannplotitems(annotation, title, timeunits):
377377
plotvals = annotation.sample/(annotation.fs*3600)
378378

379379
# title
380-
if title is not None and type(title) != str:
380+
if title is not None and not isinstance(title, str):
381381
raise TypeError("The 'title' field must be a string")
382382

383383
return plotvals

wfdb/readwrite/_headers.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def wrheaderfile(self, recwritefields, sigwritefields):
228228
if field in recwritefields:
229229
stringfield = str(getattr(self, field))
230230
# If fs is float, check whether it as an integer
231-
if field == 'fs' and type(self.fs) == float:
231+
if field == 'fs' and isinstance(self.fs, float):
232232
if round(self.fs, 8) == float(int(self.fs)):
233233
stringfield = str(int(self.fs))
234234
recordline = recordline + recfieldspecs[field].delimiter + stringfield
@@ -371,12 +371,12 @@ def getsigsegments(self, signame=None):
371371
if signame is None:
372372
signame = self.getsignames()
373373

374-
if type(signame) == list:
374+
if isinstance(signame, list):
375375
sigdict = {}
376376
for sig in signame:
377377
sigdict[sig] = self.getsigsegments(sig)
378378
return sigdict
379-
elif type(signame) == str:
379+
elif isinstance(signame, str):
380380
sigsegs = []
381381
for i in range(self.nseg):
382382
if self.segname[i] != '~' and signame in self.segments[i].signame:
@@ -594,38 +594,39 @@ def __init__(self, allowedtypes, delimiter, dependency, write_req, read_def, wri
594594
# so that the user doesn't need to. But when reading, it should
595595
# be clear that the fields are missing.
596596

597-
inttypes = [int, np.int64, np.int32, np.int16, np.int8]
598-
floattypes = inttypes + [float, np.float64, np.float32]
597+
inttypes = (int, np.int64, np.int32, np.int16, np.int8)
598+
floattypes = inttypes + (float, np.float64, np.float32)
599+
int_dtypes = ('int64', 'uint64', 'int32', 'uint32','int16','uint16')
599600

600601
# Record specification fields
601-
recfieldspecs = OrderedDict([('recordname', WFDBheaderspecs([str], '', None, True, None, None)),
602+
recfieldspecs = OrderedDict([('recordname', WFDBheaderspecs((str), '', None, True, None, None)),
602603
('nseg', WFDBheaderspecs(inttypes, '/', 'recordname', True, None, None)),
603604
('nsig', WFDBheaderspecs(inttypes, ' ', 'recordname', True, None, None)),
604605
('fs', WFDBheaderspecs(floattypes, ' ', 'nsig', True, 250, None)),
605606
('counterfreq', WFDBheaderspecs(floattypes, '/', 'fs', False, None, None)),
606607
('basecounter', WFDBheaderspecs(floattypes, '(', 'counterfreq', False, None, None)),
607608
('siglen', WFDBheaderspecs(inttypes, ' ', 'fs', True, None, None)),
608-
('basetime', WFDBheaderspecs([str], ' ', 'siglen', False, None, '00:00:00')),
609-
('basedate', WFDBheaderspecs([str], ' ', 'basetime', False, None, None))])
609+
('basetime', WFDBheaderspecs((str), ' ', 'siglen', False, None, '00:00:00')),
610+
('basedate', WFDBheaderspecs((str), ' ', 'basetime', False, None, None))])
610611

611612
# Signal specification fields.
612-
sigfieldspecs = OrderedDict([('filename', WFDBheaderspecs([str], '', None, True, None, None)),
613-
('fmt', WFDBheaderspecs([str], ' ', 'filename', True, None, None)),
613+
sigfieldspecs = OrderedDict([('filename', WFDBheaderspecs((str), '', None, True, None, None)),
614+
('fmt', WFDBheaderspecs((str), ' ', 'filename', True, None, None)),
614615
('sampsperframe', WFDBheaderspecs(inttypes, 'x', 'fmt', False, 1, None)),
615616
('skew', WFDBheaderspecs(inttypes, ':', 'fmt', False, None, None)),
616617
('byteoffset', WFDBheaderspecs(inttypes, '+', 'fmt', False, None, None)),
617618
('adcgain', WFDBheaderspecs(floattypes, ' ', 'fmt', True, 200., None)),
618619
('baseline', WFDBheaderspecs(inttypes, '(', 'adcgain', True, 0, None)),
619-
('units', WFDBheaderspecs([str], '/', 'adcgain', True, 'mV', None)),
620+
('units', WFDBheaderspecs((str), '/', 'adcgain', True, 'mV', None)),
620621
('adcres', WFDBheaderspecs(inttypes, ' ', 'adcgain', False, None, 0)),
621622
('adczero', WFDBheaderspecs(inttypes, ' ', 'adcres', False, None, 0)),
622623
('initvalue', WFDBheaderspecs(inttypes, ' ', 'adczero', False, None, None)),
623624
('checksum', WFDBheaderspecs(inttypes, ' ', 'initvalue', False, None, None)),
624625
('blocksize', WFDBheaderspecs(inttypes, ' ', 'checksum', False, None, 0)),
625-
('signame', WFDBheaderspecs([str], ' ', 'blocksize', False, None, None))])
626+
('signame', WFDBheaderspecs((str), ' ', 'blocksize', False, None, None))])
626627

627628
# Segment specification fields.
628-
segfieldspecs = OrderedDict([('segname', WFDBheaderspecs([str], '', None, True, None, None)),
629+
segfieldspecs = OrderedDict([('segname', WFDBheaderspecs((str), '', None, True, None, None)),
629630
('seglen', WFDBheaderspecs(inttypes, ' ', 'segname', True, None, None))])
630631

631632

wfdb/readwrite/_signals.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -967,7 +967,7 @@ def skewsig(sig, skew, nsig, readlen, fmt, nanreplace, sampsperframe=None):
967967
if max(skew)>0:
968968

969969
# Expanded frame samples. List of arrays.
970-
if type(sig) == list:
970+
if isinstance(sig, list):
971971
# Shift the channel samples
972972
for ch in range(nsig):
973973
if skew[ch]>0:
@@ -1000,7 +1000,7 @@ def skewsig(sig, skew, nsig, readlen, fmt, nanreplace, sampsperframe=None):
10001000

10011001
# Integrity check of signal shape after reading
10021002
def checksigdims(sig, readlen, nsig, sampsperframe):
1003-
if type(sig) == np.ndarray:
1003+
if isinstance(sig, np.ndarray):
10041004
if sig.shape != (readlen, nsig):
10051005
raise ValueError('Samples were not loaded correctly')
10061006
else:
@@ -1027,7 +1027,7 @@ def checksigdims(sig, readlen, nsig, sampsperframe):
10271027

10281028
# Return min and max digital values for each format type. Accepts lists.
10291029
def digi_bounds(fmt):
1030-
if type(fmt) == list:
1030+
if isinstance(fmt, list):
10311031
digibounds = []
10321032
for f in fmt:
10331033
digibounds.append(digi_bounds(f))
@@ -1046,7 +1046,7 @@ def digi_bounds(fmt):
10461046

10471047
# Return nan value for the format type(s).
10481048
def digi_nan(fmt):
1049-
if type(fmt) == list:
1049+
if isinstance(fmt, list):
10501050
diginans = []
10511051
for f in fmt:
10521052
diginans.append(digi_nan(f))
@@ -1085,7 +1085,7 @@ def estres(signals):
10851085
"""
10861086

10871087
# Expanded sample signals. List of numpy arrays
1088-
if type(signals) == list:
1088+
if isinstance(signals, list):
10891089
nsig = len(signals)
10901090
# Uniform numpy array
10911091
else:
@@ -1097,7 +1097,7 @@ def estres(signals):
10971097

10981098
for ch in range(nsig):
10991099
# Estimate the number of steps as the range divided by the minimum increment.
1100-
if type(signals) == list:
1100+
if isinstance(signals, list):
11011101
sortedsig = np.sort(signals[ch])
11021102
else:
11031103
sortedsig = np.sort(signals[:,ch])
@@ -1121,7 +1121,7 @@ def estres(signals):
11211121
# If singlefmt is True, the format for the maximum resolution will be returned.
11221122
def wfdbfmt(res, singlefmt = True):
11231123

1124-
if type(res) == list:
1124+
if isinstance(res, list):
11251125
# Return a single format
11261126
if singlefmt is True:
11271127
res = [max(res)]*len(res)
@@ -1145,7 +1145,7 @@ def wfdbfmt(res, singlefmt = True):
11451145
# Return the resolution of the WFDB format(s).
11461146
def wfdbfmtres(fmt, maxres=False):
11471147

1148-
if type(fmt)==list:
1148+
if isinstance(fmt, list):
11491149
res = [wfdbfmtres(f) for f in fmt]
11501150
if maxres is True:
11511151
res = np.max(res)

wfdb/readwrite/annotations.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def checkfield(self, field):
177177
raise TypeError('The '+field+' field must be one of the following types:', ann_field_types[field])
178178

179179
if field in int_ann_fields:
180-
if item.dtype not in int_dtypes:
180+
if item.dtype not in _headers.int_dtypes:
181181
raise TypeError('The '+field+' field must have an integer-based dtype.')
182182

183183
# Field specific checks
@@ -485,7 +485,7 @@ def get_custom_label_attribute(self, attribute):
485485
if attribute not in ann_label_fields:
486486
raise ValueError('Invalid attribute specified')
487487

488-
if type(self.custom_labels) == pd.DataFrame:
488+
if isinstance(self.custom_labels, pd.DataFrame):
489489
if 'label_store' not in list(self.custom_labels):
490490
raise ValueError('label_store not defined in custom_labels')
491491
a = list(self.custom_labels[attribute].values)
@@ -1449,13 +1449,10 @@ def lists_to_arrays(*args):
14491449
# Allowed types of each Annotation object attribute.
14501450
ann_field_types = {'recordname': (str), 'extension': (str), 'sample': (np.ndarray),
14511451
'symbol': (list, np.ndarray), 'subtype': (np.ndarray), 'chan': (np.ndarray),
1452-
'num': (np.ndarray), 'aux_note': (list, np.ndarray), 'fs': tuple(_headers.floattypes),
1452+
'num': (np.ndarray), 'aux_note': (list, np.ndarray), 'fs': _headers.floattypes,
14531453
'label_store': (np.ndarray), 'description':(list, np.ndarray), 'custom_labels': (pd.DataFrame, list, tuple),
14541454
'contained_labels':(pd.DataFrame, list, tuple)}
14551455

1456-
# Acceptable numpy integer dtypes
1457-
int_dtypes = ('int64', 'uint64', 'int32', 'uint32','int16','uint16')
1458-
14591456
strtypes = (str, np.str_)
14601457

14611458
# Elements of the annotation label

wfdb/readwrite/records.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def checkreadinputs(self, sampfrom, sampto, channels, physical, m2s, smoothframe
237237
if not hasattr(sampto, '__index__'):
238238
raise TypeError('sampto must be an integer')
239239

240-
if type(channels) != list:
240+
if not isinstance(channels, list):
241241
raise TypeError('channels must be a list of integers')
242242

243243
# Duration Ranges
@@ -265,7 +265,7 @@ def checkreadinputs(self, sampfrom, sampto, channels, physical, m2s, smoothframe
265265
raise ValueError("returnres must be one of the following when physical is True: 64, 32, 16")
266266

267267
# Cannot expand multiple samples/frame for multi-segment records
268-
if type(self) == MultiRecord:
268+
if isinstance(self, MultiRecord):
269269

270270
# If m2s == True, Physical must be true. There is no
271271
# meaningful representation of digital signals transferred
@@ -286,7 +286,7 @@ def checkitemtype(item, field, allowedtypes, channels=None):
286286
if channels is not None:
287287

288288
# First make sure the item is a list
289-
if type(item) != list:
289+
if not isinstance(item, list):
290290
raise TypeError("Field: '"+field+"' must be a list")
291291

292292
# Expand to make sure all channels must have present field
@@ -302,17 +302,17 @@ def checkitemtype(item, field, allowedtypes, channels=None):
302302
mustexist=channels[ch]
303303
# The field must exist for the channel
304304
if mustexist:
305-
if type(item[ch]) not in allowedtypes:
305+
if not isinstance(item[ch], allowedtypes):
306306
raise TypeError("Channel "+str(ch)+" of field: '"+field+"' must be one of the following types:", allowedtypes)
307307

308308
# The field may be None for the channel
309309
else:
310-
if type(item[ch]) not in allowedtypes and item[ch] is not None:
310+
if not isinstance(item[ch], allowedtypes) and item[ch] is not None:
311311
raise TypeError("Channel "+str(ch)+" of field: '"+field+"' must be a 'None', or one of the following types:", allowedtypes)
312312

313313
# Single scalar to check
314314
else:
315-
if type(item) not in allowedtypes:
315+
if not isinstance(item, allowedtypes):
316316
raise TypeError("Field: '"+field+"' must be one of the following types:", allowedtypes)
317317

318318

@@ -810,7 +810,7 @@ def rdsamp(recordname, sampfrom=0, sampto=None, channels = None, physical = True
810810
record.checkreadinputs(sampfrom, sampto, channels, physical, m2s, smoothframes, returnres)
811811

812812
# A single segment record
813-
if type(record) == Record:
813+
if isinstance(record, Record):
814814

815815
# Only 1 sample/frame, or frames are smoothed. Return uniform numpy array
816816
if smoothframes or max([record.sampsperframe[c] for c in channels])==1:
@@ -896,7 +896,7 @@ def rdsamp(recordname, sampfrom=0, sampto=None, channels = None, physical = True
896896
record = record.multi_to_single(returnres=returnres)
897897

898898
# Perform dtype conversion if necessary
899-
if type(record) == Record and record.nsig>0:
899+
if isinstance(record, Record) and record.nsig>0:
900900
record.convert_dtype(physical, returnres, smoothframes)
901901

902902
return record
@@ -1277,7 +1277,7 @@ def dldatabase(pbdb, dlbasedir, records = 'all', annotators = 'all' , keepsubdir
12771277
record = rdheader(baserecname, pbdir = posixpath.join(pbdb, dirname))
12781278

12791279
# Single segment record
1280-
if type(record) == Record:
1280+
if isinstance(record, Record):
12811281
# Add all dat files of the segment
12821282
for file in record.filename:
12831283
allfiles.append(posixpath.join(dirname, file))

0 commit comments

Comments
 (0)