# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import numba
import numpy as np
import pytest
from helpers import NUMBA_TYPES_TO_NP, random_int
from numba import cuda, types
from numba.core import cgutils
from numba.core.extending import (
    lower_builtin,
    make_attribute_wrapper,
    models,
    register_model,
    type_callable,
    typeof_impl,
)
from pynvjitlink import patch

import cuda.cooperative.experimental as cudax

# Disable Numba's low-occupancy launch warnings. The kernels below are all
# launched as a single block (``kernel[1, threads_per_block]``), which would
# presumably trigger them on every test.
numba.config.CUDA_LOW_OCCUPANCY_WARNINGS = 0


# Route Numba's CUDA linking through pynvjitlink with link-time optimization
# enabled. NOTE(review): presumably required so the ``block_reduce.files``
# passed to ``cuda.jit(link=...)`` can be linked -- confirm against the
# cuda.cooperative documentation.
patch.patch_numba_linker(lto=True)


class Complex:
    """Pair of integer components (``real``, ``imag``) usable from jitted code.

    ``construct`` and ``assign`` take a pointer-like ``this`` instead of
    ``self``: the tests look them up on the class (``Complex.construct``) and
    hand them to ``cudax.block.reduce`` through its ``methods`` mapping.
    """

    def __init__(self, real, imag):
        self.real = real
        self.imag = imag

    def construct(this):
        """Store a ``Complex`` with int32 zero components into ``this[0]``."""
        zero = numba.int32(0)
        this[0] = Complex(zero, zero)

    def assign(this, that):
        """Copy the components of ``that[0]`` into ``this[0]``."""
        source = that[0]
        this[0] = Complex(source.real, source.imag)


class ComplexType(types.Type):
    """Numba type describing a ``Complex`` value inside compiled code."""

    def __init__(self):
        super().__init__(name="Complex")


# Shared instance returned by the typing hooks below and passed as the
# ``dtype`` of the user-defined-type reductions.
complex_type = ComplexType()


@typeof_impl.register(Complex)
def typeof_complex(val, c):
    """Map any Python ``Complex`` instance to the Numba ``complex_type``."""
    return complex_type


@type_callable(Complex)
def type__complex(context):
    """Provide the typer for ``Complex(real, imag)`` calls in jitted code."""

    def typer(real, imag):
        # Accept only integer components; returning None rejects the call.
        if all(isinstance(arg, types.Integer) for arg in (real, imag)):
            return complex_type
        return None

    return typer


@register_model(ComplexType)
class ComplexModel(models.StructModel):
    """Data model laying out ``ComplexType`` as a struct of two int32 fields."""

    def __init__(self, dmm, fe_type):
        fields = [
            ("real", types.int32),
            ("imag", types.int32),
        ]
        super().__init__(dmm, fe_type, fields)


# Expose the ``real`` and ``imag`` struct members of ComplexModel as
# attributes of ComplexType values in jitted code (used below as
# ``lhs_ptr[0].real`` and ``block_output.imag``).
make_attribute_wrapper(ComplexType, "real", "real")
make_attribute_wrapper(ComplexType, "imag", "imag")


@lower_builtin(Complex, types.Integer, types.Integer)
def impl_complex(context, builder, sig, args):
    """Lower ``Complex(real, imag)`` to a populated ``ComplexType`` struct.

    The typer accepts any pair of integer types, but both struct fields are
    int32 (see ``ComplexModel``). Each argument is therefore cast to int32
    before it is stored, so that e.g. int64 literals do not produce an
    ill-typed store. For int32 arguments the cast is a no-op.
    """
    typ = sig.return_type
    real_ty, imag_ty = sig.args
    real, imag = args
    state = cgutils.create_struct_proxy(typ)(context, builder)
    state.real = context.cast(builder, real, real_ty, types.int32)
    state.imag = context.cast(builder, imag, imag_ty, types.int32)
    return state._getvalue()


