// $Id: RE.cc,v 1.3 2005/03/08 15:26:41 vern Exp $
//
// Copyright (c) 1998, 1999, 2001, 2002, 2003
//      The Regents of the University of California.  All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that: (1) source code distributions
// retain the above copyright notice and this paragraph in its entirety, (2)
// distributions including binary code include the above copyright notice and
// this paragraph in its entirety in the documentation or other materials
// provided with the distribution, and (3) all advertising materials mentioning
// features or use of this software display the following acknowledgement:
// ``This product includes software developed by the University of California,
// Lawrence Berkeley Laboratory and its contributors.'' Neither the name of
// the University nor the names of its contributors may be used to endorse
// or promote products derived from this software without specific prior
// written permission.
// THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR IMPLIED
// WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE.

#include "config.h"

#include <stdlib.h>

#include "RE.h"
#include "DFA.h"
#include "CCL.h"
#include "EquivClass.h"
#include "Serializer.h"

// NOTE(review): not referenced in this file; presumably the character
// class the RE parser is currently building — confirm in the grammar.
CCL* curr_ccl = 0;

// Matcher the parser should build into. Compile() and CompileSet() point
// this at themselves immediately before calling RE_parse().
Specific_RE_Matcher* rem;
// The parser leaves the NFA it built here; Compile() and CompileSet()
// read it right after RE_parse() succeeds.
NFA_Machine* nfa = 0;
// NOTE(review): not read in this file; presumably consulted by the
// scanner/parser to fold case — confirm.
int case_insensitive = 0;

// Parser/scanner entry points defined elsewhere. RE_parse() returns
// non-zero when the pattern fails to parse.
extern int RE_parse(void);
extern void RE_set_input(const char* str);

// If positive, the set-wise matching (RE_Match_State::Match) always
// returns false - for benchmarking.
extern int rule_bench;

// Creates an empty matcher of the given kind. No pattern, DFA or '.'
// character class exists yet; those are produced by AddPat(), Compile()
// and AnyCCL() respectively.
Specific_RE_Matcher::Specific_RE_Matcher(match_type arg_mt, int arg_multiline)
: equiv_class(NUM_SYM), mt(arg_mt), multiline(arg_multiline),
  any_ccl(0), pattern_text(0), dfa(0), ecs(0),
  accepted(new AcceptingSet())
	{
	}

Specific_RE_Matcher::~Specific_RE_Matcher()
	{
	// Free every character class registered with this matcher.
	loop_over_list(ccl_list, i)
		delete ccl_list[i];

	// The DFA is reference-counted, so only drop our reference; the
	// pattern text and the accepting set are owned outright.
	Unref(dfa);
	delete [] pattern_text;
	delete accepted;
	}

// Returns the character class used for '.', building it on first use.
// It is the negation of {'\n'} (so it excludes newline) unless the
// matcher is multiline, in which case it is the negation of the empty set.
CCL* Specific_RE_Matcher::AnyCCL()
	{
	if ( any_ccl )
		return any_ccl;

	any_ccl = new CCL();

	if ( ! multiline )
		any_ccl->Add('\n');

	any_ccl->Negate();
	EC()->CCL_Use(any_ccl);

	return any_ccl;
	}

void Specific_RE_Matcher::ConvertCCLs()
	{
	for ( int i = 0; i < ccl_list.length(); ++i )
		equiv_class.ConvertCCL(ccl_list[i]);
	}

// Appends new_pat to the accumulated pattern text, wrapped according to
// the match type chosen at construction.
void Specific_RE_Matcher::AddPat(const char* new_pat)
	{
	if ( mt != MATCH_EXACTLY )
		{
		AddAnywherePat(new_pat);
		return;
		}

	AddExactPat(new_pat);
	}

// Adds new_pat so that it may match anywhere in the input.
// The first pattern becomes "^?.*(new)"; each later one is OR-ed onto the
// existing text as "(old)|(^?.*(new))".
void Specific_RE_Matcher::AddAnywherePat(const char* new_pat)
	{
	// "^?.*(" + ")" is 6 characters, plus the terminating NUL.
	int size = strlen(new_pat) + 7;
	char* combined;

	if ( ! pattern_text )
		{
		combined = new char[size];
		snprintf(combined, size, "^?.*(%s)", new_pat);
		}
	else
		{
		// The alternation adds "(", ")|(" ... ")" : 5 more characters.
		size += strlen(pattern_text) + 5;
		combined = new char[size];
		snprintf(combined, size, "(%s)|(^?.*(%s))", pattern_text, new_pat);
		}

	delete [] pattern_text;
	pattern_text = combined;
	}

