/*
 * Copyright 2020 LINE Corporation
 *
 * LINE Corporation licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */

package com.linecorp.armeria.internal.common;

import static org.assertj.core.api.Assertions.assertThat;
import static org.awaitility.Awaitility.await;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;

import java.time.Duration;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;

import org.hamcrest.Matchers;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;

import com.google.common.base.Stopwatch;

import com.linecorp.armeria.internal.common.KeepAliveHandler.PingState;
import com.linecorp.armeria.testing.junit.common.EventLoopExtension;

import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;

/**
 * Tests for {@link KeepAliveHandler}.
 *
 * <p>Each test creates an anonymous {@link KeepAliveHandler} subclass whose hooks
 * ({@code writePing}, {@code pingResetsPreviousPing}, {@code hasRequestsInProgress}) are stubbed,
 * initializes it with a mocked {@link ChannelHandlerContext} and then observes either how often the hooks
 * are invoked or which {@link PingState} the handler reports. Every handler is destroyed in a
 * {@code finally} block so that a failing assertion does not leave it alive on the event loop that is
 * shared by all tests.
 */
class KeepAliveHandlerTest {

    @RegisterExtension
    static EventLoopExtension eventLoop = new EventLoopExtension();

    private EmbeddedChannel channel;
    private ChannelHandlerContext ctx;

    /**
     * Creates a spied {@link EmbeddedChannel} whose {@code eventLoop()} returns the event loop provided by
     * {@link #eventLoop}, and a mocked {@link ChannelHandlerContext} that returns that channel.
     */
    @BeforeEach
    void setUp() {
        channel = spy(new EmbeddedChannel());
        when(channel.eventLoop()).thenReturn(eventLoop.get());
        ctx = mock(ChannelHandlerContext.class);
        when(ctx.channel()).thenReturn(channel);
    }

    /**
     * Finishes the channel created in {@link #setUp()}.
     */
    @AfterEach
    void tearDown() {
        channel.finish();
    }

    /**
     * With a connection idle timeout of 1000 and pings disabled (0), expects
     * {@code hasRequestsInProgress()} to have been invoked 10 times within 20 seconds.
     */
    @Test
    void testIdle() {
        final AtomicInteger counter = new AtomicInteger();

        final KeepAliveHandler idleTimeoutScheduler =
                new KeepAliveHandler(channel, "test", 1000, 0) {

                    @Override
                    protected boolean pingResetsPreviousPing() {
                        return true;
                    }

                    @Override
                    protected ChannelFuture writePing(ChannelHandlerContext ctx) {
                        return channel.newSucceededFuture();
                    }

                    @Override
                    protected boolean hasRequestsInProgress(ChannelHandlerContext ctx) {
                        counter.incrementAndGet();
                        return false;
                    }
                };

        try {
            idleTimeoutScheduler.initialize(ctx);
            await().timeout(20, TimeUnit.SECONDS).untilAtomic(counter, Matchers.is(10));
        } finally {
            idleTimeoutScheduler.destroy();
        }
    }

    /**
     * With the connection idle timeout disabled (0) and a ping interval of 1000, expects
     * {@code writePing()} to be invoked between 1000 and 5000 milliseconds after the stopwatch was started.
     */
    @Test
    void testPing() {
        final Stopwatch stopwatch = Stopwatch.createStarted();

        final KeepAliveHandler idleTimeoutScheduler =
                new KeepAliveHandler(channel, "test", 0, 1000) {

                    @Override
                    protected boolean pingResetsPreviousPing() {
                        return true;
                    }

                    @Override
                    protected ChannelFuture writePing(ChannelHandlerContext ctx) {
                        // Stopping the stopwatch is the signal that a ping has been written.
                        stopwatch.stop();
                        return channel.newSucceededFuture();
                    }

                    @Override
                    protected boolean hasRequestsInProgress(ChannelHandlerContext ctx) {
                        return false;
                    }
                };

        try {
            idleTimeoutScheduler.initialize(ctx);
            await().until(stopwatch::isRunning, Matchers.is(false));
            final Duration elapsed = stopwatch.elapsed();
            assertThat(elapsed.toMillis()).isBetween(1000L, 5000L);
        } finally {
            idleTimeoutScheduler.destroy();
        }
    }

