/* 
 * Copyright (C) 2004 Red Hat Inc.
 * Copyright (C) 2005 Martin Koegler
 * Copyright (C) 2010 TigerVNC Team
 * Copyright (C) 2014-2021 m-privacy GmbH, Berlin
 *    
 * This is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 * 
 * This software is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this software; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307,
 * USA.
 */

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#ifndef HAVE_GNUTLS
#error "This source should not be compiled without HAVE_GNUTLS defined"
#endif

#ifdef WIN32
#include <winsock2.h>
#endif
#include <gcrypt.h>
#include <gnutls/x509.h>
/* Defines the pthread-based libgcrypt thread callback table
   (gcry_threads_pthread) that the SSecurityTLS constructor registers via
   gcry_control(GCRYCTL_SET_THREAD_CBS, ...). */
GCRY_THREAD_OPTION_PTHREAD_IMPL;

#if !defined(WIN32) && !defined(WIN64)
#include <syslog.h>
#if !defined(__APPLE__)
#include <errno.h>
#include <sys/capability.h>
#endif
#endif

#include <stdlib.h>

#include <rfb/SSecurityTLS.h>
#include <rfb/SConnection.h>
#include <rfb/LogWriter.h>
#include <rfb/Exception.h>
#include <rdr/TLSInStream.h>
#include <rdr/TLSOutStream.h>
#include <gnutls/x509.h>

/* Size in bits of the Diffie-Hellman parameters generated by
   gnutls_dh_params_generate2() every time setParams() runs.
   NOTE(review): 1024-bit DH is below common current recommendations
   (2048 bits or more); consider known/pre-generated parameters instead of
   generating new ones per session. */
#define DH_BITS 1024 /* XXX This should be configurable! */

using namespace rfb;

/*
 * Configuration parameters for the server's X.509 setup.  The constructor
 * copies their current values (getData()) into certfile, keyfile, cafile and
 * crlfile.  Each is declared with an empty-string third argument
 * (presumably the default value).
 */
StringParameter SSecurityTLS::X509_CertFile(
  "X509Cert",
  "Path to the X509 certificate in PEM format",
  "",
  ConfServer);

StringParameter SSecurityTLS::X509_KeyFile(
  "X509Key",
  "Path to the key of the X509 certificate in PEM format",
  "",
  ConfServer);

StringParameter SSecurityTLS::X509_CAFile(
  "x509ca",
  "specifies path to the CA of the x509 user certificates in PEM format",
  "",
  ConfServer);

StringParameter SSecurityTLS::X509_CRLFile(
  "x509crl",
  "specifies path to the CRL of the x509 user certificates in PEM format",
  "",
  ConfServer);

/*
 * Text description of the most recently negotiated cipher suite, written by
 * getCipher().  This is a static member, so every SSecurityTLS instance
 * shares the same pointer.
 */
char* SSecurityTLS::ciphersuite = NULL;

static LogWriter vlog("TLS");

/*
 * Create the TLS security handler for connection `sc`.
 *
 *   _anon              - use anonymous (certificate-less) credentials
 *   _requireclientcert - require and verify an X.509 client certificate
 *
 * Throws AuthFailureException if gnutls_global_init() fails.
 */
SSecurityTLS::SSecurityTLS(SConnection* sc, bool _anon, bool _requireclientcert)
  : SSecurity(sc), session(NULL), dh_params(NULL), anon_cred(NULL),
    cert_cred(NULL), requireclientcert(_requireclientcert), anon(_anon), tlsis(NULL), tlsos(NULL),
    rawis(NULL), rawos(NULL)
{
  // Register the pthread callbacks with libgcrypt before GnuTLS is
  // initialised.
  gcry_control(GCRYCTL_SET_THREAD_CBS, &gcry_threads_pthread);

  // Initialise GnuTLS before taking ownership of any heap data: if this
  // throws, the destructor does not run, so buffers obtained earlier would
  // leak.
  if (gnutls_global_init() != GNUTLS_E_SUCCESS)
    throw AuthFailureException("gnutls_global_init failed");

  // getData() copies are released with delete[] in the destructor.
  certfile = X509_CertFile.getData();
  keyfile = X509_KeyFile.getData();
  cafile = X509_CAFile.getData();
  crlfile = X509_CRLFile.getData();

#if !defined(WIN32) && !defined(WIN64)
  openlog("Xtightgatevnc", LOG_PID, LOG_AUTH);
#endif

  // ciphersuite is a static shared by all instances.  Any value left by an
  // earlier instance may already have been freed by that instance's
  // destructor, so it is only dropped here, never deleted.
  ciphersuite = NULL;
}

