diff --git a/libcxx/include/__ranges/join_view.h b/libcxx/include/__ranges/join_view.h index 327b349f476a7..6097754f403ec 100644 --- a/libcxx/include/__ranges/join_view.h +++ b/libcxx/include/__ranges/join_view.h @@ -410,8 +410,12 @@ struct __segmented_iterator_traits<_JoinViewIterator> { static constexpr _LIBCPP_HIDE_FROM_ABI _JoinViewIterator __compose(__segment_iterator __seg_iter, __local_iterator __local_iter) { - return _JoinViewIterator( - std::move(__seg_iter).__get_data(), std::move(__seg_iter).__get_iter(), std::move(__local_iter)); + auto&& __outer = std::move(__seg_iter).__get_iter(); + if (__local_iter == ranges::end(*__outer)) { + ++__outer; + return _JoinViewIterator(*std::move(__seg_iter).__get_data(), __outer); + } + return _JoinViewIterator(std::move(__seg_iter).__get_data(), __outer, std::move(__local_iter)); } }; diff --git a/libcxx/test/std/ranges/range.adaptors/range.join/range.join.iterator/find.pass.cpp b/libcxx/test/std/ranges/range.adaptors/range.join/range.join.iterator/find.pass.cpp new file mode 100644 index 0000000000000..f8bee2227d0b9 --- /dev/null +++ b/libcxx/test/std/ranges/range.adaptors/range.join/range.join.iterator/find.pass.cpp @@ -0,0 +1,37 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14, c++17 + +#include +#include +#include +#include + +#include "../types.h" + +constexpr bool test() { + // Test the segmented iterator implementation of join_view + // https://github.com/llvm/llvm-project/issues/158279 + { + int buffer1[2][1] = {{1}, {2}}; + auto joined = std::views::join(buffer1); + assert(std::ranges::find(joined, 1) == std::ranges::begin(joined)); + assert(std::ranges::find(joined, 2) == std::ranges::next(std::ranges::begin(joined))); + assert(std::ranges::find(joined, 3) == std::ranges::end(joined)); + } + + return true; +} + +int main(int, char**) { + test(); + static_assert(test()); + + return 0; +}