/* Copyright (C) 2002-2005 RealVNC Ltd.  All Rights Reserved.
 * Copyright (C) 2005 Martin Koegler
 * Copyright (C) 2010 TigerVNC Team
 * Copyright (C) 2015-2024 m-privacy GmbH
 * 
 * This is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 * 
 * This software is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this software; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307,
 * USA.
 */

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include <rdr/Exception.h>
#include <rdr/TLSException.h>
#include <rdr/TLSOutStream.h>
#include <rdr/TLSErrno.h>
#include <rdr/mutex.h>
#include <rfb/LogWriter.h>
#include <rfb/Configuration.h>
#include <errno.h>
#ifndef WIN32
#include <sys/types.h>
#include <unistd.h>
#include <sys/time.h>
#include <sys/resource.h>
#endif
#include <stdlib.h>

#ifdef HAVE_GNUTLS
using namespace rdr;

static rfb::LogWriter vlog("TLSOutStream");

/*
 * Tunables for the three outgoing queue priorities (high/medium/low) of the
 * TLS layer. The TLSOutStream constructor copies them per instance:
 *  - TLSKeepOutBuffers*   -> keepBuffers[]
 *  - TLSThreshOutBuffers* -> flushThreshBuffers[]
 *  - TLSMaxOutBuffers*    -> maxBuffers[]
 * How these limits are applied is implemented in the OutStream base class
 * (not visible here).
 */
rfb::IntParameter TLSKeepOutBuffersHigh("TLSKeepOutBuffersHigh", "outgoing buffers to keep for high priority, TLS layer (min 4, max 65535)", 32, 4, MAXBUFFERNUM);
rfb::IntParameter TLSKeepOutBuffersMedium("TLSKeepOutBuffersMedium", "outgoing buffers to keep for medium priority, TLS layer (min 50, max 65535)", 4096, 50, MAXBUFFERNUM);
rfb::IntParameter TLSKeepOutBuffersLow("TLSKeepOutBuffersLow", "outgoing buffers to keep for low priority, TLS layer (min 2, max 65535)", 128, 2, MAXBUFFERNUM);
rfb::IntParameter TLSThreshOutBuffersHigh("TLSThreshOutBuffersHigh", "when to force flush outgoing buffers for high priority, TLS layer (min 8, max 65535)", 64, 8, MAXBUFFERNUM);
rfb::IntParameter TLSThreshOutBuffersMedium("TLSThreshOutBuffersMedium", "when to force flush outgoing buffers for medium priority, TLS layer (min 50, max 65535)", 8192, 50, MAXBUFFERNUM);
rfb::IntParameter TLSThreshOutBuffersLow("TLSThreshOutBuffersLow", "when to force flush outgoing buffers for low priority, TLS layer (min 4, max 65535)", 1024, 4, MAXBUFFERNUM);
rfb::IntParameter TLSMaxOutBuffersHigh("TLSMaxOutBuffersHigh", "available outgoing buffers for high priority, TLS layer (min 4, max 65535)", 256, 4, MAXBUFFERNUM);
rfb::IntParameter TLSMaxOutBuffersMedium("TLSMaxOutBuffersMedium", "available outgoing buffers for medium priority, TLS layer (min 50, max 65535)", 16384, 50, MAXBUFFERNUM);
rfb::IntParameter TLSMaxOutBuffersLow("TLSMaxOutBuffersLow", "available outgoing buffers for low priority, TLS layer (min 2, max 65535)", 4096, 2, MAXBUFFERNUM);
/*
 * TLSCork:       the flush thread corks the session while sending the
 *                buffers of one priority and uncorks it afterwards.
 * TLSPushFlush:  vec_func() flushes the lower stream after every push.
 * SendTLSHeader: the flush thread writes a check header before each buffer.
 */
rfb::BoolParameter TLSCork("TLSCork", "Use GnuTLS record:cork to send bigger packets", true);
rfb::BoolParameter TLSPushFlush("TLSPushFlush", "Flush outstream for every TLS push", false);
rfb::BoolParameter sendTLSHeader("SendTLSHeader", "Send every TLS buffer with checksum header", false);