/*
 * Tear down the TLS state held by this object.
 *
 * Every pointer is reset after it is released, so calling this more than
 * once is harmless.  processMsg() calls it after a fatal handshake error,
 * and the destructor calls it again.
 *
 * Steps, in this order: attempt a graceful TLS close, free the DH
 * parameters and the credentials, switch the connection back to the raw
 * streams saved by processMsg(), forget the TLS stream objects, and finally
 * deinitialise the GnuTLS session.
 */
void SSecurityTLS::shutdown()
{
  if (session) {
    // Failure to close the session cleanly is only logged.
    if (gnutls_bye(session, GNUTLS_SHUT_RDWR) != GNUTLS_E_SUCCESS) {
      /* FIXME: Treat as non-fatal error */
      vlog.error("shutdown(): TLS session has not been terminated gracefully");
    }
  }

  if (dh_params) {
    gnutls_dh_params_deinit(dh_params);
    dh_params = 0;
  }

  if (anon_cred) {
    gnutls_anon_free_server_credentials(anon_cred);
    anon_cred = 0;
  }

  if (cert_cred) {
    gnutls_certificate_free_credentials(cert_cred);
    cert_cred = 0;
  }

  // rawis/rawos are only set once processMsg() has wrapped them in TLS
  // streams; give the connection its unencrypted streams back.
  if (sc && rawis && rawos) {
    sc->setStreams(rawis, rawos);
    rawis = NULL;
    rawos = NULL;
  }

  // NOTE(review): the TLS stream objects are only forgotten, not deleted
  // (the delete calls are commented out).  Presumably their ownership lies
  // elsewhere or deleting them caused problems -- confirm, otherwise this
  // leaks one TLSInStream/TLSOutStream pair per session.
  if (tlsis) {
//    delete tlsis;
    tlsis = NULL;
  }
  if (tlsos) {
//    delete tlsos;
    tlsos = NULL;
  }

  if (session) {
    gnutls_deinit(session);
    session = 0;
  }
}


/*
 * Release the TLS session state, the configuration strings and the
 * cipher-suite description, then drop this instance's GnuTLS global
 * reference.
 */
SSecurityTLS::~SSecurityTLS()
{
  shutdown();

  delete[] keyfile;
  delete[] certfile;
  delete[] cafile;
  delete[] crlfile;

  // ciphersuite is a static shared by every instance.  Reset it after
  // freeing so that another instance's destructor (or a later reader)
  // never sees the freed buffer.  Without the reset, a second delete[]
  // would be a double free.
  delete[] ciphersuite;
  ciphersuite = NULL;

  gnutls_global_deinit();
}

