Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 2c2a4ae

Browse files
author
shawnchen1996
committed
add tests for cbar anchor
1 parent a55b35c commit 2c2a4ae

1 file changed

Lines changed: 48 additions & 0 deletions

File tree

lib/matplotlib/tests/test_colorbar.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,3 +622,51 @@ def test_colorbar_int(clim):
622622
im = ax.imshow([[*map(np.int16, clim)]])
623623
fig.colorbar(im)
624624
assert (im.norm.vmin, im.norm.vmax) == clim
625+
626+
627+
def test_anchored_cbar_position_using_specgrid():
628+
data = np.arange(1200).reshape(30, 40)
629+
levels = [0, 200, 400, 600, 800, 1000, 1200]
630+
shrink = 0.5
631+
anchor_y = 0.3
632+
# vertival
633+
fig, ax = plt.subplots()
634+
cs = ax.contourf(data, levels=levels)
635+
cbar = plt.colorbar(
636+
cs, ax=ax, use_gridspec=True,
637+
orientation='vertical', anchor=(1, anchor_y), shrink=shrink)
638+
639+
# y1: the top of ax, y0: the bottom of ax, p0: the y postion of anchor
640+
# cy1 : the top of colorbar ax, cy0: the bottom of colorbar ax
641+
y1 = ax.get_position().y1
642+
y0 = ax.get_position().y0
643+
p0 = (y1 - y0) * anchor_y + y0
644+
cy1 = cbar.ax.get_position().y1
645+
cy0 = cbar.ax.get_position().y0
646+
647+
assert np.isclose(
648+
[cy1, cy0],
649+
[y1 * shrink + (1 - shrink) * p0, p0 * (1 - shrink) + y0 * shrink]
650+
).all()
651+
652+
# horizontal
653+
shrink = 0.5
654+
anchor_x = 0.3
655+
fig, ax = plt.subplots()
656+
cs = ax.contourf(data, levels=levels)
657+
cbar = plt.colorbar(
658+
cs, ax=ax, use_gridspec=True,
659+
orientation='horizontal', anchor=(anchor_x, 1), shrink=shrink)
660+
661+
# x1: the right of ax, x0: the left of ax, p0: the x postion of anchor
662+
# cx1 : the right of colorbar ax, cx0: the left of colorbar ax
663+
x1 = ax.get_position().x1
664+
x0 = ax.get_position().x0
665+
p0 = (x1 - x0) * anchor_x + x0
666+
cx1 = cbar.ax.get_position().x1
667+
cx0 = cbar.ax.get_position().x0
668+
669+
assert np.isclose(
670+
[cx1, cx0],
671+
[x1 * shrink + (1 - shrink) * p0, p0 * (1 - shrink) + x0 * shrink]
672+
).all()

0 commit comments

Comments
 (0)