/* Copyright (C) 2005 Martin Koegler
 * Copyright (C) 2010 TigerVNC Team
 * Copyright (C) 2012-2021 m-privacy
 *
 * This is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This software is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this software; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307,
 * USA.
 */

#ifdef WIN32
#include <winsock2.h>
#endif
#include <gsasl.h>
#include <stdlib.h>
#include <string.h>

#include <rfb/CConnection.h>
#include <rfb/CSecurityKrb.h>
#include <rfb/util.h>
#include <rfb/Exception.h>
#include <rfb/LogWriter.h>

using namespace rfb;

/*
 * Viewer configuration for Kerberos (GSSAPI) authentication. The values are
 * consumed by callback() when gsasl asks for the matching SASL property.
 */

/* Principal name; callback() uses it for both GSASL_AUTHID and GSASL_AUTHZID. */
StringParameter CSecurityKrb::KrbAuthid("krbauthid",
                                        "specifies Kerberos username",
                                        "", ConfViewer);

/* Kerberos service name for GSASL_SERVICE; "host" unless overridden. */
StringParameter CSecurityKrb::KrbService("krbservice",
                                         "specifies name of Kerberos service",
                                         "host", ConfViewer);

/* Explicit GSASL_HOSTNAME; when empty, the connection target name is used. */
StringParameter CSecurityKrb::KrbHostname("krbhostname",
                                          "specifies Kerberos hostname (defaults to connection target name)",
                                          "", ConfViewer);

/* Log channel for everything in this file. */
static LogWriter vlog("CSecurityKrb");

/*
 * strdup()'d copy of the connection target name, filled in by
 * CSecurityKrb::processMsg() and used by callback() as the GSASL_HOSTNAME
 * fallback when krbhostname is empty.
 */
static char *connhost = 0;

/*
 * gsasl property callback, registered with gsasl_callback_set() in the
 * CSecurityKrb constructor and invoked from gsasl_step() whenever the GSSAPI
 * mechanism needs a property:
 *   GSASL_AUTHID, GSASL_AUTHZID - krbauthid (an empty string if unset)
 *   GSASL_SERVICE               - krbservice, or "host" if unset
 *   GSASL_HOSTNAME              - krbhostname, else the connection target
 *                                 name; GSASL_NO_HOSTNAME if neither exists
 * Any other property is answered with GSASL_NO_CALLBACK.
 *
 * StringParameter::getData() hands back a caller-owned copy that must be
 * delete[]d. gsasl_property_set() stores its own copy of the value, so each
 * copy is released as soon as it has been passed to gsasl.
 */
