/* Copyright (C) 2002-2005 RealVNC Ltd.  All Rights Reserved.
 * Copyright (C) 2005 Martin Koegler
 * Copyright (C) 2010 TigerVNC Team
 * Copyright (C) 2015-2021 m-privacy GmbH
 *
 * This is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This software is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this software; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307,
 * USA.
 */

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include <rdr/Exception.h>
#include <rdr/MultiInStream.h>
#include <rfb/LogWriter.h>
#include <rdr/MultiStream.h>
#include <errno.h>
#include <limits.h>
#include <string.h>
#include <unistd.h>
#ifndef WIN32
#include <sys/socket.h>
#endif
#include <stdlib.h>
#include <zstd.h>

/* Size in bytes of a multiplexing packet header as parsed by
 * MultiInStream::overrun(): one stream type byte (readU8) followed by a
 * 16-bit payload length (readU16).
 */
#define MULTIHEADERSIZE 3

using namespace rdr;

static rfb::LogWriter vlog("MultiInStream");

/* Buffer size used for every stream other than the VNC stream and the two
 * PulseAudio streams.
 */
enum {
	NON_VNC_BUF_SIZE =  65 * 1024,
};

/* Per-stream high-water mark of buffer usage in bytes, logged by the
 * destructor. Note: this is a file-scope static, so it is shared by all
 * MultiInStream instances and is zeroed again by every constructor call.
 */
static unsigned int maxUsed[STREAM_QUANTITY];

/* Running totals for the zstd-compressed PulseAudio stream: bytes received
 * compressed and the bytes they decompressed to. File-scope statics, so
 * they accumulate across all MultiInStream instances (the code shown here
 * never resets them).
 */
static U64 zstdCompressedBytes = 0;
static U64 zstdDecompressedBytes = 0;

/* Default, minimum and maximum size in bytes of the PulseAudio input
 * buffers, exposed as the MultiPAInBufferSize parameter.
 */
#define PAINDEF (384 * 1024)
#define PAINMIN (64 * 1024)
#define PAINMAX (1024 * 1024)
rfb::IntParameter multiPAInBufferSize("MultiPAInBufferSize", "Multi PulseAudio input buffer size", PAINDEF, PAINMIN, PAINMAX);

/*
 * Create a demultiplexing input stream on top of _in.
 *
 * Allocates one receive buffer per stream id:
 * - the VNC stream gets _bufSize bytes;
 * - both PulseAudio streams (plain and zstd) get MultiPAInBufferSize bytes;
 * - every other stream gets NON_VNC_BUF_SIZE bytes.
 * All per-stream counters and flags are cleared, and the InStream window
 * (ptr/end) starts empty at the beginning of the VNC buffer.
 *
 * @param _in      underlying stream; deleted by the destructor.
 * @param _bufSize size of the VNC buffer, i.e. the largest request
 *                 overrun() accepts.
 */
MultiInStream::MultiInStream(InStream* _in, size_t _bufSize)
	: packType(VNC_STREAM_ID), packRestSize(0), in(_in), bufSize(_bufSize) {
	classLog = &vlog;
	vlog.debug("Init with VNC buffer size %zu", bufSize);

	/* Read the configurable parameter exactly once, so the allocation size
	 * and the recorded buffer end cannot disagree if the parameter changes
	 * while we are constructing.
	 */
	const size_t paBufSize = (int) multiPAInBufferSize;

	for (int i=0; i<STREAM_QUANTITY; i++) {
		size_t size;
		if (i == VNC_STREAM_ID)
			size = bufSize;
		else if (i == PULSEAUDIO_STREAM_ID || i == PULSEAUDIO_ZSTD_STREAM_ID)
			size = paBufSize;
		else
			size = NON_VNC_BUF_SIZE;

		dataStart[i] = dataEnd[i] = bufStart[i] = new U8[size];
		bufEnd[i] = bufStart[i] + size;

		streamCounter[i] = 0;
		callbacks[i].callback = NULL;
		maxUsed[i] = 0;
		isMultiPart[i] = false;
		streamUseCount[i] = 0;
	}
	ptr = end = bufStart[VNC_STREAM_ID];
}

/*
 * Free every per-stream buffer and the wrapped input stream.
 *
 * Per-stream buffer high-water marks and multi-packet counts are logged at
 * debug level, as is the cumulative zstd ratio once anything has been
 * decompressed.
 */
MultiInStream::~MultiInStream() {
	for (int id = 0; id < STREAM_QUANTITY; ++id) {
		delete[] bufStart[id];
		if (maxUsed[id] > 0)
			vlog.debug("maxUsed[%u] = %u", id, maxUsed[id]);
		if (streamUseCount[id] > 0)
			vlog.debug("stream type %u had %u multi packets", id, streamUseCount[id]);
	}

	const bool haveZstdStats = zstdDecompressedBytes > 0;
	if (haveZstdStats) {
		const U64 ratio = 100 * zstdCompressedBytes / zstdDecompressedBytes;
		vlog.debug("Pulse ZSTD decompressed total %llu bytes to %llu bytes, global ratio %llu%%", zstdCompressedBytes, zstdDecompressedBytes, ratio);
	}

	/* delete on a null pointer is a no-op, so no guard is needed */
	delete in;
	in = NULL;

	resetClassLog();
}

