// $Id: TCP_Endpoint.cc,v 1.5 2005/01/23 22:29:58 vern Exp $
//
// Copyright (c) 1996, 1997, 1998, 1999, 2000, 2001, 2002, 2003
//      The Regents of the University of California.  All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that: (1) source code distributions
// retain the above copyright notice and this paragraph in its entirety, (2)
// distributions including binary code include the above copyright notice and
// this paragraph in its entirety in the documentation or other materials
// provided with the distribution, and (3) all advertising materials mentioning
// features or use of this software display the following acknowledgement:
// ``This product includes software developed by the University of California,
// Lawrence Berkeley Laboratory and its contributors.'' Neither the name of
// the University nor the names of its contributors may be used to endorse
// or promote products derived from this software without specific prior
// written permission.
// THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR IMPLIED
// WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE.

#include "Net.h"
#include "NetVar.h"
#include "TCP.h"
#include "TCP_Rewriter.h"
#include "Sessions.h"
#include "Event.h"
#include "File.h"
#include "Val.h"


// Create one side of a TCP connection with all tracking state zeroed and
// the address part of the TCP pseudo-header checksum precomputed.
//
// arg_conn: the owning connection (not reference-counted here).
// is_orig:  nonzero if this endpoint is the connection originator.
TCP_Endpoint::TCP_Endpoint(TCP_Connection* arg_conn, int is_orig)
	{
	contents_processor = 0;
	prev_state = state = TCP_INACTIVE;
	peer = 0;
	start_time = last_time = 0.0;
	start_seq = last_seq = ack_seq = 0;
	window = 0;
	window_scale = 0;
	window_seq = window_ack_seq = 0;
	contents_start_seq = 0;
	// DoSerialize() writes FIN_seq unconditionally, so it must always
	// hold a defined value.
	FIN_seq = 0;
	SYN_cnt = FIN_cnt = RST_cnt = 0;
	did_close = 0;
	analyzer = 0;
	contents_file = 0;
	conn = arg_conn;

	src_addr = is_orig ? conn->RespAddr() : conn->OrigAddr();
	dst_addr = is_orig ? conn->OrigAddr() : conn->RespAddr();

	// The two addresses are simply added into checksum_base below.
	// Ones-complement addition is commutative, so the src/dst ordering
	// above does not affect the resulting checksum.
#ifdef BROv6
	checksum_base = ones_complement_checksum((void*) src_addr, 16, 0);
	checksum_base = ones_complement_checksum((void*) dst_addr, 16, checksum_base);
#else
	checksum_base = ones_complement_checksum((void*) src_addr, 4, 0);
	checksum_base = ones_complement_checksum((void*) dst_addr, 4, checksum_base);
#endif
	// Note, for IPv6, strictly speaking this field is 32 bits
	// rather than 16 bits.  But because the upper bits are all zero,
	// we get the same checksum either way.  The same applies to
	// later when we add in the data length in ValidChecksum().
	checksum_base += htons(IPPROTO_TCP);
	}

// Release what the endpoint owns.  The contents processor is owned here;
// AddContentsProcessor() likewise deletes a processor it replaces.
TCP_Endpoint::~TCP_Endpoint()
	{
	// The endpoint analyzers are deliberately left alone: they are torn
	// down together with the owning TCP_Connection (in its Done()).
	delete contents_processor;
	}

// Record the opposite endpoint of the connection.  The originator side
// also registers the initial (own state, peer state) pair with the TCP
// statistics, so that the pair is counted only once.
void TCP_Endpoint::SetPeer(TCP_Endpoint* p)
	{
	peer = p;

	if ( ! IsOrig() )
		return;

	sessions->tcp_stats.StateEntered(state, peer->state);
	}

// Install a new contents processor, taking ownership of it.  A different,
// previously installed processor is deleted.  Passing 0 removes the
// current processor.  If a contents file is active, it is handed to the
// new processor.
void TCP_Endpoint::AddContentsProcessor(TCP_Contents* arg_contents_processor)
	{
	if ( contents_processor != arg_contents_processor )
		delete contents_processor;
	contents_processor = arg_contents_processor;

	// Guard against a null processor.  Removing the processor while a
	// contents file is set must not dereference 0.
	if ( contents_file && contents_processor )
		contents_processor->SetContentsFile(contents_file);
	}

