#!/usr/bin/env python

# This file is part of Window-Switch.
# Copyright (c) 2009-2012 Antoine Martin <antoine@nagafix.co.uk>
# Window-Switch is released under the terms of the GNU GPL v3

import os

from winswitch.util.simple_logger import Logger
logger = Logger("conch_util")
debug_import = logger.get_debug_import()

debug_import("common")
from winswitch.util.common import is_valid_file, csv_list, hash_text, visible_command
debug_import("consts")
from winswitch.consts import LOCALHOST
debug_import("main_loop")
from winswitch.util.main_loop import listenTCP, get_reactor, addSystemEventTrigger

# import those here so we can find out if that is causing the crash:
debug_import("Crypto")
from Crypto import Util
from Crypto.Cipher import XOR
assert Util is not None and XOR is not None

debug_import("twisted.conch.ssh")
from twisted.conch.ssh import transport, userauth, connection, keys, forwarding, common
debug_import("twisted.conch.ssh.agent")
from twisted.conch.ssh.agent import SSHAgentClient
debug_import("twisted.conch.error")
from twisted.conch.error import ConchError
debug_import("twisted.internet")
from twisted.internet import defer, protocol
debug_import("twisted.version")
from twisted import version
TWISTED_MAJOR = version.major


# Authentication switches (mostly useful for testing):
# USE_AGENT: when True, ConchUserAuth connects to the ssh-agent socket named by
# $SSH_AUTH_SOCK (if set) and offers the agent's keys for public key authentication.
# ONLY_AGENT: when True, turns off all authentication except for the ssh agent
# (no saved/prompted password, no key file, no passphrase prompt).
USE_AGENT = True
ONLY_AGENT = False

# Number of times the whole authentication sequence is restarted once every
# method has failed, before giving up and disconnecting.
MAX_ATTEMPTS=3

class ConchTransport(transport.SSHClientTransport):
	"""
	SSH client transport.
	It verifies the server's host key against the fingerprint recorded on the
	factory. When the key is unknown or has changed, it asks the user through
	factory.dialog_util (if any). Once the connection is secure, it requests the
	user authentication service (ConchUserAuth wrapping a ConchConnection).
	"""

	def __init__(self):
		Logger(self)
		# close the SSH connection when the reactor shuts down
		addSystemEventTrigger('before', 'shutdown', lambda *args : self.loseConnection())

	def verifyHostKey(self, hostKey, fingerprint):
		"""
		Called by twisted with the server's host key and its fingerprint.
		Returns a Deferred which:
		- fires its callback straight away if the fingerprint matches factory.hostkey_fingerprint,
		- otherwise, if factory.dialog_util is set, fires once the user has answered the
		  confirmation dialog (callback after calling factory.set_fingerprint() if accepted,
		  errback if declined),
		- otherwise fails straight away.
		"""
		factory = self.factory
		self.slog("dialog_util=%s, factory=%s, hostkey_fingerprint=%s" % (factory.dialog_util, factory, factory.hostkey_fingerprint), "[...]", fingerprint)
		d = defer.Deferred()
		if fingerprint and fingerprint==factory.hostkey_fingerprint:
			d.callback(1)
		elif factory.dialog_util:
			def nok(*args):
				# fail with a real exception so the failure is meaningful when logged
				d.errback(Exception("host key %s rejected by the user" % fingerprint))
			def ok(*args):
				factory.set_fingerprint(fingerprint)
				d.callback(0)
			if factory.hostkey_fingerprint:
				title = "Warning: SSH fingerprint mismatch!"
				text = "\nThe current fingerprint recorded for %s is:\n%s" % (factory.get_host_info(), factory.hostkey_fingerprint)
				text += "\n\nThe new fingerprint supplied by the server is:\n%s" % fingerprint
				text += "\n\nTo prevent 'man-in-the-middle' attacks, please ensure this is correct!"
			else:
				title = "Confirm SSH Host Key Fingerprint"
				text = "Please confirm the SSH host key fingerprint"
				text += "\nfor %s:\n\n%s" % (factory.get_host_info(), fingerprint)
			uuid = "verifyHostKey-%s" % fingerprint
			factory.dialog_util.ask(title, text, nok, ok, icon=None, UUID=uuid)
		else:
			d.errback(Exception("unknown host key:%s" % fingerprint))
		return d

	def connectionSecure(self):
		""" Encryption is established: request the user authentication service. """
		auth = ConchUserAuth(self.factory, ConchConnection(self.factory))
		self.slog("auth=%s" % auth)
		self.requestService(auth)

	def loseConnection(self):
		""" Logs the call and delegates to the upstream implementation. """
		self.slog(None)
		transport.SSHClientTransport.loseConnection(self)

	def connectionMade(self):
		""" Logs the call and delegates to the upstream implementation. """
		self.slog(None)
		transport.SSHClientTransport.connectionMade(self)



