Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 24d662e

Browse files
committed
rec_join now handles two record arrays with the same column names with "*fixes"
svn path=/trunk/matplotlib/; revision=6170
1 parent 1c45457 commit 24d662e

1 file changed

Lines changed: 33 additions & 12 deletions

File tree

lib/matplotlib/mlab.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,6 +1665,14 @@ def safe_isinf(x):
16651665
except TypeError: return False
16661666
else: return b
16671667

1668+
def rec_view(rec):
1669+
""" Return a view of an ndarray as a recarray
1670+
http://projects.scipy.org/pipermail/numpy-discussion/2008-August/036429.html
1671+
Reverting Travis' fix because it doesn't work for object arrays
1672+
"""
1673+
return rec.view(np.recarray)
1674+
#return rec.view(dtype=(np.record, rec.dtype), type=np.recarray)
1675+
16681676
def rec_append_field(rec, name, arr, dtype=None):
16691677
"""
16701678
return a new record array with field name populated with data from array arr.
@@ -1703,7 +1711,7 @@ def rec_append_fields(rec, names, arrs, dtypes=None):
17031711
newrec[field] = rec[field]
17041712
for name, arr in zip(names, arrs):
17051713
newrec[name] = arr
1706-
return newrec.view(np.recarray)
1714+
return rec_view(newrec)
17071715

17081716

17091717
def rec_drop_fields(rec, names):
@@ -1719,7 +1727,7 @@ def rec_drop_fields(rec, names):
17191727
for field in newdtype.names:
17201728
newrec[field] = rec[field]
17211729

1722-
return newrec.view(np.recarray)
1730+
return rec_view(newrec)
17231731

17241732

17251733

@@ -1789,7 +1797,7 @@ def rec_summarize(r, summaryfuncs):
17891797
return np.rec.fromarrays(arrays, names=names)
17901798

17911799

1792-
def rec_join(key, r1, r2, jointype='inner', defaults=None):
1800+
def rec_join(key, r1, r2, jointype='inner', defaults=None, r1postfix='1', r2postfix='2'):
17931801
"""
17941802
join record arrays r1 and r2 on key; key is a tuple of field
17951803
names. If r1 and r2 have equal values on all the keys in the key
@@ -1803,6 +1811,9 @@ def rec_join(key, r1, r2, jointype='inner', defaults=None):
18031811
18041812
The defaults keyword is a dictionary filled with
18051813
{column_name:default_value} pairs.
1814+
1815+
The keywords r1postfix and r2postfix are postfixed to column names
1816+
(other than keys) that are both in r1 and r2.
18061817
"""
18071818

18081819
for name in key:
@@ -1850,13 +1861,21 @@ def key_desc(name):
18501861
return (name, dt2.descr[0][1])
18511862

18521863

1853-
18541864
keydesc = [key_desc(name) for name in key]
1865+
1866+
def mapped_r1field(name):
1867+
""" the column name in newrec that corresponds to the column in r1 """
1868+
if name in key or name not in r2.dtype.names: return name
1869+
else: return name + r1postfix
18551870

1856-
newdtype = np.dtype(keydesc +
1857-
[desc for desc in r1.dtype.descr if desc[0] not in key ] +
1858-
[desc for desc in r2.dtype.descr if desc[0] not in key ] )
1871+
def mapped_r2field(name):
1872+
""" the column name in newrec that corresponds to the column in r2 """
1873+
if name in key or name not in r1.dtype.names: return name
1874+
else: return name + r2postfix
18591875

1876+
r1desc = [(mapped_r1field(desc[0]), desc[1]) for desc in r1.dtype.descr if desc[0] not in key]
1877+
r2desc = [(mapped_r2field(desc[0]), desc[1]) for desc in r2.dtype.descr if desc[0] not in key]
1878+
newdtype = np.dtype(keydesc + r1desc + r2desc)
18601879

18611880
newrec = np.empty(common_len + left_len + right_len, dtype=newdtype)
18621881

@@ -1867,20 +1886,22 @@ def key_desc(name):
18671886
newrec[k] = v
18681887

18691888
for field in r1.dtype.names:
1889+
newfield = mapped_r1field(field)
18701890
if common_len:
1871-
newrec[field][:common_len] = r1[field][r1ind]
1891+
newrec[newfield][:common_len] = r1[field][r1ind]
18721892
if (jointype == "outer" or jointype == "leftouter") and left_len:
1873-
newrec[field][common_len:(common_len+left_len)] = r1[field][left_ind]
1893+
newrec[newfield][common_len:(common_len+left_len)] = r1[field][left_ind]
18741894

18751895
for field in r2.dtype.names:
1896+
newfield = mapped_r2field(field)
18761897
if field not in key and common_len:
1877-
newrec[field][:common_len] = r2[field][r2ind]
1898+
newrec[newfield][:common_len] = r2[field][r2ind]
18781899
if jointype == "outer" and right_len:
1879-
newrec[field][-right_len:] = r2[field][right_ind]
1900+
newrec[newfield][-right_len:] = r2[field][right_ind]
18801901

18811902
newrec.sort(order=key)
18821903

1883-
return newrec.view(np.recarray)
1904+
return rec_view(newrec)
18841905

18851906

18861907
def csv2rec(fname, comments='#', skiprows=0, checkrows=0, delimiter=',',

0 commit comments

Comments
 (0)