/*
 * Bytes subtracted from the lower stream's maximum buffer size (in the
 * constructor) to leave room for TLS record overhead.
 */
#define TLSEXTRA 128

/*
 * Run flag checked by the flush thread loop; cleared by the destructor.
 * NOTE(review): this is a single static flag shared by all instances and is
 * not atomic, so clearing it stops the flush thread of every TLSOutStream.
 */
static bool runThreads = true;
/*
 * Static synchronisation objects shared by all instances and initialised by
 * the constructor. buffersAvailable* is used by flush() to wake the flush
 * thread; flushed*[prio] is broadcast by the flush thread when the queue of
 * that priority has been drained.
 */
TGVNC_CONDITION_TYPE TLSOutStream::buffersAvailableCondition;
MUTEX_TYPE TLSOutStream::buffersAvailableConditionLock;
TGVNC_CONDITION_TYPE TLSOutStream::flushedCondition[QPRIONUM];
MUTEX_TYPE TLSOutStream::flushedConditionLock[QPRIONUM];

/*
 * Wrap the lower stream _out with a TLS layer that encrypts through
 * _session. This object takes over _out: the destructor deletes it.
 *
 * The constructor:
 *  - installs vec_func() as the session's vectored push function;
 *  - keeps the current receive transport pointer, saves the old pointers in
 *    recv/send, and makes this object the send transport pointer;
 *  - copies the queue tunables into the per-instance limits;
 *  - starts the flush thread.
 *
 * NOTE(review): the synchronisation objects are static but are
 * (re)initialised by every instance. That is only safe while no other
 * TLSOutStream is alive.
 */
TLSOutStream::TLSOutStream(OutStream* _out, gnutls_session _session)
  : out(_out), session(_session), flushActive(false)
{
  vlog.debug("Initializing");
  MUTEX_INIT(&buffersAvailableConditionLock);
  TGVNC_CONDITION_INIT(&buffersAvailableCondition);
  for (int q = QPRIOMIN; q <= QPRIOMAX; q++) {
    MUTEX_INIT(&flushedConditionLock[q]);
    TGVNC_CONDITION_INIT(&flushedCondition[q]);
  }
  // Send every TLS record through vec_func(), which receives this object as
  // its transport pointer; the receive pointer is left unchanged.
  gnutls_transport_set_vec_push_function(session, vec_func);
  gnutls_transport_get_ptr2(session, &recv, &send);
  gnutls_transport_set_ptr2(session, recv, this);

  classLog = &vlog;
  memset(bufferStats, 0, sizeof(U32) * MAXBUFFERSIZE);
  maxBuffers[QPRIOHIGH] = TLSMaxOutBuffersHigh;
  maxBuffers[QPRIOMEDIUM] = TLSMaxOutBuffersMedium;
  maxBuffers[QPRIOLOW] = TLSMaxOutBuffersLow;
  keepBuffers[QPRIOHIGH] = TLSKeepOutBuffersHigh;
  keepBuffers[QPRIOMEDIUM] = TLSKeepOutBuffersMedium;
  keepBuffers[QPRIOLOW] = TLSKeepOutBuffersLow;
  flushThreshBuffers[QPRIOHIGH] = TLSThreshOutBuffersHigh;
  flushThreshBuffers[QPRIOMEDIUM] = TLSThreshOutBuffersMedium;
  flushThreshBuffers[QPRIOLOW] = TLSThreshOutBuffersLow;
  // Leave room for TLS record overhead in each lower-stream buffer.
  maxBufferSize = out->getMaxBufferSize() - TLSEXTRA;

  // runThreads is file-static and is cleared by ~TLSOutStream(). Re-arm it.
  // Otherwise, once an earlier instance has been destroyed, the flush thread
  // of every later instance leaves its loop immediately and never sends data.
  // NOTE(review): if another thread destroys an older instance concurrently,
  // this can keep that instance's flush thread alive. Construction and
  // destruction are assumed to be serialised.
  runThreads = true;
  THREAD_CREATE(flushThread, flushThreadId, this);
  THREAD_SET_NAME(flushThreadId, "tg-tls-flush");
}

