|
18 | 18 | from matplotlib.patches import Circle, PathPatch |
19 | 19 | from matplotlib.path import Path |
20 | 20 | from matplotlib.text import Text |
21 | | -from matplotlib import _api |
| 21 | +from matplotlib import _api |
22 | 22 |
|
23 | 23 | import matplotlib.pyplot as plt |
24 | 24 | import numpy as np |
@@ -2844,3 +2844,330 @@ def test_ctrl_rotation_snaps_to_5deg(): |
2844 | 2844 | assert ax.roll == pytest.approx(expected_roll) |
2845 | 2845 |
|
2846 | 2846 | plt.close(fig) |
| 2847 | + |
| 2848 | + |
| 2849 | +# ============================================================================= |
| 2850 | +# Tests for 3D scale transforms (log, symlog, logit, etc.) |
| 2851 | +# ============================================================================= |
| 2852 | + |
| 2853 | +def _make_log_data(): |
| 2854 | + """Data spanning 1 to ~1000 for log scale.""" |
| 2855 | + t = np.linspace(0, 2 * np.pi, 50) |
| 2856 | + x = 10 ** (t / 2) |
| 2857 | + y = 10 ** (1 + np.sin(t)) |
| 2858 | + z = 10 ** (2 * (1 + np.cos(t) / 2)) |
| 2859 | + return x, y, z |
| 2860 | + |
| 2861 | + |
| 2862 | +def _make_surface_log_data(): |
| 2863 | + """Grid data for surface with positive Z.""" |
| 2864 | + x = np.linspace(1, 10, 20) |
| 2865 | + y = np.linspace(1, 10, 20) |
| 2866 | + X, Y = np.meshgrid(x, y) |
| 2867 | + Z = X * Y |
| 2868 | + return X, Y, Z |
| 2869 | + |
| 2870 | + |
| 2871 | +def _make_triangulation_data(): |
| 2872 | + """Data for trisurf with positive values.""" |
| 2873 | + np.random.seed(42) |
| 2874 | + x = np.random.uniform(1, 100, 100) |
| 2875 | + y = np.random.uniform(1, 100, 100) |
| 2876 | + z = x * y / 10 |
| 2877 | + return x, y, z |
| 2878 | + |
| 2879 | + |
| 2880 | +@mpl3d_image_comparison(['scale3d_lines_log.png'], style='mpl20') |
| 2881 | +def test_scale3d_lines_log(): |
| 2882 | + """Test Line3D and Line3DCollection with log scale (plot, wireframe).""" |
| 2883 | + fig = plt.figure() |
| 2884 | + |
| 2885 | + # Left: regular plot (Line3D) |
| 2886 | + ax1 = fig.add_subplot(1, 2, 1, projection='3d') |
| 2887 | + x, y, z = _make_log_data() |
| 2888 | + ax1.plot(x, y, z) |
| 2889 | + ax1.set(xscale='log', yscale='log', zscale='log') |
| 2890 | + |
| 2891 | + # Right: wireframe (Line3DCollection) |
| 2892 | + ax2 = fig.add_subplot(1, 2, 2, projection='3d') |
| 2893 | + X, Y, Z = _make_surface_log_data() |
| 2894 | + ax2.plot_wireframe(X, Y, Z, rstride=5, cstride=5) |
| 2895 | + ax2.set(xscale='log', yscale='log', zscale='log') |
| 2896 | + |
| 2897 | + |
| 2898 | +@mpl3d_image_comparison(['scale3d_scatter_log.png'], style='mpl20') |
| 2899 | +def test_scale3d_scatter_log(): |
| 2900 | + """Test Path3DCollection with log scale (scatter).""" |
| 2901 | + fig = plt.figure() |
| 2902 | + ax = fig.add_subplot(projection='3d') |
| 2903 | + x, y, z = _make_log_data() |
| 2904 | + ax.scatter(x, y, z, c=z, cmap='viridis') |
| 2905 | + ax.set(xscale='log', yscale='log', zscale='log') |
| 2906 | + |
| 2907 | + |
| 2908 | +@mpl3d_image_comparison(['scale3d_surface_log.png'], style='mpl20') |
| 2909 | +def test_scale3d_surface_log(): |
| 2910 | + """Test Poly3DCollection with log scale (surface, trisurf).""" |
| 2911 | + fig = plt.figure() |
| 2912 | + |
| 2913 | + # Left: plot_surface |
| 2914 | + ax1 = fig.add_subplot(1, 2, 1, projection='3d') |
| 2915 | + X, Y, Z = _make_surface_log_data() |
| 2916 | + ax1.plot_surface(X, Y, Z, cmap='viridis', alpha=0.8) |
| 2917 | + ax1.set(xscale='log', yscale='log', zscale='log') |
| 2918 | + |
| 2919 | + # Right: plot_trisurf |
| 2920 | + ax2 = fig.add_subplot(1, 2, 2, projection='3d') |
| 2921 | + x, y, z = _make_triangulation_data() |
| 2922 | + ax2.plot_trisurf(x, y, z, cmap='viridis', alpha=0.8) |
| 2923 | + ax2.set(xscale='log', yscale='log', zscale='log') |
| 2924 | + |
| 2925 | + |
| 2926 | +@mpl3d_image_comparison(['scale3d_bar3d_log.png'], style='mpl20') |
| 2927 | +def test_scale3d_bar3d_log(): |
| 2928 | + """Test bar3d with log scale.""" |
| 2929 | + fig = plt.figure() |
| 2930 | + ax = fig.add_subplot(projection='3d') |
| 2931 | + |
| 2932 | + # Bar positions (in log space, use positive values) |
| 2933 | + x, y = np.meshgrid([1, 10, 100], [1, 10, 100]) |
| 2934 | + x, y = x.flatten(), y.flatten() |
| 2935 | + z = np.ones_like(x, dtype=float) |
| 2936 | + ax.bar3d(x, y, z, x * 0.3, y * 0.3, x * y / 10, alpha=0.8) |
| 2937 | + ax.set(xscale='log', yscale='log', zscale='log') |
| 2938 | + |
| 2939 | + |
| 2940 | +@mpl3d_image_comparison(['scale3d_contour_log.png'], style='mpl20') |
| 2941 | +def test_scale3d_contour_log(): |
| 2942 | + """Test contour and contourf with log scale.""" |
| 2943 | + fig = plt.figure() |
| 2944 | + X, Y, Z = _make_surface_log_data() |
| 2945 | + |
| 2946 | + # Left: contour (Line3DCollection) |
| 2947 | + ax1 = fig.add_subplot(1, 2, 1, projection='3d') |
| 2948 | + ax1.contour(X, Y, Z, levels=10) |
| 2949 | + ax1.set(xscale='log', yscale='log', zscale='log') |
| 2950 | + |
| 2951 | + # Right: contourf (Poly3DCollection) |
| 2952 | + ax2 = fig.add_subplot(1, 2, 2, projection='3d') |
| 2953 | + ax2.contourf(X, Y, Z, levels=10, alpha=0.8) |
| 2954 | + ax2.set(xscale='log', yscale='log', zscale='log') |
| 2955 | + |
| 2956 | + |
| 2957 | +@mpl3d_image_comparison(['scale3d_stem_quiver_log.png'], style='mpl20') |
| 2958 | +def test_scale3d_stem_quiver_log(): |
| 2959 | + """Test stem and quiver with log scale.""" |
| 2960 | + fig = plt.figure() |
| 2961 | + |
| 2962 | + # Left: stem |
| 2963 | + ax1 = fig.add_subplot(1, 2, 1, projection='3d') |
| 2964 | + x, y, z = [1, 10, 100], [1, 10, 100], [10, 100, 1000] |
| 2965 | + ax1.stem(x, y, z, bottom=1) |
| 2966 | + ax1.set(xscale='log', yscale='log', zscale='log') |
| 2967 | + |
| 2968 | + # Right: quiver |
| 2969 | + ax2 = fig.add_subplot(1, 2, 2, projection='3d') |
| 2970 | + x, y, z = np.array([1, 10, 100]), np.array([1, 10, 100]), np.array([1, 10, 100]) |
| 2971 | + ax2.quiver(x, y, z, x * 0.5, y * 0.5, z * 0.5) |
| 2972 | + ax2.set(xscale='log', yscale='log', zscale='log') |
| 2973 | + |
| 2974 | + |
| 2975 | +@mpl3d_image_comparison(['scale3d_text_log.png'], style='mpl20', remove_text=False) |
| 2976 | +def test_scale3d_text_log(): |
| 2977 | + """Test Text3D with log scale.""" |
| 2978 | + fig = plt.figure() |
| 2979 | + ax = fig.add_subplot(projection='3d') |
| 2980 | + ax.text(1, 1, 1, "Point A") |
| 2981 | + ax.text(10, 10, 10, "Point B") |
| 2982 | + ax.text(100, 100, 100, "Point C") |
| 2983 | + ax.set(xscale='log', yscale='log', zscale='log', |
| 2984 | + xlim=(0.5, 200), ylim=(0.5, 200), zlim=(0.5, 200)) |
| 2985 | + |
| 2986 | + |
| 2987 | +@mpl3d_image_comparison(['scale3d_all_scales.png'], style='mpl20', remove_text=False) |
| 2988 | +def test_scale3d_all_scales(): |
| 2989 | + """Test all scale types with mixed scales on each axis.""" |
| 2990 | + fig, axs = plt.subplots(1, 2, subplot_kw={'projection': '3d'}, figsize=(10, 6)) |
| 2991 | + |
| 2992 | + # Data that works across all scale types |
| 2993 | + t = np.linspace(0.1, 0.9, 30) |
| 2994 | + # x: positive for log/asinh, y: spans neg/pos for symlog, z: (0,1) for logit |
| 2995 | + x = t * 100 # 10 to 90 |
| 2996 | + y = (t - 0.5) * 20 # -10 to 10 |
| 2997 | + z = t # 0.1 to 0.9 |
| 2998 | + |
| 2999 | + # Subplot 1: x=log, y=symlog, z=logit |
| 3000 | + axs[0].scatter(x, y, z) |
| 3001 | + axs[0].set(xscale='log', yscale='symlog', zscale='logit', |
| 3002 | + xlabel='log', ylabel='symlog', zlabel='logit') |
| 3003 | + |
| 3004 | + # Subplot 2: x=asinh, y=linear, z=function (square root) |
| 3005 | + axs[1].scatter(x, y, z) |
| 3006 | + axs[1].set_xscale('asinh') |
| 3007 | + axs[1].set_zscale('function', functions=(lambda v: v**0.5, lambda v: v**2)) |
| 3008 | + axs[1].set(xlabel='asinh', ylabel='linear', zlabel='function') |
| 3009 | + |
| 3010 | + |
| 3011 | +@mpl3d_image_comparison(['scale3d_log_bases.png'], style='mpl20', remove_text=False) |
| 3012 | +def test_scale3d_log_bases(): |
| 3013 | + """Test log scale with different bases and subs.""" |
| 3014 | + fig, axs = plt.subplots(2, 2, subplot_kw={'projection': '3d'}, figsize=(10, 8)) |
| 3015 | + x, y, z = _make_log_data() |
| 3016 | + |
| 3017 | + for ax, base, title in [(axs[0, 0], 10, 'base=10'), |
| 3018 | + (axs[0, 1], 2, 'base=2'), |
| 3019 | + (axs[1, 0], np.e, 'base=e')]: |
| 3020 | + ax.scatter(x, y, z, s=10) |
| 3021 | + ax.set_xscale('log', base=base) |
| 3022 | + ax.set_yscale('log', base=base) |
| 3023 | + ax.set_zscale('log', base=base) |
| 3024 | + ax.set_title(title) |
| 3025 | + if base == np.e: |
| 3026 | + # Format tick labels as e^n instead of 2.718...^n |
| 3027 | + def fmt_e(x, pos=None): |
| 3028 | + if x <= 0: |
| 3029 | + return '' |
| 3030 | + exp = np.log(x) |
| 3031 | + if np.isclose(exp, round(exp)): |
| 3032 | + return r'$e^{%d}$' % round(exp) |
| 3033 | + return '' |
| 3034 | + ax.xaxis.set_major_formatter(fmt_e) |
| 3035 | + ax.yaxis.set_major_formatter(fmt_e) |
| 3036 | + ax.zaxis.set_major_formatter(fmt_e) |
| 3037 | + |
| 3038 | + # subs |
| 3039 | + axs[1, 1].scatter(x, y, z, s=10) |
| 3040 | + axs[1, 1].set_xscale('log', subs=[2, 5]) |
| 3041 | + axs[1, 1].set_yscale('log', subs=[2, 5]) |
| 3042 | + axs[1, 1].set_zscale('log', subs=[2, 5]) |
| 3043 | + axs[1, 1].set_title('subs=[2,5]') |
| 3044 | + |
| 3045 | + |
| 3046 | +@mpl3d_image_comparison(['scale3d_symlog_params.png'], style='mpl20', |
| 3047 | + remove_text=False) |
| 3048 | +def test_scale3d_symlog_params(): |
| 3049 | + """Test symlog scale with different linthresh values.""" |
| 3050 | + fig, axs = plt.subplots(1, 2, subplot_kw={'projection': '3d'}) |
| 3051 | + |
| 3052 | + # Data spanning negative, zero, and positive |
| 3053 | + t = np.linspace(-3, 3, 50) |
| 3054 | + x = np.sinh(t) * 10 |
| 3055 | + y = t ** 3 |
| 3056 | + z = np.sign(t) * np.abs(t) ** 2 |
| 3057 | + |
| 3058 | + for ax, linthresh in [(axs[0], 0.1), (axs[1], 10)]: |
| 3059 | + ax.scatter(x, y, z, c=np.abs(z), cmap='viridis', s=10) |
| 3060 | + ax.set_xscale('symlog', linthresh=linthresh) |
| 3061 | + ax.set_yscale('symlog', linthresh=linthresh) |
| 3062 | + ax.set_zscale('symlog', linthresh=linthresh) |
| 3063 | + ax.set_title(f'linthresh={linthresh}') |
| 3064 | + |
| 3065 | + |
| 3066 | +@pytest.mark.parametrize('scale_type,kwargs', [ |
| 3067 | + ('log', {'base': 10}), |
| 3068 | + ('log', {'base': 2}), |
| 3069 | + ('log', {'subs': [2, 5]}), |
| 3070 | + ('log', {'nonpositive': 'mask'}), |
| 3071 | + ('symlog', {'base': 2}), |
| 3072 | + ('symlog', {'linthresh': 1}), |
| 3073 | + ('symlog', {'linscale': 0.5}), |
| 3074 | + ('symlog', {'subs': [2, 5]}), |
| 3075 | + ('asinh', {'linear_width': 0.5}), |
| 3076 | + ('asinh', {'base': 2}), |
| 3077 | + ('logit', {'nonpositive': 'clip'}), |
| 3078 | +]) |
| 3079 | +def test_scale3d_keywords_accepted(scale_type, kwargs): |
| 3080 | + """Verify that scale keywords are accepted on all 3 axes.""" |
| 3081 | + fig = plt.figure() |
| 3082 | + ax = fig.add_subplot(projection='3d') |
| 3083 | + for setter in [ax.set_xscale, ax.set_yscale, ax.set_zscale]: |
| 3084 | + setter(scale_type, **kwargs) |
| 3085 | + assert (ax.get_xscale(), ax.get_yscale(), ax.get_zscale()) == (scale_type,) * 3 |
| 3086 | + |
| 3087 | + |
| 3088 | +@pytest.mark.parametrize('axis', ['x', 'y', 'z']) |
| 3089 | +def test_scale3d_limit_range_log(axis): |
| 3090 | + """Log scale should warn when setting non-positive limits.""" |
| 3091 | + fig = plt.figure() |
| 3092 | + ax = fig.add_subplot(projection='3d') |
| 3093 | + getattr(ax, f'set_{axis}scale')('log') |
| 3094 | + |
| 3095 | + # Setting non-positive limits should warn |
| 3096 | + with pytest.warns(UserWarning, match="non-positive"): |
| 3097 | + getattr(ax, f'set_{axis}lim')(-10, 100) |
| 3098 | + |
| 3099 | + |
| 3100 | +def test_scale3d_limit_range_logit(): |
| 3101 | + """Logit scale should constrain axis to (0, 1).""" |
| 3102 | + fig = plt.figure() |
| 3103 | + ax = fig.add_subplot(projection='3d') |
| 3104 | + ax.set(xscale='logit', yscale='logit', zscale='logit', |
| 3105 | + xlim=(-0.5, 1.5), ylim=(-0.5, 1.5), zlim=(-0.5, 1.5)) |
| 3106 | + |
| 3107 | + # Limits should be constrained to (0, 1) |
| 3108 | + for name, lim in [('x', ax.get_xlim()), ('y', ax.get_ylim()), |
| 3109 | + ('z', ax.get_zlim())]: |
| 3110 | + assert lim[0] > 0, f"{name} lower limit should be > 0 for logit" |
| 3111 | + assert lim[1] < 1, f"{name} upper limit should be < 1 for logit" |
| 3112 | + |
| 3113 | + |
| 3114 | +@pytest.mark.parametrize('scale_type', ['log', 'symlog', 'logit', 'asinh']) |
| 3115 | +def test_scale3d_transform_roundtrip(scale_type): |
| 3116 | + """Forward/inverse transform should preserve values.""" |
| 3117 | + fig = plt.figure() |
| 3118 | + ax = fig.add_subplot(projection='3d') |
| 3119 | + ax.set(xscale=scale_type, yscale=scale_type, zscale=scale_type) |
| 3120 | + |
| 3121 | + # Use appropriate test values for each scale type |
| 3122 | + test_values = { |
| 3123 | + 'log': [1, 10, 100, 1000], |
| 3124 | + 'symlog': [-100, -1, 0, 1, 100], |
| 3125 | + 'asinh': [-100, -1, 0, 1, 100], |
| 3126 | + 'logit': [0.01, 0.1, 0.5, 0.9, 0.99], |
| 3127 | + }[scale_type] |
| 3128 | + test_values = np.array(test_values) |
| 3129 | + |
| 3130 | + # Test round-trip for each axis |
| 3131 | + for axis in [ax.xaxis, ax.yaxis, ax.zaxis]: |
| 3132 | + trans = axis.get_transform() |
| 3133 | + forward = trans.transform(test_values.reshape(-1, 1)) |
| 3134 | + inverse = trans.inverted().transform(forward) |
| 3135 | + np.testing.assert_allclose(inverse.flatten(), test_values, rtol=1e-10) |
| 3136 | + |
| 3137 | + |
| 3138 | +def test_scale3d_invalid_keywords_raise(): |
| 3139 | + """Invalid kwargs should raise TypeError.""" |
| 3140 | + fig = plt.figure() |
| 3141 | + ax = fig.add_subplot(projection='3d') |
| 3142 | + |
| 3143 | + with pytest.raises(TypeError): |
| 3144 | + ax.set_xscale('log', invalid_kwarg=True) |
| 3145 | + |
| 3146 | + with pytest.raises(TypeError): |
| 3147 | + ax.set_yscale('symlog', invalid_kwarg=True) |
| 3148 | + |
| 3149 | + with pytest.raises(TypeError): |
| 3150 | + ax.set_zscale('logit', invalid_kwarg=True) |
| 3151 | + |
| 3152 | + |
| 3153 | +def test_scale3d_persists_after_plot(): |
| 3154 | + """Scale should persist after adding plot data.""" |
| 3155 | + fig = plt.figure() |
| 3156 | + ax = fig.add_subplot(projection='3d') |
| 3157 | + ax.set(xscale='log', yscale='log', zscale='log') |
| 3158 | + ax.plot(*_make_log_data()) |
| 3159 | + assert (ax.get_xscale(), ax.get_yscale(), ax.get_zscale()) == ('log',) * 3 |
| 3160 | + |
| 3161 | + |
| 3162 | +def test_scale3d_autoscale_with_log(): |
| 3163 | + """Autoscale should work correctly with log scale.""" |
| 3164 | + fig = plt.figure() |
| 3165 | + ax = fig.add_subplot(projection='3d') |
| 3166 | + ax.set(xscale='log', yscale='log', zscale='log') |
| 3167 | + ax.scatter([1, 10, 100], [1, 10, 100], [1, 10, 100]) |
| 3168 | + |
| 3169 | + # All limits should be positive |
| 3170 | + for name, lim in [('x', ax.get_xlim()), ('y', ax.get_ylim()), |
| 3171 | + ('z', ax.get_zlim())]: |
| 3172 | + assert lim[0] > 0, f"{name} lower limit should be positive" |
| 3173 | + assert lim[1] > 0, f"{name} upper limit should be positive" |
0 commit comments