diff --git a/lib/mpl_toolkits/axisartist/grid_finder.py b/lib/mpl_toolkits/axisartist/grid_finder.py index ff61887f9ef2..897a4152b0e7 100644 --- a/lib/mpl_toolkits/axisartist/grid_finder.py +++ b/lib/mpl_toolkits/axisartist/grid_finder.py @@ -192,25 +192,28 @@ def get_grid_info(self, x1, y1, x2, y2): grid_info = { "extremes": extremes, - "lon_lines": lon_lines, - "lat_lines": lat_lines, - "lon": self._clip_grid_lines_and_find_ticks( - lon_lines, lon_values, lon_levs, bb), - "lat": self._clip_grid_lines_and_find_ticks( - lat_lines, lat_values, lat_levs, bb), + # "lon", "lat", filled below. } - tck_labels = grid_info["lon"]["tick_labels"] = {} - for direction in ["left", "bottom", "right", "top"]: - levs = grid_info["lon"]["tick_levels"][direction] - tck_labels[direction] = self._format_ticks( - 1, direction, lon_factor, levs) - - tck_labels = grid_info["lat"]["tick_labels"] = {} - for direction in ["left", "bottom", "right", "top"]: - levs = grid_info["lat"]["tick_levels"][direction] - tck_labels[direction] = self._format_ticks( - 2, direction, lat_factor, levs) + for idx, lon_or_lat, levs, factor, values, lines in [ + (1, "lon", lon_levs, lon_factor, lon_values, lon_lines), + (2, "lat", lat_levs, lat_factor, lat_values, lat_lines), + ]: + grid_info[lon_or_lat] = gi = { + "lines": [[l] for l in lines], + "ticks": {"left": [], "right": [], "bottom": [], "top": []}, + } + for (lx, ly), v, level in zip(lines, values, levs): + all_crossings = _find_line_box_crossings(np.column_stack([lx, ly]), bb) + for side, crossings in zip( + ["left", "right", "bottom", "top"], all_crossings): + for crossing in crossings: + gi["ticks"][side].append({"level": level, "loc": crossing}) + for side in gi["ticks"]: + levs = [tick["level"] for tick in gi["ticks"][side]] + labels = self._format_ticks(idx, side, factor, levs) + for tick, label in zip(gi["ticks"][side], labels): + tick["label"] = label return grid_info @@ -228,30 +231,6 @@ def _get_raw_grid_lines(self, return lon_lines, lat_lines - def _clip_grid_lines_and_find_ticks(self, lines, values, levs, bb): - gi = { - "values": [], - "levels": [], - "tick_levels": dict(left=[], bottom=[], right=[], top=[]), - "tick_locs": dict(left=[], bottom=[], right=[], top=[]), - "lines": [], - } - - tck_levels = gi["tick_levels"] - tck_locs = gi["tick_locs"] - for (lx, ly), v, lev in zip(lines, values, levs): - tcks = _find_line_box_crossings(np.column_stack([lx, ly]), bb) - gi["levels"].append(v) - gi["lines"].append([(lx, ly)]) - - for tck, direction in zip(tcks, - ["left", "right", "bottom", "top"]): - for t in tck: - tck_levels[direction].append(lev) - tck_locs[direction].append(t) - - return gi - def set_transform(self, aux_trans): if isinstance(aux_trans, Transform): self._aux_transform = aux_trans diff --git a/lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py b/lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py index 8dbfa6adf90f..a7eb9d5cfe21 100644 --- a/lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py +++ b/lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py @@ -82,9 +82,9 @@ def iter_major(): for nth_coord, show_labels in [ (self.nth_coord_ticks, True), (1 - self.nth_coord_ticks, False)]: gi = self.grid_helper._grid_info[["lon", "lat"][nth_coord]] - for (xy, angle_normal), l in zip( - gi["tick_locs"][side], gi["tick_labels"][side]): - yield xy, angle_normal, angle_tangent, (l if show_labels else "") + for tick in gi["ticks"][side]: + yield (*tick["loc"], angle_tangent, + (tick["label"] if show_labels else "")) return iter_major(), iter([]) @@ -321,12 +321,8 @@ def get_tick_iterator(self, nth_coord, axis_side, minor=False): angle_tangent = dict(left=90, right=90, bottom=0, top=0)[axis_side] lon_or_lat = ["lon", "lat"][nth_coord] if not minor: # major ticks - for (xy, angle_normal), l in zip( - self._grid_info[lon_or_lat]["tick_locs"][axis_side], - self._grid_info[lon_or_lat]["tick_labels"][axis_side]): - yield xy, angle_normal, angle_tangent, l + for tick in self._grid_info[lon_or_lat]["ticks"][axis_side]: + yield *tick["loc"], angle_tangent, tick["label"] else: - for (xy, angle_normal), l in zip( - self._grid_info[lon_or_lat]["tick_locs"][axis_side], - self._grid_info[lon_or_lat]["tick_labels"][axis_side]): - yield xy, angle_normal, angle_tangent, "" + for tick in self._grid_info[lon_or_lat]["ticks"][axis_side]: + yield *tick["loc"], angle_tangent, ""