// $Id: SMB.cc,v 1.2 2004/08/31 21:22:47 vern Exp $
//
// Copyright (c) 1996-2004
//      The Regents of the University of California.  All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that: (1) source code distributions
// retain the above copyright notice and this paragraph in its entirety, (2)
// distributions including binary code include the above copyright notice and
// this paragraph in its entirety in the documentation or other materials
// provided with the distribution, and (3) all advertising materials mentioning
// features or use of this software display the following acknowledgement:
// ``This product includes software developed by the University of California,
// Lawrence Berkeley Laboratory and its contributors.'' Neither the name of
// the University nor the names of its contributors may be used to endorse
// or promote products derived from this software without specific prior
// written permission.
// THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR IMPLIED
// WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE.

#include "NetVar.h"
#include "SMB.h"
#include "Val.h"

// SMB command codes. Each SMB_COMMAND(name, value) entry in the X-macro
// list SMB_COM.def becomes an enumerator `name = value` here; the macro
// is undefined again right after the list is expanded.
enum SMB_Command {
#define SMB_COMMAND(name, value) name = value,
#include "SMB_COM.def"
#undef SMB_COMMAND
};

// Tables indexed by the one-byte SMB command code (0-255).
// SMB_command_name holds the printable command name, filled in by
// init_SMB_command_name(); SMB_command_str caches one StringVal per
// name, created lazily by get_SMB_command_str().
const char* SMB_command_name[256];
StringVal* SMB_command_str[256];

// Fill in SMB_command_name[] from SMB_COM.def and clear the
// SMB_command_str[] cache. Every slot defaults to "<unknown>"; the
// commands listed in SMB_COM.def then overwrite their own slot with
// their name. Only the first call does any work.
static void init_SMB_command_name()
	{
	static int initialized = 0;
	if ( initialized )
		return;

	initialized = 1;

	for ( int i = 0; i < 256; ++i )
		{
		SMB_command_name[i] = "<unknown>";
		SMB_command_str[i] = 0;
		}

#define SMB_COMMAND(name, value) SMB_command_name[value] = #name;
#include "SMB_COM.def"
	// Undefine the X-macro again, matching the SMB_Command enum
	// definition, so this expansion cannot leak into later code.
#undef SMB_COMMAND
	}

// Return the cached StringVal holding the name of SMB command `cmd`
// (expected to be in 0..255). The cache keeps its own reference, so a
// caller that hands the value on must Ref() it first.
//
// The name table is initialized on demand. This keeps the function safe
// to call even before any SMB_Session has been constructed; otherwise
// SMB_command_name[] would still hold only null pointers.
StringVal* get_SMB_command_str(int cmd)
	{
	init_SMB_command_name();

	if ( ! SMB_command_str[cmd] )
		SMB_command_str[cmd] = new StringVal(SMB_command_name[cmd]);

	return SMB_command_str[cmd];
	}

// Set up SMB state for connection `c`. Initially there is no DCE/RPC
// sub-session, no IPC$ share or known named pipe, and no pending
// transaction.
SMB_Session::SMB_Session(Connection* c)
	{
	init_SMB_command_name();

	conn = c;

	dce_rpc_session = 0;
	transaction_name = 0;
	transaction_subcmd = 0;

	is_IPC = 0;
	IPC_pipe = IPC_NONE;
	}

// Free the DCE/RPC sub-session (if any) and drop our reference to the
// last transaction name.
SMB_Session::~SMB_Session()
	{
	delete dce_rpc_session;
	Unref(transaction_name);
	}

void SMB_Session::Deliver(int is_orig, int len, const u_char* msg)
	{
	if ( len == 0 )
		return;

	if ( len < 35 )
		{
		Weird("truncated_SMB_header");
		return;
		}

	if ( strncmp((const char*) msg, "\xffSMB", 4) )
		{
		Weird("non-SMB message");
		return;
		}

	u_char cmd = msg[4];

	if ( smb_message )
		{
		val_list* vl = new val_list;
		StringVal* cmd_str = get_SMB_command_str(cmd);
		Ref(cmd_str);

		vl->append(conn->BuildConnVal());
		vl->append(new Val(is_orig, TYPE_BOOL));
		vl->append(cmd_str);
		vl->append(new StringVal(len, (const char*) msg));

		conn->ConnectionEvent(smb_message, vl);
		}

	switch ( cmd ) {
	case SMB_COM_TREE_CONNECT_ANDX:
		if ( is_orig )
			ParseTreeConnectAndx(len, msg);
		break;

	case SMB_COM_NT_CREATE_ANDX:
		if ( is_orig )
			ParseNtCreateAndx(len, msg);
		break;

	case SMB_COM_TRANSACTION:
	case SMB_COM_TRANSACTION2:
	case SMB_COM_TRANSACTION_SECONDARY:
	case SMB_COM_TRANSACTION2_SECONDARY:
		ParseTransaction(is_orig, cmd, len, msg);
		break;

	case SMB_COM_READ_ANDX:
		ParseReadAndx(is_orig, len, msg);
		break;

	case SMB_COM_WRITE_ANDX:
		ParseWriteAndx(is_orig, len, msg);
		break;
	}
	}

void SMB_Session::Weird(const char* msg)
	{
	conn->Weird(msg);
	}

int SMB_Session::GetWord(int len, const u_char* msg, int off, uint16* w)
	{
	if ( off + 1 >= len )
		{
		Weird("truncated_SMB_message__GetWord");
		return -1;
		}

	*w = msg[off] + (msg[off+1] << 8);

	return 0;
	}

int SMB_Session::Get_SMB_DataPortion(int len, const u_char* msg,
					uint16* byte_count, const u_char** data)
	{
	if ( len <= 32 )
		{
		Weird("truncated_SMB_message__Get_SMB_DataPortion_1");
		return -1;
		}

	uint8 word_count = msg[32];
	uint16 off = 33 + word_count * 2;

	if ( GetWord(len, msg, off, byte_count) == -1 )
		return -1;

	if ( off + 2 + *byte_count > len )
		{
		Weird("truncated_SMB_message__Get_SMB_DataPortion_2");
		return -1;
		}

	*data = msg + off + 2;

	return 0;
	}

// Extract a NUL-terminated string from [data, data+len-1]. `data` is
// advanced, and `len` decreased, past the bytes consumed, so successive
// calls walk through consecutive strings.
//
// If bit 0x80 of byte 11 of the SMB message `msg` is set (presumably the
// Unicode flag in flags2), the input is read as little-endian 16-bit
// characters, word-aligned relative to `msg`. Only the low byte of each
// character is kept. The result is always a new BroString, which is
// empty when no bytes remain.
BroString* SMB_Session::ExtractString(const u_char* msg, int& len,
					const u_char*& data)
	{
	// len can already be negative: a Unicode string with an odd byte
	// count consumes one byte more than it has. Treat any non-positive
	// length as "nothing left" rather than reading past the buffer.
	if ( len <= 0 )
		return new BroString("");

	int unicode = (msg[11] & 0x80) >> 7;
	if ( unicode && (data - msg) % 2 == 1 )
		{
		// A unicode string must be word-aligned: skip the pad byte.
		// Both the length and the pointer adjustment must be guarded
		// by the condition.
		--len;
		++data;

		// The pad byte may have been the only byte left.
		if ( len <= 0 )
			return new BroString("");
		}

	// One output byte per input character, plus the terminating NUL.
	int slen = 1 + (unicode ? (len + 1) / 2 : len);
	u_char* s = new u_char[slen];

	int k;
	for ( k = 0; k < slen; ++k )
		{
		s[k] = *data;
		--len; ++data;

		if ( unicode )
			{
			// Skip the high byte of the 16-bit character.
			// TODO: should verify whether *data == '\0'
			--len; ++data;
			}

		if ( s[k] == '\0' )
			break;

		if ( len <= 0 )
			{
			// Input exhausted without a NUL: terminate here.
			++k;
			s[k] = '\0';
			break;
			}
		}

	// k is the string length, not counting the terminating NUL.
	return new BroString(1, s, k);
	}

