Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 6b3de8f

Browse files
Code for 3d scale transform tests
1 parent 2e19d87 commit 6b3de8f

1 file changed

Lines changed: 328 additions & 1 deletion

File tree

lib/mpl_toolkits/mplot3d/tests/test_axes3d.py

Lines changed: 328 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from matplotlib.patches import Circle, PathPatch
1919
from matplotlib.path import Path
2020
from matplotlib.text import Text
21-
from matplotlib import _api
21+
from matplotlib import _api
2222

2323
import matplotlib.pyplot as plt
2424
import numpy as np
@@ -2844,3 +2844,330 @@ def test_ctrl_rotation_snaps_to_5deg():
28442844
assert ax.roll == pytest.approx(expected_roll)
28452845

28462846
plt.close(fig)
2847+
2848+
2849+
# =============================================================================
2850+
# Tests for 3D scale transforms (log, symlog, logit, etc.)
2851+
# =============================================================================
2852+
2853+
def _make_log_data():
2854+
"""Data spanning 1 to ~1000 for log scale."""
2855+
t = np.linspace(0, 2 * np.pi, 50)
2856+
x = 10 ** (t / 2)
2857+
y = 10 ** (1 + np.sin(t))
2858+
z = 10 ** (2 * (1 + np.cos(t) / 2))
2859+
return x, y, z
2860+
2861+
2862+
def _make_surface_log_data():
2863+
"""Grid data for surface with positive Z."""
2864+
x = np.linspace(1, 10, 20)
2865+
y = np.linspace(1, 10, 20)
2866+
X, Y = np.meshgrid(x, y)
2867+
Z = X * Y
2868+
return X, Y, Z
2869+
2870+
2871+
def _make_triangulation_data():
2872+
"""Data for trisurf with positive values."""
2873+
np.random.seed(42)
2874+
x = np.random.uniform(1, 100, 100)
2875+
y = np.random.uniform(1, 100, 100)
2876+
z = x * y / 10
2877+
return x, y, z
2878+
2879+
2880+
@mpl3d_image_comparison(['scale3d_lines_log.png'], style='mpl20')
2881+
def test_scale3d_lines_log():
2882+
"""Test Line3D and Line3DCollection with log scale (plot, wireframe)."""
2883+
fig = plt.figure()
2884+
2885+
# Left: regular plot (Line3D)
2886+
ax1 = fig.add_subplot(1, 2, 1, projection='3d')
2887+
x, y, z = _make_log_data()
2888+
ax1.plot(x, y, z)
2889+
ax1.set(xscale='log', yscale='log', zscale='log')
2890+
2891+
# Right: wireframe (Line3DCollection)
2892+
ax2 = fig.add_subplot(1, 2, 2, projection='3d')
2893+
X, Y, Z = _make_surface_log_data()
2894+
ax2.plot_wireframe(X, Y, Z, rstride=5, cstride=5)
2895+
ax2.set(xscale='log', yscale='log', zscale='log')
2896+
2897+
2898+
@mpl3d_image_comparison(['scale3d_scatter_log.png'], style='mpl20')
2899+
def test_scale3d_scatter_log():
2900+
"""Test Path3DCollection with log scale (scatter)."""
2901+
fig = plt.figure()
2902+
ax = fig.add_subplot(projection='3d')
2903+
x, y, z = _make_log_data()
2904+
ax.scatter(x, y, z, c=z, cmap='viridis')
2905+
ax.set(xscale='log', yscale='log', zscale='log')
2906+
2907+
2908+
@mpl3d_image_comparison(['scale3d_surface_log.png'], style='mpl20')
2909+
def test_scale3d_surface_log():
2910+
"""Test Poly3DCollection with log scale (surface, trisurf)."""
2911+
fig = plt.figure()
2912+
2913+
# Left: plot_surface
2914+
ax1 = fig.add_subplot(1, 2, 1, projection='3d')
2915+
X, Y, Z = _make_surface_log_data()
2916+
ax1.plot_surface(X, Y, Z, cmap='viridis', alpha=0.8)
2917+
ax1.set(xscale='log', yscale='log', zscale='log')
2918+
2919+
# Right: plot_trisurf
2920+
ax2 = fig.add_subplot(1, 2, 2, projection='3d')
2921+
x, y, z = _make_triangulation_data()
2922+
ax2.plot_trisurf(x, y, z, cmap='viridis', alpha=0.8)
2923+
ax2.set(xscale='log', yscale='log', zscale='log')
2924+
2925+
2926+
@mpl3d_image_comparison(['scale3d_bar3d_log.png'], style='mpl20')
2927+
def test_scale3d_bar3d_log():
2928+
"""Test bar3d with log scale."""
2929+
fig = plt.figure()
2930+
ax = fig.add_subplot(projection='3d')
2931+
2932+
# Bar positions (in log space, use positive values)
2933+
x, y = np.meshgrid([1, 10, 100], [1, 10, 100])
2934+
x, y = x.flatten(), y.flatten()
2935+
z = np.ones_like(x, dtype=float)
2936+
ax.bar3d(x, y, z, x * 0.3, y * 0.3, x * y / 10, alpha=0.8)
2937+
ax.set(xscale='log', yscale='log', zscale='log')
2938+
2939+
2940+
@mpl3d_image_comparison(['scale3d_contour_log.png'], style='mpl20')
2941+
def test_scale3d_contour_log():
2942+
"""Test contour and contourf with log scale."""
2943+
fig = plt.figure()
2944+
X, Y, Z = _make_surface_log_data()
2945+
2946+
# Left: contour (Line3DCollection)
2947+
ax1 = fig.add_subplot(1, 2, 1, projection='3d')
2948+
ax1.contour(X, Y, Z, levels=10)
2949+
ax1.set(xscale='log', yscale='log', zscale='log')
2950+
2951+
# Right: contourf (Poly3DCollection)
2952+
ax2 = fig.add_subplot(1, 2, 2, projection='3d')
2953+
ax2.contourf(X, Y, Z, levels=10, alpha=0.8)
2954+
ax2.set(xscale='log', yscale='log', zscale='log')
2955+
2956+
2957+
@mpl3d_image_comparison(['scale3d_stem_quiver_log.png'], style='mpl20')
2958+
def test_scale3d_stem_quiver_log():
2959+
"""Test stem and quiver with log scale."""
2960+
fig = plt.figure()
2961+
2962+
# Left: stem
2963+
ax1 = fig.add_subplot(1, 2, 1, projection='3d')
2964+
x, y, z = [1, 10, 100], [1, 10, 100], [10, 100, 1000]
2965+
ax1.stem(x, y, z, bottom=1)
2966+
ax1.set(xscale='log', yscale='log', zscale='log')
2967+
2968+
# Right: quiver
2969+
ax2 = fig.add_subplot(1, 2, 2, projection='3d')
2970+
x, y, z = np.array([1, 10, 100]), np.array([1, 10, 100]), np.array([1, 10, 100])
2971+
ax2.quiver(x, y, z, x * 0.5, y * 0.5, z * 0.5)
2972+
ax2.set(xscale='log', yscale='log', zscale='log')
2973+
2974+
2975+
@mpl3d_image_comparison(['scale3d_text_log.png'], style='mpl20', remove_text=False)
2976+
def test_scale3d_text_log():
2977+
"""Test Text3D with log scale."""
2978+
fig = plt.figure()
2979+
ax = fig.add_subplot(projection='3d')
2980+
ax.text(1, 1, 1, "Point A")
2981+
ax.text(10, 10, 10, "Point B")
2982+
ax.text(100, 100, 100, "Point C")
2983+
ax.set(xscale='log', yscale='log', zscale='log',
2984+
xlim=(0.5, 200), ylim=(0.5, 200), zlim=(0.5, 200))
2985+
2986+
2987+
@mpl3d_image_comparison(['scale3d_all_scales.png'], style='mpl20', remove_text=False)
2988+
def test_scale3d_all_scales():
2989+
"""Test all scale types with mixed scales on each axis."""
2990+
fig, axs = plt.subplots(1, 2, subplot_kw={'projection': '3d'}, figsize=(10, 6))
2991+
2992+
# Data that works across all scale types
2993+
t = np.linspace(0.1, 0.9, 30)
2994+
# x: positive for log/asinh, y: spans neg/pos for symlog, z: (0,1) for logit
2995+
x = t * 100 # 10 to 90
2996+
y = (t - 0.5) * 20 # -10 to 10
2997+
z = t # 0.1 to 0.9
2998+
2999+
# Subplot 1: x=log, y=symlog, z=logit
3000+
axs[0].scatter(x, y, z)
3001+
axs[0].set(xscale='log', yscale='symlog', zscale='logit',
3002+
xlabel='log', ylabel='symlog', zlabel='logit')
3003+
3004+
# Subplot 2: x=asinh, y=linear, z=function (square root)
3005+
axs[1].scatter(x, y, z)
3006+
axs[1].set_xscale('asinh')
3007+
axs[1].set_zscale('function', functions=(lambda v: v**0.5, lambda v: v**2))
3008+
axs[1].set(xlabel='asinh', ylabel='linear', zlabel='function')
3009+
3010+
3011+
@mpl3d_image_comparison(['scale3d_log_bases.png'], style='mpl20', remove_text=False)
3012+
def test_scale3d_log_bases():
3013+
"""Test log scale with different bases and subs."""
3014+
fig, axs = plt.subplots(2, 2, subplot_kw={'projection': '3d'}, figsize=(10, 8))
3015+
x, y, z = _make_log_data()
3016+
3017+
for ax, base, title in [(axs[0, 0], 10, 'base=10'),
3018+
(axs[0, 1], 2, 'base=2'),
3019+
(axs[1, 0], np.e, 'base=e')]:
3020+
ax.scatter(x, y, z, s=10)
3021+
ax.set_xscale('log', base=base)
3022+
ax.set_yscale('log', base=base)
3023+
ax.set_zscale('log', base=base)
3024+
ax.set_title(title)
3025+
if base == np.e:
3026+
# Format tick labels as e^n instead of 2.718...^n
3027+
def fmt_e(x, pos=None):
3028+
if x <= 0:
3029+
return ''
3030+
exp = np.log(x)
3031+
if np.isclose(exp, round(exp)):
3032+
return r'$e^{%d}$' % round(exp)
3033+
return ''
3034+
ax.xaxis.set_major_formatter(fmt_e)
3035+
ax.yaxis.set_major_formatter(fmt_e)
3036+
ax.zaxis.set_major_formatter(fmt_e)
3037+
3038+
# subs
3039+
axs[1, 1].scatter(x, y, z, s=10)
3040+
axs[1, 1].set_xscale('log', subs=[2, 5])
3041+
axs[1, 1].set_yscale('log', subs=[2, 5])
3042+
axs[1, 1].set_zscale('log', subs=[2, 5])
3043+
axs[1, 1].set_title('subs=[2,5]')
3044+
3045+
3046+
@mpl3d_image_comparison(['scale3d_symlog_params.png'], style='mpl20',
3047+
remove_text=False)
3048+
def test_scale3d_symlog_params():
3049+
"""Test symlog scale with different linthresh values."""
3050+
fig, axs = plt.subplots(1, 2, subplot_kw={'projection': '3d'})
3051+
3052+
# Data spanning negative, zero, and positive
3053+
t = np.linspace(-3, 3, 50)
3054+
x = np.sinh(t) * 10
3055+
y = t ** 3
3056+
z = np.sign(t) * np.abs(t) ** 2
3057+
3058+
for ax, linthresh in [(axs[0], 0.1), (axs[1], 10)]:
3059+
ax.scatter(x, y, z, c=np.abs(z), cmap='viridis', s=10)
3060+
ax.set_xscale('symlog', linthresh=linthresh)
3061+
ax.set_yscale('symlog', linthresh=linthresh)
3062+
ax.set_zscale('symlog', linthresh=linthresh)
3063+
ax.set_title(f'linthresh={linthresh}')
3064+
3065+
3066+
@pytest.mark.parametrize('scale_type,kwargs', [
3067+
('log', {'base': 10}),
3068+
('log', {'base': 2}),
3069+
('log', {'subs': [2, 5]}),
3070+
('log', {'nonpositive': 'mask'}),
3071+
('symlog', {'base': 2}),
3072+
('symlog', {'linthresh': 1}),
3073+
('symlog', {'linscale': 0.5}),
3074+
('symlog', {'subs': [2, 5]}),
3075+
('asinh', {'linear_width': 0.5}),
3076+
('asinh', {'base': 2}),
3077+
('logit', {'nonpositive': 'clip'}),
3078+
])
3079+
def test_scale3d_keywords_accepted(scale_type, kwargs):
3080+
"""Verify that scale keywords are accepted on all 3 axes."""
3081+
fig = plt.figure()
3082+
ax = fig.add_subplot(projection='3d')
3083+
for setter in [ax.set_xscale, ax.set_yscale, ax.set_zscale]:
3084+
setter(scale_type, **kwargs)
3085+
assert (ax.get_xscale(), ax.get_yscale(), ax.get_zscale()) == (scale_type,) * 3
3086+
3087+
3088+
@pytest.mark.parametrize('axis', ['x', 'y', 'z'])
3089+
def test_scale3d_limit_range_log(axis):
3090+
"""Log scale should warn when setting non-positive limits."""
3091+
fig = plt.figure()
3092+
ax = fig.add_subplot(projection='3d')
3093+
getattr(ax, f'set_{axis}scale')('log')
3094+
3095+
# Setting non-positive limits should warn
3096+
with pytest.warns(UserWarning, match="non-positive"):
3097+
getattr(ax, f'set_{axis}lim')(-10, 100)
3098+
3099+
3100+
def test_scale3d_limit_range_logit():
3101+
"""Logit scale should constrain axis to (0, 1)."""
3102+
fig = plt.figure()
3103+
ax = fig.add_subplot(projection='3d')
3104+
ax.set(xscale='logit', yscale='logit', zscale='logit',
3105+
xlim=(-0.5, 1.5), ylim=(-0.5, 1.5), zlim=(-0.5, 1.5))
3106+
3107+
# Limits should be constrained to (0, 1)
3108+
for name, lim in [('x', ax.get_xlim()), ('y', ax.get_ylim()),
3109+
('z', ax.get_zlim())]:
3110+
assert lim[0] > 0, f"{name} lower limit should be > 0 for logit"
3111+
assert lim[1] < 1, f"{name} upper limit should be < 1 for logit"
3112+
3113+
3114+
@pytest.mark.parametrize('scale_type', ['log', 'symlog', 'logit', 'asinh'])
3115+
def test_scale3d_transform_roundtrip(scale_type):
3116+
"""Forward/inverse transform should preserve values."""
3117+
fig = plt.figure()
3118+
ax = fig.add_subplot(projection='3d')
3119+
ax.set(xscale=scale_type, yscale=scale_type, zscale=scale_type)
3120+
3121+
# Use appropriate test values for each scale type
3122+
test_values = {
3123+
'log': [1, 10, 100, 1000],
3124+
'symlog': [-100, -1, 0, 1, 100],
3125+
'asinh': [-100, -1, 0, 1, 100],
3126+
'logit': [0.01, 0.1, 0.5, 0.9, 0.99],
3127+
}[scale_type]
3128+
test_values = np.array(test_values)
3129+
3130+
# Test round-trip for each axis
3131+
for axis in [ax.xaxis, ax.yaxis, ax.zaxis]:
3132+
trans = axis.get_transform()
3133+
forward = trans.transform(test_values.reshape(-1, 1))
3134+
inverse = trans.inverted().transform(forward)
3135+
np.testing.assert_allclose(inverse.flatten(), test_values, rtol=1e-10)
3136+
3137+
3138+
def test_scale3d_invalid_keywords_raise():
3139+
"""Invalid kwargs should raise TypeError."""
3140+
fig = plt.figure()
3141+
ax = fig.add_subplot(projection='3d')
3142+
3143+
with pytest.raises(TypeError):
3144+
ax.set_xscale('log', invalid_kwarg=True)
3145+
3146+
with pytest.raises(TypeError):
3147+
ax.set_yscale('symlog', invalid_kwarg=True)
3148+
3149+
with pytest.raises(TypeError):
3150+
ax.set_zscale('logit', invalid_kwarg=True)
3151+
3152+
3153+
def test_scale3d_persists_after_plot():
3154+
"""Scale should persist after adding plot data."""
3155+
fig = plt.figure()
3156+
ax = fig.add_subplot(projection='3d')
3157+
ax.set(xscale='log', yscale='log', zscale='log')
3158+
ax.plot(*_make_log_data())
3159+
assert (ax.get_xscale(), ax.get_yscale(), ax.get_zscale()) == ('log',) * 3
3160+
3161+
3162+
def test_scale3d_autoscale_with_log():
3163+
"""Autoscale should work correctly with log scale."""
3164+
fig = plt.figure()
3165+
ax = fig.add_subplot(projection='3d')
3166+
ax.set(xscale='log', yscale='log', zscale='log')
3167+
ax.scatter([1, 10, 100], [1, 10, 100], [1, 10, 100])
3168+
3169+
# All limits should be positive
3170+
for name, lim in [('x', ax.get_xlim()), ('y', ax.get_ylim()),
3171+
('z', ax.get_zlim())]:
3172+
assert lim[0] > 0, f"{name} lower limit should be positive"
3173+
assert lim[1] > 0, f"{name} upper limit should be positive"

0 commit comments

Comments
 (0)