// $Id: ChunkedIO.cc,v 1.5 2005/08/23 21:40:04 vern Exp $

#include <unistd.h>
#include <fcntl.h>
#include <errno.h>
#include <signal.h>
#include <sys/time.h>
#include <netinet/in.h>
#include <assert.h>

#include "config.h"
#include "ChunkedIO.h"
#include "NetVar.h"

void ChunkedIO::Stats(char* buffer, int length)
	{
	snprintf(buffer, length,
		"bytes=%luK/%luK chunks=%lu/%lu io=%lu/%lu bytes/io=%.2fK/%.2fK",
		bytes_read / 1024, bytes_written / 1024,
		chunks_read, chunks_written, reads, writes,
		bytes_read / (1024.0 * reads),
		bytes_written / (1024.0 * writes));

	buffer[length - 1] = '\0';
	}

// Starts with pure = false and every traffic counter reported by Stats()
// at zero.
ChunkedIO::ChunkedIO()
	{
	// NOTE(review): presumably the flag IsPure() reports; when not pure,
	// the Fd writer prefixes each chunk with its length.
	pure = false;

	// Traffic counters.
	bytes_read = bytes_written = 0;
	chunks_read = chunks_written = 0;
	reads = writes = 0;
	}

// Wraps the already-open descriptor ARG_FD; ARG_TAG is used in diagnostics.
// Puts the descriptor into non-blocking mode (terminating the process if
// that fails), allocates the read/write buffers, and records the parent
// pid so ReadChunk() can detect a vanished parent.
ChunkedIOFd::ChunkedIOFd(int arg_fd, const char* arg_tag)
	{
	fd = arg_fd;
	tag = arg_tag;
	eof = 0;
	last_flush = current_time();

	int old_flags = fcntl(fd, F_GETFL, 0);
	if ( old_flags < 0 )
		{
		error(fmt("can't obtain socket flags: %s", strerror(errno)));
		exit(1);
		}

	// All reads and writes below rely on non-blocking semantics.
	if ( fcntl(fd, F_SETFL, old_flags | O_NONBLOCK) < 0 )
		{
		error(fmt("can't set fd to non-blocking: %s (%d)",
			  strerror(errno), getpid()));
		exit(1);
		}

	// Incoming side.
	read_buffer = new char[BUFFER_SIZE];
	read_len = read_pos = 0;
	partial = 0;

	// Outgoing side.
	write_buffer = new char[BUFFER_SIZE];
	write_len = write_pos = 0;

	// Overflow queue for chunks that don't fit into write_buffer.
	pending_head = pending_tail = 0;
	pending_len = 0;

	parent = getppid();
	}

// Releases both buffers and closes the descriptor. Also discards all data
// still owned by this object: chunks waiting in the pending write queue and
// any partially reassembled oversized chunk.
ChunkedIOFd::~ChunkedIOFd()
	{
	delete [] read_buffer;
	delete [] write_buffer;
	close(fd);

	// WriteChunk() takes ownership of queued chunks (Flush() normally frees
	// them via PutIntoWriteBuffer()); anything still queued would leak.
	while ( pending_head )
		{
		ChunkQueue* q = pending_head;
		pending_head = q->next;

		delete [] q->chunk->data;
		delete q->chunk;
		delete q;
		}

	if ( partial )
		{
		delete [] partial->data;
		delete partial;
		}
	}

// Sends CHUNK. A chunk that fits into one buffer (minus its length prefix)
// goes straight to WriteChunk(). Larger chunks are cut into pieces, each
// flagged as partial except the last, so the reader can reassemble them;
// the original chunk is freed once all pieces are handed off.
// Returns false if the pending queue is over MAX_BUFFERED_CHUNKS
// (errno = ENOSPC) or if writing a piece fails.
bool ChunkedIOFd::Write(Chunk* chunk)
	{
#ifdef DEBUG
	DBG_LOG(DBG_CHUNKEDIO, "write of size %d [%s]",
		chunk->len, fmt_bytes(chunk->data, 20));
#endif

	// Refuse to queue more if the peer isn't draining; otherwise memory
	// could grow without bound.
	if ( pending_len > MAX_BUFFERED_CHUNKS )
		{
		DBG_LOG(DBG_CHUNKEDIO, "write queue full");
		errno = ENOSPC;
		return false;
		}

	if ( chunk->len <= BUFFER_SIZE - sizeof(uint32) )
		return WriteChunk(chunk, false);

	// Too large for a single buffer: send it in pieces.
	const char* src = chunk->data;
	unsigned long remaining = chunk->len;

	do
		{
		Chunk* piece = new Chunk;
		piece->len = min(BUFFER_SIZE - sizeof(uint32), remaining);
		piece->data = new char[piece->len];
		memcpy(piece->data, src, piece->len);

		src += piece->len;
		remaining -= piece->len;

		// Every piece but the final one carries the partial flag.
		bool more_follows = (remaining > 0);
		if ( ! WriteChunk(piece, more_follows) )
			return false;
		}
	while ( remaining );

	delete [] chunk->data;
	delete chunk;

	return true;
	}

// Accepts one chunk that fits into the write buffer and takes ownership of
// it; PARTIAL marks a non-final piece of a split chunk. If nothing is queued
// ahead of it and it fits, it is copied into the write buffer immediately.
// Otherwise it is appended to the pending queue and a flush is attempted.
// Returns false only if that flush fails.
bool ChunkedIOFd::WriteChunk(Chunk* chunk, bool partial)
	{
	assert(chunk->len <= BUFFER_SIZE - sizeof(uint32) );

	if ( chunk->len == 0 )
		internal_error( "attempt to write 0 bytes chunk");

	if ( partial )
		chunk->len |= FLAG_PARTIAL;

	++chunks_written;

	// Bypass the queue only when it's empty, so chunks are never reordered.
	bool queue_empty = (pending_head == 0);
	if ( queue_empty && PutIntoWriteBuffer(chunk) )
		return true;

	ChunkQueue* entry = new ChunkQueue;
	entry->chunk = chunk;
	entry->next = 0;

	if ( ! pending_tail )
		pending_head = entry;
	else
		pending_tail->next = entry;

	pending_tail = entry;
	++pending_len;

	// Complain periodically while the backlog keeps growing.
	if ( pending_len && pending_len % 10000 == 0 )
		warn(fmt("ChunkedIO: %d blocks pending in process %d %s",
				pending_len, getpid(), tag));

	return Flush();
	}


// Appends CHUNK to the write buffer and frees it. Unless in pure mode, the
// payload is preceded by a 4-byte network-order length header (partial flag
// included). Returns false, leaving CHUNK untouched, if it doesn't fit.
bool ChunkedIOFd::PutIntoWriteBuffer(Chunk* chunk)
	{
	const uint32 payload = chunk->len & ~FLAG_PARTIAL;
	const bool framed = ! IsPure();
	const size_t header = framed ? sizeof(uint32) : 0;

	if ( write_len + payload + header > BUFFER_SIZE )
		return false;

	char* dst = write_buffer + write_len;

	if ( framed )
		{
		// The header keeps FLAG_PARTIAL so the reader can reassemble.
		uint32 wire_len = htonl(chunk->len);
		memcpy(dst, &wire_len, sizeof(wire_len));
		dst += sizeof(wire_len);
		}

	memcpy(dst, chunk->data, payload);
	dst += payload;
	write_len = dst - write_buffer;

	delete [] chunk->data;
	delete chunk;

	return true;
	}

// Writes as much of the buffered output to the descriptor as the kernel
// accepts right now. Returns true if the buffer was fully written (it is
// then reset) or the write would block (the remainder stays buffered).
// Returns false on any other error, setting eof on EPIPE. A write
// interrupted by a signal (EINTR) is retried instead of being reported as
// a failure.
bool ChunkedIOFd::FlushWriteBuffer()
	{
	last_flush = current_time();

	while ( write_pos != write_len )
		{
		uint32 len = write_len - write_pos;

		int written = write(fd, write_buffer + write_pos, len);

		if ( written < 0 )
			{
			// A signal arrived before anything was written; just retry.
			if ( errno == EINTR )
				continue;

			if ( errno == EPIPE )
				eof = true;

			// These errnos are equal on POSIX.
			return errno == EWOULDBLOCK || errno == EAGAIN;
			}

		bytes_written += written;
		if ( written > 0 )
			++writes;

		if ( unsigned(written) == len )
			{
			write_pos = write_len = 0;
			return true;
			}

		if ( written == 0 )
			internal_error("written==0");

		// Short write.
		write_pos += written;
		}

	return true;
	}

