|
33 | 33 | # `np.minimum` instead of the builtin `min`, and likewise for `max`. This is |
34 | 34 | # done so that `nan`s are propagated, instead of being silently dropped. |
35 | 35 |
|
| 36 | +import copy |
36 | 37 | import functools |
37 | 38 | import textwrap |
38 | 39 | import weakref |
@@ -139,11 +140,33 @@ def __setstate__(self, data_dict): |
139 | 140 | k: weakref.ref(v, lambda _, pop=self._parents.pop, k=k: pop(k)) |
140 | 141 | for k, v in self._parents.items() if v is not None} |
141 | 142 |
|
142 | | - def __copy__(self, *args): |
143 | | - raise NotImplementedError( |
144 | | - "TransformNode instances can not be copied. " |
145 | | - "Consider using frozen() instead.") |
146 | | - __deepcopy__ = __copy__ |
| 143 | + def __copy__(self): |
| 144 | + other = copy.copy(super()) |
| 145 | + # If `c = a + b; a1 = copy(a)`, then modifications to `a1` do not |
| 146 | + # propagate back to `c`, i.e. we need to clear the parents of `a1`. |
| 147 | + other._parents = {} |
| 148 | + # If `c = a + b; c1 = copy(c)`, then modifications to `a` also need to |
| 149 | + # be propagated to `c1`. |
| 150 | + for key, val in vars(self).items(): |
| 151 | + if isinstance(val, TransformNode) and id(self) in val._parents: |
| 152 | + other.set_children(val) # val == getattr(other, key) |
| 153 | + return other |
| 154 | + |
| 155 | + def __deepcopy__(self, memo): |
| 156 | + # We could deepcopy the entire transform tree, but nothing except |
| 157 | + # `self` is accessible publicly, so we may as well just freeze `self`. |
| 158 | + other = self.frozen() |
| 159 | + if other is not self: |
| 160 | + return other |
| 161 | + # Some classes implement frozen() as returning self, which is not |
| 162 | + # acceptable for deepcopying, so we need to handle them separately. |
| 163 | + other = copy.deepcopy(super(), memo) |
| 164 | + # If `c = a + b; a1 = copy(a)`, then modifications to `a1` do not |
| 165 | + # propagate back to `c`, i.e. we need to clear the parents of `a1`. |
| 166 | + other._parents = {} |
| 167 | + # If `c = a + b; c1 = copy(c)`, this creates a separate tree |
| 168 | + # (`c1 = a1 + b1`) so nothing needs to be done. |
| 169 | + return other |
147 | 170 |
|
148 | 171 | def invalidate(self): |
149 | 172 | """ |
|
0 commit comments