// Adds new_pat for exact matching.
// The first pattern becomes "^?(new)$?"; each later one is OR-ed onto the
// existing text as "(old)|(^?(new)$?)".
void Specific_RE_Matcher::AddExactPat(const char* new_pat)
	{
	// "^?(" + ")$?" is 6 characters, plus the terminating NUL.
	int size = strlen(new_pat) + 7;
	char* combined;

	if ( ! pattern_text )
		{
		combined = new char[size];
		snprintf(combined, size, "^?(%s)$?", new_pat);
		}
	else
		{
		// The alternation adds "(", ")|(" ... ")" : 5 more characters.
		size += strlen(pattern_text) + 5;
		combined = new char[size];
		snprintf(combined, size, "(%s)|(^?(%s)$?)", pattern_text, new_pat);
		}

	delete [] pattern_text;
	pattern_text = combined;
	}

// Parses the accumulated pattern text and builds the DFA.
// Returns 1 on success. Returns 0 if there is no pattern or it fails to
// parse; a parse failure is reported via run_time(). `lazy` is currently
// unused.
int Specific_RE_Matcher::Compile(int lazy)
	{
	// No pattern: leave dfa null, which the matching routines treat as
	// the empty pattern.
	if ( ! pattern_text )
		return 0;

	// The parser works through globals: `rem` names the matcher being
	// built and the resulting NFA is left in `nfa`.
	rem = this;
	RE_set_input(pattern_text);

	int parse_error = RE_parse();
	if ( parse_error )
		{
		run_time("error compiling pattern /%s/", pattern_text);
		return 0;
		}

	// Equivalence classes first, then the CCLs expressed in them, then
	// the DFA over both.
	EC()->BuildECs();
	ConvertCCLs();

	dfa = new DFA_Machine(nfa, EC());

	// Release the parser's NFA and clear the global so it is not reused.
	Unref(nfa);
	nfa = 0;

	ecs = EC()->EquivClasses();

	return 1;
	}

int Specific_RE_Matcher::CompileSet(const string_list& set, const int_list& idx)
	{
	if ( set.length() != idx.length() )
		internal_error("compileset: lengths of sets differ");

	rem = this;

	NFA_Machine* set_nfa = 0;

	loop_over_list(set, i)
		{
		RE_set_input(set[i]);
		if ( RE_parse() )
			{
			run_time("error compiling pattern /%s/", set[i]);
			return 0;
			}

		nfa->FinalState()->SetAccept(idx[i]);
		set_nfa = set_nfa ? make_alternate(nfa, set_nfa) : nfa;
		}

	// Prefix the expression with a "^".
	nfa = new NFA_Machine(new NFA_State(SYM_BOL, rem->EC()));
	if ( set_nfa )
		nfa->AppendMachine( set_nfa );

	EC()->BuildECs();
	ConvertCCLs();

	dfa = new DFA_Machine(nfa, EC());
	ecs = EC()->EquivClasses();

	return 1;
	}

// Returns what `defs` holds for the name `def`.
// NOTE(review): presumably a named sub-pattern definition used by the
// parser, and null when undefined — confirm against defs' type.
const char* Specific_RE_Matcher::LookupDef(const char* def)
	{
	return defs.Lookup(def);
	}

// Convenience overload: whole-input match against a NUL-terminated string
// (the terminator itself is not matched).
int Specific_RE_Matcher::MatchAll(const char* s)
	{
	int len = strlen(s);
	return MatchAll(byte_vec(s), len);
	}

// Convenience overload: whole-input match against a BroString.
int Specific_RE_Matcher::MatchAll(const BroString* s)
	{
	// s->Len() does not include '\0'.
	int len = s->Len();
	return MatchAll(s->Bytes(), len);
	}

// Convenience overload: Match() over a NUL-terminated string.
int Specific_RE_Matcher::Match(const char* s)
	{
	int len = strlen(s);
	return Match(byte_vec(s), len);
	}

// Convenience overload: Match() over a BroString's bytes.
int Specific_RE_Matcher::Match(const BroString* s)
	{
	int len = s->Len();
	return Match(s->Bytes(), len);
	}

// Convenience overload: LongestMatch() over a NUL-terminated string.
int Specific_RE_Matcher::LongestMatch(const char* s)
	{
	int len = strlen(s);
	return LongestMatch(byte_vec(s), len);
	}