// Flushes only if more than one time unit of current_time() has passed
// since the last flush (the threshold is arbitrary). Otherwise it does
// nothing and reports success.
bool ChunkedIOFd::OptionalFlush()
	{
	const bool overdue = current_time() - last_flush > 1;
	return overdue ? Flush() : true;
	}

bool ChunkedIOFd::Flush()
	{
	// Try to write data out.
	while ( pending_head )
		{
		if ( ! FlushWriteBuffer() )
			return false;

		// If we couldn't write the whole buffer, we stop here
		// and try again next time.
		if ( write_len > 0 )
			return true;

		// Put as many pending chunks into the buffer as possible.
		while ( pending_head )
			{
			if ( ! PutIntoWriteBuffer(pending_head->chunk) )
				break;

			ChunkQueue* q = pending_head;
			pending_head = pending_head->next;
			if ( ! pending_head )
				pending_tail = 0;

			--pending_len;
			delete q;
			}
		}

	return FlushWriteBuffer();
	}

// Peeks at the read buffer without consuming anything. If it holds a
// complete framed chunk, returns that chunk's length header in host order
// (FLAG_PARTIAL included); otherwise returns 0.
uint32 ChunkedIOFd::ChunkAvailable()
	{
	const int header = int(sizeof(uint32));
	int avail = read_len - read_pos;

	if ( avail < header )
		return 0;

	// Copy out the header because it may be unaligned within the buffer.
	uint32 hdr;
	memcpy(&hdr, read_buffer + read_pos, sizeof(hdr));
	hdr = ntohl(hdr);

	const uint32 body = hdr & ~FLAG_PARTIAL;
	if ( uint32(avail - header) < body )
		return 0;

	assert(body);

	return hdr;
	}

// Removes the next complete chunk from the read buffer and returns it as a
// newly allocated Chunk whose len still carries FLAG_PARTIAL, if set.
// Returns 0 if no complete chunk is buffered yet.
ChunkedIO::Chunk* ChunkedIOFd::ExtractChunk()
	{
	const uint32 header = ChunkAvailable();
	const uint32 size = header & ~FLAG_PARTIAL;

	if ( size == 0 )
		return 0;

	const char* payload = read_buffer + read_pos + sizeof(uint32);

	Chunk* result = new Chunk;
	result->len = header;
	result->data = new char[size];
	memcpy(result->data, payload, size);

	read_pos += sizeof(uint32) + size;
	++chunks_read;

	return result;
	}

// Returns a newly allocated chunk holding C1's data followed by C2's, and
// frees both inputs. The lens are used as plain byte counts; Read() clears
// FLAG_PARTIAL before calling this.
ChunkedIO::Chunk* ChunkedIOFd::ConcatChunks(Chunk* c1, Chunk* c2)
	{
	Chunk* joined = new Chunk;
	joined->len = c1->len + c2->len;
	joined->data = new char[joined->len];

	char* out = joined->data;
	memcpy(out, c1->data, c1->len);
	out += c1->len;
	memcpy(out, c2->data, c2->len);

	Chunk* inputs[2] = { c1, c2 };
	for ( int i = 0; i < 2; ++i )
		{
		delete [] inputs[i]->data;
		delete inputs[i];
		}

	return joined;
	}