/*
 * Tear down the TLS layer. In order, it:
 *  1. stops and joins the flush thread;
 *  2. unhooks this object from the GnuTLS session's push callbacks;
 *  3. logs buffer statistics;
 *  4. deletes the wrapped lower stream.
 *
 * The static mutexes and condition variables are not destroyed here (the
 * code doing so is disabled); they are shared by all instances and
 * re-initialised by every constructor.
 */
TLSOutStream::~TLSOutStream()
{
  // runThreads is file-static: clearing it ends the flush loop of every
  // instance. The thread only re-checks it after its current round or its
  // timed wait, so the join may block for a moment.
  runThreads = false;
  THREAD_JOIN(flushThreadId);

  // Detach from the session. The transport pointers saved by the
  // constructor are not restored; that call is disabled:
  //  gnutls_transport_set_ptr2(session, recv, send);
  gnutls_transport_set_vec_push_function(session, NULL);
  gnutls_transport_set_push_function(session, NULL);

  printBufferUsage();

  const bool debugEnabled = vlog.getLevel() >= vlog.LEVEL_DEBUG;
  if (debugEnabled) {
    U64 totalUses = 0;
    for (U32 size = 0; size < MAXBUFFERSIZE; size++) {
      totalUses += bufferStats[size];
      // Only report sizes that occurred often enough to be interesting.
      if (bufferStats[size] >= 100)
        vlog.debug("buffer size %u used %u times", size, bufferStats[size]);
    }
    vlog.debug("buffer use total %llu, %llu flush calls", totalUses, flushCalls);
  }

  delete out;
  vlog.debug("~TLSOutStream: exiting with serial %u.", serial);
}

/*
 * Entry point of the flush thread that the constructor starts for each
 * TLSOutStream (param is that stream). It first tries to raise its own
 * scheduling priority, then loops while runThreads is set:
 *   - it sleeps on buffersAvailableCondition until flush() signals it or the
 *     timed wait (500, presumably milliseconds) expires;
 *   - it drains the queues from QPRIOMAX (high) down to QPRIOMIN, passing
 *     every buffer through writeTLS(). Each buffer may be preceded by a
 *     check header (SendTLSHeader), and each priority's batch may be corked
 *     (TLSCork);
 *   - it restarts at QPRIOMAX when new high priority data arrives, or when
 *     medium data arrives while lower data is being sent;
 *   - after each priority it flushes the lower stream. If that priority's
 *     queue is empty, it broadcasts flushedCondition[prio], which wakes
 *     flush(..., wait=true) callers.
 * When the loop ends it logs statistics and exits.
 *
 * NOTE(review): nothing here catches exceptions. writeTLS() and the failed
 * gnutls_record_uncork() branch throw TLSException, and out->flush() may
 * throw too. Unless THREAD_CREATE wraps this entry point, an exception
 * escaping a thread function terminates the process. Confirm this is
 * intended.
 * NOTE(review): the thread waits at the top of every round even if buffers
 * were queued while it was busy. A signal sent while it is not waiting is
 * presumably lost, so such data waits for the next wakeup or the timeout.
 */
THREAD_FUNC TLSOutStream::flushThread(void* param) {
  TLSOutStream * myself = (TLSOutStream *) param;

  struct queueBuffer * buffer;
  unsigned sent;
  bool restart;
  int n;
  struct timeval nowTime;
  int currentPrio;
  U64 sleepCount = 0;
  /* per-priority byte counters, only used for the exit log */
  ssize_t writtenBytes[QPRIONUM];
  U8 headerBuf[CHECKHEADERSIZE];

#ifdef WIN32
  vlog.debug("flushThread (tid %lu) created", GetCurrentThreadId());
#else
#if defined(__APPLE__)
  vlog.debug("flushThread (tid %u) created", gettid());
#else
  vlog.debug("flushThread (tid %lu) created", gettid());
#endif
#endif

  for (int q = QPRIOMIN; q <= QPRIOMAX; q++)
    writtenBytes[q] = 0;

  /*
   * Try to give this thread a higher scheduling priority. A negative nice()
   * value typically needs privileges; failure is only logged.
   */
#ifdef WIN32
  vlog.debug("flushThread: trying to raise thread priority");
  SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_ABOVE_NORMAL);

