// $Id: RPC.cc,v 1.5 2005/04/30 00:03:13 vern Exp $
//
// Copyright (c) 1996, 1997, 1998, 1999, 2000, 2001, 2002
//      The Regents of the University of California.  All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that: (1) source code distributions
// retain the above copyright notice and this paragraph in its entirety, (2)
// distributions including binary code include the above copyright notice and
// this paragraph in its entirety in the documentation or other materials
// provided with the distribution, and (3) all advertising materials mentioning
// features or use of this software display the following acknowledgement:
// ``This product includes software developed by the University of California,
// Lawrence Berkeley Laboratory and its contributors.'' Neither the name of
// the University nor the names of its contributors may be used to endorse
// or promote products derived from this software without specific prior
// written permission.
// THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR IMPLIED
// WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE.

#include "config.h"

#include <stdlib.h>

#include "NetVar.h"
#include "XDR.h"
#include "RPC.h"
#include "Sessions.h"

// Upper bound on the length of an RPC-over-TCP record that
// TCP_Contents_RPC::Deliver() will buffer.  Longer records are capped at
// this size and reported as "excessive_RPC_len".
#define MAX_RPC_LEN 8192

// The following correspond to the different RPC status values defined
// in bro.init.  DeliverRPC() stores them in the same "status" variable
// that otherwise holds the accept_stat read off the wire (e.g.
// RPC_SUCCESS, RPC_PROG_MISMATCH), so they must not collide with those
// codes.  NOTE(review): keep in sync with bro.init and RPC.h.
#define BRO_RPC_TIMEOUT 6
#define BRO_RPC_AUTH_ERROR 7
#define BRO_RPC_UNKNOWN_ERROR 8

// Records a newly seen RPC call.
//
// On entry, buf/n describe the call message just past its xid and
// message-type words.  As the fixed call header (RPC version, program,
// version, procedure, credentials and verifier) is parsed, buf and n are
// advanced past it.  The caller treats a null buf afterwards as a
// truncated header.  A private copy of the original bytes is kept so
// that CompareRexmit() can check later retransmissions against it.
RPC_CallInfo::RPC_CallInfo(uint32 arg_xid, const u_char*& buf, int& n)
	{
	// Nothing has been built from this call yet.
	valid_call = false;
	v = 0;

	xid = arg_xid;

	// Snapshot the raw call body before parsing consumes it.
	const int orig_n = n;
	u_char* copy = new u_char[orig_n];
	memcpy(copy, buf, orig_n);
	call_buf = copy;
	call_n = orig_n;

	rpc_version = extract_XDR_uint32(buf, n);
	prog = extract_XDR_uint32(buf, n);
	vers = extract_XDR_uint32(buf, n);
	proc = extract_XDR_uint32(buf, n);
	cred_flavor = skip_XDR_opaque_auth(buf, n);
	verf_flavor = skip_XDR_opaque_auth(buf, n);

	// How many bytes the header just parsed occupied.
	header_len = orig_n - n;
	}

// Releases the request value (if any was attached) and the private copy
// of the raw call bytes.
RPC_CallInfo::~RPC_CallInfo()
	{
	Unref(v);
	delete [] call_buf;
	}

// Returns nonzero if buf/n is byte-for-byte identical to the call body
// saved when this call was first seen, i.e. a consistent retransmission.
int RPC_CallInfo::CompareRexmit(const u_char* buf, int n) const
	{
	// Different lengths can never match; only compare bytes otherwise.
	return n == call_n && memcmp(call_buf, buf, call_n) == 0;
	}


void rpc_callinfo_delete_func(void* v)
	{
	delete (RPC_CallInfo*) v;
	}

// Creates an interpreter that reports anomalies against arg_conn.
RPC_Interpreter::RPC_Interpreter(Connection* arg_conn)
	{
	// Give the call table a way to free the RPC_CallInfo entries it holds.
	calls.SetDeleteFunc(rpc_callinfo_delete_func);
	conn = arg_conn;
	}

RPC_Interpreter::~RPC_Interpreter()
	{
	// Nothing is released explicitly here.  Outstanding entries in
	// "calls" are left to that member's own destructor (presumably via
	// the delete function registered in the constructor -- confirm).
	}

