diff --git a/lib/mpl_toolkits/axisartist/grid_finder.py b/lib/mpl_toolkits/axisartist/grid_finder.py index 2276d28bae90..eb988b14b289 100644 --- a/lib/mpl_toolkits/axisartist/grid_finder.py +++ b/lib/mpl_toolkits/axisartist/grid_finder.py @@ -176,24 +176,26 @@ def _clip_grid_lines_and_find_ticks(self, lines, values, levs, bb): return gi def update_transform(self, aux_trans): - if isinstance(aux_trans, Transform): - def transform_xy(x, y): - ll1 = np.column_stack([x, y]) - ll2 = aux_trans.transform(ll1) - lon, lat = ll2[:, 0], ll2[:, 1] - return lon, lat - - def inv_transform_xy(x, y): - ll1 = np.column_stack([x, y]) - ll2 = aux_trans.inverted().transform(ll1) - lon, lat = ll2[:, 0], ll2[:, 1] - return lon, lat - + if not isinstance(aux_trans, Transform) and len(aux_trans) != 2: + raise TypeError("'aux_trans' must be either a Transform instance " + "or a pair of callables") + self._aux_transform = aux_trans + + def transform_xy(self, x, y): + aux_trf = self._aux_transform + if isinstance(aux_trf, Transform): + return aux_trf.transform(np.column_stack([x, y])).T else: - transform_xy, inv_transform_xy = aux_trans + transform_xy, inv_transform_xy = aux_trf + return transform_xy(x, y) - self.transform_xy = transform_xy - self.inv_transform_xy = inv_transform_xy + def inv_transform_xy(self, x, y): + aux_trf = self._aux_transform + if isinstance(aux_trf, Transform): + return aux_trf.inverted().transform(np.column_stack([x, y])).T + else: + transform_xy, inv_transform_xy = aux_trf + return inv_transform_xy(x, y) def update(self, **kw): for k in kw: