package io.quarkus.undertow.runtime;

import java.io.IOException;
import java.math.BigInteger;
import java.net.SocketAddress;
import java.nio.file.Path;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.EventListener;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.function.Supplier;

import javax.servlet.AsyncEvent;
import javax.servlet.AsyncListener;
import javax.servlet.DispatcherType;
import javax.servlet.Filter;
import javax.servlet.MultipartConfigElement;
import javax.servlet.Servlet;
import javax.servlet.ServletContainerInitializer;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;

import org.jboss.logging.Logger;

import io.netty.buffer.ByteBuf;
import io.quarkus.arc.InjectableContext;
import io.quarkus.arc.ManagedContext;
import io.quarkus.arc.runtime.BeanContainer;
import io.quarkus.runtime.LaunchMode;
import io.quarkus.runtime.RuntimeValue;
import io.quarkus.runtime.ShutdownContext;
import io.quarkus.runtime.annotations.Recorder;
import io.quarkus.runtime.configuration.MemorySize;
import io.quarkus.vertx.http.runtime.HttpConfiguration;
import io.undertow.httpcore.BufferAllocator;
import io.undertow.httpcore.StatusCodes;
import io.undertow.server.DefaultExchangeHandler;
import io.undertow.server.HandlerWrapper;
import io.undertow.server.HttpHandler;
import io.undertow.server.HttpServerExchange;
import io.undertow.server.handlers.PathHandler;
import io.undertow.server.handlers.ResponseCodeHandler;
import io.undertow.server.handlers.resource.CachingResourceManager;
import io.undertow.server.handlers.resource.ClassPathResourceManager;
import io.undertow.server.handlers.resource.PathResourceManager;
import io.undertow.server.handlers.resource.ResourceManager;
import io.undertow.server.session.SessionIdGenerator;
import io.undertow.servlet.ServletExtension;
import io.undertow.servlet.Servlets;
import io.undertow.servlet.api.ClassIntrospecter;
import io.undertow.servlet.api.DeploymentInfo;
import io.undertow.servlet.api.DeploymentManager;
import io.undertow.servlet.api.ErrorPage;
import io.undertow.servlet.api.FilterInfo;
import io.undertow.servlet.api.InstanceFactory;
import io.undertow.servlet.api.InstanceHandle;
import io.undertow.servlet.api.ListenerInfo;
import io.undertow.servlet.api.ServletContainer;
import io.undertow.servlet.api.ServletContainerInitializerInfo;
import io.undertow.servlet.api.ServletInfo;
import io.undertow.servlet.api.ServletSecurityInfo;
import io.undertow.servlet.api.ThreadSetupHandler;
import io.undertow.servlet.handlers.DefaultServlet;
import io.undertow.servlet.handlers.ServletPathMatches;
import io.undertow.servlet.handlers.ServletRequestContext;
import io.undertow.servlet.spec.HttpServletRequestImpl;
import io.undertow.util.AttachmentKey;
import io.undertow.vertx.VertxHttpExchange;
import io.vertx.core.Handler;
import io.vertx.core.http.HttpServerRequest;
import io.vertx.core.net.impl.PartialPooledByteBufAllocator;

/**
 * Provides the runtime methods to bootstrap Undertow. This class is present in the final uber-jar,
 * and is invoked from generated bytecode
 */
@Recorder
public class UndertowDeploymentRecorder {

    private static final Logger log = Logger.getLogger("io.quarkus.undertow");

    /**
     * Stable entry point for servlet requests. It forwards every request to the handler chain most recently
     * installed by {@link #startUndertow}; before that, requests are answered by {@link ResponseCodeHandler#HANDLE_404}.
     */
    public static final HttpHandler ROOT_HANDLER = new HttpHandler() {
        @Override
        public void handleRequest(HttpServerExchange exchange) throws Exception {
            currentRoot.handleRequest(exchange);
        }
    };

    private static final List<HandlerWrapper> hotDeploymentWrappers = new CopyOnWriteArrayList<>();
    private static volatile List<Path> hotDeploymentResourcePaths;
    private static volatile HttpHandler currentRoot = ResponseCodeHandler.HANDLE_404;
    private static volatile ServletContext servletContext;

    /** Exchange attachment that carries request-context state across the async phases of a servlet request. */
    private static final AttachmentKey<InjectableContext.ContextState> REQUEST_CONTEXT = AttachmentKey
            .create(InjectableContext.ContextState.class);

    protected static final int DEFAULT_BUFFER_SIZE;
    protected static final boolean DEFAULT_DIRECT_BUFFERS;

    static {
        // maxMemory() is the JVM's maximum heap size; buffer defaults are scaled to it
        long maxMemory = Runtime.getRuntime().maxMemory();
        if (maxMemory < 64 * 1024 * 1024) {
            //smaller than 64mb: use 512b heap buffers
            DEFAULT_DIRECT_BUFFERS = false;
            DEFAULT_BUFFER_SIZE = 512;
        } else if (maxMemory < 128 * 1024 * 1024) {
            //smaller than 128mb: use 1k direct buffers
            DEFAULT_DIRECT_BUFFERS = true;
            DEFAULT_BUFFER_SIZE = 1024;
        } else {
            //use 16k buffers for best performance
            //as 16k is generally the max amount of data that can be sent in a single write() call
            DEFAULT_DIRECT_BUFFERS = true;
            DEFAULT_BUFFER_SIZE = 1024 * 16 - 20; //the 20 is to allow some space for protocol headers, see UNDERTOW-1209
        }

    }

    /**
     * Sets the filesystem paths that serve static resources in hot-deployment mode. When non-null, these paths
     * take precedence over {@code META-INF/resources} on the classpath in {@link #createDeployment}.
     *
     * @param resources resource root directories, or {@code null} to use the classpath-only resource manager
     */
    public static void setHotDeploymentResources(List<Path> resources) {
        hotDeploymentResourcePaths = resources;
    }

    /**
     * Creates the base servlet {@link DeploymentInfo}. This covers session id generation, the class loader, static
     * resource handling, welcome pages, the default servlet, and any registered hot-deployment wrappers.
     *
     * @param name deployment name
     * @param knownFile static resource files known at build time (used when not hot-deploying)
     * @param knownDirectories static resource directories known at build time (used when not hot-deploying)
     * @param launchMode current launch mode; resources are cached only in {@link LaunchMode#NORMAL}
     * @param context shutdown context used to close the resource manager
     * @param contextPath servlet context path
     * @return the new deployment info
     */
    public RuntimeValue<DeploymentInfo> createDeployment(String name, Set<String> knownFile, Set<String> knownDirectories,
            LaunchMode launchMode, ShutdownContext context, String contextPath) {
        DeploymentInfo d = new DeploymentInfo();
        d.setSessionIdGenerator(new QuarkusSessionIdGenerator());
        d.setDeploymentName(name);
        d.setContextPath(contextPath);
        d.setEagerFilterInit(true);
        // prefer the TCCL; fall back to an empty loader (parented to the system loader) if none is set
        ClassLoader cl = Thread.currentThread().getContextClassLoader();
        if (cl == null) {
            cl = new ClassLoader() {
            };
        }
        d.setClassLoader(cl);
        //TODO: we need better handling of static resources
        ResourceManager resourceManager;
        if (hotDeploymentResourcePaths == null) {
            resourceManager = new KnownPathResourceManager(knownFile, knownDirectories,
                    new ClassPathResourceManager(d.getClassLoader(), "META-INF/resources"));
        } else {
            // hot deployment: filesystem roots first, then the classpath
            List<ResourceManager> managers = new ArrayList<>();
            for (Path i : hotDeploymentResourcePaths) {
                managers.add(new PathResourceManager(i));
            }
            managers.add(new ClassPathResourceManager(d.getClassLoader(), "META-INF/resources"));
            resourceManager = new DelegatingResourceManager(managers.toArray(new ResourceManager[0]));
        }

        if (launchMode == LaunchMode.NORMAL) {
            //todo: cache configuration
            resourceManager = new CachingResourceManager(1000, 0, null, resourceManager, 2000);
        }
        d.setResourceManager(resourceManager);

        d.addWelcomePages("index.html", "index.htm");

        d.addServlet(new ServletInfo(ServletPathMatches.DEFAULT_SERVLET_NAME, DefaultServlet.class).setAsyncSupported(true));
        for (HandlerWrapper i : hotDeploymentWrappers) {
            d.addOuterHandlerChainWrapper(i);
        }
        context.addShutdownTask(new Runnable() {
            @Override
            public void run() {
                try {
                    d.getResourceManager().close();
                } catch (IOException e) {
                    log.error("Failed to close Servlet ResourceManager", e);
                }
            }
        });
        return new RuntimeValue<>(d);
    }