#else

#if defined(__APPLE__)
  vlog.debug("flushThread %u: current priority is %i, now set nice(-10)", gettid(), getpriority(PRIO_PROCESS, 0));
#else
  vlog.debug("flushThread %lu: current priority is %i, now set nice(-10)", gettid(), getpriority(PRIO_PROCESS, 0));
#endif
  /* nice() may legitimately return -1, so errno is needed to detect errors */
  errno = 0;
  n = nice(-10);
  if (n == -1 && errno != 0) {
#if defined(__APPLE__)
    vlog.error("flushThread %u: failed to set nice(-10), error: %s", gettid(), strerror(errno));
#else
    vlog.error("flushThread %lu: failed to set nice(-10), error: %s", gettid(), strerror(errno));
#endif
  } else {
#if defined(__APPLE__)
    vlog.debug("flushThread %u: new priority is %i", gettid(), getpriority(PRIO_PROCESS, 0));
#else
    vlog.debug("flushThread %lu: new priority is %i", gettid(), getpriority(PRIO_PROCESS, 0));
#endif
  }
#endif

  while (runThreads) {
    /*
     * flushActive tells flush() whether this thread is busy sending;
     * non-waiting callers do not signal while it is true.
     */
    myself->flushActive = false;
    /* sleep */
//    vlog.verbose("flushThread(): sleep");
    MUTEX_LOCK(&buffersAvailableConditionLock);
#ifdef WIN32
    TGVNC_CONDITION_TIMED_WAIT(&buffersAvailableCondition, &buffersAvailableConditionLock, 500);
#else
    int err = TGVNC_CONDITION_TIMED_WAIT(&buffersAvailableCondition, &buffersAvailableConditionLock, 500);
    if (err == EINTR)
      vlog.verbose("flushThread(): TGVNC_CONDITION_TIMED_WAIT was interrupted");
    else if (err != 0 && err != ETIMEDOUT)
      vlog.verbose("flushThread(): TGVNC_CONDITION_TIMED_WAIT returned error %i", err);
#endif
    MUTEX_UNLOCK(&buffersAvailableConditionLock);
//    vlog.verbose("flushThread(): woken up");
    myself->flushActive = true;
    /* done once, on the first wakeup */
    if (!sleepCount) {
      vlog.debug("flushThread(): disable lower outstream master key");
      myself->out->setAllowMasterKey(false);
    }
    sleepCount++;

    /* one timestamp per round, handed to returnQueueBuffer() (seconds only) */
    gettimeofday(&nowTime, NULL);
    currentPrio = QPRIOMAX;

    /* walk priorities from highest to lowest, skipping empty queues */
    while(currentPrio >= QPRIOMIN) {
      if (myself->queueEmpty((QPrio)currentPrio)) {
        currentPrio--;
        continue;
      }
      restart = false;
      /* with TLSCork, records of this batch are collected until uncork below */
      if(TLSCork)
        gnutls_record_cork(myself->session);
      while (( buffer = myself->popBuffer((QPrio)currentPrio) )) {
//      vlog.debug("flushThread(): send buffer %u, prio %u, used %u", buffer->number, currentPrio, buffer->used);
        myself->bufferStats[buffer->used]++;
        /* optional check header, written completely before the payload */
        if(sendTLSHeader) {
          myself->fillCheckHeader(buffer, headerBuf);
          sent = 0;
          while (sent < CHECKHEADERSIZE) {
            n = myself->writeTLS(headerBuf + sent, CHECKHEADERSIZE - sent);
            sent += n;
          }
          writtenBytes[currentPrio] += CHECKHEADERSIZE;
        }
        sent = 0;
        /* buffer splitting can happen in case of congestion */
        while (sent < buffer->used) {
          n = myself->writeTLS(buffer->data + sent, buffer->used - sent);
          if (n < (int) (buffer->used - sent))
            vlog.verbose("flushThread(): send buffer %u, prio %u, used %u, only sent %u of requested %u bytes", buffer->number, currentPrio, buffer->used, n, buffer->used - sent);
          sent += n;
        }
        writtenBytes[currentPrio] += buffer->used;
        myself->returnQueueBuffer(buffer, (QPrio)currentPrio, nowTime.tv_sec);

        /* new high prio buffers must be sent first */
        if (currentPrio < QPRIOMAX && !myself->queueEmpty(QPRIOMAX)) {
          vlog.verbose("flushThread(): new high prio %u buffer at prio %u, restart at high prio", QPRIOMAX, currentPrio);
          restart = true;
          break;
        }
        /* new medium prio buffers must be sent before low */
        if (currentPrio < QPRIOMEDIUM && !myself->queueEmpty(QPRIOMEDIUM)) {
          vlog.verbose("flushThread(): new medium prio buffer at prio %u, restart at high prio", currentPrio);
          restart = true;
          break;
        }
      }
      /* push out everything corked for this batch and wait until it is sent */
      if(TLSCork) {
        n = gnutls_record_uncork(myself->session, GNUTLS_RECORD_WAIT);
        if (n < 0) {
          myself->printBufferUsage();
          if(vlog.getLevel() >= vlog.LEVEL_DEBUG)
            for (U32 i=0; i<MAXBUFFERSIZE; i++)
              if(myself->bufferStats[i] >= 10)
                vlog.debug("buffer size %u used %u times", i, myself->bufferStats[i]);
          throw TLSException("gnutls_record_uncork(GNUTLS_RECORD_WAIT)", n);
        }
#if 0
        if (n > 0) {
          vlog.verbose("flushThread(): uncorked %u bytes in prio %u, flushing now", n, currentPrio);
        }
#endif
      }
      myself->out->flush(TLSKEY);
      if (myself->queueEmpty((QPrio)currentPrio)) {
        /* wake up all flushers at prio currentPrio */
        MUTEX_LOCK(&flushedConditionLock[currentPrio]);
        TGVNC_CONDITION_BROADCAST(&flushedCondition[currentPrio]);
        MUTEX_UNLOCK(&flushedConditionLock[currentPrio]);
      }
      if (restart)
        currentPrio = QPRIOMAX;
      else
        currentPrio--;
    }
  }
