// SPDX-FileCopyrightText: Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved.
// SPDX-License-Identifier: BSD-3

#include <cub/device/device_partition.cuh>

#include <thrust/count.h>

#include <cuda/std/algorithm>
#include <cuda/std/type_traits>

#include <look_back_helper.cuh>
#include <nvbench_helper.cuh>

// %RANGE% TUNE_TRANSPOSE trp 0:1:1
// %RANGE% TUNE_LOAD ld 0:1:1
// %RANGE% TUNE_ITEMS_PER_THREAD ipt 7:24:1
// %RANGE% TUNE_THREADS_PER_BLOCK tpb 128:1024:32
// %RANGE% TUNE_MAGIC_NS ns 0:2048:4
// %RANGE% TUNE_DELAY_CONSTRUCTOR_ID dcid 0:7:1
// %RANGE% TUNE_L2_WRITE_LATENCY_NS l2w 0:1200:5

#if !TUNE_BASE
// Map the integer tuning knobs onto the CUB enumerators consumed by the tuned policy below.

// TUNE_TRANSPOSE (0 or 1) picks the block-load algorithm.
#  if TUNE_TRANSPOSE == 0
#    define TUNE_LOAD_ALGORITHM cub::BLOCK_LOAD_DIRECT
#  else // TUNE_TRANSPOSE == 1
#    define TUNE_LOAD_ALGORITHM cub::BLOCK_LOAD_WARP_TRANSPOSE
#  endif // TUNE_TRANSPOSE

// TUNE_LOAD (0 or 1) picks the cache load modifier.
#  if TUNE_LOAD == 0
#    define TUNE_LOAD_MODIFIER cub::LOAD_DEFAULT
#  else // TUNE_LOAD == 1
#    define TUNE_LOAD_MODIFIER cub::LOAD_CA
#  endif // TUNE_LOAD

// Policy hub handed to cub::DispatchSelectIf when the benchmark is built for tuning (TUNE_BASE not set).
//
// It exposes exactly one policy: policy_t names itself as its own predecessor in the ChainedPolicy, and is also
// published as MaxPolicy. All knobs come from the TUNE_* macros of this translation unit.
//
// InputT is the element type being partitioned; it is used to scale the per-thread item count.
template <typename InputT>
struct policy_hub_t
{
  struct policy_t : cub::ChainedPolicy<300, policy_t, policy_t>
  {
    // The tuning knob is interpreted as the item count for 4-byte elements.
    static constexpr int NOMINAL_4B_ITEMS_PER_THREAD = TUNE_ITEMS_PER_THREAD;

    // Scale the nominal count by the element size so that the per-thread footprint stays comparable across types:
    // wider elements get proportionally fewer items (never fewer than 1), narrower elements are capped at the
    // nominal count.
    static constexpr int ITEMS_PER_THREAD = ::cuda::std::min(
      NOMINAL_4B_ITEMS_PER_THREAD,
      ::cuda::std::max(1, NOMINAL_4B_ITEMS_PER_THREAD * 4 / static_cast<int>(sizeof(InputT))));

    // delay_constructor_t is not defined in this file; presumably it is provided by look_back_helper.cuh and driven
    // by the TUNE_MAGIC_NS / TUNE_DELAY_CONSTRUCTOR_ID / TUNE_L2_WRITE_LATENCY_NS knobs - confirm there.
    using SelectIfPolicyT =
      cub::AgentSelectIfPolicy<TUNE_THREADS_PER_BLOCK,
                               ITEMS_PER_THREAD,
                               TUNE_LOAD_ALGORITHM,
                               TUNE_LOAD_MODIFIER,
                               cub::BLOCK_SCAN_WARP_SCANS,
                               delay_constructor_t>;
  };

  using MaxPolicy = policy_t;
};
#endif // !TUNE_BASE

// Overload for the two-pointer ("distinct") output wrapper.
//
// Counts how many of the num_items inputs starting at d_in satisfy select_op, then stores into
// d_partition_out_buffer a wrapper built from d_out and d_out advanced by that count. Both pointers therefore
// refer to the same allocation.
template <typename InItT, typename T, typename OffsetT, typename SelectOpT>
void init_output_partition_buffer(
  InItT d_in,
  OffsetT num_items,
  T* d_out,
  SelectOpT select_op,
  cub::detail::select::partition_distinct_output_t<T*, T*>& d_partition_out_buffer)
{
  using distinct_output_t = cub::detail::select::partition_distinct_output_t<T*, T*>;

  const auto in_end       = d_in + num_items;
  const auto num_selected = thrust::count_if(d_in, in_end, select_op);

  // Second output range begins right after the slots reserved for the counted items.
  T* const second_begin = d_out + num_selected;

  d_partition_out_buffer = distinct_output_t{d_out, second_begin};
}

// Overload for a single plain output pointer.
//
// The input range, item count and selection operator are accepted only so that both overloads can be called with
// the same argument list; they are ignored and d_partition_out_buffer is simply set to d_out.
template <typename InItT, typename T, typename OffsetT, typename SelectOpT>
void init_output_partition_buffer(
  InItT /* d_in */, OffsetT /* num_items */, T* d_out, SelectOpT /* select_op */, T*& d_partition_out_buffer)
{
  d_partition_out_buffer = d_out;
}