bool SSecurityTLS::processMsg()
{
  const gnutls_datum_t *cert_list;
  unsigned int cert_list_size = 0;
  unsigned int status;
  vlog.debug("Process security message (session %p)", session);

  if (!session) {
    rdr::InStream* is = sc->getInStream();
    rdr::OutStream* os = sc->getOutStream();

    if (gnutls_init(&session, GNUTLS_SERVER) != GNUTLS_E_SUCCESS)
      throw AuthFailureException("gnutls_init failed");

#if !defined(WIN32) && !defined(WIN64) && !defined(__APPLE__)
    if (setuid(getuid())) {
      cap_t proccap;
      proccap = cap_get_proc();
      if (!proccap) {
        openlog("Xtightgatevnc", LOG_PID, LOG_AUTH);
        syslog(LOG_DEBUG, "SSecurityTLS::processMsg(): failed cap_get_proc() as user %u, error: %s", getuid(), strerror(errno));
      } else {
        char * capstext;
        capstext = cap_to_text(proccap, NULL);
        if (capstext) {
          openlog("Xtightgatevnc", LOG_PID, LOG_AUTH);
          syslog(LOG_DEBUG, "SSecurityTLS::processMsg(): failed setuid(%u) as user %u, caps %s, error: %s", getuid(), getuid(), capstext, strerror(errno));
          cap_free(capstext);
        }
        cap_free(proccap);
      }
      throw AuthFailureException("could not setuid");
    }
#endif

    try {
      setParams(session);
    }
    catch(...) {
      os->writeU8(0);
      os->flush();
      throw;
    }

    os->writeU8(1);
    os->flush();

    // Create these early as they set up the push/pull functions
    // for GnuTLS
    tlsis = new rdr::TLSInStream(is, session);
    tlsos = new rdr::TLSOutStream(os, session);

    rawis = is;
    rawos = os;
  }

  int err;
  err = gnutls_handshake(session);
  if (err != GNUTLS_E_SUCCESS) {
    if (!gnutls_error_is_fatal(err)) {
      vlog.debug("Deferring completion of TLS handshake: %s", gnutls_strerror(err));
      return false;
    }
    vlog.error("TLS Handshake failed: %s", gnutls_strerror (err));
#if !defined(WIN32) && !defined(WIN64)
    syslog(LOG_ERR, "TLS Handshake failed: %s", gnutls_strerror (err));
#endif
    shutdown();
    throw AuthFailureException("TLS Handshake failed");
  }

  /* if client provides a cert and we have a cafile, check it */
  if (*cafile && requireclientcert && gnutls_certificate_type_get(session) == GNUTLS_CRT_X509) {
    if ((err = gnutls_certificate_set_x509_trust_file(cert_cred,cafile,GNUTLS_X509_FMT_PEM)) < 0) {
      vlog.error("load of CA cert failed, error: %s", gnutls_strerror (err));
      throw AuthFailureException("load of CA cert failed");
    }
    if ((err = gnutls_certificate_verify_peers2(session, &status)) >= 0) {
      if (status & GNUTLS_CERT_REVOKED) {
        vlog.error("TLS certificate has been revoked, error: %s", gnutls_strerror (err));
#if !defined(WIN32) && !defined(WIN64)
        syslog(LOG_ERR, "TLS certificate has been revoked, error: %s", gnutls_strerror (err));
#endif
        throw AuthFailureException("certificate has been revoked");
      }
      if (status & GNUTLS_CERT_NOT_ACTIVATED) {
        vlog.error("TLS certificate has not been activated, error: %s", gnutls_strerror (err));
#if !defined(WIN32) && !defined(WIN64)
        syslog(LOG_ERR, "TLS certificate has not been activated, error: %s", gnutls_strerror (err));
#endif
        throw AuthFailureException("certificate has not been activated");
      }
      if (status & GNUTLS_CERT_EXPIRED) {
        vlog.error("TLS certificate has expired, error: %s", gnutls_strerror (err));
#if !defined(WIN32) && !defined(WIN64)
        syslog(LOG_ERR, "TLS certificate has expired, error: %s", gnutls_strerror (err));
#endif
        throw AuthFailureException("certificate has expired");
      }
      if (status & GNUTLS_CERT_SIGNER_NOT_FOUND) {
        vlog.error("TLS certificate signer not found, error: %s", gnutls_strerror (err));
#if !defined(WIN32) && !defined(WIN64)
        syslog(LOG_ERR, "TLS certificate signer not found, error: %s", gnutls_strerror (err));
#endif
        throw AuthFailureException("certificate signer not found");
      }
      if (status & GNUTLS_CERT_INVALID) {
        vlog.error("TLS certificate is invalid, error: %s", gnutls_strerror (err));
#if !defined(WIN32) && !defined(WIN64)
        syslog(LOG_ERR, "TLS certificate is invalid, error: %s", gnutls_strerror (err));
#endif
        throw AuthFailureException("certificate invalid");
      }
      cert_list = gnutls_certificate_get_peers(session, &cert_list_size);
      if (cert_list_size) {
        char cn[256];
        size_t cnsize = 256;
        gnutls_x509_crt_t crt;

        gnutls_x509_crt_init(&crt);
        /* cert_list[0] contains the cert with CN=user: save in sc->CertUserName */
        if (gnutls_x509_crt_import(crt, &cert_list[0],GNUTLS_X509_FMT_DER) < 0)
          throw AuthFailureException("decoding of first certificate failed");
        if (gnutls_x509_crt_get_dn_by_oid(crt,
                  GNUTLS_OID_X520_COMMON_NAME, 0, 0, cn, &cnsize))
          throw AuthFailureException("CN could not be decoded from user certificate");
        gnutls_x509_crt_deinit(crt);
        sc->CertUserName = strdup(cn);
        vlog.debug("received user name in client CN: %s", sc->CertUserName);
      } else {
        vlog.error("got empty certificate list from client!");
#if !defined(WIN32) && !defined(WIN64)
        syslog(LOG_ERR, "got empty certificate list from client!");
#endif
        throw AuthFailureException("got empty certificate list from client!");
      }
    } else {
      vlog.debug("gnutls_certificate_verify_peers2() returned error: %s", gnutls_strerror(err));
    }
  } else {
    vlog.debug("no client certificate or no CA given");
    if (requireclientcert) {
#if !defined(WIN32) && !defined(WIN64)
      syslog(LOG_ERR, "user certificate and local ca required!");
#endif
      throw AuthFailureException("user certificate and local ca required");
    }
  }

  getCipher();

  vlog.debug("TLS handshake completed with %s",
             gnutls_session_get_desc(session));

  sc->setStreams(tlsis, tlsos);

  return true;
}