// Parses one complete RPC message (call or reply) and dispatches it.
//
// buf/n hold the whole message, starting with the XDR-encoded transaction
// ID (xid) and message type.  is_orig says which side of the connection
// sent it.  A call from the responder, or a reply from the originator, is
// reported as a weird but still processed.
//
// Calls:   remembered in the "calls" table keyed by xid.  A repeated xid
//          is treated as a retransmission and checked for consistency.
// Replies: matched against that table and turned into an event via
//          RPC_BuildReply(); the matching entry is then removed.
//
// Returns 0 if the message is truncated or could not be parsed far enough
// to be useful, and 1 otherwise.  Trailing bytes that are not zero
// padding are reported as "excess_RPC".
int RPC_Interpreter::DeliverRPC(const u_char* buf, int n, int is_orig)
	{
	uint32 xid = extract_XDR_uint32(buf, n);
	uint32 msg_type = extract_XDR_uint32(buf, n);

	// Throughout, a null buf after an XDR extraction means the message
	// was too short for the field just read.
	if ( ! buf )
		return 0;

	HashKey h(xid);
	RPC_CallInfo* call = calls.Lookup(&h);

	if ( msg_type == RPC_CALL )
		{
		if ( ! is_orig )
			Weird("responder_RPC_call");

		if ( call )
			{
			// Same xid seen before: treat as a retransmission,
			// which should be identical to the original call.
			if ( ! call->CompareRexmit(buf, n) )
				Weird("RPC_rexmit_inconsistency");

			if ( call->HeaderLen() > n )
				{
				Weird("RPC_underflow");
				return 0;
				}

			// Skip the call header, which was already parsed
			// when the original call was recorded.
			n -= call->HeaderLen();
			buf += call->HeaderLen();
			}

		else
			{
			// New call: parse its header (advancing buf/n) and
			// remember it so that the reply can be matched up.
			call = new RPC_CallInfo(xid, buf, n);
			if ( ! buf )
				{
				delete call;
				return 0;
				}

			calls.Insert(&h, call);
			}

		if ( RPC_BuildCall(call, buf, n) )
			call->SetValidCall();
		else
			{
			Weird("bad_RPC");
			return 0;
			}
		}

	else if ( msg_type == RPC_REPLY )
		{
		if ( is_orig )
			Weird("originator_RPC_reply");

		uint32 reply_stat = extract_XDR_uint32(buf, n);
		if ( ! buf )
			return 0;

		uint32 status = BRO_RPC_UNKNOWN_ERROR;

		if ( reply_stat == RPC_MSG_ACCEPTED )
			{
			(void) skip_XDR_opaque_auth(buf, n);	// verifier
			status = extract_XDR_uint32(buf, n);	// accept_stat
			if ( ! buf )
				return 0;

			if ( status == RPC_PROG_MISMATCH )
				{
				// Skip the lowest/highest supported versions.
				(void) extract_XDR_uint32(buf, n);
				(void) extract_XDR_uint32(buf, n);

				if ( ! buf )
					return 0;
				}
			}

		else if ( reply_stat == RPC_MSG_DENIED )
			{
			uint32 reject_stat = extract_XDR_uint32(buf, n);
			if ( ! buf )
				return 0;

			if ( reject_stat == RPC_MISMATCH )
				{
				// Don't store reject_stat itself as the status.
				// RPC_MISMATCH (a reject_stat) and RPC_SUCCESS
				// (an accept_stat) are both 0 per RFC 5531, so
				// the denied call would be reported as
				// successful and its (nonexistent) results
				// parsed below.  None of the BRO_RPC_* codes
				// covers an RPC protocol-version mismatch, so
				// report it as an unknown error.
				status = BRO_RPC_UNKNOWN_ERROR;

				// Skip the lowest/highest supported RPC versions.
				(void) extract_XDR_uint32(buf, n);
				(void) extract_XDR_uint32(buf, n);

				if ( ! buf )
					return 0;
				}

			else if ( reject_stat == RPC_AUTH_ERROR )
				{
				status = BRO_RPC_AUTH_ERROR;

				(void) extract_XDR_uint32(buf, n);	// auth_stat
				if ( ! buf )
					return 0;
				}

			else
				{
				status = BRO_RPC_UNKNOWN_ERROR;
				Weird("bad_RPC");
				}
			}

		else
			Weird("bad_RPC");

		if ( call )
			{
			int success = status == RPC_SUCCESS;

			if ( ! call->IsValidCall() )
				{
				if ( success )
					Weird("successful_RPC_reply_to_invalid_request");
				// We can't process this further, even if
				// it was successful, because the call
				// info won't be fully set up.
				}

			else
				{
				EventHandlerPtr event;
				Val* reply = 0;
				if ( ! RPC_BuildReply(call, success, buf,
							n, event, reply) )
					Weird("bad_RPC");
				else
					Event(event, call->TakeRequestVal(),
						status, reply);
				}

			// The call has been answered; forget it.
			delete calls.RemoveEntry(&h);
			}
		else
			{
			Weird("unpaired_RPC_response");
			n = 0;	// don't also complain about the unparsed body
			}
		}

	else
		Weird("bad_RPC");

	if ( n > 0 )
		{
		// If it's just padded with zeroes, don't complain.
		for ( ; n > 0; --n, ++buf )
			if ( *buf != 0 )
				break;

		if ( n > 0 )
			Weird("excess_RPC");
		}

	else if ( n < 0 )
		internal_error("RPC underflow");

	return 1;
	}