// Push analyzer `a` onto the front of this endpoint's analyzer chain.
// The current head is handed to `a` via its own AddAnalyzer() first,
// presumably so it becomes a's successor (see NextAnalyzer() in DataSent).
void TCP_Endpoint::AddAnalyzer(TCP_EndpointAnalyzer* a)
	{
	TCP_EndpointAnalyzer* old_head = analyzer;
	a->AddAnalyzer(old_head);
	analyzer = a;
	}

// Forward to the contents processor.  Without one, nothing is pending.
int TCP_Endpoint::DataPending() const
	{
	return contents_processor ? contents_processor->DataPending() : 0;
	}

// Forward to the contents processor.  Without one, there is never
// undelivered data.
int TCP_Endpoint::HasUndeliveredData() const
	{
	if ( ! contents_processor )
		return 0;

	return contents_processor->HasUndeliveredData();
	}

// Ask the contents processor, if any, to run matching on its
// still-undelivered data.  No-op without a processor.
void TCP_Endpoint::MatchUndeliveredData()
	{
	if ( ! contents_processor )
		return;

	contents_processor->MatchUndeliveredData();
	}

// Forward an end-of-stream check to the contents processor.
// No-op without a processor.
void TCP_Endpoint::CheckEOF()
	{
	if ( ! contents_processor )
		return;

	contents_processor->CheckEOF();
	}

// Report how much data the contents processor is buffering, split into
// data stuck behind a sequence hole and data awaiting acknowledgment.
// Both outputs are 0 when there is no processor.
void TCP_Endpoint::SizeBufferedData(int& waiting_on_hole, int& waiting_on_ack)
	{
	if ( ! contents_processor )
		{
		waiting_on_hole = 0;
		waiting_on_ack = 0;
		return;
		}

	contents_processor->SizeBufferedData(waiting_on_hole, waiting_on_ack);
	}

int TCP_Endpoint::ValidChecksum(const struct tcphdr* tp, int len) const
	{
	uint32 sum = checksum_base;
	int tcp_len = tp->th_off * 4 + len;

	if ( len % 2 == 1 )
		// Add in pad byte.
		sum += htons(((const u_char*) tp)[tcp_len - 1] << 8);

	sum += htons((unsigned short) tcp_len);	// fill out pseudo header
	sum = ones_complement_checksum((void*) tp, tcp_len, sum);

	return sum == 0xffff;
	}

// True while an endpoint has not progressed past connection setup:
// inactive, or having sent only a SYN or SYN-ACK.
static inline bool is_handshake(EndpointState state)
	{
	switch ( state ) {
	case TCP_INACTIVE:
	case TCP_SYN_SENT:
	case TCP_SYN_ACK_SENT:
		return true;

	default:
		return false;
	}
	}

// Move this endpoint to new_state, remembering the old state in
// prev_state.  The connection-wide TCP statistics are updated, and the
// inactivity timer is armed when this transition ends the handshake.
// Does nothing if the state is unchanged.
void TCP_Endpoint::SetState(EndpointState new_state)
	{
	if ( new_state == state )
		return;

	// Both sides were still handshaking and this side is now leaving the
	// handshake.  That completes the handshake, so start inactivity timing.
	bool finishes_handshake = ! is_handshake(new_state) &&
				is_handshake(state) &&
				is_handshake(peer->state);

	if ( finishes_handshake )
		conn->SetInactivityTimeout(tcp_inactivity_timeout);

	prev_state = state;
	state = new_state;

	// ChangeState() takes the originator's (old, new) pair first, then the
	// responder's.  The peer did not change, so its pair is (peer, peer).
	if ( ! IsOrig() )
		sessions->tcp_stats.ChangeState(peer->state, peer->state,
						prev_state, state);
	else
		sessions->tcp_stats.ChangeState(prev_state, state,
						peer->state, peer->state);
	}

