/* Copyright (C) 2002-2005 RealVNC Ltd.  All Rights Reserved.
 * Copyright (C) 2005 Martin Koegler
 * Copyright (C) 2010 TigerVNC Team
 * Copyright (C) 2014-2024 m-privacy GmbH
 * 
 * This is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 * 
 * This software is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this software; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307,
 * USA.
 */

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include <rdr/Exception.h>
#include <rdr/TLSException.h>
#include <rdr/TLSInStream.h>
#include <rfb/LogWriter.h>

#include <errno.h>
#include <string.h>
#include <cstdlib>
#include <zlib.h>

#if !defined(__APPLE__) && !defined(WIN32) && !defined(_GNU_SOURCE)
#if __GLIBC_PREREQ(2,30)
#define _GNU_SOURCE
#include <unistd.h>

#else

#include <sys/syscall.h>

/*
 * Fallback gettid() for builds where __GLIBC_PREREQ(2,30) is false
 * (see the surrounding preprocessor conditional).
 *
 * Returns the result of the SYS_gettid system call.
 */
pid_t
gettid(void)
{

    return syscall(SYS_gettid);
}
#endif
#endif

#ifdef HAVE_GNUTLS 
using namespace rdr;

// Log writer used by every function in this file.
static rfb::LogWriter vlog("TLSInStream");

// Error code handed to GnuTLS by errno_func(). Written by pull(); being a
// static member it is shared by all TLSInStream instances.
int TLSInStream::pullErrno = 0;

// Set by the TLSInStream constructor to the object being constructed.
// errno_func() compares the transport pointer it receives against this value.
static TLSInStream * globalSelf;

/*
 * Pull callback registered with GnuTLS by the constructor.
 *
 * str  - transport pointer; the TLSInStream that registered itself
 * data - destination for the raw bytes
 * size - maximum number of bytes wanted
 *
 * Returns the number of bytes copied into data, 0 when the stream has been
 * shut down or the underlying stream signals EndOfStream, and -1 on failure.
 * On -1 pullErrno is EAGAIN (no data available right now) or EINVAL (the
 * underlying stream threw an Exception).
 */
ssize_t TLSInStream::pull(gnutls_transport_ptr str, void* data, size_t size)
{
  TLSInStream* const stream = (TLSInStream*) str;
  InStream* const source = stream->in;

  if (stream->shutdown)
    return 0;

  try {
    if (!source->hasData(1)) {
      stream->pullErrno = EAGAIN;
      return -1;
    }

    // Never hand out more than the underlying stream currently holds
    if (size > source->avail())
      size = source->avail();

    source->readBytes(data, size);
  } catch (EndOfStream&) {
    return 0;
  } catch (Exception& e) {
    vlog.error("Failure reading TLS data: %s", e.str());
    stream->pullErrno = EINVAL;
    return -1;
  }

  stream->pullErrno = 0;
  return size;
}

/*
 * Pull-timeout callback registered with GnuTLS by the constructor.
 *
 * str - transport pointer; the TLSInStream that registered itself
 * ms  - how long to wait for data, in milliseconds; 0 means "do not wait"
 *
 * Returns the number of bytes available in the underlying stream, or 0 if
 * nothing arrived within the timeout. With a non-zero timeout the underlying
 * stream is polled in 10 ms steps.
 *
 * NOTE(review): unlike pull(), exceptions thrown by hasData() are not caught
 * here and would propagate through the GnuTLS caller - confirm this is
 * intended.
 */
int TLSInStream::pull_timeout(gnutls_transport_ptr_t str, unsigned int ms)
{
  TLSInStream* const stream = (TLSInStream*) str;
  InStream* const source = stream->in;

  vlog.debug("pull_timeout(): %u ms", ms);

  if (ms == 0) {
    // Single non-waiting probe
    source->hasData(1);
    return source->avail();
  }

  for (unsigned int waited = 0; waited < ms; waited += 10) {
    if (source->hasData(1))
      return source->avail();
#ifdef WIN32
    Sleep(10);
#else
    usleep(10000);
#endif
  }

  return 0;
}

/*
 * Errno callback registered with GnuTLS by the constructor.
 *
 * trans - transport pointer GnuTLS is asking about
 *
 * Returns pullErrno when trans is the TLSInStream recorded in globalSelf.
 * For any other transport pointer an error is logged and EAGAIN is returned.
 */
int TLSInStream::errno_func(gnutls_transport_ptr_t trans)
{
  if (trans == (gnutls_transport_ptr_t) globalSelf) {
    return pullErrno;
  }
  // The thread id is cast explicitly: gettid() returns pid_t, which does not
  // match the %lu conversion and must not be passed through varargs as-is.
#ifdef WIN32
  vlog.error("Thread %lu: errno_func called for wrong transport %p", (unsigned long) GetCurrentThreadId(), trans);
#else
#if defined(__APPLE__)
  vlog.error("errno_func called for wrong transport %p", trans);
#else
  vlog.error("Thread %lu: errno_func called for wrong transport %p", (unsigned long) gettid(), trans);
#endif
#endif
  return EAGAIN;
}

