﻿using System.Collections.Concurrent;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Models.Data.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Models.Data.Organizations;
using Bit.Core.Repositories;

namespace Bit.Core.AdminConsole.AbilitiesCache;

/// <summary>
/// In-memory cache of <see cref="OrganizationAbility"/> and <see cref="ProviderAbility"/> records keyed by id.
/// Each cache is loaded in full from its repository on first use and reloaded in full once the last
/// successful load is older than the refresh interval (10 minutes).
/// </summary>
/// <remarks>
/// Each cache has one semaphore that serializes full reloads with upserts and deletes. A write therefore
/// can never land in a dictionary that an in-flight reload is about to replace. Readers only wait on the
/// semaphore when a reload is due. The dictionary references are <c>volatile</c>, and the refresh
/// timestamps are read and written with <see cref="Interlocked"/>. Together these ensure the lock-free
/// staleness check cannot pair a fresh timestamp with the previous dictionary instance.
/// </remarks>
public class VNextInMemoryApplicationCacheService(
    IOrganizationRepository organizationRepository,
    IProviderRepository providerRepository,
    TimeProvider timeProvider) : IVNextInMemoryApplicationCacheService
{
    // Replaced wholesale on every reload; volatile so lock-free readers observe the latest instance.
    private volatile ConcurrentDictionary<Guid, OrganizationAbility> _orgAbilities = new();
    // Serializes organization reloads, upserts and deletes.
    private readonly SemaphoreSlim _orgInitLock = new(1, 1);
    // UTC ticks of the last successful organization reload; DateTimeOffset.MinValue means "never loaded".
    private long _lastOrgAbilityRefreshUtcTicks = DateTimeOffset.MinValue.UtcTicks;

    // Replaced wholesale on every reload; volatile so lock-free readers observe the latest instance.
    private volatile ConcurrentDictionary<Guid, ProviderAbility> _providerAbilities = new();
    // Serializes provider reloads, upserts and deletes.
    private readonly SemaphoreSlim _providerInitLock = new(1, 1);
    // UTC ticks of the last successful provider reload; DateTimeOffset.MinValue means "never loaded".
    private long _lastProviderAbilityRefreshUtcTicks = DateTimeOffset.MinValue.UtcTicks;

    private readonly TimeSpan _refreshInterval = TimeSpan.FromMinutes(10);

    /// <summary>
    /// Returns the organization ability cache, loading or reloading it from the repository first if it is stale.
    /// </summary>
    /// <returns>The live cache dictionary (not a copy), keyed by organization id.</returns>
    public virtual async Task<IDictionary<Guid, OrganizationAbility>> GetOrganizationAbilitiesAsync()
    {
        await InitOrganizationAbilitiesAsync();
        return _orgAbilities;
    }

    /// <summary>
    /// Looks up a single organization ability, loading or reloading the cache first if it is stale.
    /// </summary>
    /// <param name="organizationId">Id of the organization to look up.</param>
    /// <returns>The cached ability, or <c>null</c> when the id is not in the cache.</returns>
    public async Task<OrganizationAbility?> GetOrganizationAbilityAsync(Guid organizationId)
    {
        (await GetOrganizationAbilitiesAsync())
            .TryGetValue(organizationId, out var organizationAbility);
        return organizationAbility;
    }

    /// <summary>
    /// Returns the provider ability cache, loading or reloading it from the repository first if it is stale.
    /// </summary>
    /// <returns>The live cache dictionary (not a copy), keyed by provider id.</returns>
    public virtual async Task<IDictionary<Guid, ProviderAbility>> GetProviderAbilitiesAsync()
    {
        await InitProviderAbilitiesAsync();
        return _providerAbilities;
    }

    /// <summary>
    /// Adds or replaces the cached ability for <paramref name="provider"/>.
    /// </summary>
    /// <param name="provider">Provider whose ability entry is (re)built from its current state.</param>
    public virtual async Task UpsertProviderAbilityAsync(Provider provider)
    {
        await InitProviderAbilitiesAsync();

        // Held under the reload lock so the entry is written to the dictionary that survives any concurrent reload.
        await RunLockedAsync(_providerInitLock, () => _providerAbilities.AddOrUpdate(
            provider.Id,
            static (_, p) => new ProviderAbility(p),
            static (_, _, p) => new ProviderAbility(p),
            provider));
    }

    /// <summary>
    /// Adds or replaces the cached ability for <paramref name="organization"/>.
    /// </summary>
    /// <param name="organization">Organization whose ability entry is (re)built from its current state.</param>
    public virtual async Task UpsertOrganizationAbilityAsync(Organization organization)
    {
        await InitOrganizationAbilitiesAsync();

        // Held under the reload lock so the entry is written to the dictionary that survives any concurrent reload.
        await RunLockedAsync(_orgInitLock, () => _orgAbilities.AddOrUpdate(
            organization.Id,
            static (_, org) => new OrganizationAbility(org),
            static (_, _, org) => new OrganizationAbility(org),
            organization));
    }

    /// <summary>
    /// Removes the cached ability for <paramref name="organizationId"/>, if present.
    /// Does not trigger a load; the next load reads current repository state anyway.
    /// </summary>
    /// <param name="organizationId">Id of the organization to evict.</param>
    public virtual async Task DeleteOrganizationAbilityAsync(Guid organizationId) =>
        await RunLockedAsync(_orgInitLock, () => _orgAbilities.TryRemove(organizationId, out _));

    /// <summary>
    /// Removes the cached ability for <paramref name="providerId"/>, if present.
    /// Does not trigger a load; the next load reads current repository state anyway.
    /// </summary>
    /// <param name="providerId">Id of the provider to evict.</param>
    public virtual async Task DeleteProviderAbilityAsync(Guid providerId) =>
        await RunLockedAsync(_providerInitLock, () => _providerAbilities.TryRemove(providerId, out _));

    /// <summary>Loads or reloads the organization cache when it is stale.</summary>
    private async Task InitOrganizationAbilitiesAsync() =>
        await InitAbilitiesAsync<OrganizationAbility>(
            dict => _orgAbilities = dict,
            () => Interlocked.Read(ref _lastOrgAbilityRefreshUtcTicks),
            ticks => Interlocked.Exchange(ref _lastOrgAbilityRefreshUtcTicks, ticks),
            _orgInitLock,
            async () => await organizationRepository.GetManyAbilitiesAsync(),
            _refreshInterval,
            ability => ability.Id);

    /// <summary>Loads or reloads the provider cache when it is stale.</summary>
    private async Task InitProviderAbilitiesAsync() =>
        await InitAbilitiesAsync<ProviderAbility>(
            dict => _providerAbilities = dict,
            () => Interlocked.Read(ref _lastProviderAbilityRefreshUtcTicks),
            ticks => Interlocked.Exchange(ref _lastProviderAbilityRefreshUtcTicks, ticks),
            _providerInitLock,
            async () => await providerRepository.GetManyAbilitiesAsync(),
            _refreshInterval,
            ability => ability.Id);

    /// <summary>
    /// Runs <paramref name="action"/> while holding <paramref name="lock"/>, always releasing it afterwards.
    /// </summary>
    private static async Task RunLockedAsync(SemaphoreSlim @lock, Action action)
    {
        await @lock.WaitAsync();
        try
        {
            action();
        }
        finally
        {
            @lock.Release();
        }
    }

    /// <summary>
    /// Reloads one ability cache from its source when the last successful load is older than
    /// <paramref name="refreshInterval"/>. Uses a lock-free freshness check, then re-checks under
    /// <paramref name="lock"/> so that concurrent callers trigger at most one reload.
    /// </summary>
    /// <param name="setCache">Publishes the newly built dictionary.</param>
    /// <param name="getLastRefreshUtcTicks">Reads the UTC ticks of the last successful load.</param>
    /// <param name="setLastRefreshUtcTicks">Records the UTC ticks of a successful load.</param>
    /// <param name="lock">Semaphore serializing reloads and writes for this cache.</param>
    /// <param name="fetchFunc">Fetches every ability from the backing store.</param>
    /// <param name="refreshInterval">Maximum age before a reload is required.</param>
    /// <param name="getId">Extracts the dictionary key from an ability.</param>
    /// <remarks>
    /// If <paramref name="fetchFunc"/> throws, the exception propagates, the previous cache and timestamp are
    /// kept, and the next call retries. The dictionary constructor throws <see cref="ArgumentException"/> if two
    /// fetched abilities share an id.
    /// </remarks>
    private async Task InitAbilitiesAsync<TAbility>(
        Action<ConcurrentDictionary<Guid, TAbility>> setCache,
        Func<long> getLastRefreshUtcTicks,
        Action<long> setLastRefreshUtcTicks,
        SemaphoreSlim @lock,
        Func<Task<IEnumerable<TAbility>>> fetchFunc,
        TimeSpan refreshInterval,
        Func<TAbility, Guid> getId)
    {
        // Fast path: no locking while the cache is fresh.
        if (IsFresh())
        {
            return;
        }

        await @lock.WaitAsync();
        try
        {
            // Another caller may have completed a reload while we waited for the lock.
            if (IsFresh())
            {
                return;
            }

            var sources = await fetchFunc();
            var abilities = new ConcurrentDictionary<Guid, TAbility>(
                sources.Select(ability => KeyValuePair.Create(getId(ability), ability)));

            // Publish the dictionary before the timestamp: the fast path reads the timestamp first, so it
            // can only treat the cache as fresh once the new dictionary is already visible.
            setCache(abilities);
            setLastRefreshUtcTicks(timeProvider.GetUtcNow().UtcTicks);
        }
        finally
        {
            @lock.Release();
        }

        bool IsFresh() =>
            timeProvider.GetUtcNow().UtcTicks - getLastRefreshUtcTicks() <= refreshInterval.Ticks;
    }
}