// Convenience overload: LongestMatch() over a BroString's bytes.
int Specific_RE_Matcher::LongestMatch(const BroString* s)
	{
	int len = s->Len();
	return LongestMatch(s->Bytes(), len);
	}

// Returns non-zero if the DFA accepts the whole input: a BOL symbol, all
// n bytes of bv, then an EOL symbol.
int Specific_RE_Matcher::MatchAll(byte_vec bv, int n)
	{
	// An empty pattern matches "all" only when the input is empty too.
	if ( ! dfa )
		return n == 0;

	// A null state means there is no transition, i.e. no match possible.
	DFA_State_Handle* state = (*dfa->StartState())->Xtion(ecs[SYM_BOL], dfa);

	for ( int pos = 0; state && pos < n; ++pos )
		state = (*state)->Xtion(ecs[bv[pos]], dfa);

	if ( ! state )
		return 0;

	state = (*state)->Xtion(ecs[SYM_EOL], dfa);

	return state && (*state)->Accept() != 0;
	}


// Feeds BOL and then bv[0..n) to the DFA and returns the length of the
// first (shortest) accepting prefix, i.e. the byte count consumed when
// an accepting state is first reached.
// If no state accepts inside the input but the state after an appended
// EOL does, returns n (or 1 when n == 0, since 0 means "no match").
// Returns 1 for an empty pattern.
int Specific_RE_Matcher::Match(byte_vec bv, int n)
	{
	if ( ! dfa )
		return 1;

	DFA_State_Handle* state = (*dfa->StartState())->Xtion(ecs[SYM_BOL], dfa);

	for ( int pos = 0; state && pos < n; ++pos )
		{
		state = (*state)->Xtion(ecs[bv[pos]], dfa);

		if ( state && (*state)->Accept() )
			return pos + 1;
		}

	if ( ! state )
		return 0;

	state = (*state)->Xtion(ecs[SYM_EOL], dfa);
	if ( state && (*state)->Accept() )
		return n > 0 ? n : 1;	// we can't return 0 here for match...

	return 0;
	}


// Writes a description of the compiled DFA to f. Does nothing if no DFA
// has been built (no pattern yet, or Compile() was never called or
// failed): every other entry point treats a null dfa as a valid state, so
// this one must not dereference it.
void Specific_RE_Matcher::Dump(FILE* f)
	{
	if ( ! dfa )
		return;

	dfa->Dump(f);
	}

// Drops the reference Match() holds on the state it stopped in, if any.
RE_Match_State::~RE_Match_State()
	{
	if ( current_state != 0 )
		StateUnref(current_state);
	}

// Incrementally runs the set DFA over one chunk of input.
// bv/n:  the chunk's bytes.
// bol:   feed a beginning-of-line symbol before the bytes.
// eol:   feed an end-of-line symbol after the bytes.
// Newly accepted ids are appended to `accepted`, with the matching
// position appended to `match_pos`. Returns true if this call added at
// least one new accepted id.
bool RE_Match_State::Match( const u_char* bv, int n, bool bol, bool eol )
	{
	if ( rule_bench > 0 )
		return false;

	if ( current_pos == -1  )
		{
		// First call to Match().
		if ( ! dfa )
			return false;

		// Initialize state and copy the accepting states of the start
		// state into the acceptance set.
		current_state = dfa->StartState();
		StateRef(current_state);

		const AcceptingSet* ac = (*current_state)->Accept();
		if ( ac )
			{
			loop_over_list(*ac, i)
				{
				accepted.append((*ac)[i]);
				match_pos.append(0);
				}
			}
		}

	else if ( ! current_state )
		// An earlier call already hit a dead transition; no further
		// input can produce a match.
		return false;

	else
		// Counterpart of the Lock() at the end of the previous call.
		(*current_state)->Unlock();

	current_pos = 0;

	int old_matches = accepted.length();

	// m counts down over the symbols to feed: the optional BOL marker
	// (m == n), the n input bytes, then the optional EOL marker (m == -1).
	int ec;
	int m = bol ? n + 1 : n;
	int e = eol ? -1 : 0;

	while ( --m >= e )
		{
		if ( m == n )
			ec = ecs[SYM_BOL];
		else if ( m == -1 )
			ec = ecs[SYM_EOL];
		else
			ec = ecs[*(bv++)];

		DFA_State_Handle* next_state = (*current_state)->Xtion(ec,dfa);

		if ( ! next_state )
			{
			// Dead end. Release the reference taken by StateRef() on
			// the state we are abandoning; once current_state is null
			// the destructor can no longer do it.
			StateUnref(current_state);
			current_state = 0;
			break;
			}

		const AcceptingSet* ac = (*next_state)->Accept();
		if ( ac )
			{
			// Record each newly accepted id with the current value of
			// current_pos. That counter restarts at 0 on every call and
			// also advances over the BOL marker.
			loop_over_list(*ac, i)
				{
				if ( ! accepted.is_member((*ac)[i]) )
					{
					accepted.append((*ac)[i]);
					match_pos.append(current_pos);
					}
				}
			}

		++current_pos;

		// Move our reference from the old state to the new one.
		StateRef(next_state);
		StateUnref(current_state);
		current_state = next_state;
		}

	// Make sure our state doesn't expire until we return.
	if ( current_state )
		(*current_state)->Lock();

	return accepted.length() != old_matches;
	}

