Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit b02182c

Browse files
Batch transform 3d ticks
1 parent 59f7cd0 commit b02182c

1 file changed

Lines changed: 71 additions & 49 deletions

File tree

lib/mpl_toolkits/mplot3d/axis3d.py

Lines changed: 71 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,10 @@ def _axmask(self):
406406
def _draw_ticks(self, renderer, edgep1, centers, deltas, highs,
407407
deltas_per_point, pos):
408408
ticks = self._ticks_to_draw
409+
n_ticks = len(ticks)
410+
if n_ticks == 0:
411+
return
412+
409413
info = self._axinfo
410414
index = info["i"]
411415
juggled = info["juggled"]
@@ -431,23 +435,39 @@ def _draw_ticks(self, renderer, edgep1, centers, deltas, highs,
431435

432436
default_label_offset = 8. # A rough estimate
433437
points = deltas_per_point * deltas
434-
# All coordinates below are in transformed coordinates for proper projection
435-
for tick in ticks:
436-
# Get tick line positions
437-
pos = edgep1.copy()
438-
pos[index] = axis_trans.transform([tick.get_loc()])[0]
439-
pos[tickdir] = out_tickdir
440-
x1, y1, z1 = proj3d.proj_transform(*pos, self.axes.M)
441-
pos[tickdir] = in_tickdir
442-
x2, y2, z2 = proj3d.proj_transform(*pos, self.axes.M)
443-
444-
# Get position of label
445-
labeldeltas = (tick.get_pad() + default_label_offset) * points
446-
pos[tickdir] = edgep1_tickdir
447-
pos = _move_from_center(pos, centers, labeldeltas, self._axmask())
448-
lx, ly, lz = proj3d.proj_transform(*pos, self.axes.M)
449-
450-
_tick_update_position(tick, (x1, x2), (y1, y2), (lx, ly))
438+
439+
# Collect tick data and batch transform tick locations
440+
tick_locs = np.array([tick.get_loc() for tick in ticks])
441+
tick_pads = np.array([tick.get_pad() for tick in ticks])
442+
transformed_locs = axis_trans.transform(tick_locs)
443+
444+
# Build position arrays for tick line endpoints (shape: n_ticks x 3)
445+
pos1 = np.tile(edgep1, (n_ticks, 1))
446+
pos1[:, index] = transformed_locs
447+
pos1[:, tickdir] = out_tickdir
448+
449+
pos2 = pos1.copy()
450+
pos2[:, tickdir] = in_tickdir
451+
452+
# Batch proj_transform for tick lines
453+
x1, y1, _ = proj3d.proj_transform(pos1[:, 0], pos1[:, 1], pos1[:, 2],
454+
self.axes.M)
455+
x2, y2, _ = proj3d.proj_transform(pos2[:, 0], pos2[:, 1], pos2[:, 2],
456+
self.axes.M)
457+
458+
# Build label positions
459+
labeldeltas = (tick_pads + default_label_offset)[:, np.newaxis] * points
460+
pos_label = pos1.copy()
461+
pos_label[:, tickdir] = edgep1_tickdir
462+
axmask = self._axmask()
463+
pos_label = _move_from_center(pos_label, centers, labeldeltas, axmask)
464+
lx, ly, _ = proj3d.proj_transform(pos_label[:, 0], pos_label[:, 1],
465+
pos_label[:, 2], self.axes.M)
466+
467+
# Update and draw each tick
468+
for i, tick in enumerate(ticks):
469+
_tick_update_position(tick, (x1[i], x2[i]), (y1[i], y2[i]),
470+
(lx[i], ly[i]))
451471
tick.tick1line.set_linewidth(tick_lw[tick._major])
452472
tick.draw(renderer)
453473

@@ -611,38 +631,40 @@ def draw_grid(self, renderer):
611631
renderer.open_group("grid3d", gid=self.get_gid())
612632

613633
ticks = self._ticks_to_draw
614-
if len(ticks):
615-
# Get general axis information:
616-
info = self._axinfo
617-
index = info["i"]
618-
619-
# Grid lines use data-space bounds (Line3DCollection applies transforms)
620-
mins, maxs, tc, highs = self.axes._get_coord_info()
621-
bounds = self.axes._get_bounds()
622-
xlim, ylim, zlim = bounds[0:2], bounds[2:4], bounds[4:6]
623-
data_mins = np.array([xlim[0], ylim[0], zlim[0]])
624-
data_maxs = np.array([xlim[1], ylim[1], zlim[1]])
625-
minmax = np.where(highs, data_maxs, data_mins)
626-
maxmin = np.where(~highs, data_maxs, data_mins)
627-
628-
# Grid points where the planes meet
629-
xyz0 = np.tile(minmax, (len(ticks), 1))
630-
xyz0[:, index] = [tick.get_loc() for tick in ticks]
631-
632-
# Grid lines go from the end of one plane through the plane
633-
# intersection (at xyz0) to the end of the other plane. The first
634-
# point (0) differs along dimension index-2 and the last (2) along
635-
# dimension index-1.
636-
lines = np.stack([xyz0, xyz0, xyz0], axis=1)
637-
lines[:, 0, index - 2] = maxmin[index - 2]
638-
lines[:, 2, index - 1] = maxmin[index - 1]
639-
self.gridlines.set_segments(lines)
640-
gridinfo = info['grid']
641-
self.gridlines.set_color(gridinfo['color'])
642-
self.gridlines.set_linewidth(gridinfo['linewidth'])
643-
self.gridlines.set_linestyle(gridinfo['linestyle'])
644-
self.gridlines.do_3d_projection()
645-
self.gridlines.draw(renderer)
634+
if len(ticks) == 0:
635+
return
636+
637+
# Get general axis information:
638+
info = self._axinfo
639+
index = info["i"]
640+
641+
# Grid lines use data-space bounds (Line3DCollection applies transforms)
642+
mins, maxs, tc, highs = self.axes._get_coord_info()
643+
bounds = self.axes._get_bounds()
644+
xlim, ylim, zlim = bounds[0:2], bounds[2:4], bounds[4:6]
645+
data_mins = np.array([xlim[0], ylim[0], zlim[0]])
646+
data_maxs = np.array([xlim[1], ylim[1], zlim[1]])
647+
minmax = np.where(highs, data_maxs, data_mins)
648+
maxmin = np.where(~highs, data_maxs, data_mins)
649+
650+
# Grid points where the planes meet
651+
xyz0 = np.tile(minmax, (len(ticks), 1))
652+
xyz0[:, index] = [tick.get_loc() for tick in ticks]
653+
654+
# Grid lines go from the end of one plane through the plane
655+
# intersection (at xyz0) to the end of the other plane. The first
656+
# point (0) differs along dimension index-2 and the last (2) along
657+
# dimension index-1.
658+
lines = np.stack([xyz0, xyz0, xyz0], axis=1)
659+
lines[:, 0, index - 2] = maxmin[index - 2]
660+
lines[:, 2, index - 1] = maxmin[index - 1]
661+
self.gridlines.set_segments(lines)
662+
gridinfo = info['grid']
663+
self.gridlines.set_color(gridinfo['color'])
664+
self.gridlines.set_linewidth(gridinfo['linewidth'])
665+
self.gridlines.set_linestyle(gridinfo['linestyle'])
666+
self.gridlines.do_3d_projection()
667+
self.gridlines.draw(renderer)
646668

647669
renderer.close_group('grid3d')
648670

0 commit comments

Comments
 (0)