// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

#include "gizmosql_security.h"

#include <utility>

#include "gizmosql_logging.h"
#include "request_ctx.h"

namespace fs = std::filesystem;

using arrow::Status;

namespace gizmosql {
// Value of the "iss" claim written into, and required of, JWTs minted by this server.
const std::string kServerJWTIssuer{"gizmosql"};
// Lifetime of server-minted JWTs, in seconds (24 hours).
const int kJWTExpiration{24 * 3600};
// NOTE(review): not referenced in this file section - confirm where it is used.
const std::string kValidUsername{"gizmosql_username"};
// Reserved Basic-auth username: when supplied, the password field carries a bootstrap JWT.
const std::string kTokenUsername{"token"};
// Authorization scheme prefixes (matched case-insensitively against header values).
const std::string kBasicPrefix{"Basic "};
const std::string kBearerPrefix{"Bearer "};
// Header key that carries credentials, in lowercase as looked up in incoming headers.
const std::string kAuthHeader{"authorization"};

// ----------------------------------------
// Reads the PEM certificate and private key files into a single CertKeyPair.
// `*out` is reset to empty first, so on any error it holds no entries; on success it
// holds exactly one pair. Returns IOError if either file cannot be opened.
Status SecurityUtilities::FlightServerTlsCertificates(
    const fs::path& cert_path, const fs::path& key_path,
    std::vector<flight::CertKeyPair>* out) {
  GIZMOSQL_LOG(INFO) << "Using TLS Cert file: " << cert_path;
  GIZMOSQL_LOG(INFO) << "Using TLS Key file: " << key_path;

  out->clear();

  // Drains an already-opened file stream into a string.
  const auto read_stream = [](std::ifstream& stream) {
    std::stringstream buffer;
    buffer << stream.rdbuf();
    return buffer.str();
  };

  try {
    std::ifstream cert_stream(cert_path);
    if (!cert_stream) {
      return Status::IOError("Could not open certificate: " + cert_path.string());
    }
    const std::string pem_cert = read_stream(cert_stream);

    std::ifstream key_stream(key_path);
    if (!key_stream) {
      return Status::IOError("Could not open key: " + key_path.string());
    }
    const std::string pem_key = read_stream(key_stream);

    out->push_back(flight::CertKeyPair{pem_cert, pem_key});
  } catch (const std::ifstream::failure& e) {
    return Status::IOError(e.what());
  }
  return Status::OK();
}

// Reads the mTLS CA certificate file at `cert_path` verbatim into `*out`.
// `*out` is written only when the file was opened; otherwise IOError is returned.
Status SecurityUtilities::FlightServerMtlsCACertificate(const std::string& cert_path,
                                                        std::string* out) {
  try {
    std::ifstream ca_stream(cert_path);
    if (!ca_stream) {
      return Status::IOError("Could not open MTLS CA certificate: " + cert_path);
    }
    std::stringstream ca_contents;
    ca_contents << ca_stream.rdbuf();
    *out = ca_contents.str();
    return Status::OK();
  } catch (const std::ifstream::failure& e) {
    return Status::IOError(e.what());
  }
}

// Function to look in CallHeaders for a key that has a value starting with prefix and
// return the rest of the value after the prefix.
std::string SecurityUtilities::FindKeyValPrefixInCallHeaders(
    const flight::CallHeaders& incoming_headers, const std::string& key,
    const std::string& prefix) {
  // Case-insensitive character comparison. Each char is converted to unsigned char
  // before calling toupper: passing a negative value (any byte >= 0x80 where char is
  // signed, e.g. UTF-8 in a client-supplied header) to toupper is undefined behavior.
  auto char_compare = [](const char char1, const char char2) {
    return ::toupper(static_cast<unsigned char>(char1)) ==
           ::toupper(static_cast<unsigned char>(char2));
  };

  // If several headers share `key`, only the one returned by find() is examined.
  const auto iter = incoming_headers.find(key);
  if (iter == incoming_headers.end()) {
    return "";
  }
  const std::string val(iter->second);
  // The value must be strictly longer than the prefix; a bare prefix yields "".
  if (val.size() > prefix.length() &&
      std::equal(val.begin(), val.begin() + prefix.length(), prefix.begin(),
                 char_compare)) {
    return val.substr(prefix.length());
  }
  return "";
}

// Classifies the "authorization" header by its scheme prefix. On success `*out` is set
// to "Basic" or "Bearer"; a missing header, an unknown scheme, or a scheme with nothing
// after it yields IOError and leaves `*out` untouched.
Status SecurityUtilities::GetAuthHeaderType(const flight::CallHeaders& incoming_headers,
                                            std::string* out) {
  const auto has_scheme = [&incoming_headers](const std::string& scheme_prefix) {
    return !FindKeyValPrefixInCallHeaders(incoming_headers, kAuthHeader, scheme_prefix)
                .empty();
  };

  if (has_scheme(kBasicPrefix)) {
    *out = "Basic";
    return Status::OK();
  }
  if (has_scheme(kBearerPrefix)) {
    *out = "Bearer";
    return Status::OK();
  }
  return Status::IOError("Invalid Authorization Header type!");
}

// Extracts credentials from a "Basic <base64(user:password)>" authorization header.
// Per RFC 7617 the user-id cannot contain ':' but the password may, so the decoded
// value is split at the FIRST colon and everything after it is the password.
// Without a colon, the whole decoded value is the username and the password is empty.
// A missing or non-Basic header decodes to an empty username and password.
void SecurityUtilities::ParseBasicHeader(const flight::CallHeaders& incoming_headers,
                                         std::string& username, std::string& password) {
  const std::string encoded_credentials =
      FindKeyValPrefixInCallHeaders(incoming_headers, kAuthHeader, kBasicPrefix);
  const std::string decoded = arrow::util::base64_decode(encoded_credentials);

  const auto colon = decoded.find(':');
  if (colon == std::string::npos) {
    username = decoded;
    password.clear();
    return;
  }
  username = decoded.substr(0, colon);
  password = decoded.substr(colon + 1);
}

// ----------------------------------------
BasicAuthServerMiddleware::BasicAuthServerMiddleware(const std::string& username,
                                                     const std::string& role,
                                                     const std::string& auth_method,
                                                     const std::string& secret_key)
    : username_(username),
      role_(role),
      auth_method_(auth_method),
      secret_key_(secret_key) {}

// Attaches a newly created server-signed JWT to the response as a
// "Bearer <token>" authorization header.
void BasicAuthServerMiddleware::SendingHeaders(flight::AddCallHeaders* outgoing_headers) {
  const std::string header_value = kBearerPrefix + CreateJWTToken();
  outgoing_headers->AddHeader(kAuthHeader, header_value);
}

void BasicAuthServerMiddleware::CallCompleted(const Status& /*status*/) {
  // Intentionally empty: this middleware keeps no per-call state to release.
}

// Identifier reported for this middleware.
std::string BasicAuthServerMiddleware::name() const {
  static constexpr char kMiddlewareName[] = "BasicAuthServerMiddleware";
  return kMiddlewareName;
}

std::string BasicAuthServerMiddleware::CreateJWTToken() const {
  auto token =
      jwt::create()
          .set_issuer(std::string(kServerJWTIssuer))
          .set_type("JWT")
          .set_id("gizmosql-server-" +
                  boost::uuids::to_string(boost::uuids::random_generator()()))
          .set_issued_at(std::chrono::system_clock::now())
          .set_expires_at(std::chrono::system_clock::now() +
                          std::chrono::seconds{kJWTExpiration})
          .set_payload_claim("sub", jwt::claim(username_))
          .set_payload_claim("role", jwt::claim(role_))
          .set_payload_claim("auth_method", jwt::claim(auth_method_))
          .set_payload_claim(
              "session_id",
              jwt::claim(boost::uuids::to_string(boost::uuids::random_generator()())))
          .sign(jwt::algorithm::hs256{secret_key_});

  return token;
}

// ----------------------------------------
// Configures username/password Basic auth and, when an allowed issuer, an allowed
// audience AND a signature verify cert path are all provided, bootstrap-token auth.
// Throws std::runtime_error if the reserved "token" username is configured, or if the
// verification certificate cannot be opened or is empty.
BasicAuthServerMiddlewareFactory::BasicAuthServerMiddlewareFactory(
    const std::string& username, const std::string& password,
    const std::string& secret_key, const std::string& token_allowed_issuer,
    const std::string& token_allowed_audience,
    const std::filesystem::path& token_signature_verify_cert_path)
    : username_(username),
      password_(password),
      secret_key_(secret_key),
      token_allowed_issuer_(token_allowed_issuer),
      token_allowed_audience_(token_allowed_audience),
      token_signature_verify_cert_path_(token_signature_verify_cert_path) {
  // "token" is reserved: StartCall treats it as a request for bootstrap-token auth.
  if (username_ == kTokenUsername) {
    throw std::runtime_error("You cannot use username: '" + kTokenUsername +
                             "' for basic authentication, because it is reserved for JWT "
                             "token-based authentication");
  }

  const bool has_issuer = !token_allowed_issuer_.empty();
  const bool has_audience = !token_allowed_audience_.empty();
  const bool has_cert_path = !token_signature_verify_cert_path_.empty();

  if (!has_issuer || !has_audience || !has_cert_path) {
    // A partial configuration leaves token auth disabled; warn so a missing setting
    // is noticed instead of silently ignored.
    if (has_issuer || has_audience || has_cert_path) {
      GIZMOSQL_LOG(WARNING)
          << "Token auth is NOT enabled on the server - it requires an allowed issuer, "
             "an allowed audience, and a signature verify cert path, but only some of "
             "them were provided";
    }
    return;
  }

  // Load the verification certificate once; every bootstrap-token check reuses it.
  std::ifstream cert_file(token_signature_verify_cert_path_);
  if (!cert_file) {
    // Constructors cannot return a Status, so configuration errors are thrown.
    throw std::runtime_error("Could not open certificate file: " +
                             token_signature_verify_cert_path_.string());
  }
  std::stringstream cert;
  cert << cert_file.rdbuf();
  token_signature_verify_cert_file_contents_ = cert.str();
  if (token_signature_verify_cert_file_contents_.empty()) {
    // An empty key would presumably make RS256 verification fail for every bootstrap
    // token at request time; fail fast at startup instead.
    throw std::runtime_error("Certificate file is empty: " +
                             token_signature_verify_cert_path_.string());
  }

  token_auth_enabled_ = true;
  GIZMOSQL_LOG(INFO) << "Token auth is enabled on the server - Allowed Issuer: '"
                     << token_allowed_issuer_ << "' - Allowed Audience: '"
                     << token_allowed_audience_ << "' - Signature Verify Cert Path: '"
                     << token_signature_verify_cert_path_.string() << "'";
}

// Compares a caller-supplied credential with the configured one without stopping at the
// first mismatching byte, so response time does not reveal how long a correct prefix the
// caller guessed. Only the lengths influence how long the loop runs.
static bool CredentialsEqual(const std::string& provided, const std::string& expected) {
  if (expected.empty()) {
    return provided.empty();
  }
  // A length difference alone makes `diff` non-zero.
  std::string::size_type diff = provided.size() ^ expected.size();
  for (std::string::size_type i = 0; i < provided.size(); ++i) {
    // Wrap around `expected` so indexing stays in bounds when `provided` is longer.
    const auto lhs = static_cast<unsigned char>(provided[i]);
    const auto rhs = static_cast<unsigned char>(expected[i % expected.size()]);
    diff |= static_cast<std::string::size_type>(lhs ^ rhs);
  }
  return diff == 0;
}

// The Flight SQL ODBC driver packs both credentials into the Basic-auth username as
// "<user>};PWD={<password>". Split at the FIRST delimiter so a password that itself
// contains the delimiter is kept intact. Leaves both strings untouched when the
// delimiter is absent.
static void SplitOdbcPackedCredentials(std::string& username, std::string& password) {
  static const std::string kOdbcDelimiter = "};PWD={";
  const auto pos = username.find(kOdbcDelimiter);
  if (pos == std::string::npos) {
    return;
  }
  password = username.substr(pos + kOdbcDelimiter.size());
  username.erase(pos);
}

// Authenticates "Basic" authorization headers, either with the configured
// username/password or (username "token") with a bootstrap JWT in the password field.
// On success installs a BasicAuthServerMiddleware; non-Basic headers install nothing.
Status BasicAuthServerMiddlewareFactory::StartCall(
    const flight::CallInfo& info, const flight::ServerCallContext& context,
    std::shared_ptr<flight::ServerMiddleware>* middleware) {
  // Bind by const reference rather than copying the header multimap on every call.
  const auto& incoming_headers = context.incoming_headers();

  std::string auth_header_type;
  ARROW_RETURN_NOT_OK(
      SecurityUtilities::GetAuthHeaderType(incoming_headers, &auth_header_type));
  if (auth_header_type != "Basic") {
    // Bearer headers are left to BearerAuthServerMiddlewareFactory.
    return Status::OK();
  }

  std::string username;
  std::string password;
  SecurityUtilities::ParseBasicHeader(incoming_headers, username, password);
  SplitOdbcPackedCredentials(username, password);

  if (username.empty() || password.empty()) {
    return MakeFlightError(flight::FlightStatusCode::Unauthenticated,
                           "No Username and/or Password supplied");
  }

  if (username == kTokenUsername) {
    // Username "token": the password field carries a bootstrap JWT, which is exchanged
    // for a server-issued JWT via the installed middleware.
    if (!token_auth_enabled_) {
      return MakeFlightError(flight::FlightStatusCode::Unauthenticated,
                             "Token auth is not enabled on the server");
    }
    ARROW_ASSIGN_OR_RAISE(auto bootstrap_decoded_token,
                          VerifyAndDecodeBootstrapToken(password, context));
    *middleware = std::make_shared<BasicAuthServerMiddleware>(
        bootstrap_decoded_token.get_subject(),
        bootstrap_decoded_token.get_payload_claim("role").as_string(), "BootstrapToken",
        secret_key_);
    return Status::OK();
  }

  // Evaluate both comparisons unconditionally so timing does not reveal which failed.
  const bool username_matches = CredentialsEqual(username, username_);
  const bool password_matches = CredentialsEqual(password, password_);
  if (!(username_matches && password_matches)) {
    GIZMOSQL_LOGKV(WARNING,
                   "User: " + username + " (peer " + context.peer() +
                       ") - Failed Basic authentication via Username / Password - "
                       "reason: user provided invalid credentials",
                   {"user", username}, {"peer", context.peer()},
                   {"kind", "authentication"}, {"authentication_type", "basic"},
                   {"result", "failure"}, {"reason", "invalid_credentials"});
    return MakeFlightError(flight::FlightStatusCode::Unauthenticated,
                           "Invalid credentials");
  }

  *middleware = std::make_shared<BasicAuthServerMiddleware>(username, "admin", "Basic",
                                                            secret_key_);
  GIZMOSQL_LOGKV(INFO,
                 "User: " + username + " (peer " + context.peer() +
                     ") - Successfully Basic authenticated via Username / Password",
                 {"user", username}, {"peer", context.peer()},
                 {"kind", "authentication"}, {"authentication_type", "basic"},
                 {"authentication_method", "username/password"},
                 {"result", "success"});
  return Status::OK();
}

// Verifies a bootstrap JWT from the configured external issuer:
// - RS256 signature against the configured certificate;
// - expected issuer and audience;
// - string "sub" and "role" claims, which the caller turns into the session's
//   username and role.
// Every failure is returned as Status::Invalid.
arrow::Result<jwt::decoded_jwt<jwt::traits::kazuho_picojson>>
BasicAuthServerMiddlewareFactory::VerifyAndDecodeBootstrapToken(
    const std::string& token, const flight::ServerCallContext& context) const {
  if (token.empty()) {
    return Status::Invalid("Bearer Token is empty");
  }

  try {
    auto decoded = jwt::decode(token);

    // get_id()/get_subject()/get_issuer() throw when a claim is absent. "jti" is
    // optional, so reading it that way just for a log line would reject an otherwise
    // valid token. These helpers read claims without throwing.
    auto has_string_claim = [&decoded](const std::string& name) {
      return decoded.has_payload_claim(name) &&
             decoded.get_payload_claim(name).get_type() == jwt::json::type::string;
    };
    auto string_claim_or_empty = [&decoded, &has_string_claim](const std::string& name) {
      return has_string_claim(name) ? decoded.get_payload_claim(name).as_string()
                                    : std::string();
    };

    const std::string iss = string_claim_or_empty("iss");
    const std::string token_id = string_claim_or_empty("jti");
    const std::string token_sub = string_claim_or_empty("sub");

    if (iss != token_allowed_issuer_) {
      GIZMOSQL_LOGKV(
          WARNING,
          "peer=" + context.peer() +
              " - Bootstrap Bearer Token has an invalid 'iss' claim value of: " + iss +
              " - token_claims=(id=" + token_id + " sub=" + token_sub + " iss=" + iss +
              ")",
          {"peer", context.peer()}, {"kind", "authentication"},
          {"authentication_type", "bearer"}, {"result", "failure"},
          {"reason", "invalid_issuer"}, {"token_id", token_id}, {"token_sub", token_sub},
          {"token_iss", iss});
      return Status::Invalid("Invalid token issuer");
    }

    // Throws on a bad signature or an issuer/audience mismatch.
    jwt::verify()
        .allow_algorithm(jwt::algorithm::rs256(token_signature_verify_cert_file_contents_,
                                               "", "", ""))
        .with_issuer(std::string(token_allowed_issuer_))
        .with_audience(token_allowed_audience_)
        .verify(decoded);

    if (!decoded.has_payload_claim("role")) {
      return Status::Invalid("Bootstrap Bearer Token MUST have a 'role' claim");
    }
    if (!has_string_claim("role")) {
      return Status::Invalid("Bootstrap Bearer Token 'role' claim MUST be a string");
    }
    if (!has_string_claim("sub")) {
      return Status::Invalid("Bootstrap Bearer Token MUST have a string 'sub' claim");
    }
    const std::string role = decoded.get_payload_claim("role").as_string();

    GIZMOSQL_LOGKV(INFO,
                   "peer=" + context.peer() +
                       " - Bootstrap Bearer Token was validated successfully" +
                       " - token_claims=(id=" + token_id + " sub=" + token_sub +
                       " iss=" + iss + " role=" + role + ")",
                   {"peer", context.peer()}, {"kind", "authentication"},
                   {"authentication_type", "bearer"}, {"result", "success"},
                   {"token_id", token_id}, {"token_sub", token_sub},
                   {"token_role", role}, {"token_iss", iss});

    return decoded;
  } catch (const std::exception& e) {
    auto error_message = e.what();
    GIZMOSQL_LOGKV(WARNING,
                   "peer=" + context.peer() +
                       " - Bootstrap Bearer Token verification failed with exception: " +
                       error_message,
                   {"peer", context.peer()}, {"kind", "authentication"},
                   {"authentication_type", "bearer"}, {"result", "failure"},
                   {"reason", error_message});

    return Status::Invalid("Bootstrap Token verification failed with error: " +
                           std::string(error_message));
  }
}

// ----------------------------------------
// Stores the decoded JWT; StartCall passes one that has already been verified.
// The parameter is taken by value and moved into the member. Top-level const on a
// by-value parameter is not part of the function signature, so dropping it here still
// matches the declaration while allowing the move.
BearerAuthServerMiddleware::BearerAuthServerMiddleware(
    jwt::decoded_jwt<jwt::traits::kazuho_picojson> decoded_jwt)
    : decoded_jwt_(std::move(decoded_jwt)) {}

// Returns a copy of the decoded JWT this middleware was constructed with.
const jwt::decoded_jwt<jwt::traits::kazuho_picojson> BearerAuthServerMiddleware::GetJWT()
    const {
  auto jwt_copy = decoded_jwt_;
  return jwt_copy;
}

// The username is the JWT's "sub" (subject) claim.
const std::string BearerAuthServerMiddleware::GetUsername() const {
  auto subject = decoded_jwt_.get_subject();
  return subject;
}

// The role comes from the JWT's custom "role" claim, read as a string.
const std::string BearerAuthServerMiddleware::GetRole() const {
  const auto role_claim = decoded_jwt_.get_payload_claim("role");
  return role_claim.as_string();
}

// Echoes the authenticated identity back to the client as "x-username" and
// "x-role" response headers.
void BearerAuthServerMiddleware::SendingHeaders(
    flight::AddCallHeaders* outgoing_headers) {
  const std::string username = GetUsername();
  outgoing_headers->AddHeader("x-username", username);
  const std::string role = GetRole();
  outgoing_headers->AddHeader("x-role", role);
}

void BearerAuthServerMiddleware::CallCompleted(const Status& /*status*/) {
  // Reset the thread-local request context (populated by the factory's StartCall)
  // so a later call served on this thread cannot observe this caller's identity.
  tl_request_ctx = {};
}

// Identifier reported for this middleware.
std::string BearerAuthServerMiddleware::name() const {
  static constexpr char kMiddlewareName[] = "BearerAuthServerMiddleware";
  return kMiddlewareName;
}

// ----------------------------------------
BearerAuthServerMiddlewareFactory::BearerAuthServerMiddlewareFactory(
    const std::string& secret_key)
    : secret_key_(secret_key) {}

// Decodes `token` and accepts it only if its issuer is kServerJWTIssuer and it carries
// a valid HS256 signature under secret_key_. Any exception thrown while decoding,
// reading claims or verifying is logged and converted into Status::Invalid.
arrow::Result<jwt::decoded_jwt<jwt::traits::kazuho_picojson>>
BearerAuthServerMiddlewareFactory::VerifyAndDecodeToken(
    const std::string& token, const flight::ServerCallContext& context) const {
  if (token.empty()) {
    return Status::Invalid("Bearer Token is empty");
  }

  try {
    auto decoded = jwt::decode(token);
    const auto iss = decoded.get_issuer();

    if (iss != kServerJWTIssuer) {
      GIZMOSQL_LOGKV(
          WARNING,
          "peer=" + context.peer() +
              " - Bearer Token has an invalid 'iss' claim value of: " + iss +
              " - token_claims=(id=" + decoded.get_id() +
              " sub=" + decoded.get_subject() + " iss=" + decoded.get_issuer() + ")",
          {"peer", context.peer()}, {"kind", "authentication"},
          {"authentication_type", "bearer"}, {"result", "failure"},
          {"reason", "invalid_issuer"}, {"token_id", decoded.get_id()},
          {"token_sub", decoded.get_subject()}, {"token_iss", decoded.get_issuer()});
      return Status::Invalid("Invalid token issuer");
    }

    // Throws if the signature or issuer does not check out.
    jwt::verify()
        .allow_algorithm(jwt::algorithm::hs256{secret_key_})
        .with_issuer(std::string(kServerJWTIssuer))
        .verify(decoded);

    // The first successful validation of a given token id is logged at INFO and every
    // repeat at DEBUG, so a token reused on each call does not flood the INFO log.
    const bool first_success_for_id = [&] {
      std::lock_guard<std::mutex> guard(token_log_mutex_);
      const std::string jti = decoded.get_id();
      if (jti.empty() || logged_token_ids_.find(jti) != logged_token_ids_.end()) {
        return false;
      }
      logged_token_ids_.insert(jti);
      // Crude memory bound: forget every id except the one just recorded.
      if (logged_token_ids_.size() > 50000) {
        logged_token_ids_.clear();
        logged_token_ids_.insert(jti);
      }
      return true;
    }();
    const arrow::util::ArrowLogLevel token_log_level =
        first_success_for_id ? arrow::util::ArrowLogLevel::ARROW_INFO
                             : arrow::util::ArrowLogLevel::ARROW_DEBUG;

    GIZMOSQL_LOGKV_DYNAMIC(
        token_log_level,
        "peer=" + context.peer() + " - Bearer Token was validated successfully" +
            " - token_claims=(id=" + decoded.get_id() + " sub=" + decoded.get_subject() +
            " iss=" + decoded.get_issuer() + ")",
        {"peer", context.peer()}, {"kind", "authentication"},
        {"authentication_type", "bearer"}, {"result", "success"},
        {"token_id", decoded.get_id()}, {"token_sub", decoded.get_subject()},
        {"token_iss", decoded.get_issuer()});

    return decoded;
  } catch (const std::exception& e) {
    auto error_message = e.what();
    GIZMOSQL_LOGKV(
        WARNING,
        "peer=" + context.peer() +
            " - Bearer Token verification failed with exception: " + error_message,
        {"peer", context.peer()}, {"kind", "authentication"},
        {"authentication_type", "bearer"}, {"result", "failure"},
        {"reason", error_message});

    return Status::Invalid("Token verification failed with error: " +
                           std::string(error_message));
  }
}

// Authenticates "Bearer" authorization headers carrying a server-issued JWT. On success
// installs a BearerAuthServerMiddleware and publishes the caller's identity to the
// thread-local request context. Calls without an authorization header, or with a
// non-Bearer scheme, install nothing and return OK.
Status BearerAuthServerMiddlewareFactory::StartCall(
    const flight::CallInfo& info, const flight::ServerCallContext& context,
    std::shared_ptr<flight::ServerMiddleware>* middleware) {
  // Bind by const reference rather than copying the header multimap on every call.
  const auto& incoming_headers = context.incoming_headers();

  if (incoming_headers.find(kAuthHeader) == incoming_headers.end()) {
    return Status::OK();
  }

  std::string auth_header_type;
  ARROW_RETURN_NOT_OK(
      SecurityUtilities::GetAuthHeaderType(incoming_headers, &auth_header_type));
  if (auth_header_type != "Bearer") {
    // Basic credentials are handled by BasicAuthServerMiddlewareFactory.
    return Status::OK();
  }

  const std::string bearer_token = SecurityUtilities::FindKeyValPrefixInCallHeaders(
      incoming_headers, kAuthHeader, kBearerPrefix);
  ARROW_ASSIGN_OR_RAISE(auto decoded_jwt, VerifyAndDecodeToken(bearer_token, context));

  // Extract every claim the request context needs before installing the middleware.
  // Claim accessors throw when a claim is absent or not a string; report that as an
  // authentication failure instead of letting the exception escape StartCall.
  std::string username;
  std::string role;
  std::string session_id;
  try {
    username = decoded_jwt.get_subject();
    role = decoded_jwt.get_payload_claim("role").as_string();
    session_id = decoded_jwt.get_payload_claim("session_id").as_string();
  } catch (const std::exception& e) {
    return MakeFlightError(flight::FlightStatusCode::Unauthenticated,
                           std::string("Bearer Token has a missing or invalid claim: ") +
                               e.what());
  }

  *middleware = std::make_shared<BearerAuthServerMiddleware>(std::move(decoded_jwt));

  // Publish the caller's identity to this thread's request context;
  // BearerAuthServerMiddleware::CallCompleted resets it when the call finishes.
  tl_request_ctx.username = std::move(username);
  tl_request_ctx.role = std::move(role);
  tl_request_ctx.peer = context.peer();
  tl_request_ctx.session_id = std::move(session_id);
  return Status::OK();
}
}  // namespace gizmosql