class ConchUserAuth(userauth.SSHUserAuthClient):
	""" The retry code is a bit ugly (duplicates upstream code)
	But this is the only way that I found to make it try all authentication methods
	before asking the user about passphrases or passwords """

	def __init__(self, factory, connection):
		Logger(self, log_colour=Logger.MAGENTA)
		self.slog(None, factory, connection)
		self.factory = factory
		self.keyAgent = None
		self.agentKeys = []
		self.agentKeyIndex = 0
		self.user = factory.username
		self.attempts = 0
		self.skip_password = False
		self.skip_passphrase = False
		self.skip_agent = False
		userauth.SSHUserAuthClient.__init__(self, self.user, connection)

	def serviceStopped(self):
		if self.keyAgent:
			self.keyAgent.transport.loseConnection()
			self.keyAgent = None

	def serviceStarted(self):
		agent_socket_filename = os.environ.get('SSH_AUTH_SOCK')
		self.slog("agent_socket_filename=%s" % agent_socket_filename)
		if USE_AGENT and agent_socket_filename:
			self.slog("starting agent")
			cc = protocol.ClientCreator(get_reactor(), SSHAgentClient)
			d = cc.connectUNIX(agent_socket_filename)
			def ebSetAgent(a):
				self.serror(None, a)
				userauth.SSHUserAuthClient.serviceStarted(self)
			def setAgent(a):
				self.slog(None, a)
				self.keyAgent = a
				d = self.keyAgent.requestIdentities()
				def got_identities(key_list):
					self.sdebug(None, key_list)
					try:
						i = 0
						for data,filename in key_list:
							self.slog("key[%s]=%s (%s bytes)" % (i, filename, len(data)), "%s identities" % len(key_list))
							i += 1
					except:
						pass
					self.agentKeys = key_list
					userauth.SSHUserAuthClient.serviceStarted(self)
				def no_identities(msg):
					self.serror(None, msg)
					userauth.SSHUserAuthClient.serviceStarted(self)
				d.addCallback(got_identities)
				d.addErrback(no_identities)
				return d
			d.addCallback(setAgent)
			d.addErrback(ebSetAgent)
		else:
			userauth.SSHUserAuthClient.serviceStarted(self)

	def signData(self, publicKey, signData):
		"""
		Extend the base signing behavior by using an SSH agent to sign the
		data, if one is available.
		@type publicKey: L{Key}
		@type signData: C{str}
		"""
		self.sdebug(None, type(publicKey), "%s chars" % len(signData))
		if self.keyAgent and not self.skip_agent: # agent key
			return self.keyAgent.signData(publicKey.blob(), signData)
		else:
			return userauth.SSHUserAuthClient.signData(self, publicKey, signData)
		
	def _cbUserauthFailure(self, result, iterator):
		"""
		Duplicated code from SSHUserAuthClient (twisted 10.1) so we can hook up re-tries for this version..
		"""
		self.slog("attempts=%s" % self.attempts, result, iterator)
		if TWISTED_MAJOR<10:
			userauth.SSHUserAuthClient._cbUserauthFailure(self, result, iterator)
			return
		if result:
			return
		try:
			method = iterator.next()
		except StopIteration:
			if self.attempts<MAX_ATTEMPTS:
				self.attempts += 1
				userauth.SSHUserAuthClient.serviceStarted(self)
			else:
				self.transport.sendDisconnect(transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE, 'no more authentication methods available')
		else:
			d = defer.maybeDeferred(self.tryAuth, method)
			d.addCallback(self._cbUserauthFailure, iterator)
			return d

	def ssh_USERAUTH_FAILURE(self, packet):
		self.slog("TWISTED_MAJOR=%s" % TWISTED_MAJOR, visible_command(str(packet)))
		if TWISTED_MAJOR<10:
			self.ssh_USERAUTH_FAILURE_8_2(packet)
		else:
			userauth.SSHUserAuthClient.ssh_USERAUTH_FAILURE(self, packet);

	def ssh_USERAUTH_FAILURE_8_2(self, packet):
		"""
		Duplicated code from SSHUserAuthClient (twisted 8.2) so we can hook up re-tries for this version..
		"""
		self.slog("attempts=%s" % self.attempts, visible_command(str(packet)))
		canContinue, partial = common.getNS(packet)
		canContinue = canContinue.split(',')
		partial = ord(partial)
		if partial:
			self.authenticatedWith.append(self.lastAuth)
		def _(x, y):
			try:
				i1 = self.preferredOrder.index(x)
			except ValueError:
				return 1
			try:
				i2 = self.preferredOrder.index(y)
			except ValueError:
				return -1
			return cmp(i1, i2)
		canContinue.sort(_)
		for method in canContinue:
			if method not in self.authenticatedWith and self.tryAuth(method):
				return
		if self.attempts<MAX_ATTEMPTS:
			self.attempts += 1
			userauth.SSHUserAuthClient.serviceStarted(self)
		else:
			self.transport.sendDisconnect(transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,
									'no more authentication methods available')

	def getPassword(self):
		self.slog("attempts=%s, skip_password=%s" % (self.attempts, self.skip_password))
		if self.skip_password or ONLY_AGENT:
			return
		d = defer.Deferred()
		if self.attempts>0 and self.factory.dialog_util:
			def nok(*args):
				self.skip_password = True
				d.errback(0)
			def ok(new_password, save_it):
				self.slog(None, hash_text(new_password), save_it)
				if save_it:
					self.factory.set_password(new_password)
				d.callback(new_password)
			title = "Password for %s" % self.factory.get_host_info()
			if self.factory.password:
				text = "The current SSH password saved for this host is incorrect"
				text += "\nPlease try again"
			else:
				self.factory.get_host_info()
				text = "Please enter the SSH password for this host"
			uuid = "getPassword-%s" % self.factory.host
			self.factory.dialog_util.ask(title, text, nok, ok, password=True, ask_save_password="Save password", icon=None, UUID=uuid)
			return	d
		elif self.attempts==0 and self.factory.password:
			d.callback(self.factory.password)
			return d
		else:
			return

	def getPublicKey(self):
		self.slog("agentKeyIndex=%s, %s agentKeys" % (self.agentKeyIndex, len(self.agentKeys)))
		""" first we try all the keys from the agent: """
		while self.agentKeyIndex<len(self.agentKeys):
			key_data,filename = self.agentKeys[self.agentKeyIndex]
			self.agentKeyIndex += 1
			try:
				key = keys.Key.fromString(key_data)
				self.slog("parsed agent key[%s]=%s %s" % (self.agentKeyIndex, filename, visible_command(key)))
				return key
			except Exception, e:
				self.serr("parsing key_data(%s)=%s" % (filename, key_data), e)
		""" next we try the key filename specified (if any) """
		self.skip_agent = True
		if ONLY_AGENT:
			return
		self.slog("testing public_key=%s" % self.factory.public_key)
		if not is_valid_file(self.factory.public_key):
			# the file doesn't exist, or we've tried a public key
			return	None
		try:
			pubk =	keys.Key.fromFile(filename=self.factory.public_key).blob()
		except ValueError, e:
			self.serr("failed to load public key %s" % self.factory.public_key, e)
			return	None
		self.slog("key(%s)=%s" % (self.factory.public_key, len(pubk or [])))
		if pubk in self.triedPublicKeys:
			self.slog("key %s has already been tried..." % self.factory.public_key)
			return	None
		return pubk

	def getPrivateKey(self):
		if ONLY_AGENT:
			return
		self.slog("skip_passphrase=%s" % self.skip_passphrase)
		if self.skip_passphrase:
			return
		def privateKeyFromFile(filename, passphrase):
			if not is_valid_file(filename):
				return
			lp = hash_text(passphrase)
			try:
				key =	keys.Key.fromFile(filename, passphrase=passphrase).keyObject
				self.slog("=%s" % key, filename, lp)
				return key
			except keys.EncryptedKeyError, e:
				#ask user for passphrase
				self.serror("Wrong passphrase? : %s" % e, filename, lp)
			except keys.BadKeyError, e:
				#can be bad passphrase
				self.serr(None, e, filename, lp)
			except ValueError, e:
				self.serr("unsupported key file format?", e, filename, lp)
			return
		key = privateKeyFromFile(self.factory.private_key, self.factory.key_passphrase)
		if key:
			#dump_key(key)
			return defer.succeed(key)
		elif self.factory.dialog_util:
			d = defer.Deferred()
			def nok(*args):
				self.sdebug(None, *args)
				self.skip_passphrase = True
				d.errback(Exception("user declined passphrase request"))
			def ok(new_passphrase, save_it):
				self.slog(None, hash_text(new_passphrase), save_it)
				try:
					key = privateKeyFromFile(self.factory.private_key, new_passphrase)
				except Exception, e:
					self.serr("cannot load private key", e, "(...)", save_it)
					key = None
				if (key):
					if save_it:
						self.factory.set_passphrase(new_passphrase)
					d.callback(key)
				else:
					ask_for_passphrase()

			def ask_for_passphrase():
				title = "Your SSH key requires a passphrase"
				text = "Please provide the passphrase for your key located at:\n%s" % self.factory.private_key
				uuid = "getPrivateKey-%s" % self.factory.host
				self.factory.dialog_util.ask(title, text, nok, ok, password=True, ask_save_password="Save this passphrase", icon=None, UUID=uuid)

			ask_for_passphrase()
			return	d
		else:
			return