@pytest.mark.parametrize("threads_per_block", [32, 64, 128, 256, 512, 1024])
def test_block_reduction_of_user_defined_type_without_temp_storage(threads_per_block):
    """Block-reduce ``Complex`` values without passing a temp-storage buffer.

    Thread ``tid`` contributes
    ``Complex(input[tid], input[threads_per_block + tid])``. ``op`` adds the
    components pairwise, so thread 0's result must equal the int32 sums of the
    two halves of the input. The compiled SASS must contain no local-memory
    loads or stores (``LDL``/``STL``).
    """

    def op(result_ptr, lhs_ptr, rhs_ptr):
        real_value = numba.int32(lhs_ptr[0].real + rhs_ptr[0].real)
        imag_value = numba.int32(lhs_ptr[0].imag + rhs_ptr[0].imag)
        result_ptr[0] = Complex(real_value, imag_value)

    block_reduce = cudax.block.reduce(
        dtype=complex_type,
        binary_op=op,
        threads_per_block=threads_per_block,
        methods={
            "construct": Complex.construct,
            "assign": Complex.assign,
        },
    )

    @cuda.jit(link=block_reduce.files)
    def kernel(input, output):
        # Called with only this thread's item (no temp_storage buffer).
        block_output = block_reduce(
            Complex(
                input[cuda.threadIdx.x], input[threads_per_block + cuda.threadIdx.x]
            )
        )

        if cuda.threadIdx.x == 0:
            output[0] = block_output.real
            output[1] = block_output.imag

    h_input = random_int(2 * threads_per_block, "int32")
    d_input = cuda.to_device(h_input)
    d_output = cuda.device_array(2, dtype="int32")
    kernel[1, threads_per_block](d_input, d_output)
    cuda.synchronize()
    h_output = d_output.copy_to_host()
    # ``op`` truncates every partial sum to int32, so the reference is also
    # accumulated in the input dtype. By default np.sum widens int32 to the
    # platform integer and would disagree whenever the true sum overflows.
    h_expected = (
        np.sum(h_input[:threads_per_block], dtype=h_input.dtype),
        np.sum(h_input[threads_per_block:], dtype=h_input.dtype),
    )

    assert h_output[0] == h_expected[0]
    assert h_output[1] == h_expected[1]

    sig = (numba.int32[::1], numba.int32[::1])
    sass = kernel.inspect_sass(sig)

    assert "LDL" not in sass
    assert "STL" not in sass


@pytest.mark.parametrize("threads_per_block", [32, 64, 128, 256, 512, 1024])
def test_block_reduction_of_user_defined_type(threads_per_block):
    """Block-reduce ``Complex`` values using a shared-memory temp buffer.

    Same setup as the no-temp-storage variant, but the kernel allocates
    ``temp_storage_bytes`` of shared memory and passes it to the reduction.
    Thread 0's result must equal the int32 sums of the two halves of the
    input, and the SASS must contain no ``LDL``/``STL`` instructions.
    """

    def op(result_ptr, lhs_ptr, rhs_ptr):
        real_value = numba.int32(lhs_ptr[0].real + rhs_ptr[0].real)
        imag_value = numba.int32(lhs_ptr[0].imag + rhs_ptr[0].imag)
        result_ptr[0] = Complex(real_value, imag_value)

    block_reduce = cudax.block.reduce(
        dtype=complex_type,
        binary_op=op,
        threads_per_block=threads_per_block,
        methods={
            "construct": Complex.construct,
            "assign": Complex.assign,
        },
    )
    temp_storage_bytes = block_reduce.temp_storage_bytes

    @cuda.jit(link=block_reduce.files)
    def kernel(input, output):
        temp_storage = cuda.shared.array(shape=temp_storage_bytes, dtype="uint8")
        block_output = block_reduce(
            temp_storage,
            Complex(
                input[cuda.threadIdx.x], input[threads_per_block + cuda.threadIdx.x]
            ),
        )

        if cuda.threadIdx.x == 0:
            output[0] = block_output.real
            output[1] = block_output.imag

    h_input = random_int(2 * threads_per_block, "int32")
    d_input = cuda.to_device(h_input)
    d_output = cuda.device_array(2, dtype="int32")
    kernel[1, threads_per_block](d_input, d_output)
    cuda.synchronize()
    h_output = d_output.copy_to_host()
    # ``op`` truncates every partial sum to int32, so the reference is also
    # accumulated in the input dtype. By default np.sum widens int32 to the
    # platform integer and would disagree whenever the true sum overflows.
    h_expected = (
        np.sum(h_input[:threads_per_block], dtype=h_input.dtype),
        np.sum(h_input[threads_per_block:], dtype=h_input.dtype),
    )

    assert h_output[0] == h_expected[0]
    assert h_output[1] == h_expected[1]

    sig = (numba.int32[::1], numba.int32[::1])
    sass = kernel.inspect_sass(sig)

    assert "LDL" not in sass
    assert "STL" not in sass