int callback (Gsasl * ctx, Gsasl_session * sctx, Gsasl_property prop)
{
	int rc = GSASL_NO_CALLBACK;

	switch (prop)
	{
	case GSASL_AUTHID:
	{
		char * authid = CSecurityKrb::KrbAuthid.getData();
		if (authid && authid[0])
		{
			gsasl_property_set(sctx, GSASL_AUTHID, authid);
			vlog.debug("set AUTHID to %s", authid);
		}
		else
		{
			gsasl_property_set(sctx, GSASL_AUTHID, "");
			vlog.debug("no AUTHID, set empty one");
		}
		delete [] authid;
		rc = GSASL_OK;
		break;
	}
	case GSASL_AUTHZID:
	{
		// The authorization identity deliberately reuses the krbauthid value.
		char * authzid = CSecurityKrb::KrbAuthid.getData();
		if (authzid && authzid[0])
		{
			gsasl_property_set(sctx, GSASL_AUTHZID, authzid);
			vlog.debug("set AUTHZID to %s", authzid);
		}
		else
		{
			gsasl_property_set(sctx, GSASL_AUTHZID, "");
			vlog.debug("no AUTHZID, set empty one");
		}
		delete [] authzid;
		rc = GSASL_OK;
		break;
	}
	case GSASL_SERVICE:
	{
		char * krbservice = CSecurityKrb::KrbService.getData();
		if (krbservice && krbservice[0])
		{
			gsasl_property_set(sctx, GSASL_SERVICE, krbservice);
			vlog.debug("set SERVICE to \"%s\"", krbservice);
		}
		else
		{
			gsasl_property_set(sctx, GSASL_SERVICE, "host");
			vlog.debug("set SERVICE to default \"host\"");
		}
		delete [] krbservice;
		rc = GSASL_OK;
		break;
	}
	case GSASL_HOSTNAME:
	{
		char * krbhostname = CSecurityKrb::KrbHostname.getData();
		if (krbhostname && krbhostname[0])
		{
			gsasl_property_set(sctx, GSASL_HOSTNAME, krbhostname);
			vlog.debug("set HOSTNAME to \"%s\"", krbhostname);
			rc = GSASL_OK;
		}
		else if (connhost)
		{
			gsasl_property_set(sctx, GSASL_HOSTNAME, connhost);
			vlog.debug("set HOSTNAME to connection target \"%s\"", connhost);
			rc = GSASL_OK;
		}
		else
		{
			vlog.debug("no HOSTNAME");
			rc = GSASL_NO_HOSTNAME;
		}
		delete [] krbhostname;
		break;
	}

	default:
		vlog.debug("callback: unsupported property %u", prop);
		break;
	}
	return rc;
}


/*
 * Create the gsasl context, install callback() as the property provider and
 * start a GSSAPI client session. Also allocate the initial 256-byte buffer
 * that processMsg() uses for server tokens.
 *
 * Throws AuthFailureException if any of these steps fails. A constructor that
 * throws never has its destructor run, so every failure path releases what
 * was already acquired before throwing.
 */
CSecurityKrb::CSecurityKrb(CConnection* cc) : CSecurity(cc)
{
	/* init */
	if ((rc = gsasl_init(&gsasl_ctx)) != GSASL_OK)
	{
		vlog.error("gsasl_init() failed: %s", gsasl_strerror (rc));
		throw AuthFailureException("gsasl_init() failed!");
	}
	vlog.debug("gsasl_init() done");

	gsasl_callback_set (gsasl_ctx, callback);

	if ((rc = gsasl_client_start (gsasl_ctx, "GSSAPI", &gsasl_session)) != GSASL_OK)
	{
		vlog.error("gsasl_client_start() failed: %s", gsasl_strerror (rc));
		gsasl_done(gsasl_ctx);
		throw AuthFailureException("gsasl_client_start() failed!");
	}
	vlog.debug("gsasl_client_start() done");

	state = 0;
	maxbuflen = 256;
	buflen = 0;
	buf = (char *)malloc(maxbuflen);
	if (!buf)
	{
		gsasl_finish(gsasl_session);
		gsasl_done(gsasl_ctx);
		throw AuthFailureException("failed to allocate buffer!");
	}
	buf[0] = 0;
}

/*
 * Release whatever gsasl state and buffer are still owned.
 *
 * On success, processMsg() frees these itself and sets the pointers to NULL,
 * so in that case nothing happens here. If authentication failed, or the
 * connection was closed before the exchange finished, the session, the
 * context and the token buffer are released now.
 */
CSecurityKrb::~CSecurityKrb()
{
	if (gsasl_session)
		gsasl_finish(gsasl_session);
	if (gsasl_ctx)
		gsasl_done(gsasl_ctx);
	free(buf);
}

/*
 * Drive the GSSAPI exchange with the server.
 *
 * Every gsasl_step() output is sent as a U32 length followed by the token
 * bytes. While gsasl reports GSASL_NEEDS_MORE, the server's reply is read in
 * the same length-prefixed form and fed into the next step. The work is split
 * into states so that it can resume after a partial read:
 *   state 0 - run gsasl_step() and send its output
 *   state 1 - read the reply length
 *   state 2 - read the reply body
 *
 * Returns false when the input stream does not yet hold enough data; the next
 * call continues where this one stopped. Returns true once authentication has
 * succeeded. Throws AuthFailureException on failure or when the server
 * announces an oversized token.
 */