int SMB_Session::LooksLikeRPC(int len, const u_char* msg,
				uint16 data_count, uint16 data_off)
	{
	uint16 frag_length;

	if ( is_IPC && IPC_pipe != IPC_NONE )
		{ // check whether data part looks like an RPC message
		if ( data_count >= 16 &&
		     msg[data_off] == 5 && msg[data_off+1] == 0 &&
		     GetWord(len, msg, data_off+8, &frag_length) != -1 &&
		     frag_length == data_count )
			return 1;
		}

	return 0;
	}

// Map a named-pipe name to a pipe number by exact, case-sensitive match
// on the whole name:
//   "\locator" -> 1, "\epmapper" -> 2, "\samr" -> 3.
// Anything else yields IPC_NONE. The caller casts the result to
// enum IPC_named_pipe; presumably that enum lists the pipes in this
// order after IPC_NONE (confirm against SMB.h).
static int lookup_IPC_name(BroString* name)
	{
	static const char* const pipe_names[] = {
		"\\locator", "\\epmapper", "\\samr",
	};
	static const int num_pipes =
		sizeof(pipe_names) / sizeof(pipe_names[0]);

	const int name_len = name->Len();
	const char* name_bytes = (const char*) name->Bytes();

	for ( int i = 0; i < num_pipes; ++i )
		{
		const char* candidate = pipe_names[i];

		if ( size_t(name_len) == strlen(candidate) &&
		     strncmp(name_bytes, candidate, name_len) == 0 )
			return i + 1;
		}

	return IPC_NONE;
	}

int SMB_Session::ParseTreeConnectAndx(int len, const u_char* msg)
	{
	uint16 password_length;
	if ( GetWord(len, msg, 32 + 7, &password_length) == -1 )
		return -1;

	uint16 byte_count;
	const u_char* data;
	if ( Get_SMB_DataPortion(len, msg, &byte_count, &data) == -1 )
		return -1;

	if ( data + byte_count > msg + len )
		{
		Weird("truncated_SMB_message__Get_SMB_DataPortion_3");
		return -1;
		}

	int data_length = byte_count;
	if ( password_length > data_length )
		{
		Weird("truncated_SMB_message__Get_SMB_DataPortion_4");
		return -1;
		}

	// Skip password.
	data_length -= password_length;
	data += password_length;

	BroString* path = ExtractString(msg, data_length, data);
	BroString* service = ExtractString(msg, data_length, data);

	if ( strstr_n(path->Len(), path->Bytes(), 5,
		      (const u_char*) "\\IPC$") != -1 )
		is_IPC = 1;	// TODO: change is_IPC to 0 on tree_disconnect

	if ( smb_com_tree_connect_andx )
		{
		val_list* vl = new val_list;
		vl->append(conn->BuildConnVal());
		vl->append(new StringVal(path));
		vl->append(new StringVal(service));
		conn->ConnectionEvent(smb_com_tree_connect_andx, vl);
		}
	else
		{
		delete path;
		delete service;
		}

	return 0;
	}

int SMB_Session::ParseNtCreateAndx(int len, const u_char* msg)
	{
	uint16 name_length;
	if ( GetWord(len, msg, 32 + 6, &name_length) == -1 )
		return -1;

	uint16 byte_count;
	const u_char* data;
	if ( Get_SMB_DataPortion(len, msg, &byte_count, &data) == -1 )
		return -1;

	if ( data + name_length > msg + len )
		{
		Weird("truncated_SMB_message__ParseNtCreateAndx");
		return -1;
		}

	int data_length = name_length;
	BroString* name = ExtractString(msg, data_length, data);

	IPC_pipe = (enum IPC_named_pipe) lookup_IPC_name(name);

	if ( smb_com_nt_create_andx )
		{
		val_list* vl = new val_list;
		vl->append(conn->BuildConnVal());
		vl->append(new StringVal(name));
		conn->ConnectionEvent(smb_com_nt_create_andx, vl);
		}
	else
		delete name;

	return 0;
	}

