-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
[PERF] Get rid of MultiIndex conversion in IntervalIndex.intersection #26225
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a5a1272
3cd095a
0486a4e
09c89f1
841a0b7
8b22623
32d4005
d502fcb
8ec6366
6000904
7cb7d2c
bcf36bb
745c0bb
ff8bb97
17d775f
03a989a
b36cbc8
1cdb170
18c2d37
d229677
35594b0
0834206
3cf5be8
9cf9b7e
ab67edd
402b09c
b4f130d
3ff4c64
3db3130
1f25adb
4a9cd29
1467e94
ea2550a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -97,6 +97,42 @@ def _new_IntervalIndex(cls, d): | |
return cls.from_arrays(**d) | ||
|
||
|
||
class SetopCheck: | ||
""" | ||
This is called to decorate the set operations of IntervalIndex | ||
to perform the type check in advance. | ||
""" | ||
def __init__(self, op_name): | ||
self.op_name = op_name | ||
|
||
def __call__(self, setop): | ||
def func(intvidx_self, other, sort=False): | ||
intvidx_self._assert_can_do_setop(other) | ||
other = ensure_index(other) | ||
|
||
if not isinstance(other, IntervalIndex): | ||
result = getattr(intvidx_self.astype(object), | ||
self.op_name)(other) | ||
if self.op_name in ('difference',): | ||
result = result.astype(intvidx_self.dtype) | ||
return result | ||
elif intvidx_self.closed != other.closed: | ||
msg = ('can only do set operations between two IntervalIndex ' | ||
'objects that are closed on the same side') | ||
raise ValueError(msg) | ||
|
||
# GH 19016: ensure set op will not return a prohibited dtype | ||
subtypes = [intvidx_self.dtype.subtype, other.dtype.subtype] | ||
common_subtype = find_common_type(subtypes) | ||
if is_object_dtype(common_subtype): | ||
msg = ('can only do {op} between two IntervalIndex ' | ||
'objects that have compatible dtypes') | ||
raise TypeError(msg.format(op=self.op_name)) | ||
|
||
return setop(intvidx_self, other, sort) | ||
return func | ||
|
||
|
||
@Appender(_interval_shared_docs['class'] % dict( | ||
klass="IntervalIndex", | ||
summary="Immutable index of intervals that are closed on the same side.", | ||
|
@@ -1102,28 +1138,78 @@ def equals(self, other): | |
def overlaps(self, other): | ||
return self._data.overlaps(other) | ||
|
||
def _setop(op_name, sort=None): | ||
def func(self, other, sort=sort): | ||
self._assert_can_do_setop(other) | ||
other = ensure_index(other) | ||
if not isinstance(other, IntervalIndex): | ||
result = getattr(self.astype(object), op_name)(other) | ||
if op_name in ('difference',): | ||
result = result.astype(self.dtype) | ||
return result | ||
elif self.closed != other.closed: | ||
msg = ('can only do set operations between two IntervalIndex ' | ||
'objects that are closed on the same side') | ||
raise ValueError(msg) | ||
@Appender(_index_shared_docs['intersection']) | ||
@SetopCheck(op_name='intersection') | ||
def intersection(self, other, sort=False): | ||
if self.left.is_unique and self.right.is_unique: | ||
jschendel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
taken = self._intersection_unique(other) | ||
else: | ||
# duplicates | ||
taken = self._intersection_non_unique(other) | ||
|
||
# GH 19016: ensure set op will not return a prohibited dtype | ||
subtypes = [self.dtype.subtype, other.dtype.subtype] | ||
common_subtype = find_common_type(subtypes) | ||
if is_object_dtype(common_subtype): | ||
msg = ('can only do {op} between two IntervalIndex ' | ||
'objects that have compatible dtypes') | ||
raise TypeError(msg.format(op=op_name)) | ||
if sort is None: | ||
taken = taken.sort_values() | ||
|
||
return taken | ||
|
||
def _intersection_unique(self, other): | ||
jschendel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Used when the IntervalIndex does not have any common endpoint, | ||
no mater left or right. | ||
Return the intersection with another IntervalIndex. | ||
|
||
Parameters | ||
---------- | ||
other : IntervalIndex | ||
|
||
Returns | ||
------- | ||
taken : IntervalIndex | ||
""" | ||
lindexer = self.left.get_indexer(other.left) | ||
TomAugspurger marked this conversation as resolved.
Show resolved
Hide resolved
|
||
rindexer = self.right.get_indexer(other.right) | ||
|
||
match = (lindexer == rindexer) & (lindexer != -1) | ||
indexer = lindexer.take(match.nonzero()[0]) | ||
|
||
return self.take(indexer) | ||
|
||
def _intersection_non_unique(self, other): | ||
""" | ||
Used when the IntervalIndex does have some common endpoints, | ||
on either sides. | ||
Return the intersection with another IntervalIndex. | ||
|
||
Parameters | ||
---------- | ||
other : IntervalIndex | ||
|
||
Returns | ||
------- | ||
taken : IntervalIndex | ||
""" | ||
mask = np.zeros(len(self), dtype=bool) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There might be an issue with this approach when dupes are present in Some examples of the buggy and inconsistent behavior with In [2]: idx2 = pd.Index(list('aa'))
...: idx3 = pd.Index(list('aaa'))
...: idx3b = pd.Index(list('baaa'))
In [3]: idx2.intersection(idx3)
Out[3]: Index(['a', 'a', 'a', 'a'], dtype='object')
In [4]: idx3.intersection(idx3)
Out[4]: Index(['a', 'a', 'a'], dtype='object')
In [5]: idx2.intersection(idx3)
Out[5]: Index(['a', 'a', 'a', 'a'], dtype='object')
In [6]: idx2.intersection(idx3b)
Out[6]: Index(['a', 'a', 'a'], dtype='object') It seems strange that @jreback : Do you know what the expected behavior for If we treat indexes like multisets, then the intersection should contain the minimum multiplicity of dupes, e.g. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah this is weird as these are set ops what happens (meaning how much breakage) if
prob need to do this for all set ops There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Haven't had time to extensively test this out but I made the two changes you suggested in |
||
|
||
if self.hasnans and other.hasnans: | ||
first_nan_loc = np.arange(len(self))[self.isna()][0] | ||
mask[first_nan_loc] = True | ||
|
||
lmiss = other.left.get_indexer_non_unique(self.left)[1] | ||
lmatch = np.setdiff1d(np.arange(len(self)), lmiss) | ||
|
||
for i in lmatch: | ||
potential = other.left.get_loc(self.left[i]) | ||
if is_scalar(potential): | ||
if self.right[i] == other.right[potential]: | ||
mask[i] = True | ||
elif self.right[i] in other.right[potential]: | ||
mask[i] = True | ||
|
||
return self[mask] | ||
|
||
def _setop(op_name, sort=None): | ||
@SetopCheck(op_name=op_name) | ||
def func(self, other, sort=sort): | ||
result = getattr(self._multiindex, op_name)(other._multiindex, | ||
sort=sort) | ||
result_name = get_op_result_name(self, other) | ||
|
@@ -1148,7 +1234,6 @@ def is_all_dates(self): | |
return False | ||
|
||
union = _setop('union') | ||
intersection = _setop('intersection', sort=False) | ||
difference = _setop('difference') | ||
symmetric_difference = _setop('symmetric_difference') | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.