|
118 | 118 | import time
|
119 | 119 | import math
|
120 | 120 | import datetime
|
| 121 | +import functools |
121 | 122 |
|
122 | 123 | import warnings
|
123 | 124 |
|
@@ -732,20 +733,105 @@ def __call__(self, x, pos=None):
|
732 | 733 |
|
733 | 734 |
|
734 | 735 | class rrulewrapper(object):
|
| 736 | + def __init__(self, freq, tzinfo=None, **kwargs): |
| 737 | + kwargs['freq'] = freq |
| 738 | + self._base_tzinfo = tzinfo |
735 | 739 |
|
736 |
| - def __init__(self, freq, **kwargs): |
737 |
| - self._construct = kwargs.copy() |
738 |
| - self._construct["freq"] = freq |
739 |
| - self._rrule = rrule(**self._construct) |
| 740 | + self._update_rrule(**kwargs) |
740 | 741 |
|
741 | 742 | def set(self, **kwargs):
|
742 | 743 | self._construct.update(kwargs)
|
| 744 | + |
| 745 | + self._update_rrule(**self._construct) |
| 746 | + |
| 747 | + def _update_rrule(self, **kwargs): |
| 748 | + tzinfo = self._base_tzinfo |
| 749 | + |
| 750 | + # rrule does not play nicely with time zones - especially pytz time |
| 751 | + # zones, it's best to use naive zones and attach timezones once the |
| 752 | + # datetimes are returned |
| 753 | + if 'dtstart' in kwargs: |
| 754 | + dtstart = kwargs['dtstart'] |
| 755 | + if dtstart.tzinfo is not None: |
| 756 | + if tzinfo is None: |
| 757 | + tzinfo = dtstart.tzinfo |
| 758 | + else: |
| 759 | + dtstart = dtstart.astimezone(tzinfo) |
| 760 | + |
| 761 | + kwargs['dtstart'] = dtstart.replace(tzinfo=None) |
| 762 | + |
| 763 | + if 'until' in kwargs: |
| 764 | + until = kwargs['until'] |
| 765 | + if until.tzinfo is not None: |
| 766 | + if tzinfo is not None: |
| 767 | + until = until.astimezone(tzinfo) |
| 768 | + else: |
| 769 | + raise ValueError('until cannot be aware if dtstart ' |
| 770 | + 'is naive and tzinfo is None') |
| 771 | + |
| 772 | + kwargs['until'] = until.replace(tzinfo=None) |
| 773 | + |
| 774 | + self._construct = kwargs.copy() |
| 775 | + self._tzinfo = tzinfo |
743 | 776 | self._rrule = rrule(**self._construct)
|
744 | 777 |
|
| 778 | + def _attach_tzinfo(self, dt, tzinfo): |
| 779 | + # pytz zones are attached by "localizing" the datetime |
| 780 | + if hasattr(tzinfo, 'localize'): |
| 781 | + return tzinfo.localize(dt, is_dst=True) |
| 782 | + |
| 783 | + return dt.replace(tzinfo=tzinfo) |
| 784 | + |
| 785 | + def _aware_return_wrapper(self, f, returns_list=False): |
| 786 | + """Decorator function that allows rrule methods to handle tzinfo.""" |
| 787 | + # This is only necessary if we're actually attaching a tzinfo |
| 788 | + if self._tzinfo is None: |
| 789 | + return f |
| 790 | + |
| 791 | + # All datetime arguments must be naive. If they are not naive, they are |
| 792 | + # converted to the _tzinfo zone before dropping the zone. |
| 793 | + def normalize_arg(arg): |
| 794 | + if isinstance(arg, datetime.datetime) and arg.tzinfo is not None: |
| 795 | + if arg.tzinfo is not self._tzinfo: |
| 796 | + arg = arg.astimezone(self._tzinfo) |
| 797 | + |
| 798 | + return arg.replace(tzinfo=None) |
| 799 | + |
| 800 | + return arg |
| 801 | + |
| 802 | + def normalize_args(args, kwargs): |
| 803 | + args = tuple(normalize_arg(arg) for arg in args) |
| 804 | + kwargs = {kw: normalize_arg(arg) for kw, arg in kwargs.items()} |
| 805 | + |
| 806 | + return args, kwargs |
| 807 | + |
| 808 | + # There are two kinds of functions we care about - ones that return |
| 809 | + # dates and ones that return lists of dates. |
| 810 | + if not returns_list: |
| 811 | + def inner_func(*args, **kwargs): |
| 812 | + args, kwargs = normalize_args(args, kwargs) |
| 813 | + dt = f(*args, **kwargs) |
| 814 | + return self._attach_tzinfo(dt, self._tzinfo) |
| 815 | + else: |
| 816 | + def inner_func(*args, **kwargs): |
| 817 | + args, kwargs = normalize_args(args, kwargs) |
| 818 | + dts = f(*args, **kwargs) |
| 819 | + return [self._attach_tzinfo(dt, self._tzinfo) for dt in dts] |
| 820 | + |
| 821 | + return functools.wraps(f)(inner_func) |
| 822 | + |
745 | 823 | def __getattr__(self, name):
|
746 | 824 | if name in self.__dict__:
|
747 | 825 | return self.__dict__[name]
|
748 |
| - return getattr(self._rrule, name) |
| 826 | + |
| 827 | + f = getattr(self._rrule, name) |
| 828 | + |
| 829 | + if name in {'after', 'before'}: |
| 830 | + return self._aware_return_wrapper(f) |
| 831 | + elif name in {'xafter', 'xbefore', 'between'}: |
| 832 | + return self._aware_return_wrapper(f, returns_list=True) |
| 833 | + else: |
| 834 | + return f |
749 | 835 |
|
750 | 836 | def __setstate__(self, state):
|
751 | 837 | self.__dict__.update(state)
|
@@ -1226,7 +1312,7 @@ def __init__(self, bymonth=None, bymonthday=1, interval=1, tz=None):
|
1226 | 1312 | bymonth = [x.item() for x in bymonth.astype(int)]
|
1227 | 1313 |
|
1228 | 1314 | rule = rrulewrapper(MONTHLY, bymonth=bymonth, bymonthday=bymonthday,
|
1229 |
| - interval=interval, **self.hms0d) |
| 1315 | + interval=interval, **self.hms0d) |
1230 | 1316 | RRuleLocator.__init__(self, rule, tz)
|
1231 | 1317 |
|
1232 | 1318 |
|
|
0 commit comments