﻿
using System;
using System.Collections.Generic;
using Assert = Microsoft.VisualStudio.TestTools.UnitTesting.Assert;
using Orleans.Providers.Streams.Common;
using Orleans.Streams;
using UnitTests.StreamingTests;
using Xunit;

namespace UnitTests.OrleansRuntime.Streams
{
    public class CachedMessageBlockTests
    {
        // Size passed to every CachedMessageBlock created by MyTestPooled; the tests below
        // expect a block to report HasCapacity until this many messages have been added.
        private const int TestBlockSize = 100;
        // Stream id stamped on the messages built by the sequence-number overload of AddAndCheck.
        private readonly Guid StreamGuid = Guid.NewGuid();

        /// <summary>
        /// Minimal queue message used as input to the block: a stream id plus the sequence token
        /// that TestCacheDataAdapter copies into a TestCachedMessage.
        /// </summary>
        private class TestQueueMessage
        {
            // Id of the stream this message belongs to.
            public Guid StreamGuid { get; set; }
            // Token whose SequenceNumber and EventIndex are copied into the cached form.
            public EventSequenceToken SequenceToken { get; set; }
        }

        /// <summary>
        /// Value-type representation of a message as stored in the block under test. It carries the
        /// stream id and the two components needed to rebuild an EventSequenceToken.
        /// </summary>
        private struct TestCachedMessage
        {
            // Id of the stream this message belongs to.
            public Guid StreamGuid { get; set; }
            // Sequence number component of the message's token.
            public long SequenceNumber { get; set; }
            // Event index component of the message's token.
            public int EventIndex { get; set; }
        }

        /// <summary>
        /// Bare-bones IBatchContainer returned by TestCacheDataAdapter.GetBatchContainer. Only the
        /// stream id and sequence token carry data; every method throws NotImplementedException.
        /// </summary>
        private class TestBatchContainer : IBatchContainer
        {
            public Guid StreamGuid { get; set; }
            // Always null; TestCacheDataAdapter.IsInStream likewise only matches a null namespace.
            public string StreamNamespace { get { return null; }}
            public StreamSequenceToken SequenceToken { get; set; }

            // Not implemented in this test double.
            public IEnumerable<Tuple<T, StreamSequenceToken>> GetEvents<T>()
            {
                throw new NotImplementedException();
            }

            // Not implemented in this test double.
            public bool ImportRequestContext()
            {
                throw new NotImplementedException();
            }

            // Not implemented in this test double.
            public bool ShouldDeliver(IStreamIdentity stream, object filterData, StreamFilterPredicate shouldReceiveFunc)
            {
                throw new NotImplementedException();
            }
        }

        /// <summary>
        /// Comparer handed to the block's search methods. It orders a cached message against a
        /// sequence token or against a stream identity.
        /// </summary>
        private class TestCacheDataComparer : ICacheDataComparer<TestCachedMessage>
        {
            /// <summary>Shared instance; the comparer holds no state.</summary>
            public static readonly ICacheDataComparer<TestCachedMessage> Instance = new TestCacheDataComparer();

            /// <summary>
            /// Rebuilds the cached message's EventSequenceToken and returns the result of comparing
            /// it to <paramref name="token"/>.
            /// </summary>
            public int Compare(TestCachedMessage cachedMessage, StreamSequenceToken token)
            {
                long sequenceNumber = cachedMessage.SequenceNumber;
                int eventIndex = cachedMessage.EventIndex;
                return new EventSequenceToken(sequenceNumber, eventIndex).CompareTo(token);
            }

            /// <summary>
            /// Compares the cached message's stream guid with the guid of
            /// <paramref name="streamIdentity"/>.
            /// </summary>
            public int Compare(TestCachedMessage cachedMessage, IStreamIdentity streamIdentity)
            {
                Guid cachedStreamGuid = cachedMessage.StreamGuid;
                return cachedStreamGuid.CompareTo(streamIdentity.Guid);
            }
        }


