#region License
/* 
 * Copyright (C) 1999-2024 John Källén.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2, or (at your option)
 * any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; see the file COPYING.  If not, write to
 * the Free Software Foundation, 675 Mass Ave, Cambridge, MA 02139, USA.
 */
#endregion

using Reko.Arch.Arm.AArch32;
using Reko.Core;
using Reko.Core.Expressions;
using Reko.Core.Hll.C;
using Reko.Core.Loading;
using Reko.Core.Machine;
using Reko.Core.Memory;
using Reko.Core.Rtl;
using Reko.Core.Serialization;
using Reko.Core.Services;
using Reko.Core.Types;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;

namespace Reko.Environments.Windows
{
    // https://stackoverflow.com/questions/16375355/what-is-the-windows-rt-on-arm-native-code-calling-convention
    public class Win32ThumbPlatform : Platform
    {
        // Trap vectors (from Thumb `udf #imm8`) mapped to the intrinsic they represent.
        private readonly Dictionary<int, SystemService> systemServices;
        // Registers reported as implicit arguments by IsImplicitArgumentRegister.
        private readonly HashSet<RegisterStorage> implicitRegs;

        /// <summary>
        /// Creates the Windows-on-ARM (Thumb) platform, registered under the
        /// platform name "winArm".
        /// </summary>
        /// <param name="services">Service provider used by the platform.</param>
        /// <param name="arch">The processor architecture this platform runs on.</param>
        public Win32ThumbPlatform(IServiceProvider services, IProcessorArchitecture arch) :
            base(services, arch, "winArm")
        {
            this.systemServices = CreateSystemServices();
            this.implicitRegs = new[] { "r11", "sp", "lr", "pc" }
                .Select(r => Architecture.GetRegister(r)!)
                .ToHashSet();
            this.StructureMemberAlignment = 8;
            this.TrashedRegisters = CreateTrashedRegisters();
        }

        /// <summary>
        /// Builds the table of trap services. On Thumb, <c>udf #imm8</c> encodes as
        /// <c>0xDE00 | imm8</c>, so e.g. 0xDEFE traps with vector 0xFE.
        /// Source: http://codemachine.com/article_armasm.html
        /// </summary>
        /// <returns>A fresh dictionary keyed by trap vector.</returns>
        private static Dictionary<int, SystemService> CreateSystemServices()
        {
            var services = new Dictionary<int, SystemService>();

            // Registers a single trap vector; services that do not return are
            // flagged with ProcedureCharacteristics.Terminates.
            void AddService(int vector, string name, bool terminates = false)
            {
                var svc = new SystemService
                {
                    SyscallInfo = new SyscallInfo
                    {
                        Vector = vector,
                        RegisterValues = Array.Empty<RegValue>(),
                    },
                    Name = name,
                };
                if (terminates)
                {
                    svc.Characteristics = new ProcedureCharacteristics
                    {
                        Terminates = true,
                    };
                }
                services.Add(vector, svc);
            }

            // 0xDEFE: Breaks into the debugger. Used by ntdll!DbgUserBreakPoint().
            AddService(0x00FE, "__debugbreak");
            // 0xDEFC: Indicates critical assertion failures in the kernel debugger.
            // Used by KeAccumulateTicks().
            AddService(0x00FC, "__assertfail");
            // 0xDEFB: Fast-fail condition resulting in KeBugCheckEx(KERNEL_SECURITY_CHECK_FAILURE).
            // Called by functions like InsertTailList() upon detecting a corrupted list.
            AddService(0x00FB, "__fastfail", terminates: true);
            // 0xDEFA: Reads the 64-bit performance counter co-processor register and returns
            // the value in R0+R1. Used by ReadTimeStampCounter(), KiCacheFlushTrial() etc.
            AddService(0x00FA, "__rdpmccntr64");
            // 0xDEFD: Invokes a debugger breakpoint. Used by DbgBreakPointWithStatusEnd(),
            // DebugPrompt() etc.
            AddService(0x00FD, "__debugservice");
            // 0xDEF9: Divide-by-zero exception, used by runtime division helpers such as
            // nt!_rt_udiv. Also generated by the compiler to check the divisor before
            // division operations.
            AddService(0x00F9, "__brkdiv0", terminates: true);

            return services;
        }

        /// <summary>
        /// This platform has no named default calling convention; always returns "".
        /// </summary>
        public override string DefaultCallingConvention => "";

        /// <summary>
        /// Clears bit 0 of <paramref name="addr"/> (the ARM/Thumb interworking bit) so that
        /// Thumb function pointers map to the actual 32-bit code address.
        /// </summary>
        public override Address AdjustProcedureAddress(Address addr)
        {
            return Address.Ptr32((uint)addr.ToLinear() & ~1u);
        }

        /// <summary>
        /// Creates a C parser that recognizes the MSVC CE keyword set.
        /// </summary>
        /// <param name="rdr">Source text to parse.</param>
        /// <param name="state">Existing parser state, or null to start from a fresh one.</param>
        public override CParser CreateCParser(TextReader rdr, ParserState? state)
        {
            state ??= new ParserState();
            var lexer = new CLexer(rdr, CLexer.MsvcCeKeywords);
            return new CParser(state, lexer);
        }

        /// <summary>
        /// Returns true for r11, sp, lr and pc.
        /// </summary>
        public override bool IsImplicitArgumentRegister(RegisterStorage reg) => implicitRegs.Contains(reg);

        /// <summary>
        /// Registers not preserved across calls: r0-r3 and ip.
        /// https://msdn.microsoft.com/en-us/library/dn736986.aspx
        /// </summary>
        private HashSet<RegisterStorage> CreateTrashedRegisters()
        {
            return new[] { "r0", "r1", "r2", "r3", "ip" }
                .Select(r => Architecture.GetRegister(r)!)
                .ToHashSet();
        }