/*
 * Attach a TLS input stream to a GnuTLS session.
 *
 * _in      - underlying stream the raw TLS data is pulled from
 * _session - GnuTLS session to hook into
 *
 * Records this object in globalSelf, saves the transport pointers currently
 * set on the session in recv/send, and then installs this object as the
 * first transport pointer together with the pull, pull-timeout and errno
 * callbacks.
 */
TLSInStream::TLSInStream(InStream* _in, gnutls_session _session)
  : session(_session),
    in(_in),
    shutdown(false),
    checkHeader(false),
    firstRead(true),
    headerBytesRead(0),
    bodyBytesLeft(0),
    crcChecked(0),
    crcFailed(0),
    crcMissed(0)
{
  vlog.debug("Init");
  classLog = &vlog;
  globalSelf = this;

  // Remember the previous transport pointers so the destructor can put
  // them back, then substitute this object for the first one.
  gnutls_transport_get_ptr2(session, &recv, &send);
  gnutls_transport_set_ptr2(session, this, send);

  // Callbacks GnuTLS uses to fetch data from this object
  gnutls_transport_set_pull_function(session, pull);
  gnutls_transport_set_pull_timeout_function(session, pull_timeout);
  gnutls_transport_set_errno_function(session, errno_func);
}

/*
 * Detach from the GnuTLS session: restore the transport pointers saved by
 * the constructor, clear the pull and pull-timeout callbacks, and log the
 * serial/CRC counters.
 */
TLSInStream::~TLSInStream()
{
  // Put back the transport pointers that were in place before construction
  gnutls_transport_set_ptr2(session, recv, send);

  gnutls_transport_set_pull_timeout_function(session, NULL);
  gnutls_transport_set_pull_function(session, NULL);

  vlog.debug("Exiting with nextSerial %u, CRC checked %u, failed %u, missed %u",
             nextSerial, crcChecked, crcFailed, crcMissed);
  resetClassLog();
}