@pytest.mark.parametrize("T", [types.uint32, types.uint64])
@pytest.mark.parametrize("threads_per_block", [32, 64, 128, 256, 512, 1024])
def test_block_reduction_of_integral_type(T, threads_per_block):
    """Block-wide minimum of one unsigned integer per thread via a custom op.

    Thread 0's result must match ``np.min`` of the input, and the kernel's
    SASS must be free of local-memory traffic (``LDL``/``STL``).
    """

    def op(a, b):
        return a if a < b else b

    reducer = cudax.block.reduce(
        dtype=T, binary_op=op, threads_per_block=threads_per_block
    )
    storage_bytes = reducer.temp_storage_bytes

    @cuda.jit(link=reducer.files)
    def kernel(d_in, d_out):
        tid = cuda.threadIdx.x
        smem = cuda.shared.array(shape=storage_bytes, dtype="uint8")
        aggregate = reducer(smem, d_in[tid])
        if tid == 0:
            d_out[0] = aggregate

    np_dtype = NUMBA_TYPES_TO_NP[T]
    host_in = random_int(threads_per_block, np_dtype)
    dev_in = cuda.to_device(host_in)
    dev_out = cuda.device_array(1, dtype=np_dtype)
    kernel[1, threads_per_block](dev_in, dev_out)
    cuda.synchronize()

    assert dev_out.copy_to_host()[0] == np.min(host_in)

    sass = kernel.inspect_sass((T[::1], T[::1]))
    for opcode in ("LDL", "STL"):
        assert opcode not in sass


@pytest.mark.parametrize("T", [types.uint32, types.uint64])
@pytest.mark.parametrize("threads_per_block", [32, 64, 128, 256, 512, 1024])
def test_block_reduction_valid(T, threads_per_block):
    """Block-wide minimum restricted to the first ``num_valid`` items.

    Every thread passes its item, but the reduction is told that only the
    first half of the block is valid. Thread 0's result must match ``np.min``
    over that prefix, and the SASS must contain no ``LDL``/``STL``.
    """

    def op(a, b):
        return a if a < b else b

    block_reduce = cudax.block.reduce(
        dtype=T, binary_op=op, threads_per_block=threads_per_block
    )
    temp_storage_bytes = block_reduce.temp_storage_bytes
    # Exact integer count (``/`` would yield a float64 inside the kernel). It is
    # passed as int32, matching test_block_sum_valid, and reused for the
    # expected slice.
    num_valid = threads_per_block // 2

    @cuda.jit(link=block_reduce.files)
    def kernel(input, output):
        temp_storage = cuda.shared.array(shape=temp_storage_bytes, dtype="uint8")
        block_output = block_reduce(
            temp_storage, input[cuda.threadIdx.x], numba.int32(num_valid)
        )

        if cuda.threadIdx.x == 0:
            output[0] = block_output

    dtype = NUMBA_TYPES_TO_NP[T]
    h_input = random_int(threads_per_block, dtype)
    # Put the smallest possible unsigned value in the invalid tail. If invalid
    # items were wrongly folded in, the min would (almost always) change.
    h_input[-1] = 0
    d_input = cuda.to_device(h_input)
    d_output = cuda.device_array(1, dtype=dtype)
    kernel[1, threads_per_block](d_input, d_output)
    cuda.synchronize()
    h_output = d_output.copy_to_host()
    h_expected = np.min(h_input[:num_valid])

    assert h_output[0] == h_expected

    sig = (T[::1], T[::1])
    sass = kernel.inspect_sass(sig)

    assert "LDL" not in sass
    assert "STL" not in sass