// Feeds BOL and then bv[0..n) to the DFA and returns the length of the
// longest accepting prefix seen before the DFA dies or the input ends.
// An accepting state right after BOL is reported as 1. If the state after
// an appended EOL accepts, returns n (or 1 when n == 0). Returns 1 for an
// empty pattern and 0 when nothing accepts.
int Specific_RE_Matcher::LongestMatch(byte_vec bv, int n)
	{
	if ( ! dfa )
		return 1;

	DFA_State_Handle* state = (*dfa->StartState())->Xtion(ecs[SYM_BOL], dfa);
	if ( ! state )
		return 0;

	int longest = (*state)->Accept() ? 1 : 0;

	for ( int pos = 0; pos < n; ++pos )
		{
		state = (*state)->Xtion(ecs[bv[pos]], dfa);

		// Dead transition: the best we have seen so far is final.
		if ( ! state )
			return longest;

		if ( (*state)->Accept() )
			longest = pos + 1;
		}

	state = (*state)->Xtion(ecs[SYM_EOL], dfa);
	if ( state && (*state)->Accept() )
		return n > 0 ? n : 1;

	return longest;
	}

// Approximate number of bytes used by this matcher and the structures it
// points to.
unsigned int Specific_RE_Matcher::MemoryAllocation() const
	{
	unsigned int size = 0;

	for ( int i = 0; i < ccl_list.length(); ++i )
		size += ccl_list[i]->MemoryAllocation();

	return size + padded_sizeof(*this)
		+ (pattern_text ? pad_size(strlen(pattern_text) + 1) : 0)
		+ defs.MemoryAllocation() - padded_sizeof(defs) // FIXME: count content
		+ ccl_dict.MemoryAllocation() - padded_sizeof(ccl_dict) // FIXME: count content
		+ ccl_list.MemoryAllocation() - padded_sizeof(ccl_list)
		+ equiv_class.Size() - padded_sizeof(EquivClass)
		+ (dfa ? dfa->MemoryAllocation() : 0) // this is ref counted; consider the bytes here?
		// any_ccl is created lazily by AnyCCL() and may still be null;
		// only count it once it exists.
		// NOTE(review): if that CCL is also registered in ccl_list it is
		// counted twice here — confirm.
		+ (any_ccl ? padded_sizeof(*any_ccl) : 0)
		+ accepted->MemoryAllocation();
	}

// Creates a matcher with no pattern: one anywhere-matcher and one
// exact-matcher, both empty.
RE_Matcher::RE_Matcher()
: re_anywhere(new Specific_RE_Matcher(MATCH_ANYWHERE)),
  re_exact(new Specific_RE_Matcher(MATCH_EXACTLY))
	{
	}

// Creates both specific matchers and seeds them with pat. The result
// still has to be compiled with Compile().
RE_Matcher::RE_Matcher(const char* pat)
: re_anywhere(new Specific_RE_Matcher(MATCH_ANYWHERE)),
  re_exact(new Specific_RE_Matcher(MATCH_EXACTLY))
	{
	AddPat(pat);
	}

// Frees the two specific matchers, both of which this object allocated
// in its constructors.
RE_Matcher::~RE_Matcher()
	{
	delete re_anywhere;
	delete re_exact;
	}

// Adds new_pat to both matchers so they stay in step; each wraps the
// pattern according to its own match type.
void RE_Matcher::AddPat(const char* new_pat)
	{
	re_anywhere->AddPat(new_pat);
	re_exact->AddPat(new_pat);
	}