bool TLSInStream::fillBuffer(size_t maxSize)
{
  size_t n;
  bool retval = false;

  if (checkHeader || firstRead) {
    if (!bodyBytesLeft) {
      U32 serialNumber;
      U32 crc;

      if (firstRead) {
        while (headerBytesRead < sizeof(U32)) {
          vlog.debug("fillBuffer(): firstRead: try to read first header, %zu done", headerBytesRead);
          n = readTLS(header + headerBytesRead, CHECKHEADERSIZE - headerBytesRead);
          if (n == 0) {
            if (headerBytesRead == 0)
              return retval;
            vlog.debug("fillBuffer(): firstRead: could only read %zu header bytes, assuming no header", headerBytesRead);
            break;
          }
          headerBytesRead += n;
        }
        firstRead = false;
        /* is this a header with serial 0 and bodyBytesLeft > 0?
         * If yes, use headers, if no, move to buffer
         */
        if (headerBytesRead >= CHECKHEADERSIZE && checkCheckHeader(header, &serialNumber, &crc, &bodyBytesLeft) && serialNumber == 0 && bodyBytesLeft > 0) {
          checkHeader = true;
          headerBytesRead = 0;
          vlog.debug("fillBuffer(): detected header in first read, enabling header support, %u body bytes", bodyBytesLeft);
        } else {
          vlog.debug("fillBuffer(): no header in first read, no header support, move %zu bytes to buffer", headerBytesRead);
          memcpy((U8*) end, header, headerBytesRead);
          end += headerBytesRead;
          if (maxSize <= headerBytesRead || !in->hasData(1))
            return true;
          maxSize -= headerBytesRead;
          retval = true;
        }
      } else {
        while (headerBytesRead < CHECKHEADERSIZE) {
          n = readTLS(header + headerBytesRead, CHECKHEADERSIZE - headerBytesRead);
          if (n == 0) {
            return retval;
          }
          headerBytesRead += n;
        }
        headerBytesRead = 0;
        if (!checkCheckHeader(header, &serialNumber, &crc, &bodyBytesLeft)) {
          U8 * tmpbuf;
          const size_t tmpbufsize = 128 * 1024;

          vlog.error("fillBuffer(): checkCheckHeader() failed, try to find correct header");
          tmpbuf = (U8 *) malloc(tmpbufsize);
          if (tmpbuf) {
            size_t offset = 0;
            size_t tmpread = 0;
//            const size_t maxread = tmpbufsize < maxSize ? tmpbufsize : maxSize;
            const size_t maxread = tmpbufsize < maxSize + CHECKHEADERSIZE ? tmpbufsize : maxSize + CHECKHEADERSIZE;

            sleep(1);
            while (tmpread < maxread) {
              n = readTLS(tmpbuf + tmpread, maxread - tmpread);
              if (n == 0)
                break;
              tmpread += n;
            }
            vlog.debug("fillBuffer(): filled tmpbuf with %zu bytes", tmpread);
            if (tmpread > CHECKHEADERSIZE) {
              while (offset < tmpread - CHECKHEADERSIZE && *((U32 *) (tmpbuf + offset)) != serialNumber)
                offset++;
              if (offset == tmpread - CHECKHEADERSIZE ) {
                /* not found, we are going to crash */
                vlog.error("fillBuffer(): failed to find missing header, giving up");
                memcpy((U8*)end, tmpbuf, tmpread);
                end += tmpread;
                free(tmpbuf);
                return true;
              } else {
                /* found, try to restart with offset */
                vlog.debug("fillBuffer(): found missing header at offset %zu, trying to shift buffer", offset);
                checkCheckHeader(tmpbuf + offset, &serialNumber, &crc, &bodyBytesLeft);
                while (true) {
                  if (bodyBytesLeft >= tmpread - offset - CHECKHEADERSIZE) {
                    if (tmpread - offset - CHECKHEADERSIZE > 0) {
                      memcpy((U8*)end, tmpbuf + offset + CHECKHEADERSIZE, tmpread - offset - CHECKHEADERSIZE);
                      end += (tmpread - offset - CHECKHEADERSIZE);
                      bodyBytesLeft -= tmpread - offset - CHECKHEADERSIZE;
                    }
                    free(tmpbuf);
                    return true;
                  } else {
                    /* we read over the start of the next header */
                    vlog.debug("fillBuffer(): dead over the next header");
                    memcpy((U8*)end, tmpbuf + offset + CHECKHEADERSIZE, bodyBytesLeft);
                    end += bodyBytesLeft;
                    offset += CHECKHEADERSIZE + bodyBytesLeft;
                    if (tmpread - offset >= CHECKHEADERSIZE) {
                      if (!checkCheckHeader(tmpbuf + offset, &serialNumber, &crc, &bodyBytesLeft)) {
                        vlog.error("fillBuffer(): checkCheckHeader() failed again, giving up");
                        free(tmpbuf);
                        return true;
                      }
                    } else {
                      headerBytesRead = tmpread - offset;
                      memcpy(header, tmpbuf + offset, headerBytesRead);
                      free(tmpbuf);
                      return true;
                    }
                  }
                }
              }
            }
            free(tmpbuf);
          } else {
            vlog.error("fillBuffer(): failed to allocate tmpbuf");
          }
        }
      }
      if (bodyBytesLeft > 0) {
        /* try a crc check, but we need the space */
        if (maxSize >= bodyBytesLeft) {
          U32 newCrc;
          const U16 savedSize = bodyBytesLeft;

          while (bodyBytesLeft > 0) {
            n = readTLS((U8*)end, bodyBytesLeft);
            if (n == 0)
              return retval;
            end += n;
            bodyBytesLeft -= n;
          }
          newCrc = crc32(0L, Z_NULL, 0);
          newCrc = crc32(newCrc, (U8*)end - savedSize, savedSize);
          crcChecked++;
          if (newCrc != crc) {
            vlog.error("CRC mismatch: expected %u, got %u with size %u", crc, newCrc, savedSize);
            crcFailed++;
          }
          return true;
        } else {
          vlog.verbose("fillBuffer(): not enough space (%lu) for serial %u CRC check with size %u, crc %u", maxSize, serialNumber, bodyBytesLeft, crc);
          crcMissed++;
        }
      }
    }
    if (maxSize > bodyBytesLeft && bodyBytesLeft > 0)
      maxSize = bodyBytesLeft;
  }

  n = readTLS((U8*) end, maxSize);
  if (n == 0)
    return retval;
  end += n;
  if (bodyBytesLeft > 0)
    bodyBytesLeft -= n;

  return true;
}

/*
 * Read up to len decrypted bytes from the TLS session into buf.
 *
 * Returns the value reported by gnutls_record_recv(), or 0 without calling
 * it when GnuTLS has no pending record data and the underlying stream has
 * no data either.
 * Throws TLSException when gnutls_record_recv() reports an error other than
 * GNUTLS_E_INTERRUPTED or GNUTLS_E_AGAIN; those two are retried.
 */
size_t TLSInStream::readTLS(U8* buf, size_t len)
{
  // Nothing buffered inside GnuTLS and nothing waiting on the wire
  if (gnutls_record_check_pending(session) == 0 && !in->hasData(1))
    return 0;

  int received;
  do {
    received = gnutls_record_recv(session, (void *) buf, len);
  } while (received == GNUTLS_E_INTERRUPTED || received == GNUTLS_E_AGAIN);

  if (received < 0)
    throw TLSException("readTLS", received);

  return received;
}

#endif
