/*
 * Copyright (C) 2004 Red Hat Inc.
 * Copyright (C) 2005 Martin Koegler
 * Copyright (C) 2010 TigerVNC Team
 * Copyright (C) 2010-2021 m-privacy GmbH
 *
 * This is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 * 
 * This software is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this software; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307,
 * USA.
 */

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#ifndef HAVE_GNUTLS
#error "This header should not be compiled without HAVE_GNUTLS defined"
#endif

#ifdef WIN32
#include <winsock2.h>
#endif
#include <stdlib.h>
#ifndef WIN32
#include <unistd.h>
#endif
#include <gcrypt.h>
#include <gnutls/x509.h>

#include <rfb/CSecurityTLS.h>
#include <rfb/CConnection.h>
#include <rfb/LogWriter.h>
#include <rfb/Exception.h>
#include <rfb/UserMsgBox.h>
#include <rdr/TLSInStream.h>
#include <rdr/TLSOutStream.h>
#include <os/os.h>
#include <vncviewer/i18n.h>

/*
 * libgcrypt thread-callback descriptor handed to
 * gcry_control(GCRYCTL_SET_THREAD_CBS, ...) in the CSecurityTLS constructor.
 * On Windows only the option word (default threading, with the option
 * version in the upper bits) is filled in; no callback functions are set.
 * Elsewhere the GCRY_THREAD_OPTION_PTHREAD_IMPL macro provides the
 * gcry_threads_pthread object that the constructor passes instead.
 */
#ifdef WIN32
static struct gcry_thread_cbs gcry_threads_other = {
  (GCRY_THREAD_OPTION_DEFAULT | (GCRY_THREAD_OPTION_VERSION << 8))
};
#else
GCRY_THREAD_OPTION_PTHREAD_IMPL;
#endif

/*
 * WITHOUT_X509_TIMES disables the certificate validity-time checks
 * (GNUTLS_CERT_NOT_ACTIVATED / GNUTLS_CERT_EXPIRED) in checkSession().
 * GnuTLS 2.6.5 and older did not define those status values, so they must
 * not be used there. GnuTLS 1.x defined LIBGNUTLS_VERSION_NUMBER, so the
 * presence of that macro is treated as "old" GnuTLS as well.
 */
#if (defined(GNUTLS_VERSION_NUMBER) && GNUTLS_VERSION_NUMBER < 0x020606) || \
    defined(LIBGNUTLS_VERSION_NUMBER)
#define WITHOUT_X509_TIMES
#endif

/* Ancient GnuTLS defines neither version macro; treat it as old, too. */
#if !defined(GNUTLS_VERSION_NUMBER) && !defined(LIBGNUTLS_VERSION_NUMBER)
#define WITHOUT_X509_TIMES
#endif

using namespace rfb;

/*
 * Viewer settings naming the X.509 files used by the certificate-based
 * security types. All default to the empty string; setDefaults() may
 * install defaults pointing into the VNC home directory.
 */
StringParameter CSecurityTLS::X509CA("X509CA", "X509 CA certificate", "", ConfViewer);
StringParameter CSecurityTLS::X509CRL("X509CRL", "X509 CRL file", "", ConfViewer);
StringParameter CSecurityTLS::x509cert("x509cert", "X509 user certificate file", "", ConfViewer);
StringParameter CSecurityTLS::x509key("x509key", "X509 user key file", "", ConfViewer);
/* Obsolete switches: per their descriptions they are ignored, and nothing
 * in this file reads them. */
BoolParameter CSecurityTLS::tlsClear("tlsclear",
             "Use unencrypted TLS (obsolete and ignored)", false);
BoolParameter CSecurityTLS::tlsNormal("tlsnormal",
              "Use NORMAL GnuTLS settings instead of PERFORMANCE (obsolete and ignored, NORMAL is default now)", false);
/* TLS input buffer size. NOTE(review): the numeric arguments presumably are
 * default (1 MiB), minimum (128 KiB) and maximum (16 MiB); confirm against
 * IntParameter. Not referenced in this file - presumably read by the TLS
 * input stream; verify. */