int TCP_Endpoint::Size() const
	{
	int size;

	if ( seq_delta(last_seq, ack_seq) > 0 || ack_seq == start_seq + 1 )
		// Either last_seq corresponds to more data sent than we've
		// seen ack'd, or we haven't seen any data ack'd (in which
		// case we should trust last_seq anyway).  This last test
		// matters for the case in which the connection has
		// transferred > 2 GB of data, in which case we will find
		// seq_delta(last_seq, ack_seq) < 0 even if ack_seq
		// corresponds to no data transferred.
		size = last_seq - start_seq;

	else
		// It could be that ack_seq > last_seq, if we've seen an
		// ack for the connection (say in a FIN) without seeing
		// the corresponding data.
		size = ack_seq - start_seq;

	// Don't include SYN octet in sequence space.  For partial connections
	// (no SYN seen), we're still careful to adjust start_seq as though
	// there was an initial SYN octet, because if we don't then the
	// packet reassembly code gets confused.
	if ( size != 0 )
		--size;

	if ( FIN_cnt > 0 && size != 0 )
		--size;	// don't include FIN octet.

	return size;
	}

int TCP_Endpoint::DataSent(double t, int seq, int len, const u_char* data,
				const IP_Hdr* ip, const struct tcphdr* tp)
	{
	int status = 0;

	if ( contents_processor )
		status = contents_processor->DataSent(t, seq, len, data);

	for ( TCP_EndpointAnalyzer* a = analyzer; a; a = a->NextAnalyzer() )
		status = a->DataSent(t, seq, len, data, ip, tp) || status;

	if ( contents_file && ! contents_processor && 
	     seq + len > contents_start_seq )
		{
		int under_seq = contents_start_seq - seq;
		if ( under_seq > 0 )
			{
			seq += under_seq;
			data += under_seq;
			len -= under_seq;
			}

		// DEBUG_MSG("%d: seek %d, data=%02x len=%d\n", IsOrig(), seq - contents_start_seq, *data, len);
		FILE* f = contents_file->Seek(seq - contents_start_seq);

		if ( fwrite(data, 1, len, f) < unsigned(len) )
			// ### this should really generate an event
			internal_error("contents write failed");
		}

	if ( rule_matcher && ! contents_processor )
		conn->Match(Rule::PAYLOAD, data, len, false, false, IsOrig());

	return status;
	}

// Pass an acknowledgment of sequence number `seq` to the contents
// processor.  No-op without a processor.
void TCP_Endpoint::AckReceived(int seq)
	{
	if ( ! contents_processor )
		return;

	contents_processor->AckReceived(seq);
	}

// Start recording this endpoint's payload to file `f`, beginning at the
// current sequence position.  The file is also handed to the contents
// processor, if there is one.
void TCP_Endpoint::SetContentsFile(BroFile* f)
	{
	contents_file = f;

	// Recording begins at the current offset relative to start_seq.
	// Offset 0 is the SYN octet itself, so begin just past it.
	contents_start_seq = last_seq - start_seq;
	if ( contents_start_seq == 0 )
		contents_start_seq = 1;

	if ( contents_processor )
		contents_processor->SetContentsFile(f);
	}

// Register TCP_Endpoint with the serialization framework under type tag
// SER_TCP_ENDPOINT.
IMPLEMENT_SERIAL(TCP_Endpoint, SER_TCP_ENDPOINT);

// Public entry point: delegates to the generic SerialObj machinery,
// which (presumably) dispatches to DoSerialize() below.
bool TCP_Endpoint::Serialize(SerialInfo* info) const
	{
	return SerialObj::Serialize(info);
	}

// Read back a TCP_Endpoint, requiring the SER_TCP_ENDPOINT type tag.
// Returns 0 on failure (callers such as DoUnserialize test for that).
TCP_Endpoint* TCP_Endpoint::Unserialize(UnserialInfo* info)
	{
	return (TCP_Endpoint*) SerialObj::Unserialize(info, SER_TCP_ENDPOINT);
	}