/*
 * Register cbFunc as the consumer for stream streamId (NULL clears it).
 * An out-of-range id is logged as an error and otherwise ignored.
 */
void MultiInStream::setCallback(unsigned streamId, bool (*cbFunc)(const U8*, int)) {
	const bool validId = streamId < STREAM_QUANTITY;
	if (validId) {
		vlog.debug("Set callback for streamId %u", streamId);
		callbacks[streamId].callback = cbFunc;
	} else {
		vlog.error("setCallback: invalid streamId %u", streamId);
	}
}

/*
 * Stream position: the number of VNC payload bytes fetched from the
 * underlying stream so far. This counts bytes read into the buffer, which
 * can be ahead of what the reader has consumed.
 */
size_t MultiInStream::pos() {
	const size_t vncBytesFetched = streamCounter[VNC_STREAM_ID];
	return vncBytesFetched;
}

/*
 * Prepare the buffer of stream packType for appending more data.
 *
 * If the buffer holds no unread data (or belongs to the unknown-stream
 * sink), both data pointers are rewound to the start of the buffer.
 * Otherwise, if unread data does not already start at bufStart, the unread
 * bytes are moved to the front. This reclaims the consumed space so the
 * whole remaining capacity is available at the tail.
 *
 * Why the move happens whenever there is a gap: the previous condition,
 * "1 > (bufEnd - dataEnd + oldDataLen)", was never true. The sum is
 * computed as size_t and is at least oldDataLen >= 1, so the memmove was
 * dead code. Consumed space in front of partially read VNC data was then
 * never reclaimed, and overrun() could hit a full tail with bytes still
 * missing. For the VNC stream the moved amount is smaller than the
 * request overrun() is serving. For the other streams dataStart is only
 * ever set to bufStart, so this branch does nothing there.
 */
void MultiInStream::checkBufferMove(unsigned packType) {
	const size_t oldDataLen = dataEnd[packType] - dataStart[packType];

	if (!oldDataLen || (packType == UNKNOWN_STREAM_ID)) {
		dataEnd[packType] = dataStart[packType] = bufStart[packType];
	} else if (dataStart[packType] != bufStart[packType]) {
		vlog.debug("checkBufferMove(): moving %zu bytes for packType %u", oldDataLen, packType);
		/* regions may overlap, hence memmove rather than memcpy */
		memmove(bufStart[packType], dataStart[packType], oldDataLen);
		dataStart[packType] = bufStart[packType];
		dataEnd[packType] = bufStart[packType] + oldDataLen;
	}
}

void MultiInStream::handleOtherData(unsigned packType) {
	const size_t dataLen = dataEnd[packType] - dataStart[packType];

	streamUseCount[packType]++;
	if (maxUsed[packType] < dataEnd[packType] - bufStart[packType])
		maxUsed[packType] = dataEnd[packType] - bufStart[packType];
	if (packType == PULSEAUDIO_ZSTD_STREAM_ID && callbacks[PULSEAUDIO_STREAM_ID].callback) {
		int dataSize = ZSTD_getFrameContentSize(dataStart[packType], dataLen);

		if (dataSize > 0) {
			rdr::U8 * data = (rdr::U8 *) malloc(dataSize);
			if (data) {
				dataSize = ZSTD_decompress(data, dataSize, dataStart[packType], dataLen);
				if (dataSize > 0 && callbacks[PULSEAUDIO_STREAM_ID].callback(data, dataSize)) {
					dataEnd[packType] = dataStart[packType] = bufStart[packType];
					free(data);
					zstdCompressedBytes += dataLen;
					zstdDecompressedBytes += dataSize;
//					vlog.verbose("handleOtherData(): zstd decompressed %zu bytes to %u bytes, global ratio %llu%%", dataLen, dataSize, 100 * zstdCompressedBytes / zstdDecompressedBytes);
					return;
				}
				free(data);
			}
		}
	}
	if (callbacks[packType].callback) {
		if (callbacks[packType].callback(dataStart[packType], dataLen))
			dataEnd[packType] = dataStart[packType] = bufStart[packType];
	} else {
		vlog.debug("Ignore incoming stream of type %u without callback", packType);
		dataEnd[packType] = dataStart[packType] = bufStart[packType];
	}
}

/*
 * Refill the VNC window so that `needed` unread bytes are available at ptr,
 * demultiplexing the underlying stream on the way.
 *
 * The underlying stream carries packets with a MULTIHEADERSIZE-byte header:
 * one type byte (optionally OR'ed with MULTIPARTMARKER) and a 16-bit payload
 * length. Routing of payload bytes:
 * - VNC payload is appended to the VNC buffer, but never more than is
 *   still missing for this request.
 * - Payload of any other known type is collected in that type's buffer.
 *   It is passed to handleOtherData() once a packet not flagged as
 *   multi-part has been read completely.
 * A packet split across calls is continued via the packType/packRestSize
 * members.
 *
 * @param needed  number of unread VNC bytes the caller wants at ptr.
 * @return true if `needed` bytes are now available; false if the
 *         underlying stream ran out of data first. ptr/end are updated in
 *         both cases.
 * @throws Exception if needed exceeds the VNC buffer size.
 */
bool MultiInStream::overrun(size_t needed) {
	if (needed > bufSize) {
		throw Exception("MultiInStream overrun: bufSize exceeded");
	}

	/* Bytes before ptr have been consumed by the reader; drop them from
	 * the VNC data window.
	 */
	dataStart[VNC_STREAM_ID] = (rdr::U8*) ptr;
	size_t freeSpaceLen;
	/* VNC bytes still to fetch. Presumably overrun() is only called when
	 * fewer than `needed` bytes are unread; otherwise this unsigned
	 * subtraction would wrap. TODO confirm against callers.
	 */
	size_t missing = needed - (dataEnd[VNC_STREAM_ID] - dataStart[VNC_STREAM_ID]);
	size_t wantedBytes;

	do {
		if (!packRestSize) {
			/* Previous packet finished: parse the next header, or stop if
			 * it has not fully arrived yet.
			 */
			if (!in->hasData(MULTIHEADERSIZE))
				break;
			packType = in->readU8();
			if ((packType & ~MULTIPARTMARKER) >= STREAM_QUANTITY) {
				/* Out-of-range types are routed into the UNKNOWN_STREAM_ID
				 * buffer.
				 * NOTE(review): isMultiPart[UNKNOWN_STREAM_ID] is not
				 * updated here, so the multi-part checks below use
				 * whatever value it last had.
				 */
				vlog.debug("Unknown incoming stream of type %u", packType);
				packType = UNKNOWN_STREAM_ID;
			} else {
				if (packType & MULTIPARTMARKER) {
					packType &= ~MULTIPARTMARKER;
					isMultiPart[packType] = true;
				} else {
					isMultiPart[packType] = false;
				}
			}
			packRestSize = in->readU16();
			/* Enforce a flush after this packet if the buffer does not have
			 * room for another packet of the same size: clearing the
			 * multi-part flag makes the packet get delivered via
			 * handleOtherData() once it is complete.
			 */
			if (isMultiPart[packType]) {
//				vlog.debug("Packet of packType %u, size %zu, is marked as MultiPart", packType, packRestSize);
				if ((packRestSize << 1) > (size_t)(bufEnd[packType] - dataEnd[packType])) {
					vlog.info("Marking packet of packType %u, size %zu, back as non-MultiPart, because buffer is running out of space. Please increase MultiPAInBufferSize.", packType, packRestSize);
					isMultiPart[packType] = false;
				}
			}
		}
		/* Give checkBufferMove() a chance to rewind the target buffer
		 * before appending to it.
		 */
		checkBufferMove(packType);
		freeSpaceLen = bufEnd[packType] - dataEnd[packType];
		wantedBytes = packRestSize;
		/* Never write past the end of the target buffer. */
		if (wantedBytes > freeSpaceLen) {
			vlog.debug("overrun(): packType %u, wantedBytes %zu > freeSpaceLen %zu", packType, wantedBytes, freeSpaceLen);
			wantedBytes = freeSpaceLen;
		}
		/* Fetch only the VNC bytes this request still lacks; the rest of
		 * the packet stays in `in`, tracked by packRestSize.
		 */
		if (packType == VNC_STREAM_ID && wantedBytes > missing)
			wantedBytes = missing;
		if (wantedBytes <= 0)
			break;
		if (!in->hasData(wantedBytes)) {
			/* VNC: wait until the whole chunk is available. Other streams:
			 * take whatever has arrived and continue on a later call.
			 */
			if (packType == VNC_STREAM_ID)
				break;
			if (wantedBytes > in->avail()) {
				vlog.verbose2("overrun(): wantedBytes %zu > avail %zu", wantedBytes, in->avail());
				wantedBytes = in->avail();
				if (wantedBytes <= 0)
					break;
			}
		}
		in->readBytes(dataEnd[packType], wantedBytes);
		dataEnd[packType] += wantedBytes;
		streamCounter[packType] += wantedBytes;
		packRestSize -= wantedBytes;
		if (packType == VNC_STREAM_ID)
			missing -= wantedBytes;
		else if (!packRestSize && !isMultiPart[packType])
			handleOtherData(packType);
		/* Keep going while we are inside non-VNC data or VNC bytes are
		 * still missing; the breaks above stop when `in` runs dry.
		 */
	} while (packType != VNC_STREAM_ID || missing > 0);

	/* Publish the (possibly relocated) VNC data window to the InStream. */
	ptr = dataStart[VNC_STREAM_ID];
	end = dataEnd[VNC_STREAM_ID];

	/* Statistics only: high-water mark of the last stream touched. */
	if (maxUsed[packType] < dataEnd[packType] - bufStart[packType])
		maxUsed[packType] = dataEnd[packType] - bufStart[packType];
	if (missing > 0)
		return false;
	else
		return true;
}