// Parse a TRANSACTION / TRANSACTION2 message, either primary or
// secondary, request or response.
//
// It records the setup sub-command in transaction_subcmd and, for
// primary TRANSACTION requests, the transaction name. Then, for a
// primary message whose data is complete in this message, it:
//   - raises smb_com_transaction / smb_com_transaction2 (if handled);
//   - passes the data to the DCE/RPC analyzer if LooksLikeRPC() accepts it.
// Secondary messages and fragmented data are not reassembled yet.
//
// Returns 0 on success or when the message is skipped, -1 if truncated.
int SMB_Session::ParseTransaction(int is_orig, int cmd, int len,
					const u_char* msg)
	{
	int is_primary =
		(cmd == SMB_COM_TRANSACTION || cmd == SMB_COM_TRANSACTION2);

	uint8 word_count = msg[32];
	if ( word_count == 0 && ! is_orig && is_primary )
		// Interim server response.
		return 0;

	// Get the data offset and length. All offsets below are added to
	// 32, the position of the word count byte, so offset 1 is the
	// first byte of the parameter words.
	int data_count_off;
	int setup_count_off, setup_off;
	uint16 total_data_count;
	uint16 data_displacement;
	uint16 data_count;
	uint16 data_off;

	if ( is_orig )
		{
		data_count_off = is_primary ? 23 : 11;
		if ( is_primary )
			{
			setup_off = 29;
			setup_count_off = data_count_off + 4;
			}
		else
			// No setup fields are read for secondary requests.
			setup_off = setup_count_off = 0;
		}
	else
		{
		data_count_off = 13;
		setup_off = 21;
		setup_count_off = data_count_off + 6;
		}

	if ( GetWord(len, msg, 32 + 3, &total_data_count) == -1 ||
	     GetWord(len, msg, 32 + data_count_off, &data_count) == -1 ||
	     GetWord(len, msg, 32 + data_count_off + 2, &data_off) == -1 )
		return -1;

	// Primary requests are treated as starting at displacement 0;
	// everything else carries an explicit data displacement word.
	if ( is_orig && is_primary )
		data_displacement = 0;
	else
		{
		if ( GetWord(len, msg, 32 + data_count_off + 4,
				&data_displacement) == -1 )
			return -1;
		}

	// setup_count_off is relative to byte 32 like the other offsets,
	// so the bounds check must include that base before indexing msg[].
	if ( setup_count_off && 32 + setup_count_off < len &&
	     msg[32 + setup_count_off] > 0 && setup_off )
		{
		if ( GetWord(len, msg, 32 + setup_off,
				&transaction_subcmd) == -1 )
			return -1;
		}
	else
		transaction_subcmd = 0;

	// Extract the name if it's a request.
	if ( is_orig && is_primary )
		{
		if ( transaction_name )
			{
			Unref(transaction_name);
			transaction_name = 0;
			}

		if ( cmd == SMB_COM_TRANSACTION )
			{
			uint16 byte_count;
			const u_char* data;

			if ( Get_SMB_DataPortion(len, msg,
						&byte_count, &data) == -1 )
				return -1;

			int data_length = byte_count;
			transaction_name = new StringVal(ExtractString(msg, data_length, data));
			}
		}

	if ( data_off + data_count > len )
		{
		Weird("truncated_SMB_message__ParseTransaction");
		return -1;
		}

	if ( ! is_primary || data_displacement != 0 || data_count < total_data_count )
		// ### FIXME: data reassembly
		return 0;

	EventHandlerPtr f;
	if ( cmd == SMB_COM_TRANSACTION ||
	     cmd == SMB_COM_TRANSACTION_SECONDARY )
		f = smb_com_transaction;

	else if ( cmd == SMB_COM_TRANSACTION2 ||
		  cmd == SMB_COM_TRANSACTION2_SECONDARY )
		f = smb_com_transaction2;

	if ( f )
		{
		val_list* vl = new val_list;

		vl->append(conn->BuildConnVal());
		vl->append(new Val(is_orig, TYPE_BOOL));
		vl->append(new Val(transaction_subcmd, TYPE_COUNT));
		if ( transaction_name )
			{
			// The session keeps its own reference.
			Ref(transaction_name);
			vl->append(transaction_name);
			}
		else
			vl->append(new StringVal(""));
		vl->append(new StringVal(data_count,
						(const char*) msg + data_off));
		conn->ConnectionEvent(f, vl);
		}

	if ( LooksLikeRPC(len, msg, data_count, data_off) )
		{
		if ( ! dce_rpc_session )
			dce_rpc_session = new DCE_RPC_Session(conn);

		dce_rpc_session->DeliverPDU(is_orig, data_count, msg + data_off);
		}

	return 0;
	}