        /// <summary>
        /// Adapter between TestQueueMessage and TestCachedMessage used by the block under test.
        /// Purging is not supported: ShouldPurge throws NotImplementedException.
        /// </summary>
        private class TestCacheDataAdapter : ICacheDataAdapter<TestQueueMessage, TestCachedMessage>
        {
            /// <summary>Settable by callers; never read inside this adapter.</summary>
            public Action<IDisposable> PurgeAction { private get; set; }

            /// <summary>
            /// Copies the stream guid and both token components from <paramref name="queueMessage"/>
            /// into <paramref name="cachedMessage"/>.
            /// </summary>
            public void QueueMessageToCachedMessage(ref TestCachedMessage cachedMessage, TestQueueMessage queueMessage)
            {
                cachedMessage.StreamGuid = queueMessage.StreamGuid;

                EventSequenceToken sourceToken = queueMessage.SequenceToken;
                cachedMessage.SequenceNumber = sourceToken.SequenceNumber;
                cachedMessage.EventIndex = sourceToken.EventIndex;
            }

            /// <summary>
            /// Wraps the cached message in a new TestBatchContainer carrying its stream guid and a
            /// freshly built sequence token.
            /// </summary>
            public IBatchContainer GetBatchContainer(ref TestCachedMessage cachedMessage)
            {
                var container = new TestBatchContainer();
                container.StreamGuid = cachedMessage.StreamGuid;
                container.SequenceToken = GetSequenceToken(ref cachedMessage);
                return container;
            }

            /// <summary>
            /// Builds a new EventSequenceToken from the cached message's sequence number and event index.
            /// </summary>
            public StreamSequenceToken GetSequenceToken(ref TestCachedMessage cachedMessage)
            {
                long sequenceNumber = cachedMessage.SequenceNumber;
                int eventIndex = cachedMessage.EventIndex;
                return new EventSequenceToken(sequenceNumber, eventIndex);
            }

            /// <summary>
            /// Returns true only when the guids match and <paramref name="streamNamespace"/> is null.
            /// </summary>
            public bool IsInStream(ref TestCachedMessage cachedMessage, Guid streamGuid, string streamNamespace)
            {
                if (cachedMessage.StreamGuid != streamGuid)
                {
                    return false;
                }

                return streamNamespace == null;
            }

            /// <summary>Not implemented in this test adapter.</summary>
            public bool ShouldPurge(ref TestCachedMessage cachedMessage, IDisposable purgeRequest)
            {
                throw new NotImplementedException();
            }
        }

        /// <summary>
        /// Trivial pool for the tests: every Allocate call builds a new block of TestBlockSize and
        /// Free does nothing.
        /// </summary>
        private class MyTestPooled : IObjectPool<CachedMessageBlock<TestCachedMessage>>
        {
            /// <summary>Creates a new block that references this pool.</summary>
            public CachedMessageBlock<TestCachedMessage> Allocate()
            {
                var freshBlock = new CachedMessageBlock<TestCachedMessage>(this, TestBlockSize);
                return freshBlock;
            }

            /// <summary>No-op; freed blocks are not kept by this pool.</summary>
            public void Free(CachedMessageBlock<TestCachedMessage> resource)
            {
                // Intentionally empty.
            }
        }

        /// <summary>
        /// Adds a single message to an empty block and removes it again, checking the block's
        /// indexes and state flags around each step.
        /// </summary>
        [Fact, TestCategory("BVT"), TestCategory("Streaming")]
        public void Add1Remove1Test()
        {
            ICacheDataAdapter<TestQueueMessage, TestCachedMessage> dataAdapter = new TestCacheDataAdapter();
            CachedMessageBlock<TestCachedMessage> block = new MyTestPooled().Allocate();

            // A fresh block is expected to start with oldest index 0 and newest index -1.
            const int oldestIndex = 0;
            const int newestIndexBeforeAdd = -1;

            AddAndCheck(block, dataAdapter, oldestIndex, newestIndexBeforeAdd);
            RemoveAndCheck(block, oldestIndex, newestIndexBeforeAdd + 1);
        }