// Returns the next complete logical chunk in *CHUNK. Oversized chunks that
// Write() split into FLAG_PARTIAL pieces are reassembled transparently:
// non-final pieces are accumulated in `partial`, and *CHUNK stays 0 until
// the final piece arrives. *CHUNK is also 0 if nothing complete has been
// buffered yet. Returns false if ReadChunk() fails.
bool ChunkedIOFd::Read(Chunk** chunk, bool may_block)
	{
	// We will be called regularly. So take the opportunity
	// to flush the write buffer once in a while.
	// (ReadChunk() calls OptionalFlush() as well.)
	OptionalFlush();

	if ( ! ReadChunk(chunk, may_block) )
		return false;

	// No complete piece available yet.
	if ( ! *chunk )
		return true;

#ifdef DEBUG
	if ( *chunk )
		DBG_LOG(DBG_CHUNKEDIO, "read of size %d %s[%s]",
				(*chunk)->len & ~FLAG_PARTIAL,
				(*chunk)->len & FLAG_PARTIAL ? "(P) " : "",
				fmt_bytes((*chunk)->data, 20));
#endif

	// No partial flag: this is either a standalone chunk or the final
	// piece of an oversized one.
	if ( ! ((*chunk)->len & FLAG_PARTIAL) )
		{
		if ( ! partial )
			return true;
		else
			{
			// This is the last chunk of an oversized one.
			*chunk = ConcatChunks(partial, *chunk);
			partial = 0;

#ifdef DEBUG
			if ( *chunk )
				DBG_LOG(DBG_CHUNKEDIO,
					"built virtual chunk of size %d [%s]",
					(*chunk)->len,
					fmt_bytes((*chunk)->data, 20));
#endif


			return true;
			}
		}

	// This chunk is a non-final piece of an oversized one. Strip the flag
	// so that len is the plain byte count ConcatChunks() expects.
	(*chunk)->len &= ~FLAG_PARTIAL;

	if ( ! partial )
		// First part of oversized chunk.
		partial = *chunk;
	else
		partial = ConcatChunks(partial, *chunk);

	// Ownership moved into `partial`; nothing to hand out yet.
	*chunk = 0;
	return true; // Read following part next time.
	}

// Low-level read. Sets *CHUNK to the next framed chunk from the read buffer
// (FLAG_PARTIAL still set in len, if present), refilling the buffer from
// the descriptor as needed. *CHUNK is 0 if no complete chunk is available
// yet. If MAY_BLOCK, first waits briefly in select() for input.
// Returns false, setting eof, if the parent process is gone (errno = EPIPE)
// or on end-of-file with no complete chunk left. Also returns false on a
// read error, setting eof only for EPIPE. A read interrupted by a signal
// (EINTR) is retried rather than reported as an error.
bool ChunkedIOFd::ReadChunk(Chunk** chunk, bool may_block)
	{
	// We will be called regularly. So take the opportunity
	// to flush the write buffer once in a while.
	OptionalFlush();

	*chunk = ExtractChunk();
	if ( *chunk )
		return true;

	int bytes_left = read_len - read_pos;

	// If we have a partial chunk left, move this to the head of
	// the buffer.
	if ( bytes_left )
		memmove(read_buffer, read_buffer + read_pos, bytes_left);

	read_pos = 0;
	read_len = bytes_left;

	// If allowed, wait a bit for something to read.
	if ( may_block )
		{
		fd_set fd_read, fd_write, fd_except;

		FD_ZERO(&fd_read);
		FD_ZERO(&fd_write);
		FD_ZERO(&fd_except);
		FD_SET(fd, &fd_read);

		// NOTE(review): this is 50 microseconds; confirm 50 ms
		// wasn't intended.
		struct timeval small_timeout;
		small_timeout.tv_sec = 0;
		small_timeout.tv_usec = 50;

		select(fd + 1, &fd_read, &fd_write, &fd_except, &small_timeout);
		}

	// Make sure the parent is still runnning
	// (only checking for EPIPE after a read doesn't
	// seem to be sufficient).
	if ( kill(parent, 0) < 0 && errno != EPERM )
		{
		eof = true;
		errno = EPIPE;
		return false;
		}

	// Try to fill the buffer.
	while ( true )
		{
		int len = BUFFER_SIZE - read_len;
		int read = ::read(fd, read_buffer + read_len, len);

		if ( read < 0 )
			{
			// Interrupted by a signal before any data arrived; retry.
			if ( errno == EINTR )
				continue;

			// These errno's are the same on POSIX.
			if ( errno == EWOULDBLOCK || errno == EAGAIN )
				{
				// Let's see if we have a chunk now --
				// even if we time out, we may have read
				// just enough in previous iterations!
				*chunk = ExtractChunk();
				return true;
				}

			if ( errno == EPIPE )
				eof = true;

			return false;
			}

		if ( read == 0 )
			{
			*chunk = ExtractChunk();
			if ( *chunk )
				return true;

			eof = true;
			return false;
			}

		read_len += read;

		++reads;
		bytes_read += read;

		if ( read == len )
			break;
		}

	// Let's see if we have a chunk now.
	*chunk = ExtractChunk();

	return true;
	}

