#!/usr/bin/env python

# This file is part of Window-Switch.
# Copyright (c) 2009-2012 Antoine Martin <antoine@nagafix.co.uk>
# Window-Switch is released under the terms of the GNU GPL v3

import re

from winswitch.util.simple_logger import Logger
from winswitch.util.common import csv_list, visible_command, escape_newlines
from winswitch.consts import DELIMITER

from twisted.conch.ssh import common, channel
from collections import deque

#set to True to log every line written (writeLine) and every line handled (handle):
LOG_LINE_DATA = False

#upper bound on the number of recent lines ExecLineChannel remembers in previous_lines:
BUFFER_PREVIOUS_LINES = 16
#we're not expecting lines longer than this (257KB), longer lines get dropped:
DEFAULT_MAX_BUFFER_SIZE = 257*1024

class PortForward:
	"""
	Asks the connection to forward local_port to remote_host:remote_port
	(via connection.forward_port) as soon as it is instantiated.
	"""

	def __init__(self, local_port, remote_host, remote_port, connection):
		Logger(self)
		self.local_port = local_port
		self.remote_host = remote_host
		self.remote_port = remote_port
		self.slog(None, local_port, remote_host, remote_port)
		connection.forward_port(local_port, remote_host, remote_port)

	def __str__(self):
		endpoints = (self.local_port, self.remote_host, self.remote_port)
		return "PortForward(%s -> %s:%s)" % endpoints

class ReverseForward:
	"""
	Asks the connection to set up a reverse forward between local_port and
	remote_host:remote_port (via connection.reverse_forward_port)
	as soon as it is instantiated.
	"""

	def __init__(self, local_port, remote_host, remote_port, connection):
		Logger(self)
		self.local_port, self.remote_host, self.remote_port = local_port, remote_host, remote_port
		#log the request, consistent with PortForward:
		self.slog(None, local_port, remote_host, remote_port)
		connection.reverse_forward_port(local_port, remote_host, remote_port)

	def __str__(self):
		return	"ReverseForward(%s <- %s:%s)" % (self.local_port, self.remote_host, self.remote_port)


class ExecChannel(channel.SSHChannel):
	name = 'session'

	def __init__(self, command, connection):
		Logger(self)
		self.slog(None, command, connection)
		self.command = command
		self.connection = connection
		self._closed = False
		channel.SSHChannel.__init__(self, 2**16, 2**15, connection)
		connection.openChannel(self)

	def openFailed(self, reason):
		self.serror(None, reason)

	def channelOpen(self, ignoredData):
		self.slog("launching command=%s" % self.command, "[...]")
		self.data = ''
		d = self.conn.sendRequest(self, 'exec', common.NS(self.command), wantReply = 1)
		d.addCallback(self.execCallback)

	def execCallback(self, ignored):
		self.sdebug(None, ignored)

	def dataReceived(self, data):
		self.data += data

	def closed(self):
		self.slog("_closed=%s" % self._closed, "data=%s" % visible_command(self.data))
		if not self._closed:
			self._closed = True
			try:
				self.loseConnection()
			except Exception, e:
				self.serror("failed to loseConnection: %s" % e)

	def is_connected(self):
		return	not self._closed


class ExecLineChannel(ExecChannel):

	def __init__(self, command, ready_callback, line_callback, close_callback, connection):
		self.command = command
		self.ready_callback = ready_callback
		self.line_callback = line_callback
		self.close_callback = close_callback
		self.buffer = ''
		self.line_count = 0
		self.previous_lines = deque([])
		self.buffer_previous_lines = BUFFER_PREVIOUS_LINES
		self.max_buffer_size = DEFAULT_MAX_BUFFER_SIZE
		self.oversized_line = False

		self.started = False
		self.DEBUG = False
		self.LOG_ALL_DATA = False
		self.stopped = False
		ExecChannel.__init__(self, command, connection)
		self.sdebug(None, command, ready_callback, line_callback, close_callback, connection)

	def closed(self):
		self.slog()
		ExecChannel.closed(self)
		if self.close_callback:
			self.close_callback()

	def writeLine(self, data):
		if LOG_LINE_DATA:
			self.sdebug(None, visible_command(data))
		self.write("%s%s" % (data, DELIMITER))

	def stop(self, retry=False, message=None):
		self.slog("stopped=%s" % self.stopped, retry, message)
		if self.stopped:
			return
		self.stopped = True
		try:
			self.conn.sendEOF(self)
		except Exception, e:
			self.serror("failed to send EOF to %s: %s" % (self.conn, e), retry, message)
		self.closed()

	def is_connected(self):
		return	not self._closed and not self.stopped

	def execCallback(self, ignored):
		ExecChannel.execCallback(self, ignored)
		self.sdebug("will fire ready_callback=%s" % self.ready_callback, ignored)
		if self.ready_callback:
			self.ready_callback()

	def dataReceived(self, data):
		"""
		We buffer things until we get at least one full line of text to handle.
		"""
		if self.DEBUG:
			self.sdebug(None, visible_command(escape_newlines(data)))
		if self.buffer:
			data = "%s%s" % (self.buffer, data)
		lines = data.splitlines(True)
		last = lines[len(lines)-1]
		if self.LOG_ALL_DATA:
			self.sdebug("last=%s" % escape_newlines(last), escape_newlines(data))
			self.sdebug("lines=%s" % csv_list(lines, '"'), escape_newlines(data))
		if not (last.endswith('\n') or last.endswith('\r')):
			#save incomplete line in buffer
			self.buffer = last
			lines = lines[:len(lines)-1]
		else:
			self.buffer = ''

		for line in lines:
			self.line_count += 1
			while line.endswith("\n") or line.endswith("\r"):
				line = line[:len(line)-1]
			if self.oversized_line:
				self.oversized_line = False
				continue
			self.handle(line)
			#keep line in previous_lines fifo list
			if self.buffer_previous_lines>0:
				self.previous_lines.append(line)
				while len(self.previous_lines)>=self.buffer_previous_lines:
					self.previous_lines.popleft()

		if len(self.buffer)>self.max_buffer_size:
			self.oversized_line = True
			self.serror("dropping oversized line starting with: %s" % visible_command(self.buffer), visible_command(data))
			self.buffer = self.buffer[0:79]+"..."

	def handle(self, line):
		if LOG_LINE_DATA:
			self.sdebug(None, visible_command(line))
		#if we haven't received anything yet, check that this isn't the dreaded "command not found"...
		if not self.started:
			cmd = self.command.split(" ")[0]
			sre = r".*:\s*%s:\s*command not found" % cmd
			#not_found_re = re.compile(sre)
			#self.sdebug("sre=%s, not_found_re=%s, match=%s" % (sre, not_found_re, not_found_re.match(line)), line)
			if re.match(sre, line):
				self.stop(False, "Command not found!")
				return
			if line:
				self.started = True
		if self.line_callback:
			self.line_callback(line) 
