/* Copyright (C) 2002-2005 RealVNC Ltd.  All Rights Reserved.
 * Copyright (C) 2014-2024 m-privacy GmbH
 * 
 * This is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 * 
 * This software is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this software; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307,
 * USA.
 */

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include <stdio.h>
#include <string.h>
#include <errno.h>
#include <stdlib.h>
#include <sys/time.h>
#include <zlib.h>
#ifdef _WIN32
#include <winsock2.h>
#define close closesocket
#undef errno
#define errno WSAGetLastError()
#include <os/winerrno.h>
#else
#include <sys/types.h>
#include <sys/socket.h>
#include <unistd.h>
#endif

/* Old systems have select() in sys/time.h */
#ifdef HAVE_SYS_SELECT_H
#include <sys/select.h>
#endif

#include <rdr/FdInStream.h>
#include <rdr/Exception.h>
#include <rfb/LogWriter.h>

using namespace rdr;

// File-local logger for this translation unit; the constructor also installs
// it as the instance's classLog (reset again in the destructor).
static rfb::LogWriter vlog("FdInStream");

//
// Wraps the already-open descriptor fd_.  When closeWhenDone_ is true the
// destructor closes it.  Check-header detection is deferred to the first
// fillBuffer() call (firstRead starts out true, checkHeader false), and all
// CRC statistics start at zero.
//
FdInStream::FdInStream(int fd_, bool closeWhenDone_)
  : fd(fd_),
    closeWhenDone(closeWhenDone_),
    checkHeader(false),
    firstRead(true),
    headerBytesRead(0),
    bodyBytesLeft(0),
    crcChecked(0),
    crcFailed(0),
    crcMissed(0)
{
  vlog.debug("Init");
  // Point classLog at this file's logger; the destructor undoes this via
  // resetClassLog().
  classLog = &vlog;
}

//
// Closes the descriptor if this stream owns it, logs the final serial number
// and CRC statistics, and finally restores classLog via resetClassLog().
//
FdInStream::~FdInStream()
{
  if (closeWhenDone) {
    close(fd);
  }
  vlog.debug("Exiting with nextSerial %u, CRC checked %u, failed %u, missed %u",
             nextSerial, crcChecked, crcFailed, crcMissed);
  resetClassLog();
}