// Reports every still-outstanding, successfully parsed call as having
// timed out: an event is raised with status BRO_RPC_TIMEOUT and whatever
// reply value RPC_BuildReply() produces for an unsuccessful, empty reply.
// UDP_RPC::Done() calls this when the pseudo-connection finishes.
// Entries are not removed from the table here.
void RPC_Interpreter::Timeout()
	{
	IterCookie* cookie = calls.InitForIteration();
	RPC_CallInfo* c;

	while ( (c = calls.NextEntry(cookie)) )
		{
		if ( ! c->IsValidCall() )
			continue;	// no parsed request to report

		// There is no reply payload.  Hand RPC_BuildReply() a null,
		// zero-length buffer and a null reply rather than
		// uninitialized pointers.
		const u_char* buf = 0;
		int n = 0;
		EventHandlerPtr event;
		Val* reply = 0;

		if ( ! RPC_BuildReply(c, 0, buf, n, event, reply) )
			Weird("bad_RPC");
		else
			Event(event, c->TakeRequestVal(),
				BRO_RPC_TIMEOUT, reply);
		}
	}

void RPC_Interpreter::Weird(const char* msg)
	{
	conn->Weird(msg);
	}


// Reassembles RPC records from one TCP endpoint's byte stream and feeds
// them to arg_interp.
TCP_Contents_RPC::TCP_Contents_RPC(RPC_Interpreter* arg_interp,
					TCP_Endpoint* arg_endp)
: TCP_Contents(arg_endp)
	{
	interp = arg_interp;

	// Allocate the record-marker buffer and reset the state machine.
	Init();
	}

// Resets the reassembler to wait for the 4-byte record marker that
// starts a new record.  Does not free any previous msg_buf; callers
// ensure it is unallocated or already released.
void TCP_Contents_RPC::Init()
	{
	state = RPC_RECORD_MARKER;
	last_frag = 0;
	buf_n = 0;

	// Room for exactly one record marker.
	buf_len = 4;
	msg_buf = new u_char[buf_len];
	}

TCP_Contents_RPC::~TCP_Contents_RPC()
	{
	// msg_buf is null right after a record has been delivered; delete []
	// of a null pointer is a harmless no-op.
	delete [] msg_buf;
	}