// Write this endpoint's state.  The field order here MUST match
// TCP_Endpoint::DoUnserialize() exactly.  Returns false on the first
// write failure.
bool TCP_Endpoint::DoSerialize(SerialInfo* info) const
	{
	DO_SERIALIZE(SER_TCP_ENDPOINT, SerialObj);

	// Addresses are written interleaved, word by word (src[i], dst[i]).
	for ( int i = 0; i < NUM_ADDR_WORDS; ++i )
		if ( ! (SERIALIZE(src_addr[i]) && SERIALIZE(dst_addr[i])) )
			return false;

	// Members that may be 0.
	SERIALIZE_OPTIONAL(contents_processor);
	SERIALIZE_OPTIONAL(analyzer);
	SERIALIZE_OPTIONAL(contents_file);

	// peer and conn are serialized unconditionally (assumed non-null).
	// Enums are written as plain ints.
	return peer->Serialize(info) &&
		conn->Serialize(info) &&
		SERIALIZE(int(state)) &&
		SERIALIZE(int(prev_state)) &&
		SERIALIZE(checksum_base) &&
		SERIALIZE(start_time) &&
		SERIALIZE(last_time) &&
		SERIALIZE(start_seq) &&
		SERIALIZE(last_seq) &&
		SERIALIZE(ack_seq) &&
		SERIALIZE(window) &&
		SERIALIZE(window_scale) &&
		SERIALIZE(window_ack_seq) &&
		SERIALIZE(window_seq) &&
		SERIALIZE(contents_start_seq) &&
		SERIALIZE(FIN_seq) &&
		SERIALIZE(SYN_cnt) &&
		SERIALIZE(FIN_cnt) &&
		SERIALIZE(RST_cnt) &&
		SERIALIZE(did_close);
	}

// Restore the state written by TCP_Endpoint::DoSerialize(), reading the
// fields in the same order.  Returns false on the first read failure.
bool TCP_Endpoint::DoUnserialize(UnserialInfo* info)
	{
	DO_UNSERIALIZE(SerialObj);

	// The unserialized endpoint gets its own address arrays, since the
	// original connection's storage is not available here.
	// NOTE(review): once assigned below, these arrays are not freed by
	// ~TCP_Endpoint(); confirm who owns them.
	uint32* src_addr_tmp = new uint32[NUM_ADDR_WORDS];
	uint32* dst_addr_tmp = new uint32[NUM_ADDR_WORDS];

	for ( int i = 0; i < NUM_ADDR_WORDS; ++i )
		if ( ! UNSERIALIZE(&src_addr_tmp[i]) ||
		     ! UNSERIALIZE(&dst_addr_tmp[i]) )
			{
			// Nothing has taken ownership yet, so don't leak them.
			delete [] src_addr_tmp;
			delete [] dst_addr_tmp;
			return false;
			}

	src_addr = src_addr_tmp;
	dst_addr = dst_addr_tmp;

	UNSERIALIZE_OPTIONAL(contents_processor, TCP_Contents::Unserialize(info));
	UNSERIALIZE_OPTIONAL(analyzer, TCP_EndpointAnalyzer::Unserialize(info));
	UNSERIALIZE_OPTIONAL(contents_file, BroFile::Unserialize(info));

	peer = TCP_Endpoint::Unserialize(info);
	if ( ! peer )
		return false;

	conn = (TCP_Connection*) TCP_Connection::Unserialize(info);
	if ( ! conn )
		return false;

	// Enums were written as plain ints.
	int tmp;
	if ( ! UNSERIALIZE(&tmp) )
		return false;
	state = (EndpointState) tmp;

	if ( ! UNSERIALIZE(&tmp) )
		return false;

	prev_state = (EndpointState) tmp;

	return	UNSERIALIZE(&checksum_base) &&
		UNSERIALIZE(&start_time) &&
		UNSERIALIZE(&last_time) &&
		UNSERIALIZE(&start_seq) &&
		UNSERIALIZE(&last_seq) &&
		UNSERIALIZE(&ack_seq) &&
		UNSERIALIZE(&window) &&
		UNSERIALIZE(&window_scale) &&
		UNSERIALIZE(&window_ack_seq) &&
		UNSERIALIZE(&window_seq) &&
		UNSERIALIZE(&contents_start_seq) &&
		UNSERIALIZE(&FIN_seq) &&
		UNSERIALIZE(&SYN_cnt) &&
		UNSERIALIZE(&FIN_cnt) &&
		UNSERIALIZE(&RST_cnt) &&
		UNSERIALIZE(&did_close);
	}