// True if a complete chunk is already buffered, or if the descriptor is
// readable right now (checked by select() with a zero timeout). Also uses
// the call as an opportunity to flush pending output.
bool ChunkedIOFd::CanRead()
	{
	OptionalFlush();

	if ( ChunkAvailable() != 0 )
		return true;

	fd_set readable;
	FD_ZERO(&readable);
	FD_SET(fd, &readable);

	struct timeval poll_now = { 0, 0 };

	int ready = select(fd + 1, &readable, 0, 0, &poll_now);
	return ready > 0;
	}

// True while chunks are still waiting in the pending write queue.
bool ChunkedIOFd::CanWrite()
	{
	if ( pending_head )
		return true;

	return false;
	}

// Describes the current errno as "<strerror text> [<errno>]". The result
// lives in a static buffer that the next call overwrites.
const char* ChunkedIOFd::Error()
	{
	static char msg[1024];
	const int code = errno;

	snprintf(msg, sizeof(msg), "%s [%d]", strerror(code), code);
	msg[sizeof(msg) - 1] = '\0';

	return msg;
	}

// Writes "pending=<n> " followed by the generic ChunkedIO statistics into
// BUFFER (at most LENGTH bytes, NUL-terminated). snprintf() returns the
// length it *would* have written, so the offset is clamped to stay inside
// the buffer when LENGTH is too small for the prefix.
void ChunkedIOFd::Stats(char* buffer, int length)
	{
	if ( ! buffer || length <= 0 )
		return;

	int i = snprintf(buffer, length, "pending=%d ", pending_len);

	if ( i < 0 )
		i = 0;
	else if ( i > length - 1 )
		i = length - 1;

	// At least one byte remains, so the base Stats() stays in bounds.
	ChunkedIO::Stats(buffer + i, length - i);
	buffer[length-1] = '\0';
	}


#ifdef USE_OPENSSL

#include <openssl/ssl.h>

// SSL context shared by all ChunkedIOSSL instances; created lazily in
// ChunkedIOSSL::Init() while it is still null.
SSL_CTX* ChunkedIOSSL::ctx;

// Wraps the socket ARG_SOCKET. ARG_SERVER selects whether Init() drives the
// handshake with SSL_accept() or SSL_connect(). The SSL objects are created
// in Init(), not here.
ChunkedIOSSL::ChunkedIOSSL(int arg_socket, bool arg_server)
	{
	socket = arg_socket;
	server = arg_server;

	// No SSL connection set up yet.
	ssl = 0;
	setup = eof = false;

	// Outgoing queue and its state machine.
	write_head = write_tail = 0;
	write_pending = 0;
	write_state = LEN;

	// Incoming state machine.
	read_chunk = 0;
	read_ptr = 0;
	read_state = LEN;
	}

// Shuts down and frees the SSL connection (if any), then closes the socket.
// Also frees chunks that were queued by Write() but never fully sent, and
// a partially received chunk, all of which this object owns.
ChunkedIOSSL::~ChunkedIOSSL()
	{
	if ( setup )
		{
		SSL_shutdown(ssl);

		// We don't care if the other side closes properly.
		setup = false;
		}

	if ( ssl )
		{
		SSL_free(ssl);
		ssl = 0;
		}

	close(socket);

	// Chunks still in the write queue (Flush() frees them once sent).
	while ( write_head )
		{
		Queue* q = write_head;
		write_head = q->next;

		delete [] q->chunk->data;
		delete q->chunk;
		delete q;
		}

	// Chunk in the middle of being read. Its data may still be 0
	// (Read() allocates it after the length arrives); delete [] 0 is a no-op.
	if ( read_chunk )
		{
		delete [] read_chunk->data;
		delete read_chunk;
		read_chunk = 0;
		}
	}