bool CSecurityKrb::processMsg()
{
	// The token length is supplied by the server; cap it so that it cannot
	// drive an arbitrarily large allocation.
	static const size_t maxTokenLen = 1024 * 1024;

	rdr::InStream* is = cc->getInStream();
	rdr::OutStream* os = cc->getOutStream();

	// callback() falls back to this name for GSASL_HOSTNAME. Refresh it so
	// that a later connection to another server does not reuse a stale name.
	const char* serverName = cc->getServerName();
	if (!serverName)
	{
		free(connhost);
		connhost = NULL;
	}
	else if (!connhost || strcmp(connhost, serverName) != 0)
	{
		free(connhost);
		connhost = strdup(serverName);
	}

	vlog.debug("CSecurityKrb::processMsg() started");

	/* authenticate */
	do
	{
		if (state == 0)
		{
			char *p = NULL;
			size_t plen = 0;

			vlog.debug("gsasl state 0: calling gsasl_step()");
			rc = gsasl_step(gsasl_session, buf, buflen, &p, &plen);
			if (rc == GSASL_NEEDS_MORE || (rc == GSASL_OK && plen > 0))
			{
#ifndef WIN32
				vlog.debug("gsasl sending to server: %zu bytes", plen);
#endif
				os->writeU32(plen);
				if (plen > 0)
					os->writeBytes(p, plen);
				os->flush();
#ifndef WIN32
				vlog.debug("gsasl sent to server: %zu bytes", plen);
#endif
			}
			// The output buffer belongs to the caller whenever the step
			// succeeded. That includes the empty (plen == 0) case, which was
			// previously never freed.
			if ((rc == GSASL_OK || rc == GSASL_NEEDS_MORE) && p)
				gsasl_free(p);
			state = 1;
		}
		if (state == 1)
		{
			vlog.debug("gsasl state 1");
			if (rc == GSASL_NEEDS_MORE)
			{
				if (!is->hasData(4))
					return false;
				buflen = is->readU32();
				if (buflen > maxTokenLen)
					throw AuthFailureException("Kerberos token from server is too large!");
				if (buflen > maxbuflen)
				{
					free(buf);
					buf = (char *)malloc(buflen);
					if (!buf)
					{
						maxbuflen = 0;
						throw AuthFailureException("failed to allocate buffer!");
					}
					maxbuflen = buflen;
				}
#ifndef WIN32
				vlog.debug("gsasl expecting from server: %zu bytes", buflen);
#endif
			}
			state = 2;
		}
		if (state == 2)
		{
			vlog.debug("gsasl state 2");
			if ((rc == GSASL_NEEDS_MORE) && (buflen > 0))
			{
				if (!is->hasData(buflen))
					return false;
				is->readBytes(buf, buflen);
#ifndef WIN32
				vlog.debug("gsasl read from server: %zu bytes", buflen);
#endif
			}
			state = 0;
		}
	}
	while (rc == GSASL_NEEDS_MORE);

	if (rc != GSASL_OK)
	{
		// The gsasl session/context and the buffer are released by the
		// destructor.
		vlog.error("Kerberos authentication error (%d): %s",
			   rc, gsasl_strerror (rc));
		throw AuthFailureException("Kerberos authentication failed!");
	}
	vlog.debug("gsasl authentication successful!");

	/* finish: release everything now and clear the members so that the
	 * destructor does not release them a second time */
	gsasl_finish(gsasl_session);
	gsasl_session = NULL;
	vlog.debug("gsasl_finish() done");
	gsasl_done(gsasl_ctx);
	gsasl_ctx = NULL;
	vlog.debug("gsasl_done() done");
	free(buf);
	buf = NULL;
	buflen = 0;
	maxbuflen = 0;

	// Authentication is complete.
	return true;
}