        /// <summary>
        /// Always returns a new <see cref="Arm32CallingConvention"/>; <paramref name="ccName"/>
        /// is ignored.
        /// </summary>
        public override CallingConvention GetCallingConvention(string? ccName)
        {
            return new Arm32CallingConvention();
        }

        /// <summary>
        /// Not implemented for this platform: emits a warning and returns null.
        /// </summary>
        public override ImageSymbol? FindMainProcedure(Program program, Address addrStart)
        {
            Services.RequireService<IEventListener>().Warn(new NullCodeLocation(program.Name),
                           "Win32 ARM main procedure finder not implemented yet.");
            return null;
        }

        /// <summary>
        /// Looks up the trap service registered for <paramref name="vector"/>.
        /// </summary>
        /// <returns>The service, or null if the vector is unknown.</returns>
        public override SystemService? FindService(int vector, ProcessorState? state, IMemory? memory)
        {
            return systemServices.TryGetValue(vector, out SystemService? svc) ? svc : null;
        }

        /// <summary>
        /// Bit sizes of the C basic types for this platform (long is 32 bits,
        /// long double is 64 bits).
        /// </summary>
        /// <exception cref="NotImplementedException">For any type not listed.</exception>
        public override int GetBitSizeFromCBasicType(CBasicType cb) => cb switch
        {
            CBasicType.Bool => 8,
            CBasicType.Char => 8,
            CBasicType.Short => 16,
            CBasicType.Int => 32,
            CBasicType.Long => 32,
            CBasicType.LongLong => 64,
            CBasicType.Float => 32,
            CBasicType.Double => 64,
            CBasicType.LongDouble => 64,
            CBasicType.Int64 => 64,
            _ => throw new NotImplementedException($"C basic type {cb} not supported."),
        };

        /// <summary>
        /// Cluster-based trampoline detection is not implemented for this platform;
        /// always returns null. The trampoline shape this would need to recognize is:
        /// <code>
        /// 00011644 E59FC000 ldr ip,[0001164C]
        /// 00011648 E59CF000 ldr pc,[ip]
        /// 0001164C AC 50 01 00
        /// </code>
        /// The overload taking <c>IEnumerable&lt;RtlInstruction&gt;</c> implements that match.
        /// </summary>
        public override Trampoline? GetTrampolineDestination(Address addrInstr, List<RtlInstructionCluster> instrs, IRewriterHost host)
        {
            return null;
        }

        /// <summary>
        /// Recognizes an import-thunk trampoline of the form
        /// <code>
        /// 000116C8 E59FC000 ldr ip,[000116D0]
        /// 000116CC E59CF000 ldr pc,[ip]
        /// 000116D0 7C 50 01 00
        /// </code>
        /// i.e. a register loaded from a constant address followed by a jump through
        /// that register.
        /// </summary>
        /// <returns>
        /// The imported procedure (or intercepted call) the thunk jumps to, or null
        /// if the instructions do not match or the pointer cannot be read.
        /// </returns>
        public override ProcedureBase? GetTrampolineDestination(Address addrInstr, IEnumerable<RtlInstruction> instrs, IRewriterHost host)
        {
            var trampInstrs = instrs.Take(2).ToArray();
            if (trampInstrs.Length != 2)
                return null;
            if (trampInstrs[0] is RtlAssignment assign &&
                assign.Dst is Identifier ip &&
                assign.Src is MemoryAccess mem &&
                mem.DataType is PrimitiveType pt &&
                mem.EffectiveAddress is Address addrFnPtr &&

                trampInstrs[1] is RtlGoto g &&
                g.Target is MemoryAccess mem2 &&
                // NOTE(review): reference comparison; relies on the rewriter reusing the
                // same Identifier instance for the register in both instructions.
                mem2.EffectiveAddress == ip)
            {
                if (!host.TryRead(this.Architecture, addrFnPtr, pt, out var iatEntry))
                    return null;
                var addrTarget = Architecture.MakeAddressFromConstant(iatEntry, true);
                ProcedureBase? proc = host.GetImportedProcedure(this.Architecture, addrTarget, addrInstr);
                if (proc is not null)
                    return proc;
                return host.GetInterceptedCall(this.Architecture, addrTarget);
            }
            return null;
        }

        /// <summary>
        /// Resolves an external procedure by name. If <paramref name="moduleName"/> names a
        /// known module, only that module's services are searched; otherwise the
        /// platform-wide signature table is consulted.
        /// </summary>
        /// <returns>The procedure, or null if no match was found.</returns>
        public override ExternalProcedure? LookupProcedureByName(string? moduleName, string procName)
        {
            var metadata = EnsureTypeLibraries(PlatformIdentifier);
            // Invariant upper-casing: culture-sensitive ToUpper() mangles names such as
            // "kernel32.dll" under e.g. the Turkish culture and breaks the lookup.
            if (moduleName is not null &&
                metadata.Modules.TryGetValue(moduleName.ToUpperInvariant(), out ModuleDescriptor? mod))
            {
                if (!mod.ServicesByName.TryGetValue(procName, out SystemService? svc))
                    return null;
                var chr = LookupCharacteristicsByName(svc.Name!);
                return new ExternalProcedure(svc.Name!, svc.Signature!, chr);
            }

            // No module given, or the module is unknown: fall back to the global signatures.
            if (!metadata.Signatures.TryGetValue(procName, out FunctionType? sig))
                return null;
            var procChr = LookupCharacteristicsByName(procName);
            return new ExternalProcedure(procName, sig, procChr);
        }
    }
}