// OpenSSL PEM passphrase callback. Copies the NUL-terminated string passed
// as PASSPHRASE (the callback user data) into BUF, truncated to SIZE bytes
// including the terminator, and returns its length. Returns 0 (no
// passphrase) when there is no user data or no room in BUF; previously a
// non-positive SIZE wrote to buf[-1] and null user data crashed strncpy.
static int pem_passwd_cb(char *buf, int size, int rwflag, void *passphrase)
	{
	(void) rwflag; // The same passphrase serves reading and writing.

	if ( ! buf || size <= 0 )
		return 0;

	if ( ! passphrase )
		{
		buf[0] = '\0';
		return 0;
		}

	strncpy(buf, (const char*) passphrase, size);
	buf[size - 1] = '\0';
	return strlen(buf);
	}

bool ChunkedIOSSL::Init()
	{
	// If the handshake doesn't succeed immediately we will
	// be called multiple times.
	if ( ! ctx )
		{
		SSL_load_error_strings();

		ctx = SSL_CTX_new(SSLv3_method());
		if ( ! ctx )
			{
			error("can't create SSL context");
			return false;
			}

		// We access global variables here. But as they are
		// declared const and we don't modify them this should
		// be fine.
		const char* key = ssl_private_key->AsString()->CheckString();

		if ( ! (key && *key &&
			SSL_CTX_use_certificate_chain_file(ctx, key)) )
			{
			error(fmt("can't read certificate from file %s", key));
			return false;
			}

		const char *passphrase = ssl_passphrase->AsString()->CheckString();

		if ( passphrase && ! streq(passphrase, "<undefined>") )
			{
			SSL_CTX_set_default_passwd_cb(ctx, pem_passwd_cb);
			SSL_CTX_set_default_passwd_cb_userdata(ctx, (void*) passphrase);
			}

		if ( ! (key && *key &&
			SSL_CTX_use_PrivateKey_file(ctx, key, SSL_FILETYPE_PEM)) )
			{
			error(fmt("can't read private key from file %s", key));
			return false;
			}

		const char* ca = ssl_ca_certificate->AsString()->CheckString();
		if ( ! (ca && *ca && SSL_CTX_load_verify_locations(ctx, ca, 0)) )
			{
			error(fmt("can't read CA certificate from file %s", ca));
			return false;
			}

		// Only use real ciphers.
		if ( ! SSL_CTX_set_cipher_list(ctx, "HIGH") )
			{
			error("can't set cipher list");
			return false;
			}

		// Require client certificate.
		SSL_CTX_set_verify(ctx,
			SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, 0);
		}

	int flags;

	if ( (flags = fcntl(socket, F_GETFL, 0)) < 0)
		{
		error(fmt("can't obtain socket flags: %s", strerror(errno)));
		return false;
		}

	if ( fcntl(socket, F_SETFL, flags|O_NONBLOCK) < 0 )
		{
		error(fmt("can't set socket to non-blocking: %s",
			  strerror(errno)));
		return false;
		}

	if ( ! ssl )
		{
		ssl = SSL_new(ctx);
		if ( ! ssl )
			{
			error("can't create SSL object");
			return false;
			}

		BIO* bio = BIO_new_socket(socket, BIO_NOCLOSE);
		SSL_set_bio(ssl, bio, bio);
		}

	int success;
	if ( server )
		success = last_ret = SSL_accept(ssl);
	else
		success = last_ret = SSL_connect(ssl);

	if ( success > 0 )
		{ // handshake done
		setup = true;
		return true;
		}

	int error = SSL_get_error(ssl, success);

	if ( success < 0 &&
	     (error == SSL_ERROR_WANT_WRITE || error == SSL_ERROR_WANT_READ) )
		// Handshake not finished yet, but that's ok for now.
		return true;

	// Some error.
	eof = true;
	return false;
	}