int SMB_Session::ParseReadAndx(int is_orig, int len, const u_char* msg)
	{
	if ( is_orig )
		// Here we only process responses to READ_ANDX.
		return 0;

	// Get the data offset and length.
	int data_count_off = 11;
	uint16 data_count;
	if ( GetWord(len, msg, 32 + data_count_off, &data_count) == -1 )
		return -1;

	uint16 data_off;
	if ( GetWord(len, msg, 32 + data_count_off + 2, &data_off) == -1 )
		return -1;

	if ( data_off + data_count > len )
		{
		Weird("truncated_SMB_message__ParseReadAndx");
		return -1;
		}

	if ( smb_com_read_andx )
		{
		val_list* vl = new val_list;
		vl->append(conn->BuildConnVal());
		vl->append(new Val(is_orig, TYPE_BOOL));
		vl->append(new StringVal(data_count,
						(const char*) msg + data_off));
		conn->ConnectionEvent(smb_com_read_andx, vl);
		}

	if ( LooksLikeRPC(len, msg, data_count, data_off) )
		{
		if ( ! dce_rpc_session )
			dce_rpc_session = new DCE_RPC_Session(conn);

		dce_rpc_session->DeliverPDU(is_orig, data_count, msg + data_off);
		}

	return 0;
	}

int SMB_Session::ParseWriteAndx(int is_orig, int len, const u_char* msg)
	{
	if ( ! is_orig )
		// Here we only process WRITE_ANDX requests.
		return 0;

	// Get the data offset and length.
	int data_count_off = 21;
	uint16 data_count;
	if ( GetWord(len, msg, 32 + data_count_off, &data_count) == -1 )
		return -1;

	uint16 data_off;
	if ( GetWord(len, msg, 32 + data_count_off + 2, &data_off) == -1 )
		return -1;

	if ( data_off + data_count > len )
		{
		Weird("truncated_SMB_message__ParseWriteAndx");
		// DEBUG_MSG("data_off=0x%x data_count=0x%x len=0x%x\n", data_off, data_count, len);
		return -1;
		}

	if ( smb_com_write_andx )
		{
		val_list* vl = new val_list;
		vl->append(conn->BuildConnVal());
		vl->append(new Val(is_orig, TYPE_BOOL));
		vl->append(new StringVal(data_count,
						(const char*) msg + data_off));
		conn->ConnectionEvent(smb_com_write_andx, vl);
		}

	if ( LooksLikeRPC(len, msg, data_count, data_off) )
		{
		if ( ! dce_rpc_session )
			dce_rpc_session = new DCE_RPC_Session(conn);

		dce_rpc_session->DeliverPDU(is_orig, data_count, msg + data_off);
		}

	return 0;
	}

// Contents processor for one direction of an SMB-over-TCP stream. It
// reassembles messages and passes them to the session `s`. It starts
// with no reassembly buffer and no message in progress.
TCP_Contents_SMB::TCP_Contents_SMB(TCP_Endpoint* arg_endp, SMB_Session* s)
: TCP_Contents(arg_endp)
	{
	msg_buf = 0;
	buf_len = buf_n = 0;
	msg_len = 0;

	smb_session = s;
	}

// Replace the reassembly buffer with a fresh one of msg_len bytes and
// mark it empty. Any previously buffered bytes are discarded.
void TCP_Contents_SMB::InitMsgBuf()
	{
	delete [] msg_buf;

	buf_len = msg_len;
	msg_buf = new u_char[buf_len];
	buf_n = 0;
	}

// Release the reassembly buffer, if one was ever allocated.
TCP_Contents_SMB::~TCP_Contents_SMB()
	{
	if ( msg_buf )
		delete [] msg_buf;
	}

