|
125 | 125 | import time
|
126 | 126 | import math
|
127 | 127 | import datetime
|
| 128 | +import functools |
128 | 129 |
|
129 | 130 | import warnings
|
130 | 131 |
|
@@ -806,20 +807,105 @@ def __call__(self, x, pos=None):
|
806 | 807 |
|
807 | 808 |
|
808 | 809 | class rrulewrapper(object):
|
| 810 | + def __init__(self, freq, tzinfo=None, **kwargs): |
| 811 | + kwargs['freq'] = freq |
| 812 | + self._base_tzinfo = tzinfo |
809 | 813 |
|
810 |
| - def __init__(self, freq, **kwargs): |
811 |
| - self._construct = kwargs.copy() |
812 |
| - self._construct["freq"] = freq |
813 |
| - self._rrule = rrule(**self._construct) |
| 814 | + self._update_rrule(**kwargs) |
814 | 815 |
|
815 | 816 | def set(self, **kwargs):
|
816 | 817 | self._construct.update(kwargs)
|
| 818 | + |
| 819 | + self._update_rrule(**self._construct) |
| 820 | + |
| 821 | + def _update_rrule(self, **kwargs): |
| 822 | + tzinfo = self._base_tzinfo |
| 823 | + |
| 824 | + # rrule does not play nicely with time zones - especially pytz time |
| 825 | + # zones, it's best to use naive zones and attach timezones once the |
| 826 | + # datetimes are returned |
| 827 | + if 'dtstart' in kwargs: |
| 828 | + dtstart = kwargs['dtstart'] |
| 829 | + if dtstart.tzinfo is not None: |
| 830 | + if tzinfo is None: |
| 831 | + tzinfo = dtstart.tzinfo |
| 832 | + else: |
| 833 | + dtstart = dtstart.astimezone(tzinfo) |
| 834 | + |
| 835 | + kwargs['dtstart'] = dtstart.replace(tzinfo=None) |
| 836 | + |
| 837 | + if 'until' in kwargs: |
| 838 | + until = kwargs['until'] |
| 839 | + if until.tzinfo is not None: |
| 840 | + if tzinfo is not None: |
| 841 | + until = until.astimezone(tzinfo) |
| 842 | + else: |
| 843 | + raise ValueError('until cannot be aware if dtstart ' |
| 844 | + 'is naive and tzinfo is None') |
| 845 | + |
| 846 | + kwargs['until'] = until.replace(tzinfo=None) |
| 847 | + |
| 848 | + self._construct = kwargs.copy() |
| 849 | + self._tzinfo = tzinfo |
817 | 850 | self._rrule = rrule(**self._construct)
|
818 | 851 |
|
| 852 | + def _attach_tzinfo(self, dt, tzinfo): |
| 853 | + # pytz zones are attached by "localizing" the datetime |
| 854 | + if hasattr(tzinfo, 'localize'): |
| 855 | + return tzinfo.localize(dt, is_dst=True) |
| 856 | + |
| 857 | + return dt.replace(tzinfo=tzinfo) |
| 858 | + |
| 859 | + def _aware_return_wrapper(self, f, returns_list=False): |
| 860 | + """Decorator function that allows rrule methods to handle tzinfo.""" |
| 861 | + # This is only necessary if we're actually attaching a tzinfo |
| 862 | + if self._tzinfo is None: |
| 863 | + return f |
| 864 | + |
| 865 | + # All datetime arguments must be naive. If they are not naive, they are |
| 866 | + # converted to the _tzinfo zone before dropping the zone. |
| 867 | + def normalize_arg(arg): |
| 868 | + if isinstance(arg, datetime.datetime) and arg.tzinfo is not None: |
| 869 | + if arg.tzinfo is not self._tzinfo: |
| 870 | + arg = arg.astimezone(self._tzinfo) |
| 871 | + |
| 872 | + return arg.replace(tzinfo=None) |
| 873 | + |
| 874 | + return arg |
| 875 | + |
| 876 | + def normalize_args(args, kwargs): |
| 877 | + args = tuple(normalize_arg(arg) for arg in args) |
| 878 | + kwargs = {kw: normalize_arg(arg) for kw, arg in kwargs.items()} |
| 879 | + |
| 880 | + return args, kwargs |
| 881 | + |
| 882 | + # There are two kinds of functions we care about - ones that return |
| 883 | + # dates and ones that return lists of dates. |
| 884 | + if not returns_list: |
| 885 | + def inner_func(*args, **kwargs): |
| 886 | + args, kwargs = normalize_args(args, kwargs) |
| 887 | + dt = f(*args, **kwargs) |
| 888 | + return self._attach_tzinfo(dt, self._tzinfo) |
| 889 | + else: |
| 890 | + def inner_func(*args, **kwargs): |
| 891 | + args, kwargs = normalize_args(args, kwargs) |
| 892 | + dts = f(*args, **kwargs) |
| 893 | + return [self._attach_tzinfo(dt, self._tzinfo) for dt in dts] |
| 894 | + |
| 895 | + return functools.wraps(f)(inner_func) |
| 896 | + |
819 | 897 | def __getattr__(self, name):
|
820 | 898 | if name in self.__dict__:
|
821 | 899 | return self.__dict__[name]
|
822 |
| - return getattr(self._rrule, name) |
| 900 | + |
| 901 | + f = getattr(self._rrule, name) |
| 902 | + |
| 903 | + if name in {'after', 'before'}: |
| 904 | + return self._aware_return_wrapper(f) |
| 905 | + elif name in {'xafter', 'xbefore', 'between'}: |
| 906 | + return self._aware_return_wrapper(f, returns_list=True) |
| 907 | + else: |
| 908 | + return f |
823 | 909 |
|
824 | 910 | def __setstate__(self, state):
|
825 | 911 | self.__dict__.update(state)
|
@@ -1304,7 +1390,7 @@ def __init__(self, bymonth=None, bymonthday=1, interval=1, tz=None):
|
1304 | 1390 | bymonth = [x.item() for x in bymonth.astype(int)]
|
1305 | 1391 |
|
1306 | 1392 | rule = rrulewrapper(MONTHLY, bymonth=bymonth, bymonthday=bymonthday,
|
1307 |
| - interval=interval, **self.hms0d) |
| 1393 | + interval=interval, **self.hms0d) |
1308 | 1394 | RRuleLocator.__init__(self, rule, tz)
|
1309 | 1395 |
|
1310 | 1396 |
|
|
0 commit comments