// $Id: Conn.h,v 1.7 2005/03/17 09:22:17 vern Exp $
//
// Copyright (c) 1996, 1997, 1998, 1999, 2000, 2001, 2002
//      The Regents of the University of California.  All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that: (1) source code distributions
// retain the above copyright notice and this paragraph in its entirety, (2)
// distributions including binary code include the above copyright notice and
// this paragraph in its entirety in the documentation or other materials
// provided with the distribution, and (3) all advertising materials mentioning
// features or use of this software display the following acknowledgement:
// ``This product includes software developed by the University of California,
// Lawrence Berkeley Laboratory and its contributors.'' Neither the name of
// the University nor the names of its contributors may be used to endorse
// or promote products derived from this software without specific prior
// written permission.
// THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR IMPLIED
// WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE.

#ifndef conn_h
#define conn_h

#include <sys/types.h>

#include "Dict.h"
#include "Val.h"
#include "Timer.h"
#include "RuleMatcher.h"
#include "Serializer.h"
#include "PersistenceSerializer.h"

class Connection;
class ConnectionTimer;
class NetSessions;
class LoginConn;
class RuleHdrTest;
class Specific_RE_Matcher;

typedef enum {
	NUL_IN_LINE,
	SINGULAR_CR,
	SINGULAR_LF,
	NUM_EVENTS_TO_FLAG,
} ConnEventToFlag;

typedef void (Connection::*timer_func)(double t);

struct ConnID {
	const uint32* src_addr;
	const uint32* dst_addr;
	uint32 src_port;
	uint32 dst_port;
	bool is_one_way;	// if true, don't canonicalize

	// Returns a ListVal suitable for looking up a connection in
	// a hash table.  addr/ports are expected to be in network order.
	// Unless is_one_way is true, the lookup sorts src and dst,
	// so src_addr/src_port and dst_addr/dst_port just have to
	// reflect the two different sides of the connection,
	// neither has to be the particular source/destination
	// or originator/responder.
	HashKey* ConnID::BuildConnKey() const;

	// The structure used internally for hashing.
	struct Key {
		uint32 ip1[NUM_ADDR_WORDS];
		uint32 ip2[NUM_ADDR_WORDS];
		uint16 port1;
		uint16 port2;
	};
};

static inline int addr_port_canon_lt(const uint32* a1, uint32 p1,
					const uint32* a2, uint32 p2)
	{
#ifdef BROv6
	// Because it's a canonical ordering, not a strict ordering,
	// we can choose to give more weight to the least significant
	// word than to the most significant word.  This matters
	// because for the common case of IPv4 addresses embedded in
	// a IPv6 address, the top three words are identical, so we can
	// save a few cycles by first testing the bottom word.
	return a1[3] < a2[3] ||
		(a1[3] == a2[3] &&
		 (a1[2] < a2[2] ||
		  (a1[2] == a2[2] &&
		   (a1[1] < a2[1] ||
		    (a1[1] == a2[1] &&
		     a1[0] < a2[0] ||
		     (a1[0] == a2[0] &&
		      p1 < p2))))));
#else
	return *a1 < *a2 || (*a1 == *a2 && p1 < p2);
#endif
	}

class Connection : public BroObj {
public:
	Connection(NetSessions* s, HashKey* k, double t, const ConnID* id);
	virtual ~Connection();

	// Invoked when connection is about to be removed.  Use Ref(this)
	// inside Done to keep the connection object around (though it'll
	// no longer be accessible from the dictionary of active
	// connections).
	virtual void Done();

	// Process the connection's next packet.  "data" points just
	// beyond the IP header.  It's updated to point just beyond
	// the transport header (or whatever should be saved, if we
	// decide not to save the full packet contents).
	//
	// If record_packet is true, the packet should be recorded.
	// If record_content is true, then its entire contents should
	// be recorded, otherwise just up through the transport header.
	// Both are assumed set to true when called.
	virtual void NextPacket(double t, int is_orig,
			const IP_Hdr* ip, int len, int caplen,
			const u_char*& data,
			int& record_packet, int& record_content,
			// arguments for reproducing packets
			const struct pcap_pkthdr* hdr,
			const u_char* const pkt,
			int hdr_size) = 0;

	HashKey* Key() const			{ return key; }
	void ClearKey()				{ key = 0; }

	double StartTime() const		{ return start_time; }
	double LastTime() const			{ return last_time; }

	const uint32* OrigAddr() const		{ return orig_addr; }
	const uint32* RespAddr() const		{ return resp_addr; }

	uint32 OrigPort() const			{ return orig_port; }
	uint32 RespPort() const			{ return resp_port; }

	virtual TransportProto ConnTransport() const = 0;

	// If we are rewriting the trace of the connection, then we do
	// not record original packets.
	virtual int RewritingTrace()	{ return 0; }

	// True if we should record subsequent packets (either headers or
	// in their entirety, depending on record_contents).  We still
	// record subsequent SYN/FIN/RST, regardless of how this is set.
	int RecordPackets() const		{ return record_packets; }
	void SetRecordPackets(int do_record)	{ record_packets = do_record; }

	// True if we should record full packets for this connection,
	// false if we should just record headers.
	int RecordContents() const		{ return record_contents; }
	void SetRecordContents(int do_record)	{ record_contents = do_record; }

	virtual void SetContentsFile(unsigned int direction, BroFile* f);
	virtual BroFile* GetContentsFile(unsigned int direction) const;

	// If true, skip processing of remainder of connection.  Note
	// that this does not in itself imply that record_packets is false;
	// we might want instead to process the connection off-line.
	void SetSkip(int do_skip)		{ skip = do_skip; }
	int Skipping() const			{ return skip; }