void SSecurityTLS::getCipher()
{
  gnutls_kx_algorithm_t kx;
  gnutls_cipher_algorithm_t cipher;
  gnutls_mac_algorithm_t mac;

  kx = gnutls_kx_get(session);
  cipher = gnutls_cipher_get(session);
  mac = gnutls_mac_get(session);
  ciphersuite = new char[256];

  snprintf(ciphersuite, 255, "%s %s(%zuB) %s(%zuB)",
     gnutls_kx_get_name(kx),
     gnutls_cipher_get_name(cipher),
     8 * gnutls_cipher_get_key_size(cipher),
     gnutls_mac_get_name(mac),
     8 * gnutls_mac_get_key_size(mac));
  vlog.info("GnuTLS cipher suite: %s", ciphersuite);
}


/*
 * Configure `session` for the server side of the handshake:
 *   - set the priority string (Security::GnuTLSPriority, plus the anonymous
 *     key exchanges when `anon` is set);
 *   - generate DH_BITS-bit Diffie-Hellman parameters;
 *   - install either anonymous or X.509 certificate credentials.
 *
 * In X.509 mode the optional CA and CRL files are loaded, and the server
 * certificate and key are loaded from certfile/keyfile.  When a CA file is
 * given and client certificates are required, the client is asked for one
 * (GNUTLS_CERT_REQUIRE).
 *
 * The DH parameters and credentials are kept in members and released by
 * shutdown().  Throws AuthFailureException on any failure.
 */