        /// <summary>
        /// Repeatedly adds two messages and removes one until the block reports no capacity, then
        /// checks that the oldest index is at the midpoint and the newest at the final slot.
        /// </summary>
        [Fact, TestCategory("BVT"), TestCategory("Streaming")]
        public void Add2Remove1UntilFull()
        {
            ICacheDataAdapter<TestQueueMessage, TestCachedMessage> dataAdapter = new TestCacheDataAdapter();
            CachedMessageBlock<TestCachedMessage> block = new MyTestPooled().Allocate();
            int oldest = 0;
            int newest = -1;

            while (block.HasCapacity)
            {
                // first append of this round
                AddAndCheck(block, dataAdapter, oldest, newest);
                newest++;

                // only do the second append and the removal if there is still room
                if (block.HasCapacity)
                {
                    // second append of this round
                    AddAndCheck(block, dataAdapter, oldest, newest);
                    newest++;

                    // remove one message from the start of the block
                    RemoveAndCheck(block, oldest, newest);
                    oldest++;
                }
            }

            Assert.AreEqual(TestBlockSize / 2, block.OldestMessageIndex);
            Assert.AreEqual(TestBlockSize - 1, block.NewestMessageIndex);
            Assert.IsFalse(block.IsEmpty);
            Assert.IsFalse(block.HasCapacity);
        }

        /// <summary>
        /// Fills a block with messages whose sequence numbers are 0, 2, 4, ... and checks the index
        /// returned by GetIndexOfFirstMessageLessThanOrEqualTo for tokens at the start, end and
        /// middle of that range. The test expects a token that falls between two stored sequence
        /// numbers to resolve to the message with the lower one.
        /// </summary>
        [Fact, TestCategory("BVT"), TestCategory("Streaming")]
        public void FirstMessageWithSequenceNumberTest()
        {
            IObjectPool<CachedMessageBlock<TestCachedMessage>> pool = new MyTestPooled();
            ICacheDataAdapter<TestQueueMessage, TestCachedMessage> dataAdapter = new TestCacheDataAdapter();
            CachedMessageBlock<TestCachedMessage> block = pool.Allocate();
            int last = -1;
            int sequenceNumber = 0;

            while (block.HasCapacity)
            {
                // add message to end of block
                AddAndCheck(block, dataAdapter, 0, last, sequenceNumber);
                last++;
                sequenceNumber += 2;
            }

            // Boundaries: the first message (sequence 0) and the last one (sequence sequenceNumber - 2).
            Assert.AreEqual(block.OldestMessageIndex, block.GetIndexOfFirstMessageLessThanOrEqualTo(new EventSequenceToken(0), TestCacheDataComparer.Instance));
            Assert.AreEqual(block.OldestMessageIndex, block.GetIndexOfFirstMessageLessThanOrEqualTo(new EventSequenceToken(1), TestCacheDataComparer.Instance));
            Assert.AreEqual(block.NewestMessageIndex, block.GetIndexOfFirstMessageLessThanOrEqualTo(new EventSequenceToken(sequenceNumber - 2), TestCacheDataComparer.Instance));
            Assert.AreEqual(block.NewestMessageIndex - 1, block.GetIndexOfFirstMessageLessThanOrEqualTo(new EventSequenceToken(sequenceNumber - 3), TestCacheDataComparer.Instance));

            // Middle of the range: the message at index i was added with sequence number 2 * i, so a
            // token s is expected at index s / 2 (integer division). Deriving the expected index
            // keeps the check valid if TestBlockSize changes.
            int midSequence = sequenceNumber / 2;
            Assert.AreEqual(midSequence / 2, block.GetIndexOfFirstMessageLessThanOrEqualTo(new EventSequenceToken(midSequence), TestCacheDataComparer.Instance));
            Assert.AreEqual((midSequence + 1) / 2, block.GetIndexOfFirstMessageLessThanOrEqualTo(new EventSequenceToken(midSequence + 1), TestCacheDataComparer.Instance));
        }