#ifndef WIN32
#if defined(__APPLE__)
  vlog.debug("flushThread (tid %u) exiting, slept %llu times, %llu flush calls", gettid(), sleepCount, myself->flushCalls);
#else
  vlog.debug("flushThread (tid %lu) exiting, slept %llu times, %llu flush calls", gettid(), sleepCount, myself->flushCalls);
#endif
#else
  vlog.debug("flushThread (tid %lu) exiting, slept %llu times, %llu flush calls", GetCurrentThreadId(), sleepCount, myself->flushCalls);
#endif
  for (int q = QPRIOMIN; q <= QPRIOMAX; q++) {
    /* NOTE(review): writtenBytes is ssize_t, so %zd (not %zu) would match */
    if (writtenBytes[q] > 0)
      vlog.debug("flushThread() wrote %zu bytes for prio %u", writtenBytes[q], q);
  }

  THREAD_EXIT(THREAD_NULL);
}

/*
 * Ask the flush thread to send the queued buffers of priority prio.
 *
 * key:  currently not validated (the check_key() call is disabled).
 * prio: priority whose queue should be sent; rejected by check_prio() if
 *       invalid.
 * wait: if true, block afterwards until the flush thread broadcasts that
 *       this priority's queue is empty, or until the timed wait (500)
 *       expires.
 *
 * Returns at once if the queue is empty, or if the flush thread is already
 * active and the caller does not want to wait.
 */