    /**
     * @return always {@code null} in this implementation
     */
    public static SocketAddress getHttpAddress() {
        return null;
    }

    /**
     * Registers a servlet with the deployment.
     *
     * @param deploymentInfo deployment to add the servlet to
     * @param name servlet name
     * @param servletClass servlet implementation class
     * @param asyncSupported whether the servlet supports async processing
     * @param loadOnStartup load-on-startup order; applied only when greater than zero
     * @param beanContainer used to build the instance factory when {@code instanceFactory} is {@code null}
     * @param initParams servlet init parameters
     * @param instanceFactory explicit instance factory, or {@code null} to obtain instances from {@code beanContainer}
     * @return the registered servlet info
     */
    @SuppressWarnings("unchecked")
    public RuntimeValue<ServletInfo> registerServlet(RuntimeValue<DeploymentInfo> deploymentInfo,
            String name,
            Class<?> servletClass,
            boolean asyncSupported,
            int loadOnStartup,
            BeanContainer beanContainer, Map<String, String> initParams,
            InstanceFactory<? extends Servlet> instanceFactory) throws Exception {

        InstanceFactory<? extends Servlet> factory = instanceFactory != null ? instanceFactory
                : new QuarkusInstanceFactory(beanContainer.instanceFactory(servletClass));
        ServletInfo servletInfo = new ServletInfo(name, (Class<? extends Servlet>) servletClass,
                factory);
        for (Map.Entry<String, String> e : initParams.entrySet()) {
            servletInfo.addInitParam(e.getKey(), e.getValue());
        }
        deploymentInfo.getValue().addServlet(servletInfo);
        servletInfo.setAsyncSupported(asyncSupported);
        if (loadOnStartup > 0) {
            servletInfo.setLoadOnStartup(loadOnStartup);
        }
        return new RuntimeValue<>(servletInfo);
    }

    /**
     * Adds an init parameter to an already registered servlet.
     */
    public void addServletInitParam(RuntimeValue<ServletInfo> info, String name, String value) {
        info.getValue().addInitParam(name, value);
    }

    /**
     * Adds a URL mapping to a servlet previously registered under {@code name}.
     *
     * @throws IllegalArgumentException if no servlet with that name is registered
     */
    public void addServletMapping(RuntimeValue<DeploymentInfo> info, String name, String mapping) throws Exception {
        ServletInfo sv = info.getValue().getServlets().get(name);
        if (sv == null) {
            throw new IllegalArgumentException(
                    "Cannot add mapping '" + mapping + "': no servlet registered with name '" + name + "'");
        }
        sv.addMapping(mapping);
    }

    /**
     * Sets the multipart configuration of a servlet.
     */
    public void setMultipartConfig(RuntimeValue<ServletInfo> sref, String location, long fileSize, long maxRequestSize,
            int fileSizeThreshold) {
        MultipartConfigElement mp = new MultipartConfigElement(location, fileSize, maxRequestSize, fileSizeThreshold);
        sref.getValue().setMultipartConfig(mp);
    }

