diff --git a/tools/pythonpkg/src/arrow/arrow_array_stream.cpp b/tools/pythonpkg/src/arrow/arrow_array_stream.cpp index 2b13ddfc2451..3fbb9420e9ef 100644 --- a/tools/pythonpkg/src/arrow/arrow_array_stream.cpp +++ b/tools/pythonpkg/src/arrow/arrow_array_stream.cpp @@ -302,6 +302,37 @@ py::object TransformFilterRecursive(TableFilter &filter, vector column_r auto &constant_filter = filter.Cast(); auto constant_field = field(py::tuple(py::cast(column_ref))); auto constant_value = GetScalar(constant_filter.constant, timezone_config, type); + + bool is_nan = false; + auto &constant = constant_filter.constant; + auto &constant_type = constant.type(); + if (constant_type.id() == LogicalTypeId::FLOAT) { + is_nan = Value::IsNan(constant.GetValue()); + } else if (constant_type.id() == LogicalTypeId::DOUBLE) { + is_nan = Value::IsNan(constant.GetValue()); + } + + // Special handling for NaN comparisons (to explicitly violate IEEE-754) + if (is_nan) { + switch (constant_filter.comparison_type) { + case ExpressionType::COMPARE_EQUAL: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return constant_field.attr("is_nan")(); + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_NOTEQUAL: + return constant_field.attr("is_nan")().attr("__invert__")(); + case ExpressionType::COMPARE_GREATERTHAN: + // Nothing is greater than NaN + return import_cache.pyarrow.dataset().attr("scalar")(false); + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + // Everything is less than or equal to NaN + return import_cache.pyarrow.dataset().attr("scalar")(true); + default: + throw NotImplementedException("Unsupported comparison type (%s) for NaN values", + EnumUtil::ToString(constant_filter.comparison_type)); + } + } + switch (constant_filter.comparison_type) { case ExpressionType::COMPARE_EQUAL: return constant_field.attr("__eq__")(constant_value); diff --git a/tools/pythonpkg/tests/fast/arrow/test_filter_pushdown.py b/tools/pythonpkg/tests/fast/arrow/test_filter_pushdown.py index 
142d1dace103..82cb4414b8f6 100644
--- a/tools/pythonpkg/tests/fast/arrow/test_filter_pushdown.py
+++ b/tools/pythonpkg/tests/fast/arrow/test_filter_pushdown.py
@@ -986,3 +986,42 @@ def test_pushdown_of_optional_filter(self, duckdb_cursor):
             ('product_code', 100),
             ('price', 100),
         ]
+
+    # DuckDB intentionally violates IEEE-754 when it comes to NaNs, ensuring a total ordering where NaN is the
+    # greatest value; filters pushed down into an Arrow scan therefore have to follow that same ordering.
+    def test_nan_filter_pushdown(self, duckdb_cursor):
+        import math
+
+        duckdb_cursor.execute(
+            """
+            create table test as select a::DOUBLE a from VALUES
+                ('inf'),
+                ('nan'),
+                ('0.34234'),
+                ('34234234.00005'),
+                ('-nan')
+            t(a);
+        """
+        )
+
+        def comparable(rows):
+            # NaN never compares equal to itself in Python, so map it to a sentinel before comparing;
+            # sorting keeps the check independent of the order in which each scan returns its rows.
+            return sorted('nan' if math.isnan(value) else repr(value) for (value,) in rows)
+
+        def assert_equal_results(con, arrow_table, query):
+            # NOTE(review): 'arrow_table' is only referenced through the formatted query text; presumably
+            # DuckDB resolves that name from this frame's locals, so keep the two in sync - confirm.
+            duckdb_res = con.sql(query.format(table='test')).fetchall()
+            arrow_res = con.sql(query.format(table='arrow_table')).fetchall()
+            assert len(duckdb_res) == len(arrow_res)
+            # Equal row counts alone would also accept a pushed-down filter that selects the wrong rows.
+            assert comparable(duckdb_res) == comparable(arrow_res)
+
+        arrow_table = duckdb_cursor.table('test').arrow()
+        assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a > 'NaN'::FLOAT")
+        assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a >= 'NaN'::FLOAT")
+        assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a < 'NaN'::FLOAT")
+        assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a <= 'NaN'::FLOAT")
+        assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a = 'NaN'::FLOAT")
+        assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a != 'NaN'::FLOAT")