bool FdInStream::fillBuffer(size_t maxSize)
{
  size_t n;

  if (checkHeader || firstRead) {
    if (!bodyBytesLeft) {
      U32 serialNumber;
      U32 crc;

      if (firstRead) {
        while (headerBytesRead < sizeof(U32)) {
          n = readFd(header + headerBytesRead, CHECKHEADERSIZE - headerBytesRead);
          if (n == 0) {
            if (headerBytesRead == 0)
              return false;
            vlog.debug("fillBuffer(): firstRead: could only read %zu header bytes, assuming no header", headerBytesRead);
            break;
          }
          headerBytesRead += n;
        }
        firstRead = false;
        /* is this a header with serial 0 and bodyBytesLeft > 0?
         * If yes, use headers, if no, move to buffer
         */
        if (headerBytesRead >= CHECKHEADERSIZE && checkCheckHeader(header, &serialNumber, &crc, &bodyBytesLeft) && serialNumber == 0 && bodyBytesLeft > 0) {
          checkHeader = true;
          headerBytesRead = 0;
          vlog.debug("fillBuffer(): detected header in first read, enabling header support, %u body bytes", bodyBytesLeft);
        } else {
          vlog.debug("fillBuffer(): no header in first read, no header support, move %zu bytes to buffer", headerBytesRead);
          memcpy((U8*) end, header, headerBytesRead);
          end += headerBytesRead;
          return true;
        }
      } else {
        while (headerBytesRead < CHECKHEADERSIZE) {
          n = readFd(header + headerBytesRead, CHECKHEADERSIZE - headerBytesRead);
          if (n == 0) {
            return false;
          }
          headerBytesRead += n;
        }
        headerBytesRead = 0;
        if (!checkCheckHeader(header, &serialNumber, &crc, &bodyBytesLeft)) {
          U8 * tmpbuf;
          const size_t tmpbufsize = 128 * 1024;

          vlog.error("fillBuffer(): checkCheckHeader() failed, try to find correct header");
          tmpbuf = (U8 *) malloc(tmpbufsize);
          if (tmpbuf) {
            size_t offset = 0;
            size_t tmpread = 0;
//            const size_t maxread = tmpbufsize < maxSize ? tmpbufsize : maxSize;
            const size_t maxread = tmpbufsize < maxSize + CHECKHEADERSIZE ? tmpbufsize : maxSize + CHECKHEADERSIZE;

            sleep(1);
            while (tmpread < maxread) {
              n = readFd(tmpbuf + tmpread, maxread - tmpread);
              if (n == 0)
                break;
              tmpread += n;
            }
            vlog.debug("fillBuffer(): filled tmpbuf with %zu bytes", tmpread);
            if (tmpread > CHECKHEADERSIZE) {
              while (offset < tmpread - CHECKHEADERSIZE && *((U32 *) (tmpbuf + offset)) != serialNumber)
                offset++;
              if (offset == tmpread - CHECKHEADERSIZE ) {
                /* not found, we are going to crash */
                vlog.error("fillBuffer(): failed to find missing header, giving up");
                memcpy((U8*)end, tmpbuf, tmpread);
                end += tmpread;
                free(tmpbuf);
                return true;
              } else {
                /* found, try to restart with offset */
                vlog.debug("fillBuffer(): found missing header at offset %zu, trying to shift buffer", offset);
                checkCheckHeader(tmpbuf + offset, &serialNumber, &crc, &bodyBytesLeft);
                while (true) {
                  if (bodyBytesLeft >= tmpread - offset - CHECKHEADERSIZE) {
                    if (tmpread - offset - CHECKHEADERSIZE > 0) {
                      memcpy((U8*)end, tmpbuf + offset + CHECKHEADERSIZE, tmpread - offset - CHECKHEADERSIZE);
                      end += (tmpread - offset - CHECKHEADERSIZE);
                      bodyBytesLeft -= tmpread - offset - CHECKHEADERSIZE;
                    }
                    free(tmpbuf);
                    return true;
                  } else {
                    /* we read over the start of the next header */
                    vlog.debug("fillBuffer(): dead over the next header");
                    memcpy((U8*)end, tmpbuf + offset + CHECKHEADERSIZE, bodyBytesLeft);
                    end += bodyBytesLeft;
                    offset += CHECKHEADERSIZE + bodyBytesLeft;
                    if (tmpread - offset >= CHECKHEADERSIZE) {
                      if (!checkCheckHeader(tmpbuf + offset, &serialNumber, &crc, &bodyBytesLeft)) {
                        vlog.error("fillBuffer(): checkCheckHeader() failed again, giving up");
                        free(tmpbuf);
                        return true;
                      }
                    } else {
                      headerBytesRead = tmpread - offset;
                      memcpy(header, tmpbuf + offset, headerBytesRead);
                      free(tmpbuf);
                      return true;
                    }
                  }
                }
              }
            }
            free(tmpbuf);
          } else {
            vlog.error("fillBuffer(): failed to allocate tmpbuf");
          }
        }
      }
      if (bodyBytesLeft > 0) {
        /* try a crc check, but we need the space */
        if (maxSize >= bodyBytesLeft) {
          U32 newCrc;
          const U16 savedSize = bodyBytesLeft;

          while (bodyBytesLeft > 0) {
            n = readFd((U8*)end, bodyBytesLeft);
            if (n == 0)
              return false;
            end += n;
            bodyBytesLeft -= n;
          }
          newCrc = crc32(0L, Z_NULL, 0);
          newCrc = crc32(newCrc, (U8*)end - savedSize, savedSize);
          crcChecked++;
          if (newCrc != crc) {
            vlog.error("CRC mismatch: expected %u, got %u with size %u", crc, newCrc, savedSize);
            crcFailed++;
          }
          return true;
        } else {
          vlog.verbose("fillBuffer(): not enough space (%lu) for serial %u CRC check with size %u, crc %u", maxSize, serialNumber, bodyBytesLeft, crc);
          crcMissed++;
        }
      }
    }
    if (maxSize > bodyBytesLeft && bodyBytesLeft > 0)
      maxSize = bodyBytesLeft;
  }

  n = readFd((U8*)end, maxSize);
  if (n == 0)
    return false;
  end += n;
  if (bodyBytesLeft > 0)
    bodyBytesLeft -= n;

  return true;
}

//
// readFd() copies at most len bytes from fd into buf and returns how many
// were copied.  The descriptor is first polled with select() using a zero
// timeout, and recv() is only called once select() reports it readable, so
// the call does not wait for data and can be used on a descriptor that has
// been set non-blocking.  Returns 0 when nothing is readable at the moment.
// Both select() and recv() are retried when interrupted (EINTR).  Throws
// EndOfStream when recv() returns 0 (peer closed the connection) and
// SystemException for any other select()/recv() failure.
//

size_t FdInStream::readFd(void* buf, size_t len)
{
  int rc;

  // Poll once, without waiting; restart only if a signal interrupted us.
  for (;;) {
    fd_set readSet;
    struct timeval noWait;

    noWait.tv_sec = 0;
    noWait.tv_usec = 0;
    FD_ZERO(&readSet);
    FD_SET(fd, &readSet);

    rc = select(fd + 1, &readSet, 0, 0, &noWait);
    if (rc >= 0 || errno != EINTR)
      break;
  }

  if (rc < 0)
    throw SystemException("select", errno);

  // Nothing readable right now.
  if (rc == 0)
    return 0;

  for (;;) {
    rc = ::recv(fd, static_cast<char*>(buf), len, 0);
    if (rc >= 0 || errno != EINTR)
      break;
  }

  if (rc < 0)
    throw SystemException("read", errno);
  if (rc == 0)
    throw EndOfStream();

  return rc;
}
