diff --git a/vm/src/convert/try_from.rs b/vm/src/convert/try_from.rs index 7eb1c9c00d..d2d83b36e7 100644 --- a/vm/src/convert/try_from.rs +++ b/vm/src/convert/try_from.rs @@ -3,6 +3,7 @@ use crate::{ builtins::PyFloat, object::{AsObject, PyObject, PyObjectRef, PyPayload, PyRef, PyResult}, }; +use malachite_bigint::Sign; use num_traits::ToPrimitive; /// Implemented by any type that can be created from a Python object. @@ -124,10 +125,19 @@ impl<'a, T: PyPayload> TryFromBorrowedObject<'a> for &'a Py { impl TryFromObject for std::time::Duration { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { if let Some(float) = obj.payload::() { - Ok(Self::from_secs_f64(float.to_f64())) + let f = float.to_f64(); + if f < 0.0 { + return Err(vm.new_value_error("negative duration")); + } + Ok(Self::from_secs_f64(f)) } else if let Some(int) = obj.try_index_opt(vm) { - let sec = int? - .as_bigint() + let int = int?; + let bigint = int.as_bigint(); + if bigint.sign() == Sign::Minus { + return Err(vm.new_value_error("negative duration")); + } + + let sec = bigint .to_u64() .ok_or_else(|| vm.new_value_error("value out of range"))?; Ok(Self::from_secs(sec)) diff --git a/vm/src/stdlib/time.rs b/vm/src/stdlib/time.rs index 12a47fca87..0de8648c12 100644 --- a/vm/src/stdlib/time.rs +++ b/vm/src/stdlib/time.rs @@ -34,7 +34,7 @@ unsafe extern "C" { #[pymodule(name = "time", with(platform))] mod decl { use crate::{ - PyObjectRef, PyResult, TryFromObject, VirtualMachine, + AsObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine, builtins::{PyStrRef, PyTypeRef}, function::{Either, FuncArgs, OptionalArg}, types::PyStructSequence, @@ -88,10 +88,37 @@ mod decl { duration_since_system_now(vm) } - #[cfg(not(unix))] #[pyfunction] - fn sleep(dur: Duration) { - std::thread::sleep(dur); + fn sleep(seconds: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + let dur = seconds.try_into_value::(vm).map_err(|e| { + if e.class().is(vm.ctx.exceptions.value_error) { + if let Some(s) = e.args().first().and_then(|arg| arg.str(vm).ok()) { + if s.as_str() == "negative duration" { + return vm.new_value_error("sleep length must be non-negative"); + } + } + } + e + })?; + + #[cfg(unix)] + { + // this is basically std::thread::sleep, but that catches interrupts and we don't want to; + let ts = nix::sys::time::TimeSpec::from(dur); + let res = unsafe { libc::nanosleep(ts.as_ref(), std::ptr::null_mut()) }; + let interrupted = res == -1 && nix::Error::last_raw() == libc::EINTR; + + if interrupted { + vm.check_signals()?; + } + } + + #[cfg(not(unix))] + { + std::thread::sleep(dur); + } + + Ok(()) } #[cfg(not(target_os = "wasi"))] @@ -690,21 +717,6 @@ mod platform { get_clock_time(ClockId::CLOCK_MONOTONIC, vm) } - #[pyfunction] - fn sleep(dur: Duration, vm: &VirtualMachine) -> PyResult<()> { - // this is basically std::thread::sleep, but that catches interrupts and we don't want to; - - let ts = TimeSpec::from(dur); - let res = unsafe { libc::nanosleep(ts.as_ref(), std::ptr::null_mut()) }; - let interrupted = res == -1 && nix::Error::last_raw() == libc::EINTR; - - if interrupted { - vm.check_signals()?; - } - - Ok(()) - } - #[cfg(not(any( target_os = "illumos", target_os = "netbsd",