@pytest.mark.parametrize("T", [types.uint32, types.uint64])
@pytest.mark.parametrize("threads_per_block", [32, 128, 512, 1024])
@pytest.mark.parametrize("items_per_thread", [1, 2, 4])
def test_block_reduction_array_local(T, threads_per_block, items_per_thread):
    """Block-wide minimum where each thread supplies a local array of items.

    Item ``k`` of thread ``tid`` is read from ``input[k * threads_per_block + tid]``
    into a ``cuda.local.array``. Thread 0's result must match ``np.min`` over
    the whole input, and the SASS must contain no ``LDL``/``STL``.
    """

    def op(a, b):
        return a if a < b else b

    reducer = cudax.block.reduce(
        dtype=T,
        binary_op=op,
        threads_per_block=threads_per_block,
        items_per_thread=items_per_thread,
    )
    storage_bytes = reducer.temp_storage_bytes

    @cuda.jit(link=reducer.files)
    def kernel(d_in, d_out):
        tid = cuda.threadIdx.x
        smem = cuda.shared.array(shape=storage_bytes, dtype="uint8")
        items = cuda.local.array(shape=items_per_thread, dtype=T)

        for k in range(items_per_thread):
            items[k] = d_in[k * threads_per_block + tid]

        aggregate = reducer(smem, items)
        if tid == 0:
            d_out[0] = aggregate

    np_dtype = NUMBA_TYPES_TO_NP[T]
    host_in = random_int(items_per_thread * threads_per_block, np_dtype)
    dev_in = cuda.to_device(host_in)
    dev_out = cuda.device_array(1, dtype=np_dtype)
    kernel[1, threads_per_block](dev_in, dev_out)
    cuda.synchronize()

    assert dev_out.copy_to_host()[0] == np.min(host_in)

    sass = kernel.inspect_sass((T[::1], T[::1]))
    for opcode in ("LDL", "STL"):
        assert opcode not in sass


@pytest.mark.parametrize("T", [types.uint32, types.uint64])
@pytest.mark.parametrize("threads_per_block", [32, 128, 512, 1024])
@pytest.mark.parametrize("items_per_thread", [1, 2, 4])
def test_block_reduction_array_global(T, threads_per_block, items_per_thread):
    """Block-wide minimum where each thread passes a slice of global memory.

    Thread ``tid`` hands the reduction the view starting at
    ``input[items_per_thread * tid]``. Thread 0's result must match ``np.min``
    over the whole input, and the SASS must contain no ``LDL``/``STL``.
    """

    def op(a, b):
        return a if a < b else b

    reducer = cudax.block.reduce(
        dtype=T,
        binary_op=op,
        threads_per_block=threads_per_block,
        items_per_thread=items_per_thread,
    )
    storage_bytes = reducer.temp_storage_bytes

    @cuda.jit(link=reducer.files)
    def kernel(d_in, d_out):
        smem = cuda.shared.array(shape=storage_bytes, dtype="uint8")

        # A slice must be passed here. A scalar element (e.g. d_in[x]) would
        # select the reference overload and load a single element instead of
        # using the statically sized array overload. This pitfall is subtle and
        # should be fixed.
        aggregate = reducer(smem, d_in[items_per_thread * cuda.threadIdx.x :])

        if cuda.threadIdx.x == 0:
            d_out[0] = aggregate

    np_dtype = NUMBA_TYPES_TO_NP[T]
    host_in = random_int(items_per_thread * threads_per_block, np_dtype)
    dev_in = cuda.to_device(host_in)
    dev_out = cuda.device_array(1, dtype=np_dtype)
    kernel[1, threads_per_block](dev_in, dev_out)
    cuda.synchronize()

    assert dev_out.copy_to_host()[0] == np.min(host_in)

    sass = kernel.inspect_sass((T[::1], T[::1]))
    for opcode in ("LDL", "STL"):
        assert opcode not in sass