    /**
     * Sets the security constraints of a servlet.
     *
     * @param sref servlet to configure
     * @param securityInfo security info to apply
     */
    public void setSecurityInfo(RuntimeValue<ServletInfo> sref, ServletSecurityInfo securityInfo) {
        sref.getValue().setServletSecurityInfo(securityInfo);
    }

    /**
     * Adds a security role reference to a servlet.
     *
     * @param sref servlet to configure
     * @param roleName role name used by the servlet
     * @param roleLink role it links to
     */
    public void addSecurityRoleRef(RuntimeValue<ServletInfo> sref, String roleName, String roleLink) {
        sref.getValue().addSecurityRoleRef(roleName, roleLink);
    }

    /**
     * Registers a filter with the deployment.
     *
     * @param info deployment to add the filter to
     * @param name filter name
     * @param filterClass filter implementation class
     * @param asyncSupported whether the filter supports async processing
     * @param beanContainer used to build the instance factory when {@code instanceFactory} is {@code null}
     * @param initParams filter init parameters
     * @param instanceFactory explicit instance factory, or {@code null} to obtain instances from {@code beanContainer}
     * @return the registered filter info
     */
    @SuppressWarnings("unchecked")
    public RuntimeValue<FilterInfo> registerFilter(RuntimeValue<DeploymentInfo> info,
            String name, Class<?> filterClass,
            boolean asyncSupported,
            BeanContainer beanContainer,
            Map<String, String> initParams,
            InstanceFactory<? extends Filter> instanceFactory) throws Exception {

        InstanceFactory<? extends Filter> factory = instanceFactory != null ? instanceFactory
                : new QuarkusInstanceFactory(beanContainer.instanceFactory(filterClass));
        FilterInfo filterInfo = new FilterInfo(name, (Class<? extends Filter>) filterClass, factory);

        for (Map.Entry<String, String> e : initParams.entrySet()) {
            filterInfo.addInitParam(e.getKey(), e.getValue());
        }
        info.getValue().addFilter(filterInfo);
        filterInfo.setAsyncSupported(asyncSupported);
        return new RuntimeValue<>(filterInfo);
    }

    /**
     * Adds an init parameter to an already registered filter.
     */
    public void addFilterInitParam(RuntimeValue<FilterInfo> info, String name, String value) {
        info.getValue().addInitParam(name, value);
    }

    /**
     * Maps the named filter to a URL pattern for the given dispatcher type.
     */
    public void addFilterURLMapping(RuntimeValue<DeploymentInfo> info, String name, String mapping,
            DispatcherType dispatcherType) throws Exception {
        info.getValue().addFilterUrlMapping(name, mapping, dispatcherType);
    }

    /**
     * Maps the named filter to a servlet name for the given dispatcher type.
     */
    public void addFilterServletNameMapping(RuntimeValue<DeploymentInfo> info, String name, String mapping,
            DispatcherType dispatcherType) throws Exception {
        info.getValue().addFilterServletNameMapping(name, mapping, dispatcherType);
    }

    /**
     * Registers a servlet listener whose instances are obtained from the bean container.
     */
    @SuppressWarnings("unchecked")
    public void registerListener(RuntimeValue<DeploymentInfo> info, Class<?> listenerClass, BeanContainer factory) {
        info.getValue()
                .addListener(new ListenerInfo((Class<? extends EventListener>) listenerClass,
                        (InstanceFactory<? extends EventListener>) new QuarkusInstanceFactory<>(
                                factory.instanceFactory(listenerClass))));
    }

    /**
     * Adds a servlet context init parameter.
     */
    public void addServletInitParameter(RuntimeValue<DeploymentInfo> info, String name, String value) {
        info.getValue().addInitParameter(name, value);
    }