void TCP_Contents_RPC::Deliver(int seq, int len, u_char* data)
	{
	if ( state == RPC_COMPLETE )
		Init();

	int n;
	for ( n = 0; buf_n < buf_len && n < len; ++n )
		msg_buf[buf_n++] = data[n];

	if ( buf_n < buf_len )
		// Haven't filled up the message buffer yet, no more to do.
		return;

	switch ( state ) {
	case RPC_RECORD_MARKER:
		{ // Have the whole record marker.
		int prev_frag_len = buf_len - 4;
		const u_char* buf = &msg_buf[prev_frag_len];
		int n = 4;

		uint32 marker = extract_XDR_uint32(buf, n);
		if ( ! buf )
			internal_error("inconsistent RPC record marker extraction");

		if ( prev_frag_len > 0 && last_frag )
			internal_error("last_frag set but more fragments");

		last_frag = (marker & 0x80000000) != 0;

		marker &= 0x7fffffff;

		if ( prev_frag_len > 0 )
			// We're adding another fragment.
			marker += prev_frag_len;

		// Fragment length is now given by marker.  Sanity-check.
		if ( marker > MAX_RPC_LEN )
			{
			Conn()->Weird("excessive_RPC_len");
			marker = MAX_RPC_LEN;
			}

		// The new size is either the full record size (if this
		// is the last fragment), or that plus 4 more bytes for
		// the next fragment header.
		int new_size = last_frag ? marker : marker + 4;

		u_char* tmp = new u_char[new_size];
		int msg_len = buf_len < marker ? buf_len : marker;

		for ( int i = 0; i < msg_len; ++i )
			tmp[i] = msg_buf[i];

		delete [] msg_buf;
		msg_buf = tmp;

		buf_len = marker;	// we only want to fill to here
		buf_n = prev_frag_len;	// overwrite this fragment's header

		state = RPC_MESSAGE_BUFFER;
		}
		break;

	case RPC_MESSAGE_BUFFER:
		{ // Have the whole fragment.
		if ( ! last_frag )
			{
			// We earlier made sure to leave an extra 4 bytes
			// at the end of the buffer - use them now for
			// the new fragment header.
			buf_len += 4;
			state = RPC_RECORD_MARKER;
			break;
			}

		if ( ! interp->DeliverRPC(msg_buf, buf_n, IsOrig()) )
			Conn()->Weird("partial_RPC");

		state = RPC_COMPLETE;
		delete [] msg_buf;
		msg_buf = 0;
		}
		break;

	case RPC_COMPLETE:
		internal_error("RPC state inconsistency");
	}

	if ( n < len )
		// More data to munch on.
		Deliver(seq, len - n, data + n);
	}

// A UDP pseudo-connection carrying RPC.  Takes ownership of arg_interp
// (it is deleted in the destructor) and schedules ExpireTimer() to run
// rpc_timeout seconds after t.
UDP_RPC::UDP_RPC(RPC_Interpreter* arg_interp, NetSessions* s, HashKey* k,
		double t, const ConnID* id, const struct udphdr* up)
: UDP_Connection(s, k, t, id, up)
	{
	interp = arg_interp;

	double expire_at = t + rpc_timeout;
	ADD_TIMER(&UDP_RPC::ExpireTimer, expire_at, 1, TIMER_RPC_EXPIRE);
	}

UDP_RPC::~UDP_RPC()
	{
	// This connection owns its interpreter.
	delete interp;
	}

// Finishes the pseudo-connection.  Calls still awaiting a reply are
// first reported as timed out via RPC_Interpreter::Timeout(); then the
// base-class Done() runs.
void UDP_RPC::Done()
	{
	interp->Timeout();

	UDP_Connection::Done();
	}

// Handles a datagram from the originator.  The whole payload is handed
// to the interpreter as one RPC message.  Always returns 1, even when
// parsing fails (the failure is reported as a weird).
int UDP_RPC::Request(double t, const u_char* data, int len)
	{
	int ok = interp->DeliverRPC(data, len, 1);
	if ( ! ok )
		Weird("bad_RPC");

	return 1;
	}

// Handles a datagram from the responder.  The whole payload is handed to
// the interpreter as one RPC message.  Always returns 1, even when
// parsing fails (the failure is reported as a weird).
int UDP_RPC::Reply(double t, const u_char* data, int len)
	{
	int ok = interp->DeliverRPC(data, len, 0);
	if ( ! ok )
		Weird("bad_RPC");

	return 1;
	}

// Timer callback scheduled by the constructor (TIMER_RPC_EXPIRE).  It
// raises connection_timeout and then asks the session manager to remove
// this connection.  The Remove() call comes last, presumably because it
// may end this object's lifetime -- confirm before adding code after it.
void UDP_RPC::ExpireTimer(double /* t */)
	{
	Event(connection_timeout);
	sessions->Remove(this);
	}