void TLSOutStream::flush(int key, QPrio prio, bool wait)
{
  (void) key; // not checked, see above

  if (!check_prio(prio, __func__))
    return;

  if (queueEmpty(prio))
    return;
  // The flush thread is already working; a non-blocking caller leaves it be.
  if (!wait && flushActive)
    return;

  ++flushCalls;

  // Wake the flush thread.
  MUTEX_LOCK(&buffersAvailableConditionLock);
  TGVNC_CONDITION_SEND_SIG(&buffersAvailableCondition);
  MUTEX_UNLOCK(&buffersAvailableConditionLock);

  if (!wait)
    return;

  // Sleep until the flush thread reports this priority as drained.
  MUTEX_LOCK(&flushedConditionLock[prio]);
  TGVNC_CONDITION_TIMED_WAIT(&flushedCondition[prio], &flushedConditionLock[prio], 500);
  MUTEX_UNLOCK(&flushedConditionLock[prio]);
}

/*
 * GnuTLS vectored push callback, installed by the constructor with the
 * TLSOutStream as transport pointer. It copies every I/O vector into the
 * wrapped lower stream and, if TLSPushFlush is set, flushes that stream.
 *
 * Returns the total number of bytes accepted. If the lower stream throws an
 * rdr::Exception, the error is reported to GnuTLS and -1 is returned. The
 * report uses gnutls_transport_set_errno(), so the GnuTLS call fails with a
 * push error and writeTLS() raises a TLSException.
 *
 * This callback runs inside GnuTLS (C code). Letting a C++ exception unwind
 * through those frames is unsafe: it can crash on some platforms and would
 * leave the session in an undefined state. Hence the catch below.
 */
ssize_t TLSOutStream::vec_func(gnutls_transport_ptr str, const giovec_t * iov, int iovcnt)
{
  TLSOutStream* self = (TLSOutStream*) str;
  OutStream *out = self->out;
  ssize_t count = 0;

  try {
    for (int i = 0; i < iovcnt; i++) {
      out->writeBytes(iov[i].iov_base, iov[i].iov_len, TLSKEY);
      count += iov[i].iov_len;
    }
    if (TLSPushFlush)
      out->flush(TLSKEY);
  } catch (rdr::Exception&) {
    vlog.error("vec_func(): lower OutStream failed, reporting push error to GnuTLS");
    // Not EAGAIN/EINTR, so GnuTLS treats this as a fatal push error and does
    // not retry.
    gnutls_transport_set_errno(self->session, EIO);
    return -1;
  }
  return count;
}

/*
 * Encrypt up to length bytes from data and send them through the GnuTLS
 * session (when corked, GnuTLS only buffers them).
 *
 * Returns the number of bytes GnuTLS accepted, which may be less than
 * length; the caller must loop until everything is sent. After
 * GNUTLS_E_INTERRUPTED or GNUTLS_E_AGAIN the call is repeated in the
 * NULL/0 form, which GnuTLS accepts for resuming the pending send.
 *
 * Throws TLSException on any other GnuTLS error. Before throwing, it prints
 * buffer usage and, at debug level, the common buffer sizes.
 */
size_t TLSOutStream::writeTLS(const U8* data, size_t length)
{
  // gnutls_record_send() returns ssize_t; keep the full width rather than
  // narrowing the result to int.
  ssize_t n = gnutls_record_send(session, data, length);
  while (n == GNUTLS_E_INTERRUPTED || n == GNUTLS_E_AGAIN) {
    vlog.verbose("writeTLS(): need to retry gnutls_record_send()");
    n = gnutls_record_send(session, NULL, 0);
  }
  if (n < 0) {
    printBufferUsage();
    if (vlog.getLevel() >= vlog.LEVEL_DEBUG)
      for (U32 i = 0; i < MAXBUFFERSIZE; i++)
        if (bufferStats[i] >= 10)
          vlog.debug("buffer size %u used %u times", i, bufferStats[i]);
    // GnuTLS error codes are small negative ints, so this cast is lossless.
    throw TLSException("writeTLS(): gnutls_record_send()", (int) n);
  }
  return (size_t) n;
}

void TLSOutStream::printBufferUsage() {
	OutStream::printBufferUsage();
	out->printBufferUsage();
}

#endif