// Pass one reassembled message of `len` bytes to the SMB session if it
// starts with the "\xffSMB" signature.
//
// Otherwise, a weird is flagged showing the 4-byte TCP framing header
// and the first four message bytes, and SetSkipDeliveries(1) is called.
// In both cases the reassembly state is reset for the next message.
void TCP_Contents_SMB::DeliverSMB(int len, const u_char* data)
	{
	const int has_signature =
		strncmp((const char*) data, "\xffSMB", 4) == 0;

	if ( has_signature )
		smb_session->Deliver(IsOrig(), len, data);
	else
		{
		Conn()->Weird(fmt("SMB-over-TCP header error: %02x%02x%02x%02x, \\x%02x%c%c%c",
			dshdr[0], dshdr[1], dshdr[2], dshdr[3],
			data[0], data[1], data[2], data[3]));
		SetSkipDeliveries(1);
		}

	msg_len = 0;
	buf_n = 0;
	}

// Reassemble SMB-over-TCP messages from a chunk of `len` stream bytes.
// The sequence number is unused.
//
// Each message is preceded by a 4-byte header:
//   - the first byte must be 0;
//   - the other three bytes give the message length, most significant
//     byte first.
// Complete messages go to DeliverSMB(). A partial header is carried over
// to the next call in dshdr, and a partial message in msg_buf; buf_n
// counts the bytes collected so far.
void TCP_Contents_SMB::Deliver(int /* seq */, int len, u_char* data)
	{
	while ( len > 0 )
		{
		// msg_len == 0 means no message is in progress, so we
		// are (still) reading a header.
		if ( ! msg_len )
			{
			// Get the SMB-over-TCP header (4 bytes).
			while ( buf_n < 4 && len > 0 )
				{
				dshdr[buf_n] = *data;
				++buf_n; ++data; --len;
				}

			// Header incomplete: wait for the next chunk.
			if ( buf_n < 4 )
				return;

			buf_n = 0;
			// Bytes 1-3 hold the message length, big-endian.
			for ( int i = 1; i < 4; ++i )
				msg_len = ( msg_len << 8 ) + dshdr[i];

			// A nonzero first byte, or a length of 4 or less,
			// makes us give up on this direction entirely.
			// NOTE(review): NetBIOS session-service packets
			// other than session messages (e.g. keep-alives
			// on port 139) would also end up here -- confirm
			// that is intended.
			if ( dshdr[0] != 0 || msg_len <= 4 )
				{
				Conn()->Weird("SMB message length error");
				SetSkipDeliveries(1);
				return;
				}
			}

		if ( buf_n == 0 && msg_len <= len )
			{
			// The fast lane:
			// Keep msg_len -- it will be changed in DeliverSMB
			// (The whole message is in this chunk and nothing
			// is buffered, so deliver it without copying.)
			int mlen = msg_len;
			DeliverSMB(msg_len, data);
			len -= mlen;
			data += mlen;
			}

		else
			{
			// Grow the buffer only when this message is larger
			// than the current one. This happens on the first
			// slow-path pass for a message (buf_n is still 0),
			// so InitMsgBuf()'s reset of buf_n loses nothing.
			if ( buf_len < msg_len )
				InitMsgBuf();

			while ( buf_n < msg_len && len > 0 )
				{
				msg_buf[buf_n] = *data;
				++buf_n;
				++data;
				--len;
				}

			// Message incomplete: wait for the next chunk.
			if ( buf_n < msg_len )
				return;

			DeliverSMB(msg_len, msg_buf);
			}
		}
	}

// A TCP connection carrying SMB. It creates the SMB_Session shared by
// both directions; the per-direction contents processors are created by
// BuildEndpoints().
SMB_TCP_Conn::SMB_TCP_Conn(NetSessions* s, HashKey* k, double t,
				const ConnID* id, const struct tcphdr* tp)
:TCP_Connection(s, k, t, id, tp)
	{
	smb_session = new SMB_Session(this);

	o_smb = 0;
	r_smb = 0;
	}

// Delete the SMB session created by the constructor. The contents
// processors are not deleted here; presumably the endpoints they were
// added to own them (confirm).
SMB_TCP_Conn::~SMB_TCP_Conn()
	{
	delete this->smb_session;
	}

// Attach an SMB contents processor to each endpoint. Both processors
// feed the same SMB_Session.
void SMB_TCP_Conn::BuildEndpoints()
	{
	o_smb = new TCP_Contents_SMB(orig, smb_session);
	orig->AddContentsProcessor(o_smb);

	r_smb = new TCP_Contents_SMB(resp, smb_session);
	resp->AddContentsProcessor(r_smb);
	}