// Appends CHUNK (taking ownership) to the write queue and tries to push
// queued data out. While queued, the chunk's len is kept in network byte
// order so Flush() can send it directly; Flush() converts it back. Always
// returns true; the result of Flush() is not propagated.
bool ChunkedIOSSL::Write(Chunk* chunk)
	{
#ifdef DEBUG
	DBG_LOG(DBG_CHUNKEDIO, "ssl write of size %d [%s]",
		chunk->len, fmt_bytes(chunk->data, 20));
#endif

	chunk->len = htonl(chunk->len);

	Queue* entry = new Queue;
	entry->chunk = chunk;
	entry->next = 0;

	if ( ! write_tail )
		write_head = entry;
	else
		write_tail->next = entry;

	write_tail = entry;
	++write_pending;

	Flush();
	return true;
	}

// Writes LEN bytes at P through SSL. Returns true when all bytes were
// written. On failure returns false and sets *ERROR to false if the write
// should simply be retried later (SSL wants read/write), or to true for a
// real error or a closed connection (eof is then set).
bool ChunkedIOSSL::WriteData(char* p, uint32 len, bool* error)
	{
	// Nothing to send. Don't pass a zero length to SSL_write(): a 0 return
	// would be misread below as the connection having been closed.
	if ( len == 0 )
		{
		*error = false;
		return true;
		}

	*error = true;

	int success = last_ret = SSL_write(ssl, p, len);

	if ( success == 0 )
		{
		// Connection closed.
		eof = true;
		return false;
		}

	if ( success < 0 )
		{
		int e = SSL_get_error(ssl, success);
		if ( e == SSL_ERROR_WANT_WRITE || e == SSL_ERROR_WANT_READ )
			*error = false; // We try again next time.
		else
			eof = true;
		return false;
		}

	// SSL guarantees us that all bytes have been written. That's nice. :-)
	*error = false;
	return true;
	}

// Finishes the handshake if necessary, then sends queued chunks. Each chunk
// is sent as its 4-byte length (already in network byte order from Write())
// followed by its payload. write_state records which part of the head chunk
// is still outstanding, so a write that must be retried resumes at the
// right place on the next call. Returns true if everything was sent or the
// write should be retried later; false on a real error (or failed Init()).
bool ChunkedIOSSL::Flush()
	{
	if ( ! setup )
		{
		// We may need to finish the handshake.
		if ( ! Init() )
			return false;
		// Handshake still in progress; nothing can be sent yet.
		if ( ! setup )
			return true;
		}

	while ( write_head )
		{
		bool error;

		Chunk* c = write_head->chunk;

		if ( write_state == LEN )
			{
			// c->len is still in network byte order here.
			if ( ! WriteData((char*)&c->len, sizeof(c->len), &error) )
				return ! error;
			write_state = DATA;

			// Convert back from network byte order.
			c->len = ntohl(c->len);
			}

		// In DATA state c->len is in host order (converted above, possibly
		// on an earlier call).
		if ( ! WriteData(c->data, c->len, &error) )
			return ! error;

		// Chunk written, throw away.
		Queue* q = write_head;
		write_head = write_head->next;
		if ( ! write_head )
			write_tail = 0;
		--write_pending;
		delete q;

		delete [] c->data;
		delete c;

		write_state = LEN;
		}

	return true;
	}


// Reads exactly LEN bytes into P through SSL. The function may be called
// repeatedly for the same buffer; progress is kept in read_ptr between
// calls. Returns true once all LEN bytes are in. On failure returns false
// and sets *ERROR to false if it should be retried later (SSL wants
// read/write), or to true for a real error or a closed connection (eof is
// then set).
bool ChunkedIOSSL::ReadData(char* p, uint32 len, bool* error)
	{
	// Nothing to read (e.g. a zero-length chunk). Calling SSL_read() with a
	// zero length could return 0, which would be misread as connection closed.
	if ( len == 0 )
		{
		*error = false;
		read_ptr = 0;
		return true;
		}

	*error = true;

	if ( ! read_ptr )
		read_ptr = p;

	for ( ;; )
		{
		int read = last_ret = SSL_read(ssl, read_ptr, len - (read_ptr - p) );

		if ( read == 0 )
			{
			// Connection closed.
			eof = true;
			return false;
			}

		if ( read < 0 )
			{
			int e = SSL_get_error(ssl, read);
			if ( e == SSL_ERROR_WANT_WRITE ||
			     e == SSL_ERROR_WANT_READ )
				*error = false; // We try again next time.
			else
				eof = true;
			return false;
			}

		read_ptr += read;

		if ( unsigned(read_ptr - p) == len )
			{
			// We read as much as requested.
			*error = false;
			read_ptr = 0;
			return true;
			}
		}
	}