    /**
     * Repeatedly signals activity ({@code onReadOrWrite()} in {@code CONNECTION_IDLE} mode,
     * {@code onPing()} otherwise) and expects that {@code hasRequestsInProgress()} is not invoked while
     * doing so. Afterwards expects 5 idle checks in {@code CONNECTION_IDLE} mode or 1 written ping in
     * {@code PING_IDLE} mode.
     *
     * @param connectionIdleTimeout the connection idle timeout passed to the handler, in milliseconds
     * @param pingInterval the ping interval passed to the handler, in milliseconds
     * @param mode either {@code CONNECTION_IDLE} or {@code PING_IDLE}
     * @throws InterruptedException if the test thread is interrupted while sleeping
     */
    @CsvSource({
            "2000, 0, CONNECTION_IDLE",
            "0, 1000, PING_IDLE",
    })
    @ParameterizedTest
    void testKeepAlive(long connectionIdleTimeout, long pingInterval, String mode)
            throws InterruptedException {
        // NOTE(review): lastIdleEventTime is written below but never read by this test.
        final AtomicLong lastIdleEventTime = new AtomicLong();
        final AtomicInteger idleCounter = new AtomicInteger();
        final AtomicInteger pingCounter = new AtomicInteger();
        final boolean connectionIdleMode = "CONNECTION_IDLE".equals(mode);
        final long idleTime = connectionIdleMode ? connectionIdleTimeout : pingInterval;
        final Consumer<KeepAliveHandler> activator =
                connectionIdleMode ? KeepAliveHandler::onReadOrWrite : KeepAliveHandler::onPing;

        final KeepAliveHandler idleTimeoutScheduler =
                new KeepAliveHandler(channel, "test", connectionIdleTimeout, pingInterval) {

                    @Override
                    protected boolean pingResetsPreviousPing() {
                        return true;
                    }

                    @Override
                    protected ChannelFuture writePing(ChannelHandlerContext ctx) {
                        pingCounter.incrementAndGet();
                        return channel.newSucceededFuture();
                    }

                    @Override
                    protected boolean hasRequestsInProgress(ChannelHandlerContext ctx) {
                        idleCounter.incrementAndGet();
                        return false;
                    }
                };

        try {
            lastIdleEventTime.set(System.nanoTime());
            idleTimeoutScheduler.initialize(ctx);

            for (int i = 0; i < 10; i++) {
                activator.accept(idleTimeoutScheduler);
                Thread.sleep(idleTime - 1000);
            }
            assertThat(idleCounter).hasValue(0);

            // idleTime is a millisecond value (it is also fed to Thread.sleep() above), so the await
            // timeouts derived from it must be expressed in milliseconds as well. Using seconds here
            // would turn a failure into a wait of several hours.
            if (connectionIdleMode) {
                await().timeout(idleTime * 10, TimeUnit.MILLISECONDS)
                       .untilAtomic(idleCounter, Matchers.is(5));
            } else {
                await().timeout(idleTime * 2, TimeUnit.MILLISECONDS)
                       .untilAtomic(pingCounter, Matchers.is(1));
            }
        } finally {
            idleTimeoutScheduler.destroy();
        }
    }

    /**
     * With an idle timeout of 10000 ms and pings disabled, expects the handler to stay
     * {@link PingState#IDLE} through initialization, {@code onReadOrWrite()} and half of the idle timeout.
     * After sleeping for another full idle timeout, the handler must still be {@link PingState#IDLE} when
     * requests are in progress and {@link PingState#SHUTDOWN} otherwise.
     *
     * @param hasRequests the value returned by the stubbed {@code hasRequestsInProgress()}
     * @throws InterruptedException if the test thread is interrupted while sleeping
     */
    @ParameterizedTest
    @CsvSource({ "true", "false" })
    void checkReadOrWrite(boolean hasRequests) throws InterruptedException {
        final long idleTimeout = 10000;
        final long pingInterval = 0;
        final ChannelFuture channelFuture = channel.newPromise();
        final KeepAliveHandler keepAliveHandler =
                new KeepAliveHandler(channel, "test", idleTimeout, pingInterval) {
                    @Override
                    protected ChannelFuture writePing(ChannelHandlerContext ctx) {
                        return channelFuture;
                    }

                    @Override
                    protected boolean pingResetsPreviousPing() {
                        return true;
                    }

                    @Override
                    protected boolean hasRequestsInProgress(ChannelHandlerContext ctx) {
                        return hasRequests;
                    }
                };

        try {
            assertThat(keepAliveHandler.state()).isEqualTo(PingState.IDLE);

            keepAliveHandler.initialize(ctx);
            assertThat(keepAliveHandler.state()).isEqualTo(PingState.IDLE);

            keepAliveHandler.onReadOrWrite();
            assertThat(keepAliveHandler.state()).isEqualTo(PingState.IDLE);

            Thread.sleep(idleTimeout / 2);
            assertThat(keepAliveHandler.state()).isEqualTo(PingState.IDLE);

            Thread.sleep(idleTimeout);
            final PingState expected = hasRequests ? PingState.IDLE : PingState.SHUTDOWN;
            assertThat(keepAliveHandler.state()).isEqualTo(expected);
        } finally {
            keepAliveHandler.destroy();
        }
    }