class PortForwardChannel(forwarding.SSHListenClientForwardingChannel):
	"""
	Client side channel used for local port forwarding.
	On top of the upstream behaviour, it notifies the optional callback / errback
	supplied at construction time when the channel opens / fails to open.
	"""

	def __init__(self, conn, callback, errback):
		Logger(self)
		self.slog(None, conn, callback, errback)
		self.channel_callback, self.channel_errback = callback, errback
		forwarding.SSHListenClientForwardingChannel.__init__(self, conn=conn)

	def channelOpen(self, specificData):
		""" Completes the upstream channel setup, then calls channel_callback(specificData) if set. """
		self.sdebug("conn=%s" % self.conn, specificData)
		forwarding.SSHListenClientForwardingChannel.channelOpen(self, specificData)
		on_open = self.channel_callback
		if on_open:
			on_open(specificData)

	def openFailed(self, reason):
		""" Lets upstream handle the failure, then calls channel_errback(reason) if set. """
		self.sdebug("errback=%s" % self.channel_errback, reason)
		forwarding.SSHListenClientForwardingChannel.openFailed(self, reason)
		on_failure = self.channel_errback
		if on_failure:
			on_failure(reason)


class ConchConnection(connection.SSHConnection):
	"""
	SSH connection service.
	It opens the channels requested by the factory (factory.channel_constructors)
	and manages local and remote (reverse) TCP port forwards.
	forwarding_sockets: local_port -> (listening port, local_port, remote_host, remote_port)
	remote_forwards: (remote_host, remote_port) -> local_port
	"""

	def __init__(self, factory):
		Logger(self)
		self.slog(None, factory)
		self.factory = factory
		self.forwarding_sockets = {}
		self.remote_forwards = {}
		connection.SSHConnection.__init__(self)

	def serviceStarted(self):
		""" Starts the service, then calls each of the factory's channel constructors with this connection. """
		self.slog("opening channels: %s" % csv_list(self.factory.channel_constructors))
		connection.SSHConnection.serviceStarted(self)
		for constructor in self.factory.channel_constructors:
			constructor(self)

	def serviceStopped(self):
		""" Stops listening on all the local forwarding ports. """
		self.slog(None)
		for listener, _, _, _ in self.forwarding_sockets.values():
			listener.stopListening()
		connection.SSHConnection.serviceStopped(self)

	def openChannel(self, channel, extra=''):
		self.slog(None, repr(channel), repr(extra))
		connection.SSHConnection.openChannel(self, channel, extra)

	def channelClosed(self, channel):
		self.sdebug(None, channel)
		connection.SSHConnection.channelClosed(self, channel)

	def forward_port(self, local_port, remote_host, remote_port, callback, errback):
		"""
		Listens on local_port and forwards each connection to remote_host:remote_port through the SSH connection.
		callback / errback are passed to each PortForwardChannel created.
		Returns the tuple (listening port, local_port, remote_host, remote_port).
		"""
		self.slog(None, local_port, remote_host, remote_port)
		forward_factory = forwarding.SSHListenForwardingFactory(self, (remote_host, remote_port), lambda conn : PortForwardChannel(conn, callback, errback))
		listener = listenTCP(local_port, forward_factory)
		spec = listener, local_port, remote_host, remote_port
		self.forwarding_sockets[local_port] = spec
		return spec

	def reverse_forward_port(self, local_port, remote_host, remote_port):
		"""
		Asks the server to forward connections made to remote_host:remote_port back to local_port on LOCALHOST.
		Returns the Deferred for the 'tcpip-forward' global request.
		"""
		self.slog(None, local_port, remote_host, remote_port)
		data = forwarding.packGlobal_tcpip_forward((remote_host, remote_port))
		d = self.sendGlobalRequest('tcpip-forward', data, wantReply=1)
		d.addCallback(self._cbRemoteForwarding, local_port, remote_host, remote_port)
		d.addErrback(self._ebRemoteForwarding, local_port, remote_host, remote_port)
		return d

	def _cbRemoteForwarding(self, result, local_port, remote_host, remote_port):
		""" The server accepted the remote forward: record it so channel_forwarded_tcpip can route it. """
		self.slog(None, result, local_port, remote_host, remote_port)
		self.remote_forwards[(remote_host, remote_port)] = local_port

	def _ebRemoteForwarding(self, failure, local_port, remote_host, remote_port):
		self.serror(None, failure, local_port, remote_host, remote_port)

	def cancelRemoteForwarding(self, remote_host, remote_port):
		""" Asks the server to cancel the remote forward and forgets about it. """
		self.slog(None, remote_host, remote_port)
		data = forwarding.packGlobal_tcpip_forward((remote_host, remote_port))
		self.sendGlobalRequest('cancel-tcpip-forward', data)
		# the forward may not be recorded (ie: if the server refused it), that's fine:
		self.remote_forwards.pop((remote_host, remote_port), None)

	def channel_forwarded_tcpip(self, windowSize, maxPacket, data):
		"""
		The server opened a channel for one of our remote forwards:
		connect it to the matching local port on LOCALHOST.
		Raises ConchError(OPEN_CONNECT_FAILED) for forwards we did not request.
		"""
		self.slog(None, windowSize, maxPacket, repr(data))
		remote, orig = forwarding.unpackOpen_forwarded_tcpip(data)
		self.slog("remote=%s, orig=%s" % (csv_list(remote), orig), windowSize, maxPacket, repr(data))
		try:
			connect = self.remote_forwards[remote]
		except KeyError:
			raise ConchError(connection.OPEN_CONNECT_FAILED, "don't know about that port")
		self.slog("connect=%s" % (connect,))
		return forwarding.SSHConnectForwardingChannel((LOCALHOST, connect),
								remoteWindow = windowSize, remoteMaxPacket = maxPacket, conn = self)


