Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 725d8d6

Browse files
authored
Merge pull request #9028 from pganssle/fix_rrulewrapper_tzinfo
Modified rrulewraper to handle timezone-aware datetimes.
2 parents d93a2a9 + 68d43e5 commit 725d8d6

File tree

2 files changed

+110
-6
lines changed

2 files changed

+110
-6
lines changed

lib/matplotlib/dates.py

Lines changed: 92 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@
125125
import time
126126
import math
127127
import datetime
128+
import functools
128129

129130
import warnings
130131

@@ -806,20 +807,105 @@ def __call__(self, x, pos=None):
806807

807808

808809
class rrulewrapper(object):
810+
def __init__(self, freq, tzinfo=None, **kwargs):
811+
kwargs['freq'] = freq
812+
self._base_tzinfo = tzinfo
809813

810-
def __init__(self, freq, **kwargs):
811-
self._construct = kwargs.copy()
812-
self._construct["freq"] = freq
813-
self._rrule = rrule(**self._construct)
814+
self._update_rrule(**kwargs)
814815

815816
def set(self, **kwargs):
816817
self._construct.update(kwargs)
818+
819+
self._update_rrule(**self._construct)
820+
821+
def _update_rrule(self, **kwargs):
822+
tzinfo = self._base_tzinfo
823+
824+
# rrule does not play nicely with time zones - especially pytz time
825+
# zones, it's best to use naive zones and attach timezones once the
826+
# datetimes are returned
827+
if 'dtstart' in kwargs:
828+
dtstart = kwargs['dtstart']
829+
if dtstart.tzinfo is not None:
830+
if tzinfo is None:
831+
tzinfo = dtstart.tzinfo
832+
else:
833+
dtstart = dtstart.astimezone(tzinfo)
834+
835+
kwargs['dtstart'] = dtstart.replace(tzinfo=None)
836+
837+
if 'until' in kwargs:
838+
until = kwargs['until']
839+
if until.tzinfo is not None:
840+
if tzinfo is not None:
841+
until = until.astimezone(tzinfo)
842+
else:
843+
raise ValueError('until cannot be aware if dtstart '
844+
'is naive and tzinfo is None')
845+
846+
kwargs['until'] = until.replace(tzinfo=None)
847+
848+
self._construct = kwargs.copy()
849+
self._tzinfo = tzinfo
817850
self._rrule = rrule(**self._construct)
818851

852+
def _attach_tzinfo(self, dt, tzinfo):
853+
# pytz zones are attached by "localizing" the datetime
854+
if hasattr(tzinfo, 'localize'):
855+
return tzinfo.localize(dt, is_dst=True)
856+
857+
return dt.replace(tzinfo=tzinfo)
858+
859+
def _aware_return_wrapper(self, f, returns_list=False):
860+
"""Decorator function that allows rrule methods to handle tzinfo."""
861+
# This is only necessary if we're actually attaching a tzinfo
862+
if self._tzinfo is None:
863+
return f
864+
865+
# All datetime arguments must be naive. If they are not naive, they are
866+
# converted to the _tzinfo zone before dropping the zone.
867+
def normalize_arg(arg):
868+
if isinstance(arg, datetime.datetime) and arg.tzinfo is not None:
869+
if arg.tzinfo is not self._tzinfo:
870+
arg = arg.astimezone(self._tzinfo)
871+
872+
return arg.replace(tzinfo=None)
873+
874+
return arg
875+
876+
def normalize_args(args, kwargs):
877+
args = tuple(normalize_arg(arg) for arg in args)
878+
kwargs = {kw: normalize_arg(arg) for kw, arg in kwargs.items()}
879+
880+
return args, kwargs
881+
882+
# There are two kinds of functions we care about - ones that return
883+
# dates and ones that return lists of dates.
884+
if not returns_list:
885+
def inner_func(*args, **kwargs):
886+
args, kwargs = normalize_args(args, kwargs)
887+
dt = f(*args, **kwargs)
888+
return self._attach_tzinfo(dt, self._tzinfo)
889+
else:
890+
def inner_func(*args, **kwargs):
891+
args, kwargs = normalize_args(args, kwargs)
892+
dts = f(*args, **kwargs)
893+
return [self._attach_tzinfo(dt, self._tzinfo) for dt in dts]
894+
895+
return functools.wraps(f)(inner_func)
896+
819897
def __getattr__(self, name):
820898
if name in self.__dict__:
821899
return self.__dict__[name]
822-
return getattr(self._rrule, name)
900+
901+
f = getattr(self._rrule, name)
902+
903+
if name in {'after', 'before'}:
904+
return self._aware_return_wrapper(f)
905+
elif name in {'xafter', 'xbefore', 'between'}:
906+
return self._aware_return_wrapper(f, returns_list=True)
907+
else:
908+
return f
823909

824910
def __setstate__(self, state):
825911
self.__dict__.update(state)
@@ -1304,7 +1390,7 @@ def __init__(self, bymonth=None, bymonthday=1, interval=1, tz=None):
13041390
bymonth = [x.item() for x in bymonth.astype(int)]
13051391

13061392
rule = rrulewrapper(MONTHLY, bymonth=bymonth, bymonthday=bymonthday,
1307-
interval=interval, **self.hms0d)
1393+
interval=interval, **self.hms0d)
13081394
RRuleLocator.__init__(self, rule, tz)
13091395

13101396

lib/matplotlib/tests/test_dates.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,24 @@ def tz_convert(*args):
566566
_test_date2num_dst(pd.date_range, tz_convert)
567567

568568

569+
@pytest.mark.parametrize("attach_tz, get_tz", [
570+
(lambda dt, zi: zi.localize(dt), lambda n: pytz.timezone(n)),
571+
(lambda dt, zi: dt.replace(tzinfo=zi), lambda n: dateutil.tz.gettz(n))])
572+
def test_rrulewrapper(attach_tz, get_tz):
573+
SYD = get_tz('Australia/Sydney')
574+
575+
dtstart = attach_tz(datetime.datetime(2017, 4, 1, 0), SYD)
576+
dtend = attach_tz(datetime.datetime(2017, 4, 4, 0), SYD)
577+
578+
rule = mdates.rrulewrapper(freq=dateutil.rrule.DAILY, dtstart=dtstart)
579+
580+
act = rule.between(dtstart, dtend)
581+
exp = [datetime.datetime(2017, 4, 1, 13, tzinfo=dateutil.tz.tzutc()),
582+
datetime.datetime(2017, 4, 2, 14, tzinfo=dateutil.tz.tzutc())]
583+
584+
assert act == exp
585+
586+
569587
def test_DayLocator():
570588
with pytest.raises(ValueError):
571589
mdates.DayLocator(interval=-1)

0 commit comments

Comments
 (0)