    /**
     * Installs the deployment's handler chain as the current root and returns a Vert.x request handler that
     * dispatches into Undertow.
     *
     * <p>The handler chain is the deployment handler, wrapped by {@code wrappers} in order. It is mounted under the
     * context path when that path is not {@code "/"}. A shutdown task stops and undeploys the deployment.
     *
     * @throws IllegalArgumentException if the configured buffer size does not fit in an {@code int}
     */
    public Handler<HttpServerRequest> startUndertow(ShutdownContext shutdown, ExecutorService executorService,
            DeploymentManager manager, List<HandlerWrapper> wrappers, HttpConfiguration httpConfiguration,
            ServletRuntimeConfig servletRuntimeConfig) throws Exception {

        shutdown.addShutdownTask(new Runnable() {
            @Override
            public void run() {
                try {
                    manager.stop();
                } catch (ServletException e) {
                    log.error("Failed to stop deployment", e);
                } finally {
                    // undeploy even if stop() failed with an unchecked exception
                    manager.undeploy();
                }
            }
        });
        HttpHandler main = manager.getDeployment().getHandler();
        for (HandlerWrapper i : wrappers) {
            main = i.wrap(main);
        }
        if (!manager.getDeployment().getDeploymentInfo().getContextPath().equals("/")) {
            PathHandler pathHandler = new PathHandler()
                    .addPrefixPath(manager.getDeployment().getDeploymentInfo().getContextPath(), main);
            main = pathHandler;
        }
        currentRoot = main;

        DefaultExchangeHandler defaultHandler = new DefaultExchangeHandler(ROOT_HANDLER);

        boolean directBuffers = servletRuntimeConfig.directBuffers.orElse(DEFAULT_DIRECT_BUFFERS);
        long bufferSize = servletRuntimeConfig.bufferSize
                .orElse(new MemorySize(BigInteger.valueOf(DEFAULT_BUFFER_SIZE))).asLongValue();
        if (bufferSize > Integer.MAX_VALUE) {
            // a plain (int) cast would silently wrap to a bogus (possibly negative) size
            throw new IllegalArgumentException("Configured servlet buffer size is too large: " + bufferSize);
        }
        UndertowBufferAllocator allocator = new UndertowBufferAllocator(directBuffers, (int) bufferSize);

        // the body size limit does not change per request, so resolve it once
        Optional<MemorySize> maxBodySize = httpConfiguration.limits.maxBodySize;
        final boolean limitBodySize = maxBodySize.isPresent();
        final long maxEntitySize = limitBodySize ? maxBodySize.get().asLongValue() : -1L;

        return new Handler<HttpServerRequest>() {
            @Override
            public void handle(HttpServerRequest event) {
                VertxHttpExchange exchange = new VertxHttpExchange(event, allocator, executorService, null);
                if (limitBodySize) {
                    exchange.setMaxEntitySize(maxEntitySize);
                }
                defaultHandler.handle(exchange);
            }
        };
    }

    /**
     * Registers a wrapper that {@link #createDeployment} adds as an outer handler chain wrapper.
     */
    public static void addHotDeploymentWrapper(HandlerWrapper handlerWrapper) {
        hotDeploymentWrappers.add(handlerWrapper);
    }

    /**
     * @return a supplier of the servlet context set by {@link #bootServletContainer}; yields {@code null} before
     *         boot and after shutdown
     */
    public Supplier<ServletContext> servletContextSupplier() {
        return new ServletContextSupplier();
    }