        /// <summary>
        /// Fills a block with messages alternating between two streams, then walks each stream with
        /// TryFindFirstMessage and TryFindNextMessage, checking the index and sequence number of
        /// every message found and the total number found per stream.
        /// </summary>
        [Fact, TestCategory("BVT"), TestCategory("Streaming")]
        public void NextInStreamTest()
        {
            IObjectPool<CachedMessageBlock<TestCachedMessage>> pool = new MyTestPooled();
            ICacheDataAdapter<TestQueueMessage, TestCachedMessage> dataAdapter = new TestCacheDataAdapter();
            CachedMessageBlock<TestCachedMessage> block = pool.Allocate();
            int last = 0;
            int sequenceNumber = 0;
            // define 2 streams
            var streams = new[] { new TestStreamIdentity { Guid = Guid.NewGuid() }, new TestStreamIdentity { Guid = Guid.NewGuid() } };

            // add both streams interleaved, until block is full:
            // even indexes hold streams[0], odd indexes hold streams[1], index i has sequence number 2 * i
            while (block.HasCapacity)
            {
                var stream = streams[last % 2];
                var message = new TestQueueMessage
                {
                    StreamGuid = stream.Guid,
                    SequenceToken = new EventSequenceToken(sequenceNumber)
                };

                // add message to end of block ("last" counts messages added so far, hence last - 1)
                AddAndCheck(block, dataAdapter, message, 0, last - 1);
                last++;
                sequenceNumber += 2;
            }

            // get index of the first message of the first stream
            int streamIndex;
            Assert.IsTrue(block.TryFindFirstMessage(streams[0], TestCacheDataComparer.Instance, out streamIndex));
            Assert.AreEqual(0, streamIndex);
            // Direct cast: an unexpected token type fails with InvalidCastException rather than a NullReferenceException.
            Assert.AreEqual(0, ((EventSequenceToken)block.GetSequenceToken(streamIndex, dataAdapter)).SequenceNumber);

            // find the remaining messages of the first stream
            int iteration = 1;
            while (block.TryFindNextMessage(streamIndex + 1, streams[0], TestCacheDataComparer.Instance, out streamIndex))
            {
                Assert.AreEqual(iteration * 2, streamIndex);
                Assert.AreEqual(iteration * 4, ((EventSequenceToken)block.GetSequenceToken(streamIndex, dataAdapter)).SequenceNumber);
                iteration++;
            }
            // each stream owns half of the block (expected value first, actual second)
            Assert.AreEqual(TestBlockSize / 2, iteration);

            // get index of the first message of the second stream
            Assert.IsTrue(block.TryFindFirstMessage(streams[1], TestCacheDataComparer.Instance, out streamIndex));
            Assert.AreEqual(1, streamIndex);
            Assert.AreEqual(2, ((EventSequenceToken)block.GetSequenceToken(streamIndex, dataAdapter)).SequenceNumber);

            // find the remaining messages of the second stream
            iteration = 1;
            while (block.TryFindNextMessage(streamIndex + 1, streams[1], TestCacheDataComparer.Instance, out streamIndex))
            {
                Assert.AreEqual(iteration * 2 + 1, streamIndex);
                Assert.AreEqual(iteration * 4 + 2, ((EventSequenceToken)block.GetSequenceToken(streamIndex, dataAdapter)).SequenceNumber);
                iteration++;
            }
            Assert.AreEqual(TestBlockSize / 2, iteration);
        }