    /**
     * With an idle timeout of 10000 ms and a ping interval of 1000 ms, expects the handler to report
     * {@link PingState#PING_SCHEDULED}, then {@link PingState#PENDING_PING_ACK} once the ping promise
     * succeeds, then {@link PingState#IDLE} right after {@code onPing()}, and finally
     * {@link PingState#SHUTDOWN}. The final transition must be observed between {@code pingInterval} and
     * {@code idleTimeout - 1000} milliseconds after {@code onPing()}.
     *
     * @param hasRequests the value returned by the stubbed {@code hasRequestsInProgress()}
     */
    @ParameterizedTest
    @CsvSource({ "true", "false" })
    void checkPing(boolean hasRequests) {
        final long idleTimeout = 10000;
        final long pingInterval = 1000;
        final ChannelPromise promise = channel.newPromise();
        final KeepAliveHandler keepAliveHandler =
                new KeepAliveHandler(channel, "test", idleTimeout, pingInterval) {
                    @Override
                    protected ChannelFuture writePing(ChannelHandlerContext ctx) {
                        return promise;
                    }

                    @Override
                    protected boolean pingResetsPreviousPing() {
                        return true;
                    }

                    @Override
                    protected boolean hasRequestsInProgress(ChannelHandlerContext ctx) {
                        return hasRequests;
                    }
                };

        try {
            assertThat(keepAliveHandler.state()).isEqualTo(PingState.IDLE);

            keepAliveHandler.initialize(ctx);
            assertThat(keepAliveHandler.state()).isEqualTo(PingState.IDLE);

            keepAliveHandler.onReadOrWrite();
            assertThat(keepAliveHandler.state()).isEqualTo(PingState.IDLE);

            keepAliveHandler.writePing(ctx);
            await().untilAsserted(
                    () -> assertThat(keepAliveHandler.state()).isEqualTo(PingState.PING_SCHEDULED));

            // Completing the promise returned by writePing() marks the ping as written.
            promise.setSuccess();
            await().untilAsserted(
                    () -> assertThat(keepAliveHandler.state()).isEqualTo(PingState.PENDING_PING_ACK));

            keepAliveHandler.onPing();
            assertThat(keepAliveHandler.state()).isEqualTo(PingState.IDLE);
            final Stopwatch stopwatch = Stopwatch.createStarted();
            await().until(keepAliveHandler::state, Matchers.is(PingState.SHUTDOWN));
            final Duration elapsed = stopwatch.elapsed();
            assertThat(elapsed.toMillis()).isBetween(pingInterval, idleTimeout - 1000);
        } finally {
            keepAliveHandler.destroy();
        }
    }

    /**
     * Once the handler reports {@link PingState#PING_SCHEDULED}, calls {@code onReadOrWrite()} and expects
     * the state to become {@link PingState#IDLE} when {@code pingResetsPreviousPing()} returns
     * {@code true}, and to remain {@link PingState#PING_SCHEDULED} when it returns {@code false}.
     *
     * @param resetPing the value returned by the stubbed {@code pingResetsPreviousPing()}
     */
    @ParameterizedTest
    @CsvSource({ "true", "false" })
    void resetPing(boolean resetPing) {
        final long idleTimeout = 10000;
        final long pingInterval = 1000;
        final ChannelPromise promise = channel.newPromise();
        final KeepAliveHandler keepAliveHandler =
                new KeepAliveHandler(channel, "test", idleTimeout, pingInterval) {
                    @Override
                    protected ChannelFuture writePing(ChannelHandlerContext ctx) {
                        return promise;
                    }

                    @Override
                    protected boolean pingResetsPreviousPing() {
                        return resetPing;
                    }

                    @Override
                    protected boolean hasRequestsInProgress(ChannelHandlerContext ctx) {
                        return true;
                    }
                };

        try {
            keepAliveHandler.initialize(ctx);
            assertThat(keepAliveHandler.state()).isEqualTo(PingState.IDLE);

            keepAliveHandler.writePing(ctx);
            await().untilAsserted(
                    () -> assertThat(keepAliveHandler.state()).isEqualTo(PingState.PING_SCHEDULED));

            // The same activity signal is sent in both cases; only the expected state differs.
            keepAliveHandler.onReadOrWrite();
            final PingState expected = resetPing ? PingState.IDLE : PingState.PING_SCHEDULED;
            assertThat(keepAliveHandler.state()).isEqualTo(expected);
        } finally {
            keepAliveHandler.destroy();
        }
    }
}