// Compiles both matchers. Returns 1 only if both succeed. The exact
// matcher is not attempted if the anywhere matcher fails.
int RE_Matcher::Compile(int lazy)
	{
	if ( ! re_anywhere->Compile(lazy) )
		return 0;

	return re_exact->Compile(lazy) != 0;
	}

// Public serialization entry point; forwards to the generic SerialObj
// machinery.
// NOTE(review): presumably that ends up calling DoSerialize() below — confirm.
bool RE_Matcher::Serialize(SerialInfo* info) const
	{
	return SerialObj::Serialize(info);
	}

// Reads an object of serial type SER_RE_MATCHER and returns it as an
// RE_Matcher.
RE_Matcher* RE_Matcher::Unserialize(UnserialInfo* info)
	{
	return (RE_Matcher*) SerialObj::Unserialize(info, SER_RE_MATCHER);
	}

// Hooks RE_Matcher into the serialization framework under SER_RE_MATCHER
// (macro defined elsewhere).
IMPLEMENT_SERIAL(RE_Matcher, SER_RE_MATCHER);

// Writes the two pattern texts: the anywhere pattern first, then the
// exact pattern. DoUnserialize() reads them back in that order and
// assigns the second one to re_exact, so the second write must come
// from re_exact (it used to repeat re_anywhere's text).
bool RE_Matcher::DoSerialize(SerialInfo* info) const
	{
	DO_SERIALIZE(SER_RE_MATCHER, SerialObj);
	return SERIALIZE(re_anywhere->PatternText())
			&& SERIALIZE(re_exact->PatternText());
	}

// Restores both matchers from the two serialized pattern texts (the
// anywhere pattern, then the exact pattern) and compiles them.
// Returns false if a string cannot be read or a pattern fails to compile.
bool RE_Matcher::DoUnserialize(UnserialInfo* info)
	{
	DO_UNSERIALIZE(SerialObj);

	// Every RE_Matcher constructor already allocates both matchers, so
	// free those before replacing them rather than leaking them.
	delete re_anywhere;
	re_anywhere = new Specific_RE_Matcher(MATCH_ANYWHERE);

	delete re_exact;
	re_exact = new Specific_RE_Matcher(MATCH_EXACTLY);

	// NOTE(review): ownership of the string returned through
	// UNSERIALIZE_STR is not visible here; if SetPat() copies it, `pat`
	// leaks — confirm.
	const char* pat;
	if ( ! UNSERIALIZE_STR(&pat, 0) )
		return false;

	re_anywhere->SetPat(pat);
	if ( ! re_anywhere->Compile() )
		{
		info->s->Error(fmt("Can't compile regexp '%s'", pat));
		return false;
		}

	if ( ! UNSERIALIZE_STR(&pat, 0) )
		return false;

	re_exact->SetPat(pat);
	if ( ! re_exact->Compile() )
		{
		info->s->Error(fmt("Can't compile regexp '%s'", pat));
		return false;
		}

	return true;
	}


// Builds and compiles a new matcher for "(text1)merge_op(text2)", where
// text1/text2 are the pattern texts of re1/re2. The caller owns the
// returned matcher. The result of Compile() is not checked.
static RE_Matcher* matcher_merge(const RE_Matcher* re1, const RE_Matcher* re2,
				const char* merge_op)
	{
	const char* text1 = re1->PatternText();
	const char* text2 = re2->PatternText();

	// The format adds only four parentheses plus the NUL; the rest is slack.
	size_t n = strlen(text1) + strlen(text2) + strlen(merge_op) + 32 /* slop */ ;

	char* merge_text = new char[n];
	snprintf(merge_text, n, "(%s)%s(%s)", text1, merge_op, text2);

	RE_Matcher* merge = new RE_Matcher(merge_text);

	// Allocated with new[], so it must be released with delete[]
	// (scalar delete here is undefined behavior).
	delete [] merge_text;

	merge->Compile();

	return merge;
	}

// Returns a new, compiled matcher for re1 followed by re2, i.e.
// "(text1)(text2)". The caller owns the result.
RE_Matcher* RE_Matcher_conjunction(const RE_Matcher* re1, const RE_Matcher* re2)
	{
	return matcher_merge(re1, re2, "");
	}

// Returns a new, compiled matcher for re1 or re2, i.e.
// "(text1)|(text2)". The caller owns the result.
RE_Matcher* RE_Matcher_disjunction(const RE_Matcher* re1, const RE_Matcher* re2)
	{
	return matcher_merge(re1, re2, "|");
	}