// Attach a new analyzer to endpoint `e`.  The analyzer links itself into
// the endpoint's analyzer chain and holds a reference to the connection.
// That reference is dropped in the destructor.
TCP_EndpointAnalyzer::TCP_EndpointAnalyzer(TCP_Endpoint* e)
	{
	endp = e;

	// Start unlinked; e->AddAnalyzer() is what sets our successor.
	next = 0;
	e->AddAnalyzer(this);

	Ref(Conn());
	}

// Drop the connection reference taken in the constructor.
TCP_EndpointAnalyzer::~TCP_EndpointAnalyzer()
	{
	Unref(Conn());
	}

// Describe the analyzer as "analyzer(orig)" or "analyzer(resp)",
// followed by the connection's own description.
void TCP_EndpointAnalyzer::Describe(ODesc* d) const
	{
	const char* side_tag;
	if ( Endpoint()->IsOrig() )
		side_tag = "analyzer(orig)";
	else
		side_tag = "analyzer(resp)";

	d->Add(side_tag);
	Conn()->Describe(d);
	}

// Default payload hook: ignores the data and reports status 0.
// Derived analyzers (e.g. TCP_EndpointStats) override it.
int TCP_EndpointAnalyzer::DataSent(double /* t */, int /* seq */,
				int /* len */, const u_char* /* data */,
				const IP_Hdr* /* ip */,
				const struct tcphdr* /* tp */)
	{
	const int no_status = 0;
	return no_status;
	}

// Register TCP_EndpointAnalyzer with the serialization framework under
// type tag SER_TCP_ENDPOINT_ANALYZER.
IMPLEMENT_SERIAL(TCP_EndpointAnalyzer, SER_TCP_ENDPOINT_ANALYZER);

// Public entry point: delegates to the generic SerialObj machinery.
bool TCP_EndpointAnalyzer::Serialize(SerialInfo* info) const
	{
	return SerialObj::Serialize(info);
	}

// Read back an analyzer, requiring the SER_TCP_ENDPOINT_ANALYZER tag.
// Returns 0 on failure.
TCP_EndpointAnalyzer* TCP_EndpointAnalyzer::Unserialize(UnserialInfo* info)
	{
	return (TCP_EndpointAnalyzer*) SerialObj::Unserialize(info, SER_TCP_ENDPOINT_ANALYZER);
	}

// Write the analyzer: first its (optional) successor in the chain, then
// its endpoint.  The order must match TCP_EndpointAnalyzer::DoUnserialize().
bool TCP_EndpointAnalyzer::DoSerialize(SerialInfo* info) const
	{
	DO_SERIALIZE(SER_TCP_ENDPOINT_ANALYZER, BroObj);

	SERIALIZE_OPTIONAL(next);
	return endp->Serialize(info);
	}

// Restore the analyzer written by DoSerialize(): the optional successor,
// then the endpoint.  Fails if the endpoint cannot be read.
bool TCP_EndpointAnalyzer::DoUnserialize(UnserialInfo* info)
	{
	DO_UNSERIALIZE(BroObj);

	UNSERIALIZE_OPTIONAL(next, TCP_EndpointAnalyzer::Unserialize(info));
	endp = TCP_Endpoint::Unserialize(info);
	return endp != 0;
	}

// Per-endpoint packet statistics analyzer.  All counters start at zero
// and the sender's IP-ID byte order is initially unknown.
TCP_EndpointStats::TCP_EndpointStats(TCP_Endpoint* e)
: TCP_EndpointAnalyzer(e)
	{
	// Packet / retransmission / ordering counters.
	num_pkts = num_rxmit = num_rxmit_bytes = 0;
	num_in_order = num_OO = num_repl = 0;

	// Tracking state for DataSent().
	max_top_seq = 0;
	last_id = 0;
	endian_type = ENDIAN_UNKNOWN;
	}