// Tries to receive one complete chunk, made of a 4-byte network-order
// length followed by the payload. On success *CHUNK is the new chunk (len
// in host order), or 0 if more data is still needed. Progress is kept in
// read_state / read_chunk, so a chunk may arrive over several calls.
// Returns false on error or closed connection. MAYBLOCK is not used.
bool ChunkedIOSSL::Read(Chunk** chunk, bool mayblock)
	{
	*chunk = 0;

	// The handshake may still be in progress.
	if ( ! setup )
		{
		if ( ! Init() )
			return false;

		if ( ! setup )
			return true;
		}

	bool failed;

	if ( read_state == LEN )
		{
		if ( read_chunk == 0 )
			{
			read_chunk = new Chunk;
			read_chunk->data = 0;
			}

		char* len_field = (char*) &read_chunk->len;
		if ( ! ReadData(len_field, sizeof(read_chunk->len), &failed) )
			return ! failed;

		read_chunk->len = ntohl(read_chunk->len);
		read_state = DATA;
		}

	if ( read_chunk->data == 0 )
		read_chunk->data = new char[read_chunk->len];

	if ( ! ReadData(read_chunk->data, read_chunk->len, &failed) )
		return ! failed;

	// Complete: hand ownership to the caller and reset for the next chunk.
	Chunk* done = read_chunk;
	read_chunk = 0;
	read_state = LEN;
	*chunk = done;

#ifdef DEBUG
	if ( *chunk )
		DBG_LOG(DBG_CHUNKEDIO, "ssl read of size %d [%s]",
			(*chunk)->len, fmt_bytes((*chunk)->data, 20));
#endif

	return true;
	}

// True if SSL already holds decrypted bytes, or if the socket is readable
// right now (checked by select() with a zero timeout). Also uses the call
// as an opportunity to flush queued output.
bool ChunkedIOSSL::CanRead()
	{
	// We will be called regularly. So take the opportunity
	// to flush the write buffer.
	Flush();

	// If Init() (reached via Flush()) could not create the SSL object,
	// ssl is still null and must not be handed to SSL_pending().
	if ( ssl && SSL_pending(ssl) )
		return true;

	fd_set fd_read;
	FD_ZERO(&fd_read);
	FD_SET(socket, &fd_read);

	struct timeval notimeout;
	notimeout.tv_sec = 0;
	notimeout.tv_usec = 0;

	return select(socket + 1, &fd_read, NULL, NULL, &notimeout) > 0;
	}

// True while the write queue still holds chunks.
bool ChunkedIOSSL::CanWrite()
	{
	return write_head ? true : false;
	}

const char* ChunkedIOSSL::Error()
	{
	return ERR_error_string(SSL_get_error(ssl, last_ret), 0);
	}

// Writes "pending=<n> " followed by the generic ChunkedIO statistics into
// BUFFER (at most LENGTH bytes, NUL-terminated). snprintf() returns the
// length it *would* have written, so the offset is clamped to stay inside
// the buffer when LENGTH is too small for the prefix.
void ChunkedIOSSL::Stats(char* buffer, int length)
	{
	if ( ! buffer || length <= 0 )
		return;

	int i = snprintf(buffer, length, "pending=%ld ", write_pending);

	if ( i < 0 )
		i = 0;
	else if ( i > length - 1 )
		i = length - 1;

	// At least one byte remains, so the base Stats() stays in bounds.
	ChunkedIO::Stats(buffer + i, length - i);
	buffer[length-1] = '\0';
	}

#endif	/* USE_OPENSSL */
