Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 2ce1c44

Browse files
committed
Modified rrulewraper to handle timezone-aware datetimes.
1 parent 4c33d97 commit 2ce1c44

File tree

2 files changed

+110
-6
lines changed

2 files changed

+110
-6
lines changed

lib/matplotlib/dates.py

Lines changed: 92 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@
118118
import time
119119
import math
120120
import datetime
121+
import functools
121122

122123
import warnings
123124

@@ -732,20 +733,105 @@ def __call__(self, x, pos=None):
732733

733734

734735
class rrulewrapper(object):
736+
def __init__(self, freq, tzinfo=None, **kwargs):
737+
kwargs['freq'] = freq
738+
self._base_tzinfo = tzinfo
735739

736-
def __init__(self, freq, **kwargs):
737-
self._construct = kwargs.copy()
738-
self._construct["freq"] = freq
739-
self._rrule = rrule(**self._construct)
740+
self._update_rrule(**kwargs)
740741

741742
def set(self, **kwargs):
742743
self._construct.update(kwargs)
744+
745+
self._update_rrule(**self._construct)
746+
747+
def _update_rrule(self, **kwargs):
748+
tzinfo = self._base_tzinfo
749+
750+
# rrule does not play nicely with time zones - especially pytz time
751+
# zones, it's best to use naive zones and attach timezones once the
752+
# datetimes are returned
753+
if 'dtstart' in kwargs:
754+
dtstart = kwargs['dtstart']
755+
if dtstart.tzinfo is not None:
756+
if tzinfo is None:
757+
tzinfo = dtstart.tzinfo
758+
else:
759+
dtstart = dtstart.astimezone(tzinfo)
760+
761+
kwargs['dtstart'] = dtstart.replace(tzinfo=None)
762+
763+
if 'until' in kwargs:
764+
until = kwargs['until']
765+
if until.tzinfo is not None:
766+
if tzinfo is not None:
767+
until = until.astimezone(tzinfo)
768+
else:
769+
raise ValueError('until cannot be aware if dtstart '
770+
'is naive and tzinfo is None')
771+
772+
kwargs['until'] = until.replace(tzinfo=None)
773+
774+
self._construct = kwargs.copy()
775+
self._tzinfo = tzinfo
743776
self._rrule = rrule(**self._construct)
744777

778+
def _attach_tzinfo(self, dt, tzinfo):
779+
# pytz zones are attached by "localizing" the datetime
780+
if hasattr(tzinfo, 'localize'):
781+
return tzinfo.localize(dt, is_dst=True)
782+
783+
return dt.replace(tzinfo=tzinfo)
784+
785+
def _aware_return_wrapper(self, f, returns_list=False):
786+
"""Decorator function that allows rrule methods to handle tzinfo."""
787+
# This is only necessary if we're actually attaching a tzinfo
788+
if self._tzinfo is None:
789+
return f
790+
791+
# All datetime arguments must be naive. If they are not naive, they are
792+
# converted to the _tzinfo zone before dropping the zone.
793+
def normalize_arg(arg):
794+
if isinstance(arg, datetime.datetime) and arg.tzinfo is not None:
795+
if arg.tzinfo is not self._tzinfo:
796+
arg = arg.astimezone(self._tzinfo)
797+
798+
return arg.replace(tzinfo=None)
799+
800+
return arg
801+
802+
def normalize_args(args, kwargs):
803+
args = tuple(normalize_arg(arg) for arg in args)
804+
kwargs = {kw: normalize_arg(arg) for kw, arg in kwargs.items()}
805+
806+
return args, kwargs
807+
808+
# There are two kinds of functions we care about - ones that return
809+
# dates and ones that return lists of dates.
810+
if not returns_list:
811+
def inner_func(*args, **kwargs):
812+
args, kwargs = normalize_args(args, kwargs)
813+
dt = f(*args, **kwargs)
814+
return self._attach_tzinfo(dt, self._tzinfo)
815+
else:
816+
def inner_func(*args, **kwargs):
817+
args, kwargs = normalize_args(args, kwargs)
818+
dts = f(*args, **kwargs)
819+
return [self._attach_tzinfo(dt, self._tzinfo) for dt in dts]
820+
821+
return functools.wraps(f)(inner_func)
822+
745823
def __getattr__(self, name):
746824
if name in self.__dict__:
747825
return self.__dict__[name]
748-
return getattr(self._rrule, name)
826+
827+
f = getattr(self._rrule, name)
828+
829+
if name in {'after', 'before'}:
830+
return self._aware_return_wrapper(f)
831+
elif name in {'xafter', 'xbefore', 'between'}:
832+
return self._aware_return_wrapper(f, returns_list=True)
833+
else:
834+
return f
749835

750836
def __setstate__(self, state):
751837
self.__dict__.update(state)
@@ -1226,7 +1312,7 @@ def __init__(self, bymonth=None, bymonthday=1, interval=1, tz=None):
12261312
bymonth = [x.item() for x in bymonth.astype(int)]
12271313

12281314
rule = rrulewrapper(MONTHLY, bymonth=bymonth, bymonthday=bymonthday,
1229-
interval=interval, **self.hms0d)
1315+
interval=interval, **self.hms0d)
12301316
RRuleLocator.__init__(self, rule, tz)
12311317

12321318

lib/matplotlib/tests/test_dates.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,24 @@ def tz_convert(*args):
442442
_test_date2num_dst(pd.date_range, tz_convert)
443443

444444

445+
@pytest.mark.parametrize("attach_tz, get_tz", [
446+
(lambda dt, zi: zi.localize(dt), lambda n: pytz.timezone(n)),
447+
(lambda dt, zi: dt.replace(tzinfo=zi), lambda n: dateutil.tz.gettz(n))])
448+
def test_rrulewrapper(attach_tz, get_tz):
449+
SYD = get_tz('Australia/Sydney')
450+
451+
dtstart = attach_tz(datetime.datetime(2017, 4, 1, 0), SYD)
452+
dtend = attach_tz(datetime.datetime(2017, 4, 4, 0), SYD)
453+
454+
rule = dates.rrulewrapper(freq=dateutil.rrule.DAILY, dtstart=dtstart)
455+
456+
act = rule.between(dtstart, dtend)
457+
exp = [datetime.datetime(2017, 4, 1, 13, tzinfo=dateutil.tz.tzutc()),
458+
datetime.datetime(2017, 4, 2, 14, tzinfo=dateutil.tz.tzutc())]
459+
460+
assert act == exp
461+
462+
445463
def test_DayLocator():
446464
with pytest.raises(ValueError):
447465
mdates.DayLocator(interval=-1)

0 commit comments

Comments
 (0)