// Swap the two low-order bytes of n, i.e. byte-swap a 16-bit value.
// Bits above the low 16 are discarded.
int endian_flip(int n)
	{
	int low_byte = n & 0xff;
	int high_byte = (n >> 8) & 0xff;
	return (low_byte << 8) | high_byte;
	}

// Update the per-endpoint statistics for one data packet.
//
// Each packet is classified by its IPv4 ID relative to the last accepted
// one:
//   - same ID                 -> counted as a replicate (num_repl);
//   - small backward step     -> counted as out-of-order (num_OO);
//   - otherwise               -> counted in-order (num_in_order).
// In-order packets are then checked for retransmission by comparing
// their end sequence number (seq + len) against the highest seen so far.
// The ID step is judged in whichever byte order (big- or little-endian)
// makes it smaller, and that choice also drives the endian_type guess.
// Always returns 0.
int TCP_EndpointStats::DataSent(double /* t */, int seq, int len,
			const u_char* /* data */,
			const IP_Hdr* ip, const struct tcphdr* /* tp */)
	{
	if ( ++num_pkts == 1 )
		{ // First packet.
		last_id = ntohs(ip->ID4());
		return 0;
		}

	int id = ntohs(ip->ID4());

	if ( id == last_id )
		{
		++num_repl;
		return 0;
		}

	// Truncating to short makes each difference wrap modulo 2^16, so an
	// ID that wrapped past 0xffff still yields a small delta.
	short id_delta = id - last_id;
	short id_endian_delta = endian_flip(id) - endian_flip(last_id);

	int abs_id_delta = id_delta > 0 ? id_delta : -id_delta;
	int abs_id_endian_delta = id_endian_delta > 0 ? id_endian_delta : -id_endian_delta;

	int final_id_delta;

	// Pick whichever byte order makes the ID step smaller.  Once both
	// orders have been observed, the sender is marked ENDIAN_CONFUSED.
	if ( abs_id_delta < abs_id_endian_delta )
		{ // Consistent with big-endian.
		if ( endian_type == ENDIAN_UNKNOWN )
			endian_type = ENDIAN_BIG;
		else if ( endian_type == ENDIAN_BIG )
			;
		else
			endian_type = ENDIAN_CONFUSED;

		final_id_delta = id_delta;
		}
	else
		{ // Consistent with little-endian.
		if ( endian_type == ENDIAN_UNKNOWN )
			endian_type = ENDIAN_LITTLE;
		else if ( endian_type == ENDIAN_LITTLE )
			;
		else
			endian_type = ENDIAN_CONFUSED;

		final_id_delta = id_endian_delta;
		}

	// A small backward step in ID means the packet arrived out of order.
	// last_id is deliberately not updated in that case.
	if ( final_id_delta < 0 && final_id_delta > -256 )
		{
		++num_OO;
		return 0;
		}

	last_id = id;

	++num_in_order;

	// An in-order packet that does not extend past the highest sequence
	// end seen so far is a retransmission.
	int top_seq = seq + len;
	if ( top_seq <= max_top_seq )
		{
		++num_rxmit;
		num_rxmit_bytes += len;
		}
	else
		max_top_seq = top_seq;

	return 0;
	}

// Package the collected statistics as a new `endpoint_stats` record for
// the script layer.  Every field is stored as a count, in this order:
// packets, retransmitted packets, retransmitted bytes, in-order packets,
// out-of-order packets, replicated packets, and the inferred IP-ID byte
// order (endian_type).  The caller receives the newly allocated record.
RecordVal* TCP_EndpointStats::BuildStats()
	{
	RecordVal* stats = new RecordVal(endpoint_stats);

	stats->Assign(0, new Val(num_pkts,TYPE_COUNT));
	stats->Assign(1, new Val(num_rxmit,TYPE_COUNT));
	stats->Assign(2, new Val(num_rxmit_bytes,TYPE_COUNT));
	stats->Assign(3, new Val(num_in_order,TYPE_COUNT));
	stats->Assign(4, new Val(num_OO,TYPE_COUNT));
	stats->Assign(5, new Val(num_repl,TYPE_COUNT));
	stats->Assign(6, new Val(endian_type,TYPE_COUNT));

	return stats;
	}