	// Arrange for the connection to expire after the given amount of time.
	void SetLifetime(double lifetime);

	// Returns true if the packet reflects a reuse of this
	// connection (i.e., not a continuation but the beginning of
	// a new connection).
	virtual int IsReuse(double t, const u_char* pkt) = 0;

	// Get/set the inactivity timeout for this connection.
	void SetInactivityTimeout(double timeout);
	double InactivityTimeout() const	{ return inactivity_timeout; }

	// Activate connection_status_update timer.
	void EnableStatusUpdateTimer();

	RecordVal* BuildConnVal();

	LoginConn* AsLoginConn()		{ return login_conn; }

	// Tries really hard to extract a program name and a version.
	Val* BuildVersionVal(const char* s, int len);

	// Raises a software_version_found event based on the
	// given string (returns false if it's not parseable).
	int VersionFoundEvent(const uint32* addr, const char* s, int len);

	// Raises a software_unparsed_version_found event.
	int UnparsedVersionFoundEvent(const uint32* addr,
					const char* full_descr, int len);
	
	void Event(EventHandlerPtr f, const char* name = 0);
	void Event(EventHandlerPtr f, Val* v1, Val* v2 = 0);
	void ConnectionEvent(EventHandlerPtr f, val_list* vl);

	virtual void Weird(const char* name);
	virtual void Weird(const char* name, const char* addl);
	virtual void Weird(const char* name, int addl_len, const char* addl);

	// Cancel all associated timers.
	void CancelTimers();

	// Interface to the rule matching subsystem.
	void InitEndpointMatcher(const IP_Hdr* ip, int caplen, int from_orig);

	// bol/eol should be set to false for type Rule::PAYLOAD; they're
	// deduced automatically.
	void Match(Rule::PatternType type, const u_char* data,
			int data_len, bool bol, bool eol, int from_orig)
		{
		rule_matcher->Match(from_orig ? orig_match_state : resp_match_state,
			type, data, data_len, bol, eol);
		}

	void FinishEndpointMatcher();

	void ClearMatchState(bool orig)
		{
		if ( orig_match_state )
			rule_matcher->ClearEndpointState(orig_match_state);
		if ( resp_match_state )
			rule_matcher->ClearEndpointState(resp_match_state);
		}

	inline int FlagEvent(ConnEventToFlag e)
		{
		if ( e >= 0 && e < NUM_EVENTS_TO_FLAG )
			{
			if ( suppress_event & (1 << e) )
				return 0;
			suppress_event |= 1 << e;
			}
		return 1;
		}

	void MakePersistent()
		{
		persistent = 1;
		persistence_serializer->Register(this);
		}

	bool IsPersistent()	{ return persistent; }

	bool Serialize(SerialInfo* info) const;
	static Connection* Unserialize(UnserialInfo* info);

	DECLARE_ABSTRACT_SERIAL(Connection);

	// Statistics.

	// Just a lower bound.
	virtual unsigned int MemoryAllocation() const;
	unsigned int MemoryAllocationConnVal() const;

	static unsigned int TotalConnections()
		{ return total_connections; }
	static unsigned int CurrentConnections()
		{ return current_connections; }

protected:
	Connection()	{ }

	virtual void UpdateEndpointVal(RecordVal* endp, int is_orig) = 0;

	// Add the given timer to expire in delta_t seconds.  If do_expire
	// is true, then the timer is also evaluated when Bro terminates,
	// otherwise not.
	//
	// Note: the name is all uppercase because this used to be a macro.

	void AddTimer(timer_func timer, double delta_t, int do_expire,
			TimerType type);

	void RemoveTimer(Timer* t);

	// Allow other classes to access pointers to these:
	friend class ConnectionTimer;
	friend class TCP_Connection;	// required by gcc 3.4

	void DeleteTimer(double t);
	void InactivityTimer(double t);
	void StatusUpdateTimer(double t);

	NetSessions* sessions;
	HashKey* key;
	timer_list timers;
	uint32 orig_addr[NUM_ADDR_WORDS];	// in network order
	uint32 resp_addr[NUM_ADDR_WORDS];	// in network order
	uint32 orig_port, resp_port;	// in network order
	double start_time, last_time;
	double inactivity_timeout;
	RecordVal* conn_val;
	RecordVal* orig_endp;
	RecordVal* resp_endp;
	LoginConn* login_conn;	// either nil, or this
	int suppress_event;	// suppress certain events to once per conn.

	RuleEndpointState* orig_match_state;
	RuleEndpointState* resp_match_state;

	unsigned int installed_status_timer:1;
	unsigned int timers_canceled:1;
	unsigned int is_active:1;
	unsigned int skip:1;
	unsigned int weird:1;
	unsigned int finished:1;
	unsigned int record_packets:1, record_contents:1;
	unsigned int persistent:1;

	// Count number of connections.
	static unsigned int total_connections;
	static unsigned int current_connections;
};

class ConnectionTimer : public Timer {
public:
	ConnectionTimer(Connection* arg_conn, timer_func arg_timer,
			double arg_t, int arg_do_expire, TimerType arg_type)
		: Timer(arg_t, arg_type)
		{ Init(arg_conn, arg_timer, arg_do_expire); }
	virtual ~ConnectionTimer();

	void Dispatch(double t, int is_expire);

protected:
	ConnectionTimer()	{}

	void Init(Connection* conn, timer_func timer, int do_expire);

	DECLARE_SERIAL(ConnectionTimer);

	Connection* conn;
	timer_func timer;
	int do_expire;
};

#define ADD_TIMER(timer, t, do_expire, type) \
	AddTimer(timer_func(timer), (t), (do_expire), (type))

#endif