IntParameter TLSInBufferSize("TLSInBufferSize", "TLS input buffer size", 1024 * 1024, 128 * 1024, 16 * 1024 * 1024);

/* Human-readable description of the negotiated cipher suite, written by
 * getCipher(). This is a static member shared by all instances. */
char * CSecurityTLS::ciphersuite;

/* "TLS" carries this class's own messages; "RawTLS" receives GnuTLS's
 * internal debug output through debug_log(). */
static LogWriter vlog("TLS");
static LogWriter vlog_raw("RawTLS");

/*
 * GnuTLS global log callback (installed by the CSecurityTLS constructor
 * when the RawTLS log level is high enough). Each message is forwarded to
 * the RawTLS writer at debug level, prefixed with its GnuTLS log level.
 */
static void debug_log(int gnutlsLevel, const char* message)
{
  vlog_raw.debug("[%d]: %s", gnutlsLevel, message);
}

CSecurityTLS::CSecurityTLS(CConnection* cc, bool _anon, bool _requireclientcert) : CSecurity(cc), session(0), anon_cred(0),
                  anon(_anon), requireclientcert(_requireclientcert), tlsis(0), tlsos(0)
{
  /*
   * shutdown() (also run from the destructor) tests these members before
   * freeing/using them, but the initializer list above does not cover them.
   * Start them out as null so an anonymous session never frees a garbage
   * cert_cred and never hands garbage raw streams back to the connection.
   * They are assigned here rather than in the initializer list to avoid
   * depending on their declaration order in the header.
   */
  cert_cred = 0;
  rawis = NULL;
  rawos = NULL;

#ifdef WIN32
  gcry_control(GCRYCTL_SET_THREAD_CBS, &gcry_threads_other);
#else
  gcry_control(GCRYCTL_SET_THREAD_CBS, &gcry_threads_pthread);
#endif
  if (gnutls_global_init() != GNUTLS_E_SUCCESS)
    throw AuthFailureException("gnutls_global_init failed");
  /* 100 means debug log; raw GnuTLS output is only enabled above that. */
  if (vlog_raw.getLevel() >= 110) {
    gnutls_global_set_log_level(10);
    gnutls_global_set_log_function(debug_log);
  }
  cafile = X509CA.getData();
  crlfile = X509CRL.getData();
  certfile = x509cert.getData();
  keyfile = x509key.getData();
  ciphersuite = NULL;
  if (requireclientcert && (!(*certfile) || !(*keyfile))) {
    vlog.error("Client TLS certificate and key required, but one or both are missing");
    /*
     * The destructor does not run when a constructor throws, so release
     * everything acquired above before leaving.
     */
    delete[] cafile;
    delete[] crlfile;
    delete[] certfile;
    delete[] keyfile;
    gnutls_global_deinit();
    throw AuthFailureException(_("Client TLS certificate and key required, but one or both are missing"));
  }
}

void CSecurityTLS::setDefaults()
{
  vlog.info("CSecurityTLS::setDefaults()");
  char* homeDir = NULL;

  if (getvnchomedir(&homeDir) == -1) {
    vlog.error("Could not obtain VNC home directory path");
    return;
  }

  vlog.info("vnc homeDir: %s", homeDir);

  int len = strlen(homeDir) + 1;
  CharArray caDefault(len + 11);
  CharArray crlDefault(len + 12);
  CharArray certDefault(len + 17);
  CharArray keyDefault(len + 12);

  sprintf(caDefault.buf, "%sx509_ca.pem", homeDir);
  sprintf(crlDefault.buf, "%sx509_crl.pem", homeDir);
  sprintf(certDefault.buf, "%sx509_usercert.pem", homeDir);
  sprintf(keyDefault.buf, "%sx509_key.pem", homeDir);
  delete [] homeDir;

  if (!fileexists(caDefault.buf))
   X509CA.setDefaultStr(strDup(caDefault.buf));
  if (!fileexists(crlDefault.buf))
   X509CRL.setDefaultStr(strDup(crlDefault.buf));
  if (!fileexists(certDefault.buf))
    x509cert.setDefaultStr(strDup(certDefault.buf));
  if (!fileexists(keyDefault.buf))
    x509key.setDefaultStr(strDup(keyDefault.buf));
}