// NVBench body that times cub::DispatchSelectIf instantiated with cub::SelectImpl::Partition.
//
// Type axes:
//   T                     - element type of the input and output buffers
//   OffsetT               - offset type given to the dispatch layer and used for the selected-count output
//   UseDistinctPartitionT - integral-constant-like type; ::value chooses between a single output pointer and the
//                           two-pointer partition_distinct_output_t wrapper
//
// Runtime axes read from `state`: "Elements{io}" (problem size) and "Entropy" (fed through
// entropy_to_probability / lerp_min_max to obtain the value the selection operator is constructed from).
template <typename T, typename OffsetT, typename UseDistinctPartitionT>
void partition(nvbench::state& state, nvbench::type_list<T, OffsetT, UseDistinctPartitionT>)
{
  using offset_t          = OffsetT;
  using input_it_t        = const T*;
  using flag_it_t         = cub::NullType*;
  using num_selected_it_t = OffsetT*;
  using select_op_t       = less_then_t<T>;
  using equality_op_t     = cub::NullType;

  // The output "iterator" type is picked at compile time from the third type axis.
  using distinct_output_t                    = cub::detail::select::partition_distinct_output_t<T*, T*>;
  constexpr bool use_distinct_out_partitions = UseDistinctPartitionT::value;
  using output_it_t = ::cuda::std::conditional_t<use_distinct_out_partitions, distinct_output_t, T*>;

  // The tuning policy hub is appended only for tuning builds; otherwise the dispatcher's default is used.
  using dispatch_t = cub::DispatchSelectIf<
    input_it_t,
    flag_it_t,
    output_it_t,
    num_selected_it_t,
    select_op_t,
    equality_op_t,
    offset_t,
    cub::SelectImpl::Partition
#if !TUNE_BASE
    ,
    policy_hub_t<T>
#endif // !TUNE_BASE
    >;

  // ---- axis values for this configuration ----
  const auto num_elements    = static_cast<std::size_t>(state.get_int64("Elements{io}"));
  const bit_entropy entropy  = str_to_entropy(state.get_string("Entropy"));
  const T selection_operand  = lerp_min_max<T>(entropy_to_probability(entropy));
  select_op_t select_op{selection_operand};

  // ---- device buffers ----
  thrust::device_vector<T> input = generate(num_elements);
  thrust::device_vector<offset_t> selected_count(1);
  thrust::device_vector<T> output(num_elements);

  input_it_t d_in                  = thrust::raw_pointer_cast(input.data());
  flag_it_t d_flags                = nullptr; // no flag sequence is supplied
  num_selected_it_t d_num_selected = thrust::raw_pointer_cast(selected_count.data());

  // Overload resolution on output_it_t decides how d_out gets filled in.
  output_it_t d_out{};
  init_output_partition_buffer(
    input.cbegin(), num_elements, thrust::raw_pointer_cast(output.data()), select_op, d_out);

  // ---- throughput bookkeeping: every element is read once and written once, plus one offset_t for the count ----
  state.add_element_count(num_elements);
  state.add_global_memory_reads<T>(num_elements);
  state.add_global_memory_writes<T>(num_elements);
  state.add_global_memory_writes<offset_t>(1);

  // ---- scratch storage: first Dispatch call receives a null storage pointer and fills in scratch_bytes ----
  std::size_t scratch_bytes{};
  dispatch_t::Dispatch(
    nullptr, scratch_bytes, d_in, d_flags, d_out, d_num_selected, select_op, equality_op_t{}, num_elements, 0);

  thrust::device_vector<nvbench::uint8_t> scratch(scratch_bytes);
  auto* d_scratch = thrust::raw_pointer_cast(scratch.data());

  // ---- measured region: the same dispatch, now with real scratch storage, on the stream NVBench provides ----
  state.exec(nvbench::exec_tag::gpu | nvbench::exec_tag::no_batch, [&](nvbench::launch& launch) {
    dispatch_t::Dispatch(
      d_scratch,
      scratch_bytes,
      d_in,
      d_flags,
      d_out,
      d_num_selected,
      select_op,
      equality_op_t{},
      num_elements,
      launch.get_stream());
  });
}

// Bring the bare names false_type / true_type into scope: TUNE_DistinctPartitions expands to one of these
// unqualified identifiers.
using ::cuda::std::false_type;
using ::cuda::std::true_type;
#ifdef TUNE_DistinctPartitions
// Restrict the third type axis to the single variant selected by the build.
using distinct_partitions = nvbench::type_list<TUNE_DistinctPartitions>; // expands to "false_type" or "true_type"
#else // !defined(TUNE_DistinctPartitions)
// No override supplied: run both the single-output and the distinct-output variants.
using distinct_partitions = nvbench::type_list<false_type, true_type>;
#endif // TUNE_DistinctPartitions

// Register `partition` under the name "base" over three type axes (element type, offset type, output layout)
// and two runtime axes: a power-of-two element count built from nvbench::range(16, 28, 4), and the entropy
// strings parsed inside the benchmark body. fundamental_types and offset_types are not defined in this file
// (presumably they come from nvbench_helper.cuh).
NVBENCH_BENCH_TYPES(partition, NVBENCH_TYPE_AXES(fundamental_types, offset_types, distinct_partitions))
  .set_name("base")
  .set_type_axes_names({"T{ct}", "OffsetT{ct}", "DistinctPartitions{ct}"})
  .add_int64_power_of_two_axis("Elements{io}", nvbench::range(16, 28, 4))
  .add_string_axis("Entropy", {"1.000", "0.544", "0.000"});
