diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index accc1650093b..1f6637d206ad 100644 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -587,6 +587,18 @@ def _determine_lims(self, xmin=None, xmax=None, *args, **kwargs): xmax += 0.05 return (xmin, xmax) + def _validate_axis_limits(self, limit, convert): + """ + Raise ValueError if specified axis limits are infinite. + + """ + if limit is not None: + converted_limits = convert(limit) + if (isinstance(limit, float) and + (not np.isreal(limit) or not np.isfinite(limit))): + raise ValueError("Axis limits cannot be NaN or Inf") + return converted_limits + def set_xlim3d(self, left=None, right=None, emit=True, auto=False, **kw): """ Set 3D x limits. @@ -605,10 +617,8 @@ def set_xlim3d(self, left=None, right=None, emit=True, auto=False, **kw): left, right = left self._process_unit_info(xdata=(left, right)) - if left is not None: - left = self.convert_xunits(left) - if right is not None: - right = self.convert_xunits(right) + left = self._validate_axis_limits(left, self.convert_xunits) + right = self._validate_axis_limits(right, self.convert_xunits) old_left, old_right = self.get_xlim() if left is None: @@ -665,10 +675,8 @@ def set_ylim3d(self, bottom=None, top=None, emit=True, auto=False, **kw): top = self.convert_yunits(top) old_bottom, old_top = self.get_ylim() - if bottom is None: - bottom = old_bottom - if top is None: - top = old_top + bottom = self._validate_axis_limits(bottom, self.convert_yunits) + top = self._validate_axis_limits(top, self.convert_yunits) if top == bottom: warnings.warn(('Attempting to set identical bottom==top results\n' @@ -713,10 +721,8 @@ def set_zlim3d(self, bottom=None, top=None, emit=True, auto=False, **kw): bottom, top = bottom self._process_unit_info(zdata=(bottom, top)) - if bottom is not None: - bottom = self.convert_zunits(bottom) - if top is not None: - top = self.convert_zunits(top) + bottom = self._validate_axis_limits(bottom, self.convert_yunits) + top = self._validate_axis_limits(top, self.convert_yunits) old_bottom, old_top = self.get_zlim() if bottom is None: diff --git a/lib/mpl_toolkits/tests/test_mplot3d.py b/lib/mpl_toolkits/tests/test_mplot3d.py index b6a0692491c4..7ba561c6d866 100644 --- a/lib/mpl_toolkits/tests/test_mplot3d.py +++ b/lib/mpl_toolkits/tests/test_mplot3d.py @@ -547,3 +547,20 @@ def test_axes3d_ortho(): fig = plt.figure() ax = fig.gca(projection='3d') ax.set_proj_type('ortho') + + +@pytest.mark.parametrize('value', [np.inf, np.nan]) +@pytest.mark.parametrize(('setter', 'side'), [ + ('set_xlim3d', 'left'), + ('set_xlim3d', 'right'), + ('set_ylim3d', 'bottom'), + ('set_ylim3d', 'top'), + ('set_zlim3d', 'bottom'), + ('set_zlim3d', 'top'), +]) +def test_invalid_axes_limits(setter, side, value): + limit = {side: value} + fig = plt.figure() + obj = fig.add_subplot(111, projection='3d') + with pytest.raises(ValueError): + getattr(obj, setter)(**limit)