/*
 * Tears down the TLS layer.
 *
 * needbye: when true and a session exists, a TLS close (gnutls_bye) is sent
 * first; pass false when the peer is not in a state to receive it.
 *
 * Afterwards the credentials are released, the connection gets its raw
 * (pre-TLS) streams back, the TLS stream wrappers are deleted and the GnuTLS
 * session is deinitialized. Every released member is reset to null, so
 * calling this more than once is harmless.
 */
void CSecurityTLS::shutdown(bool needbye)
{
  if (needbye && session) {
    int byeResult = gnutls_bye(session, GNUTLS_SHUT_RDWR);
    if (byeResult != GNUTLS_E_SUCCESS)
      vlog.error("gnutls_bye failed");
  }

  if (anon_cred != 0) {
    gnutls_anon_free_client_credentials(anon_cred);
    anon_cred = 0;
  }
  if (cert_cred != 0) {
    gnutls_certificate_free_credentials(cert_cred);
    cert_cred = 0;
  }

  /* Hand the connection its original transport streams back. */
  if (rawis != NULL && rawos != NULL) {
    cc->setStreams(rawis, rawos);
    rawis = NULL;
    rawos = NULL;
  }

  /* Deleting a null pointer is a no-op, so no guard is needed. */
  delete tlsis;
  tlsis = NULL;
  delete tlsos;
  tlsos = NULL;

  if (session != 0) {
    gnutls_deinit(session);
    session = 0;
  }
}


/*
 * Closes the TLS session (sending a TLS close if one is open), frees the
 * configuration strings taken in the constructor and drops this instance's
 * reference on the GnuTLS library.
 */
CSecurityTLS::~CSecurityTLS()
{
  shutdown(true);

  /*
   * ciphersuite is a static member shared by all instances; reset it after
   * freeing so the shared pointer is not left dangling.
   */
  delete[] ciphersuite;
  ciphersuite = NULL;

  delete[] cafile;
  delete[] crlfile;
  delete[] certfile;
  delete[] keyfile;

  gnutls_global_deinit();
}

bool CSecurityTLS::processMsg()
{
  rdr::InStream* is = cc->getInStream();
  rdr::OutStream* os = cc->getOutStream();
  client = cc;
  int err;

  if (!session) {
    if (!is->hasData(1))
      return false;

    if (is->readU8() == 0)
      throw AuthFailureException(_("Server failed to initialize TLS session"));

    if (gnutls_init(&session, GNUTLS_CLIENT) != GNUTLS_E_SUCCESS)
      throw AuthFailureException("gnutls_init failed");

    if ((err = gnutls_priority_set_direct(session, "NORMAL:+ANON-ECDH:+ANON-DH", NULL)) != GNUTLS_E_SUCCESS) {
      vlog.error("gnutls_priority_set_direct(NORMAL:+ANON-ECDH:+ANON-DH) failed: %s\n", gnutls_strerror (err));
      throw AuthFailureException("gnutls_priority_set_direct(NORMAL:+ANON-ECDH:+ANON-DH) failed");
    }

    setParam();

    // Create these early as they set up the push/pull functions
    // for GnuTLS
    tlsis = new rdr::TLSInStream(is, session);
    tlsos = new rdr::TLSOutStream(os, session);

    rawis = is;
    rawos = os;
  }
  if ((err = gnutls_handshake(session)) != GNUTLS_E_SUCCESS) {
    if (!gnutls_error_is_fatal(err)) {
      vlog.debug("Deferring completion of TLS handshake: %s", gnutls_strerror(err));
      return false;
    }

    vlog.error("TLS Handshake failed: %s\n", gnutls_strerror (err));
    shutdown(false);
    throw AuthFailureException(_("TLS handshake failed"));
  }

  vlog.debug("TLS handshake completed with %s",
             gnutls_session_get_desc(session));

  checkSession();
  getCipher();

  cc->setStreams(tlsis, tlsos);
  tlsis->setTLSDone(true);

  vlog.debug("TLS session has been established");
  return true;
}