class ConchFactory(protocol.ClientFactory):
	"""
	Client factory for ConchTransport connections.
	It holds the connection settings used by ConchTransport, ConchUserAuth and ConchConnection:
	- username, password, host, port
	- public_key / private_key: paths to the key files (ie: '~/.ssh/id_dsa.pub' and '~/.ssh/id_dsa')
	- key_passphrase
	- hostkey_fingerprint: the known fingerprint of the server's host key
	Other attributes:
	- channel_constructors: callables invoked with the ConchConnection once its service has started
	- server_name: optional display name used by get_host_info()
	- dialog_util: optional object used to ask the user for confirmations, passwords and passphrases
	The set_* methods are called when the user supplies new values. They are ignored here;
	presumably subclasses override them to save the values.
	"""

	def __init__(self, username, password, host, port, public_key, private_key, key_passphrase, hostkey_fingerprint):
		Logger(self)
		self.slog(None, username, hash_text(password), host, port, public_key, hash_text(private_key), hash_text(key_passphrase), hostkey_fingerprint)
		self.protocol = ConchTransport
		self.username = username
		self.password = password
		self.host = host
		self.port = port
		self.public_key = public_key
		self.private_key = private_key
		self.key_passphrase = key_passphrase
		self.hostkey_fingerprint = hostkey_fingerprint
		self.channel_constructors = []
		self.server_name = None
		self.dialog_util = None

	def get_host_info(self):
		""" Human readable description of the server: '[server_name at ]host:port' """
		if self.server_name:
			return "%s at %s:%s" % (self.server_name, self.host, self.port)
		return "%s:%s" % (self.host, self.port)

	def set_fingerprint(self, fingerprint):
		self.sdebug("ignored", fingerprint)

	def set_password(self, new_password):
		self.sdebug("ignored", hash_text(new_password))

	def set_passphrase(self, new_passphrase):
		self.sdebug("ignored", hash_text(new_passphrase))

	def clientConnectionLost(self, connector, reason):
		self.slog(None, connector, reason)
		protocol.ClientFactory.clientConnectionLost(self, connector, reason)

	def clientConnectionFailed(self, connector, reason):
		self.slog(None, connector, reason)
		# delegate to the matching upstream hook (was wrongly calling clientConnectionLost)
		protocol.ClientFactory.clientConnectionFailed(self, connector, reason)