void SSecurityTLS::setParams(gnutls_session_t session)
{
  static const char kx_anon_priority[] = ":+ANON-ECDH:+ANON-DH";

  int ret;
  char *prio;
  const char *err;

  prio = static_cast<char*>(malloc(strlen(Security::GnuTLSPriority) +
                                   strlen(kx_anon_priority) + 1));
  if (prio == NULL)
    throw AuthFailureException("Not enough memory for GnuTLS priority string");

  strcpy(prio, Security::GnuTLSPriority);
  if (anon)
    strcat(prio, kx_anon_priority);

  ret = gnutls_priority_set_direct(session, prio, &err);

  if (ret != GNUTLS_E_SUCCESS) {
    // On a syntax error `err` points into `prio`, so it must be reported
    // before the buffer is released (freeing first was a use-after-free).
    if (ret == GNUTLS_E_INVALID_REQUEST)
      vlog.error("GnuTLS priority syntax error at: %s", err);
    free(prio);
    throw AuthFailureException("gnutls_set_priority_direct failed");
  }

  free(prio);

  if (gnutls_dh_params_init(&dh_params) != GNUTLS_E_SUCCESS)
    throw AuthFailureException("gnutls_dh_params_init failed");

  if (gnutls_dh_params_generate2(dh_params, DH_BITS) != GNUTLS_E_SUCCESS)
    throw AuthFailureException("gnutls_dh_params_generate2 failed");

  if (anon) {
    if (gnutls_anon_allocate_server_credentials(&anon_cred) != GNUTLS_E_SUCCESS)
      throw AuthFailureException("gnutls_anon_allocate_server_credentials failed");

    gnutls_anon_set_server_dh_params(anon_cred, dh_params);

    if (gnutls_credentials_set(session, GNUTLS_CRD_ANON, anon_cred)
        != GNUTLS_E_SUCCESS)
      throw AuthFailureException("gnutls_credentials_set failed");

    vlog.debug("Anonymous session has been set");

  } else {
    if (gnutls_certificate_allocate_credentials(&cert_cred) != GNUTLS_E_SUCCESS)
      throw AuthFailureException("gnutls_certificate_allocate_credentials failed");

    gnutls_certificate_set_dh_params(cert_cred, dh_params);

    // Trust anchors used to verify client certificates.
    if (*cafile) {
      if ((ret = gnutls_certificate_set_x509_trust_file(cert_cred,cafile,GNUTLS_X509_FMT_PEM)) < 0) {
        vlog.error("load of CA cert failed, error: %s", gnutls_strerror (ret));
        throw AuthFailureException("load of CA cert failed");
      }
    } else
      vlog.debug("no CA file given");
    if (*crlfile) {
      if ((ret = gnutls_certificate_set_x509_crl_file(cert_cred,crlfile,GNUTLS_X509_FMT_PEM)) < 0) {
        vlog.error("load of CRL cert failed, error: %s", gnutls_strerror (ret));
        throw AuthFailureException("load of CRL failed");
      }
    } else
      vlog.debug("no CRL file given");

    // The server's own certificate and private key.
    switch (gnutls_certificate_set_x509_key_file(cert_cred, certfile, keyfile, GNUTLS_X509_FMT_PEM)) {
    case GNUTLS_E_SUCCESS:
      break;
    case GNUTLS_E_CERTIFICATE_KEY_MISMATCH:
      throw AuthFailureException("Private key does not match certificate");
    case GNUTLS_E_UNSUPPORTED_CERTIFICATE_TYPE:
      throw AuthFailureException("Unsupported certificate type");
    default:
      throw AuthFailureException("Error loading X509 certificate or key");
    }

    if (gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, cert_cred)
        != GNUTLS_E_SUCCESS)
      throw AuthFailureException("gnutls_credentials_set failed");

    // Only demand a client certificate when there is a CA to check it with.
    if (*cafile) {
      if (requireclientcert) {
        vlog.debug("server_set_request: GNUTLS_CERT_REQUIRE");
        gnutls_certificate_server_set_request(session, GNUTLS_CERT_REQUIRE);
      }
    }

    vlog.debug("X509 session has been set");

  }

}