        /// <summary>
        /// Convenience overload: builds a message on this fixture's StreamGuid with the given
        /// sequence number and forwards it to the message-based AddAndCheck.
        /// </summary>
        /// <param name="block">Block to add to.</param>
        /// <param name="dataAdapter">Adapter used by the block to convert and read messages.</param>
        /// <param name="first">Oldest message index expected before and after the add.</param>
        /// <param name="last">Newest message index expected before the add.</param>
        /// <param name="sequenceNumber">Sequence number for the new message's token.</param>
        private void AddAndCheck(CachedMessageBlock<TestCachedMessage> block, ICacheDataAdapter<TestQueueMessage, TestCachedMessage> dataAdapter, int first, int last, int sequenceNumber = 1)
        {
            var queueMessage = new TestQueueMessage();
            queueMessage.StreamGuid = StreamGuid;
            queueMessage.SequenceToken = new EventSequenceToken(sequenceNumber);

            AddAndCheck(block, dataAdapter, queueMessage, first, last);
        }

        /// <summary>
        /// Verifies the block's state, appends <paramref name="message"/>, then verifies that the
        /// newest index advanced by one and that the stored token equals the message's token.
        /// </summary>
        /// <param name="block">Block to add to; must currently have capacity.</param>
        /// <param name="dataAdapter">Adapter used by the block to convert and read messages.</param>
        /// <param name="message">Message to append.</param>
        /// <param name="first">Oldest message index expected before and after the add.</param>
        /// <param name="last">Newest message index expected before the add.</param>
        private void AddAndCheck(CachedMessageBlock<TestCachedMessage> block, ICacheDataAdapter<TestQueueMessage, TestCachedMessage> dataAdapter, TestQueueMessage message, int first, int last)
        {
            // state before the add
            Assert.AreEqual(first, block.OldestMessageIndex);
            Assert.AreEqual(last, block.NewestMessageIndex);
            Assert.IsTrue(block.HasCapacity);

            block.Add(message, dataAdapter);

            // state after the add: newest index moves forward by one, oldest is untouched
            int expectedNewest = last + 1;
            bool expectEmpty = first > expectedNewest;
            bool expectCapacity = expectedNewest + 1 < TestBlockSize;

            Assert.AreEqual(expectEmpty, block.IsEmpty);
            Assert.AreEqual(expectCapacity, block.HasCapacity);
            Assert.AreEqual(first, block.OldestMessageIndex);
            Assert.AreEqual(expectedNewest, block.NewestMessageIndex);

            // the token read back from the new slot must match what was added
            var storedToken = block.GetSequenceToken(expectedNewest, dataAdapter);
            Assert.IsTrue(storedToken.Equals(message.SequenceToken));
        }

        /// <summary>
        /// Verifies the block's state, removes one message, then verifies that the oldest index
        /// advanced by one while the newest index and capacity are unchanged.
        /// </summary>
        /// <param name="block">Block to remove from; must not be empty.</param>
        /// <param name="first">Oldest message index expected before the remove.</param>
        /// <param name="last">Newest message index expected before and after the remove.</param>
        private void RemoveAndCheck(CachedMessageBlock<TestCachedMessage> block, int first, int last)
        {
            // The capacity expectation depends only on the newest index, so it is the same before and after.
            bool expectCapacity = last + 1 < TestBlockSize;

            // state before the remove
            Assert.AreEqual(first, block.OldestMessageIndex);
            Assert.AreEqual(last, block.NewestMessageIndex);
            Assert.IsFalse(block.IsEmpty);
            Assert.AreEqual(expectCapacity, block.HasCapacity);

            Assert.IsTrue(block.Remove());

            // state after the remove: oldest index moves forward by one
            int expectedOldest = first + 1;
            bool expectEmpty = expectedOldest > last;

            Assert.AreEqual(expectEmpty, block.IsEmpty);
            Assert.AreEqual(expectCapacity, block.HasCapacity);
            Assert.AreEqual(expectedOldest, block.OldestMessageIndex);
            Assert.AreEqual(last, block.NewestMessageIndex);
        }
    }
}