@pytest.mark.parametrize("T", [types.uint32, types.uint64])
@pytest.mark.parametrize("threads_per_block", [32, 64, 128, 256, 512, 1024])
def test_block_sum(T, threads_per_block):
    """Block-wide sum of one unsigned integer per thread.

    Thread 0's result must equal the input's sum computed in ``T`` (with
    unsigned wraparound), and the SASS must contain no ``LDL``/``STL``.
    """
    block_reduce = cudax.block.sum(dtype=T, threads_per_block=threads_per_block)
    temp_storage_bytes = block_reduce.temp_storage_bytes

    @cuda.jit(link=block_reduce.files)
    def kernel(input, output):
        temp_storage = cuda.shared.array(shape=temp_storage_bytes, dtype="uint8")
        block_output = block_reduce(temp_storage, input[cuda.threadIdx.x])

        if cuda.threadIdx.x == 0:
            output[0] = block_output

    dtype = NUMBA_TYPES_TO_NP[T]
    h_input = random_int(threads_per_block, dtype)
    d_input = cuda.to_device(h_input)
    d_output = cuda.device_array(1, dtype=dtype)
    kernel[1, threads_per_block](d_input, d_output)
    cuda.synchronize()
    h_output = d_output.copy_to_host()
    # The device accumulates in T and wraps on overflow. The reference is
    # accumulated in the same dtype, because np.sum's default accumulator for
    # small unsigned types is the (wider, platform-dependent) platform integer.
    h_expected = np.sum(h_input, dtype=dtype)

    assert h_output[0] == h_expected

    sig = (T[::1], T[::1])
    sass = kernel.inspect_sass(sig)

    assert "LDL" not in sass
    assert "STL" not in sass


@pytest.mark.parametrize("T", [types.uint32, types.uint64])
@pytest.mark.parametrize("threads_per_block", [32, 64, 128, 256, 512, 1024])
def test_block_sum_valid(T, threads_per_block):
    """Block-wide sum restricted to the first ``num_valid`` items.

    Every thread passes its item, but only the first half of the block is
    declared valid. Thread 0's result must equal the sum of that prefix
    computed in ``T``, and the SASS must contain no ``LDL``/``STL``.
    """
    block_reduce = cudax.block.sum(dtype=T, threads_per_block=threads_per_block)
    temp_storage_bytes = block_reduce.temp_storage_bytes
    # Exact integer count (``/`` would round-trip through float64 inside the
    # kernel). It is reused for the expected slice.
    num_valid = threads_per_block // 2

    @cuda.jit(link=block_reduce.files)
    def kernel(input, output):
        temp_storage = cuda.shared.array(shape=temp_storage_bytes, dtype="uint8")
        block_output = block_reduce(
            temp_storage, input[cuda.threadIdx.x], numba.int32(num_valid)
        )

        if cuda.threadIdx.x == 0:
            output[0] = block_output

    dtype = NUMBA_TYPES_TO_NP[T]
    h_input = random_int(threads_per_block, dtype)
    h_input[-1] = 0
    d_input = cuda.to_device(h_input)
    d_output = cuda.device_array(1, dtype=dtype)
    kernel[1, threads_per_block](d_input, d_output)
    cuda.synchronize()
    h_output = d_output.copy_to_host()
    # Accumulate the reference in T to match the device's wrapping arithmetic.
    h_expected = np.sum(h_input[:num_valid], dtype=dtype)

    assert h_output[0] == h_expected

    sig = (T[::1], T[::1])
    sass = kernel.inspect_sass(sig)

    assert "LDL" not in sass
    assert "STL" not in sass