    /**
     * Deploys and starts the servlet deployment.
     *
     * <p>If no exception handler is configured, the Quarkus error handler and error servlet are installed when either
     * no 500 error page is mapped or the launch mode is dev/test. Stack traces are shown only in dev/test. Instances are
     * created through the bean container when it provides a factory for the class, otherwise through the
     * deployment's previous class introspecter.
     *
     * @return the started deployment manager
     * @throws RuntimeException wrapping any exception raised while deploying or starting
     */
    public DeploymentManager bootServletContainer(RuntimeValue<DeploymentInfo> info, BeanContainer beanContainer,
            LaunchMode launchMode, ShutdownContext shutdownContext) {
        if (info.getValue().getExceptionHandler() == null) {
            //if a 500 error page has not been mapped we change the default to our more modern one, with a UID in the
            //log. If this is not production we also include the stack trace
            boolean alreadyMapped = false;
            for (ErrorPage i : info.getValue().getErrorPages()) {
                if (i.getErrorCode() != null && i.getErrorCode() == StatusCodes.INTERNAL_SERVER_ERROR) {
                    alreadyMapped = true;
                    break;
                }
            }
            if (!alreadyMapped || launchMode.isDevOrTest()) {
                info.getValue().setExceptionHandler(new QuarkusExceptionHandler());
                info.getValue().addErrorPage(new ErrorPage("/@QuarkusError", StatusCodes.INTERNAL_SERVER_ERROR));
                info.getValue().addServlet(new ServletInfo("@QuarkusError", QuarkusErrorServlet.class)
                        .addMapping("/@QuarkusError").setAsyncSupported(true)
                        .addInitParam(QuarkusErrorServlet.SHOW_STACK, Boolean.toString(launchMode.isDevOrTest())));
            }
        }

        try {
            ClassIntrospecter defaultVal = info.getValue().getClassIntrospecter();
            info.getValue().setClassIntrospecter(new ClassIntrospecter() {
                @Override
                public <T> InstanceFactory<T> createInstanceFactory(Class<T> clazz) throws NoSuchMethodException {
                    BeanContainer.Factory<T> res = beanContainer.instanceFactory(clazz);
                    if (res == null) {
                        return defaultVal.createInstanceFactory(clazz);
                    }
                    return new InstanceFactory<T>() {
                        @Override
                        public InstanceHandle<T> createInstance() throws InstantiationException {
                            BeanContainer.Instance<T> ih = res.create();
                            return new InstanceHandle<T>() {
                                @Override
                                public T getInstance() {
                                    return ih.get();
                                }

                                @Override
                                public void release() {
                                    ih.close();
                                }
                            };
                        }
                    };
                }
            });
            ServletContainer servletContainer = Servlets.defaultContainer();
            DeploymentManager manager = servletContainer.addDeployment(info.getValue());
            manager.deploy();
            manager.start();
            servletContext = manager.getDeployment().getServletContext();
            shutdownContext.addShutdownTask(new Runnable() {
                @Override
                public void run() {
                    servletContext = null;
                }
            });
            return manager;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Adds an attribute that will be set on the servlet context.
     */
    public void addServletContextAttribute(RuntimeValue<DeploymentInfo> deployment, String key, Object value1) {
        deployment.getValue().addServletContextAttribute(key, value1);
    }

    /**
     * Adds a servlet extension to the deployment.
     */
    public void addServletExtension(RuntimeValue<DeploymentInfo> deployment, ServletExtension extension) {
        deployment.getValue().addServletExtension(extension);
    }

    /**
     * Returns a servlet extension that installs a thread setup action. The action activates the CDI request context
     * around request processing.
     *
     * <p>If the context is already active, the action runs unchanged. Otherwise the context is activated, restoring
     * any state stored on the exchange. When the servlet request has gone async, the state is saved on the exchange
     * and the context is deactivated; for a newly created context, an async listener terminates it on
     * completion, timeout or error. Otherwise the context is terminated immediately.
     */
    public ServletExtension setupRequestScope(BeanContainer beanContainer) {
        return new ServletExtension() {
            @Override
            public void handleDeployment(DeploymentInfo deploymentInfo, ServletContext servletContext) {
                deploymentInfo.addThreadSetupAction(new ThreadSetupHandler() {
                    @Override
                    public <T, C> ThreadSetupHandler.Action<T, C> create(Action<T, C> action) {
                        return new Action<T, C>() {
                            @Override
                            public T call(HttpServerExchange exchange, C context) throws Exception {
                                // without an exchange there is nowhere to keep request-context state, so the
                                // action runs without touching the request context
                                if (exchange == null) {
                                    return action.call(exchange, context);
                                }
                                ManagedContext requestContext = beanContainer.requestContext();
                                if (requestContext.isActive()) {
                                    return action.call(exchange, context);
                                } else {
                                    InjectableContext.ContextState existingRequestContext = exchange
                                            .getAttachment(REQUEST_CONTEXT);
                                    try {
                                        requestContext.activate(existingRequestContext);
                                        return action.call(exchange, context);
                                    } finally {
                                        ServletRequestContext src = exchange
                                                .getAttachment(ServletRequestContext.ATTACHMENT_KEY);
                                        // guard against a missing servlet request context so an NPE here cannot
                                        // mask an exception from the action or leave the context active
                                        HttpServletRequestImpl req = src == null ? null : src.getOriginalRequest();
                                        if (req != null && req.isAsyncStarted()) {
                                            exchange.putAttachment(REQUEST_CONTEXT, requestContext.getState());
                                            requestContext.deactivate();
                                            if (existingRequestContext == null) {
                                                // only the call that created the context registers its cleanup
                                                req.getAsyncContextInternal().addListener(new AsyncListener() {
                                                    @Override
                                                    public void onComplete(AsyncEvent event) throws IOException {
                                                        requestContext.activate(exchange
                                                                .getAttachment(REQUEST_CONTEXT));
                                                        requestContext.terminate();
                                                    }

                                                    @Override
                                                    public void onTimeout(AsyncEvent event) throws IOException {
                                                        onComplete(event);
                                                    }

                                                    @Override
                                                    public void onError(AsyncEvent event) throws IOException {
                                                        onComplete(event);
                                                    }

                                                    @Override
                                                    public void onStartAsync(AsyncEvent event) throws IOException {

                                                    }
                                                });
                                            }
                                        } else {
                                            requestContext.terminate();
                                        }
                                    }
                                }
                            }
                        };
                    }
                });
            }
        };
    }

    /**
     * Registers a servlet container initializer together with the classes it handles.
     */
    public void addServletContainerInitializer(RuntimeValue<DeploymentInfo> deployment,
            Class<? extends ServletContainerInitializer> sciClass, Set<Class<?>> handlesTypes) {
        deployment.getValue().addServletContainerInitializer(new ServletContainerInitializerInfo(sciClass, handlesTypes));
    }

    /**
     * Adds a servlet context init parameter.
     */
    public void addContextParam(RuntimeValue<DeploymentInfo> deployment, String paramName, String paramValue) {
        deployment.getValue().addInitParameter(paramName, paramValue);
    }

    /**
     * Session id generator that creates its {@link SecureRandom} lazily.
     * We can't have SecureRandom in the native image heap, so we need to lazy init.
     */
    private static class QuarkusSessionIdGenerator implements SessionIdGenerator {

        private volatile SecureRandom random;

        private volatile int length = 30;

        private static final char[] SESSION_ID_ALPHABET;

        private static final String ALPHABET_PROPERTY = "io.undertow.server.session.SecureRandomSessionIdGenerator.ALPHABET";

        static {
            String alphabet = System.getProperty(ALPHABET_PROPERTY,
                    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_");
            if (alphabet.length() != 64) {
                // the encoder below maps 6-bit values onto this alphabet, so it must have exactly 2^6 entries
                throw new RuntimeException(
                        ALPHABET_PROPERTY + " must be exactly 64 characters long");
            }
            SESSION_ID_ALPHABET = alphabet.toCharArray();
        }

        @Override
        public String createSessionId() {
            final byte[] bytes = new byte[length];
            secureRandom().nextBytes(bytes);
            return new String(encode(bytes));
        }

        /**
         * Returns the shared SecureRandom, creating it on first use. The volatile field plus the re-check under the lock
         * guarantees a single instance even when the first calls race.
         */
        private SecureRandom secureRandom() {
            SecureRandom r = random;
            if (r == null) {
                synchronized (this) {
                    r = random;
                    if (r == null) {
                        r = new SecureRandom();
                        random = r;
                    }
                }
            }
            return r;
        }

        public int getLength() {
            return length;
        }

        public void setLength(final int length) {
            this.length = length;
        }

        /**
         * Encode the bytes into a String with a slightly modified Base64-algorithm
         * This code was written by Kevin Kelley <kelley@ruralnet.net>
         * and adapted by Thomas Peuss <jboss@peuss.de>
         *
         * @param data The bytes you want to encode
         * @return the encoded String
         */
        private char[] encode(byte[] data) {
            char[] out = new char[((data.length + 2) / 3) * 4];
            char[] alphabet = SESSION_ID_ALPHABET;
            //
            // 3 bytes encode to 4 chars.  Output is always an even
            // multiple of 4 characters.
            //
            for (int i = 0, index = 0; i < data.length; i += 3, index += 4) {
                boolean quad = false;
                boolean trip = false;

                int val = (0xFF & (int) data[i]);
                val <<= 8;
                if ((i + 1) < data.length) {
                    val |= (0xFF & (int) data[i + 1]);
                    trip = true;
                }
                val <<= 8;
                if ((i + 2) < data.length) {
                    val |= (0xFF & (int) data[i + 2]);
                    quad = true;
                }
                // missing trailing bytes are padded with the last alphabet character
                out[index + 3] = alphabet[(quad ? (val & 0x3F) : 63)];
                val >>= 6;
                out[index + 2] = alphabet[(trip ? (val & 0x3F) : 63)];
                val >>= 6;
                out[index + 1] = alphabet[val & 0x3F];
                val >>= 6;
                out[index] = alphabet[val & 0x3F];
            }
            return out;
        }
    }

    /**
     * Supplies the servlet context captured by {@link UndertowDeploymentRecorder#bootServletContainer}, or
     * {@code null} if none is currently set.
     */
    public static class ServletContextSupplier implements Supplier<ServletContext> {

        @Override
        public ServletContext get() {
            return servletContext;
        }
    }

    /**
     * Buffer allocator backed by {@link PartialPooledByteBufAllocator#DEFAULT}, with a configurable default buffer size
     * and default direct/heap choice.
     */
    private static class UndertowBufferAllocator implements BufferAllocator {

        private final boolean defaultDirectBuffers;
        private final int defaultBufferSize;

        private UndertowBufferAllocator(boolean defaultDirectBuffers, int defaultBufferSize) {
            this.defaultDirectBuffers = defaultDirectBuffers;
            this.defaultBufferSize = defaultBufferSize;
        }

        @Override
        public ByteBuf allocateBuffer() {
            return allocateBuffer(defaultDirectBuffers);
        }

        @Override
        public ByteBuf allocateBuffer(boolean direct) {
            if (direct) {
                return PartialPooledByteBufAllocator.DEFAULT.directBuffer(defaultBufferSize);
            } else {
                return PartialPooledByteBufAllocator.DEFAULT.heapBuffer(defaultBufferSize);
            }
        }

        @Override
        public ByteBuf allocateBuffer(int bufferSize) {
            return allocateBuffer(defaultDirectBuffers, bufferSize);
        }

        @Override
        public ByteBuf allocateBuffer(boolean direct, int bufferSize) {
            if (direct) {
                return PartialPooledByteBufAllocator.DEFAULT.directBuffer(bufferSize);
            } else {
                return PartialPooledByteBufAllocator.DEFAULT.heapBuffer(bufferSize);
            }
        }

        @Override
        public int getBufferSize() {
            return defaultBufferSize;
        }
    }
}