/*
 * Formats a one-line description of the negotiated key exchange, cipher and
 * MAC (with their key sizes in bits, i.e. bytes * 8) into the static
 * ciphersuite buffer, and logs it.
 */
void CSecurityTLS::getCipher()
{
  const size_t bufSize = 256;

  gnutls_kx_algorithm_t kx = gnutls_kx_get(session);
  gnutls_cipher_algorithm_t cipher = gnutls_cipher_get(session);
  gnutls_mac_algorithm_t mac = gnutls_mac_get(session);

  /*
   * The gnutls_*_get_name() functions return NULL for algorithms they cannot
   * name; never pass NULL to a %s conversion.
   */
  const char* kxName = gnutls_kx_get_name(kx);
  const char* cipherName = gnutls_cipher_get_name(cipher);
  const char* macName = gnutls_mac_get_name(mac);
  if (kxName == NULL)
    kxName = "unknown";
  if (cipherName == NULL)
    cipherName = "unknown";
  if (macName == NULL)
    macName = "unknown";

  /* Release a previously formatted description instead of leaking it. */
  delete [] ciphersuite;
  ciphersuite = new char[bufSize];

  snprintf(ciphersuite, bufSize, "%s %s(%zuB) %s(%zuB)",
           kxName,
           cipherName,
           static_cast<size_t>(8 * gnutls_cipher_get_key_size(cipher)),
           macName,
           static_cast<size_t>(8 * gnutls_mac_get_key_size(mac)));
  vlog.info("GnuTLS cipher suite: %s", ciphersuite);
}

void CSecurityTLS::setParam()
{
  static const char kx_anon_priority[] = ":+ANON-ECDH:+ANON-DH";

  int ret;
  char *prio;
  const char *err;

  prio = (char*)malloc(strlen(Security::GnuTLSPriority) +
                       strlen(kx_anon_priority) + 1);
  if (prio == NULL)
    throw AuthFailureException(_("Not enough memory for GnuTLS priority string"));

  strcpy(prio, Security::GnuTLSPriority);
  if (anon)
    strcat(prio, kx_anon_priority);

  ret = gnutls_priority_set_direct(session, prio, &err);

  free(prio);

  if (ret != GNUTLS_E_SUCCESS) {
    if (ret == GNUTLS_E_INVALID_REQUEST)
      vlog.error("GnuTLS priority syntax error at: %s", err);
    throw AuthFailureException(_("gnutls_set_priority_direct failed"));
  }

  if (anon) {
    if (gnutls_anon_allocate_client_credentials(&anon_cred) != GNUTLS_E_SUCCESS)
      throw AuthFailureException(_("gnutls_anon_allocate_client_credentials failed"));

    if (gnutls_credentials_set(session, GNUTLS_CRD_ANON, anon_cred) != GNUTLS_E_SUCCESS)
      throw AuthFailureException(_("gnutls_credentials_set failed"));

    vlog.debug("Anonymous session has been set");
  } else {
    if (gnutls_certificate_allocate_credentials(&cert_cred) != GNUTLS_E_SUCCESS)
      throw AuthFailureException(_("gnutls_certificate_allocate_credentials failed"));

    if (gnutls_certificate_set_x509_system_trust(cert_cred) != GNUTLS_E_SUCCESS)
      vlog.error("Could not load system certificate trust store");

    if (*cafile && gnutls_certificate_set_x509_trust_file(cert_cred,cafile,GNUTLS_X509_FMT_PEM) < 0)
      throw AuthFailureException(_("Failed to load the CA certificate"));

    /* Load previously saved certs */
    char *homeDir = NULL;
    int err;
    if (getvnchomedir(&homeDir) == -1)
      vlog.error("Could not obtain VNC home directory path");
    else {
      CharArray caSave(strlen(homeDir) + 19 + 1);
      sprintf(caSave.buf, "%sx509_savedcerts.pem", homeDir);
      delete [] homeDir;

      err = gnutls_certificate_set_x509_trust_file(cert_cred, caSave.buf,
                                                   GNUTLS_X509_FMT_PEM);
      if (err < 0)
        vlog.debug("Failed to load saved server certificates from %s", caSave.buf);
    }

    if (*crlfile && gnutls_certificate_set_x509_crl_file(cert_cred,crlfile,GNUTLS_X509_FMT_PEM) < 0)
      throw AuthFailureException("load of CRL failed");

    if (requireclientcert && *certfile && *keyfile) {
      int err;
      if ((err = gnutls_certificate_set_x509_key_file(cert_cred, certfile, keyfile, GNUTLS_X509_FMT_PEM)) != GNUTLS_E_SUCCESS) {
        vlog.error("gnutls_certificate_set_x509_key_file() failed with error: %s", gnutls_strerror(err));
        throw AuthFailureException(_("The TLS certificate file or the TLS key file could not be loaded."));
      } else
        vlog.debug("Client certificate %s and key %s loaded", certfile, keyfile);
    } else
      vlog.debug("No client certificate needed or provided");

    if (gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, cert_cred) != GNUTLS_E_SUCCESS)
      throw AuthFailureException(_("gnutls_credentials_set failed"));

    if (gnutls_server_name_set(session, GNUTLS_NAME_DNS,
                               client->getServerName(),
                               strlen(client->getServerName())) != GNUTLS_E_SUCCESS)
      vlog.error("Failed to configure the server name for TLS handshake");

    vlog.debug("X509 session has been set");
  }
}