@pytest.mark.parametrize("T", [types.uint32, types.uint64])
@pytest.mark.parametrize("threads_per_block", [32, 128, 512, 1024])
@pytest.mark.parametrize("items_per_thread", [1, 2, 4])
def test_block_sum_array_local(T, threads_per_block, items_per_thread):
    """Block-wide sum where each thread supplies a local array of items.

    Item ``i`` of thread ``tid`` is read from ``input[i * threads_per_block + tid]``.
    Thread 0's result must equal the whole input's sum computed in ``T``, and
    the SASS must contain no ``LDL``/``STL``.
    """
    block_reduce = cudax.block.sum(
        dtype=T, threads_per_block=threads_per_block, items_per_thread=items_per_thread
    )
    temp_storage_bytes = block_reduce.temp_storage_bytes

    @cuda.jit(link=block_reduce.files)
    def kernel(input, output):
        temp_storage = cuda.shared.array(shape=temp_storage_bytes, dtype="uint8")
        thread_items = cuda.local.array(shape=items_per_thread, dtype=T)

        for i in range(items_per_thread):
            thread_items[i] = input[i * threads_per_block + cuda.threadIdx.x]

        block_output = block_reduce(temp_storage, thread_items)

        if cuda.threadIdx.x == 0:
            output[0] = block_output

    dtype = NUMBA_TYPES_TO_NP[T]
    h_input = random_int(items_per_thread * threads_per_block, dtype)
    d_input = cuda.to_device(h_input)
    d_output = cuda.device_array(1, dtype=dtype)
    kernel[1, threads_per_block](d_input, d_output)
    cuda.synchronize()
    h_output = d_output.copy_to_host()
    # Accumulate the reference in T to match the device's wrapping arithmetic.
    h_expected = np.sum(h_input, dtype=dtype)

    assert h_output[0] == h_expected

    sig = (T[::1], T[::1])
    sass = kernel.inspect_sass(sig)

    assert "LDL" not in sass
    assert "STL" not in sass


@pytest.mark.parametrize("T", [types.uint32, types.uint64])
@pytest.mark.parametrize("threads_per_block", [32, 128, 512, 1024])
@pytest.mark.parametrize("items_per_thread", [1, 2, 4])
def test_block_sum_array_global(T, threads_per_block, items_per_thread):
    """Block-wide sum where each thread passes a slice of global memory.

    Thread ``tid`` passes the view starting at ``input[items_per_thread * tid]``.
    Thread 0's result must equal the whole input's sum computed in ``T``, and
    the SASS must contain no ``LDL``/``STL``.
    """
    block_reduce = cudax.block.sum(
        dtype=T, threads_per_block=threads_per_block, items_per_thread=items_per_thread
    )
    temp_storage_bytes = block_reduce.temp_storage_bytes

    @cuda.jit(link=block_reduce.files)
    def kernel(input, output):
        temp_storage = cuda.shared.array(shape=temp_storage_bytes, dtype="uint8")

        # If a scalar (e.g. input[x]) is passed here, the reference overload
        # will be selected and only one element will be loaded, instead of the
        # statically sized array overload. This is very subtle and should be
        # fixed.
        block_output = block_reduce(
            temp_storage, input[items_per_thread * cuda.threadIdx.x :]
        )

        if cuda.threadIdx.x == 0:
            output[0] = block_output

    dtype = NUMBA_TYPES_TO_NP[T]
    h_input = random_int(items_per_thread * threads_per_block, dtype)
    d_input = cuda.to_device(h_input)
    d_output = cuda.device_array(1, dtype=dtype)
    kernel[1, threads_per_block](d_input, d_output)
    cuda.synchronize()
    h_output = d_output.copy_to_host()
    # Accumulate the reference in T to match the device's wrapping arithmetic.
    h_expected = np.sum(h_input, dtype=dtype)

    assert h_output[0] == h_expected

    sig = (T[::1], T[::1])
    sass = kernel.inspect_sass(sig)

    assert "LDL" not in sass
    assert "STL" not in sass