void CSecurityTLS::checkSession()
{
  const unsigned allowed_errors = GNUTLS_CERT_INVALID |
				  GNUTLS_CERT_SIGNER_NOT_FOUND |
				  GNUTLS_CERT_SIGNER_NOT_CA;
  unsigned int status;
  const gnutls_datum *cert_list;
  unsigned int cert_list_size = 0;
  int err;
  gnutls_datum info;

  if (anon)
    return;

  if (gnutls_certificate_type_get(session) != GNUTLS_CRT_X509)
	  throw AuthFailureException(_("This certificate type is not supported."));

  err = gnutls_certificate_verify_peers2(session, &status);
  if (err != 0) {
    vlog.error("server certificate verification failed: %s", gnutls_strerror(err));
    throw AuthFailureException(_("The verification of the server certificate has failed."));
  }

  if (status & GNUTLS_CERT_REVOKED)
    throw AuthFailureException("The server certificate has been revoked.");

#ifndef WITHOUT_X509_TIMES
  if (status & GNUTLS_CERT_NOT_ACTIVATED)
    throw AuthFailureException("Das Server-Zertifikat wurde nicht aktiviert");

  if (status & GNUTLS_CERT_EXPIRED) {
    vlog.debug("server certificate has expired");
    if (!msg->showMsgBox(UserMsgBox::M_YESNO, _("Certificate expired"),
             _("The server certificate has expired. Do you wish to continue?")
      throw AuthFailureException(_("The server certificate is expired");
  }
#endif
  /* Process other errors later */

  cert_list = gnutls_certificate_get_peers(session, &cert_list_size);
  if (!cert_list_size)
    throw AuthFailureException("empty certificate chain");

  /* Process only server's certificate, not issuer's certificate */
  gnutls_x509_crt crt;
  gnutls_x509_crt_init(&crt);

  if (gnutls_x509_crt_import(crt, &cert_list[0], GNUTLS_X509_FMT_DER) < 0)
    throw AuthFailureException(_("Decoding of certificate failed"));

  if (gnutls_x509_crt_check_hostname(crt, client->getServerName()) == 0) {
    char buf[255];
    vlog.debug("hostname mismatch");
    snprintf(buf, sizeof(buf), _("The hostname (%s) does not match the server certificate. Do you still want to connect?"), client->getServerName());
    buf[sizeof(buf) - 1] = '\0';
    if (!msg->showMsgBox(UserMsgBox::M_YESNO, _("Incorrect server certificate"), buf))
      throw AuthFailureException("hostname cert mismatch");
  }

  if (status == 0) {
    /* Everything is fine (hostname + verification) */
    gnutls_x509_crt_deinit(crt);
    return;
  }

  if (status & GNUTLS_CERT_INVALID)
    vlog.debug("server certificate invalid");
  if (status & GNUTLS_CERT_SIGNER_NOT_FOUND)
    vlog.debug("server cert signer not found");
  if (status & GNUTLS_CERT_SIGNER_NOT_CA)
    vlog.debug("server cert signer not CA");

  if (status & GNUTLS_CERT_INSECURE_ALGORITHM)
    throw AuthFailureException("The server certificate uses an insecure algorithm");

  if ((status & (~allowed_errors)) != 0) {
    /* No other errors are allowed */
    vlog.debug("GNUTLS status of certificate verification: %u", status);
    throw AuthFailureException(_("Invalid status of server certificate verification"));
  }

  vlog.debug("Saved server certificates don't match");

  if (gnutls_x509_crt_print(crt, GNUTLS_CRT_PRINT_ONELINE, &info)) {
    /*
     * GNUTLS doesn't correctly export gnutls_free symbol which is
     * a function pointer. Linking with Visual Studio 2008 Express will
     * fail when you call gnutls_free().
     */
#if WIN32
    free(info.data);
#else
    gnutls_free(info.data);
#endif
    throw AuthFailureException(_("Failed to locate the TLS certificate for display."));
  }

  size_t out_size = 64 * 1024;
  char *out_buf = NULL;
  char *certinfo = NULL;
  int len = 0;

  vlog.debug("certificate issuer unknown");

  len = snprintf(NULL, 0, _("The TLS certificate was signed by an unknown authority:\n\n%s\n\nSave it and continue?"), info.data);
  if (len++ < 0)
    throw AuthFailureException("certificate decoding error");

  vlog.debug("%s", info.data);

  certinfo = new char[len];

  snprintf(certinfo, len, _("The TLS certificate was signed by an unknown authority:\n\n%s\n\nSave it and continue?"), info.data);

  for (int i = 0; i < len - 1; i++)
    if (certinfo[i] == ',' && certinfo[i + 1] == ' ')
      certinfo[i] = '\n';

  if (!msg->showMsgBox(UserMsgBox::M_YESNO, _("The issuer of the TLS certificate is unknown"), certinfo)) {
    delete [] certinfo;
    throw AuthFailureException(_("Terminating program because the issuer of the TLS certificate is unknown."));
  }

  delete [] certinfo;

#if 0
  if (gnutls_x509_crt_export(crt, GNUTLS_X509_FMT_PEM, NULL, &out_size)
      != GNUTLS_E_SHORT_MEMORY_BUFFER)
    throw AuthFailureException(_("The issuer of the TLS certificate is unknown, and the export has failed."));
#endif

  // Save cert
  out_buf =  new char[out_size];

  if (gnutls_x509_crt_export(crt, GNUTLS_X509_FMT_PEM, out_buf, &out_size) < 0)
    throw AuthFailureException(_("The issuer of the TLS certificate is unknown, and the export has failed."));

  char *homeDir = NULL;
  if (getvnchomedir(&homeDir) == -1)
    vlog.error("Could not obtain VNC home directory path");
  else {
    FILE *f;
    CharArray caSave(strlen(homeDir) + 1 + 19);
    sprintf(caSave.buf, "%sx509_savedcerts.pem", homeDir);
//    delete [] homeDir;
    f = fopen(caSave.buf, "a+");
    if (!f)
    msg->showMsgBox(UserMsgBox::M_OK, _("Save Certificate Error"),
          _("Failed to save the TLS certificate"));
    else {
      fprintf(f, "%s\n", out_buf);
      fclose(f);
    }
  }

  delete [] out_buf;

  gnutls_x509_crt_deinit(crt);
  /*
   * GNUTLS doesn't correctly export gnutls_free symbol which is
   * a function pointer. Linking with Visual Studio 2008 Express will
   * fail when you call gnutls_free().
   */
#if WIN32
  free(info.data);
#else
  gnutls_free(info.data);
#endif
}
