/* 
 * Copyright (c) 2008, 2010, Oracle and/or its affiliates. All rights reserved.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; version 2 of the
 * License.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
 * 02110-1301  USA
 */

#include <stdafx.h>

#include "wb_tunnel.h"
#include "wb_context.h"
#include "mforms/mforms.h"

#include "base/string_utilities.h"

#include <errno.h>
#ifdef _WIN32
#include <windows.h>
#else
#include <unistd.h>
#include <signal.h>
#include <sys/types.h>
#include <sys/wait.h>
#endif

#include "python_context.h"

#include <sstream>
#include "boost/scoped_array.hpp"

#include "base/log.h"
ENABLE_LOG("wbprivate")

using namespace wb;
using namespace base;

/**
 * Error raised when the SSH tunnel helper reports a message starting with
 * "Authentication error". It is a separate type so create_tunnel() can offer
 * the user a "Retry" instead of failing immediately.
 */
class tunnel_auth_error : public std::runtime_error
{
public:
  // explicit: a plain string should not silently become an auth error.
  explicit tunnel_auth_error(const std::string &err) : std::runtime_error(err) {}
};

class SSHTunnel : public sql::TunnelConnection
{
  TunnelManager *_tm;
  int _port;
  
public:
  SSHTunnel(TunnelManager *tm, int port)
  : _tm(tm), _port(port)
  {    
  }
  
  virtual ~SSHTunnel()
  {
    disconnect();
  }
  
  virtual int get_port()
  {
    return _port;
  }
  
  virtual void connect(db_mgmt_ConnectionRef connectionProperties)
  {    
    if (_port == 0)
      throw std::runtime_error("Could not connect SSH tunnel");
    
    _tm->wait_tunnel(_port);
    /*
    if (!g_str_has_prefix(result.c_str(), "OK"))
    {
      if (g_str_has_prefix(result.c_str(), "Private key file is encrypted"))
      {
        std::string password;
        _tm->wb()->request_input("Enter passphrase for key", 1, password);
        if (!password.empty())
        {
          grt::DictRef parameter_values= connectionProperties->parameterValues();
          parameter_values["sshPassword"]= grt::StringRef(password);
          connect(connectionProperties);
          return;
        }
      }

      throw std::runtime_error("Could not connect SSH tunnel: "+result);
    }*/
  }
  
  virtual void disconnect()
  {
    _tm->close_tunnel(_port);
  }
};



// Remembers the workbench context (used for the base directory and status
// text). The Python helper object is created later by start().
TunnelManager::TunnelManager(wb::WBContext *wb)
  : _wb(wb)
{}


/**
 * Runs sshtunnel.py from the GRT base directory in the embedded Python
 * interpreter and stores a new TunnelManager() instance in _tunnel.
 *
 * @throws std::runtime_error if the script cannot be run or no TunnelManager
 *         object is obtained.
 */
void TunnelManager::start()
{
  std::string progpath = _wb->get_grt_manager()->get_basedir()+"/sshtunnel.py";

  grt::WillEnterPython lock;
  grt::PythonContext *py = grt::PythonContext::get();
  if (py->run_file(progpath.c_str(), false) < 0)
  {
    g_warning("Tunnel manager could not be executed");
    throw std::runtime_error("Cannot start SSH tunnel manager");
  }
  _tunnel = py->eval_string("TunnelManager()");

  // Fail here rather than letting every later method call on a null helper
  // fail with a less obvious "Error calling TunnelManager..." message.
  if (!_tunnel)
  {
    g_warning("Tunnel manager object could not be created");
    throw std::runtime_error("Cannot start SSH tunnel manager");
  }
}


int TunnelManager::lookup_tunnel(const char *server, const char *username, const char *target)
{
  grt::WillEnterPython lock;
  
  // Note: without the (char*) cast gcc will complain about passing a const char* to a char*.
  //       Ideally the function signature should be changed to take a const char*.
  PyObject *ret = PyObject_CallMethod(_tunnel, (char*) "lookup_tunnel", (char*) "sss", server, username, target);
  if (!ret)
  {
    PyErr_Print();
    return -1;
  }
  if (ret == Py_None)
  {
    Py_XDECREF(ret);
    return -1;
  }
  int port = PyInt_AsLong(ret);
  Py_XDECREF(ret);
  return port;
}



void TunnelManager::shutdown()
{
  grt::WillEnterPython lock;
  PyObject *ret = PyObject_CallMethod(_tunnel, (char*) "shutdown", (char*) "", NULL);
  if (!ret)
  {
    PyErr_Print();
    return;
  }
  Py_XDECREF(ret);
}


int TunnelManager::open_tunnel(const char *server, const char *username, const char *password, 
                               const char *keyfile, const char *target)
{
  grt::WillEnterPython lock;
  PyObject *ret = PyObject_CallMethod(_tunnel, (char*) "open_tunnel", (char*) "sssss",
                                      server, username, password, keyfile, target);
  if (!ret)
  {
    PyErr_Print();
    throw std::runtime_error("Error calling TunnelManager.open_tunnel");
  }
  if (PyTuple_Size(ret) != 2)
  {
    Py_XDECREF(ret);
    throw std::runtime_error("TunnelManager.open_tunnel returned invalid value");
  }
  
  PyObject *status = PyTuple_GetItem(ret, 0);
  PyObject *value = PyTuple_GetItem(ret, 1);

  if (status == Py_False)
  {
    char *error = PyString_AsString(value);
    Py_XDECREF(ret);
    
    if (g_str_has_prefix(error, "Authentication error"))
      throw tunnel_auth_error(error);

    throw std::runtime_error(error);
  }
  else
  {
    int port = PyInt_AsLong(value);
    Py_XDECREF(ret);
    return port;    
  }
}


void TunnelManager::wait_tunnel(int port)
{
  grt::WillEnterPython lock;
  PyObject *ret = PyObject_CallMethod(_tunnel, (char*) "wait_connection", (char*) "i", port);
  if (!ret)
  {
    PyErr_Print();
    throw std::runtime_error("Error calling TunnelManager.wait_connection");
  }
  if (ret == Py_None)
  {
    Py_XDECREF(ret);
    return;
  }
  std::string str = PyString_AsString(ret);
  Py_XDECREF(ret);
  
  if (g_str_has_prefix(str.c_str(), "Authentication error"))
    throw tunnel_auth_error(str);
  
  throw std::runtime_error("Error connecting SSH tunnel: "+str);
}


void TunnelManager::close_tunnel(int port)
{
  grt::WillEnterPython lock;
  PyObject *ret = PyObject_CallMethod(_tunnel, (char*) "close", (char*) "i", port);
  if (!ret)
  {
    PyErr_Print();
    return;
  }
  Py_XDECREF(ret);
}


sql::TunnelConnection* TunnelManager::create_tunnel(db_mgmt_ConnectionRef connectionProperties)
{
  if (!_tunnel)
  {
    log_info("Starting tunnel");
    start();
  }

  sql::TunnelConnection* tunnel = 0;
  grt::DictRef parameter_values= connectionProperties->parameterValues();

  if (connectionProperties->driver()->name() == "MysqlNativeSSH")
  {
    std::string server = parameter_values.get_string("sshHost");
    std::string username = parameter_values.get_string("sshUserName");
    std::string password = parameter_values.get_string("sshPassword");
    std::string keyfile = bec::expand_tilde(parameter_values.get_string("sshKeyFile"));
    std::string target = parameter_values.get_string("hostName");
    int target_port = parameter_values.get_int("port", 3306);

    target = strfmt("%s:%i", target.c_str(), target_port);

    // before anything, check if a tunnel already exists for this server/user/target tuple
    _wb->get_grt_manager()->replace_status_text("Looking for existing SSH tunnel to "+server+"...");
    int tunnel_port;
    tunnel_port = lookup_tunnel(server.c_str(), username.c_str(), target.c_str());
    if (tunnel_port > 0)
    {
      _wb->get_grt_manager()->replace_status_text("Existing SSH tunnel found, connecting...");
      tunnel = new ::SSHTunnel(this, tunnel_port);
    }
    else
    { 
      bool reset_password = false;
    retry:
      
      _wb->get_grt_manager()->replace_status_text("Existing SSH tunnel not found, opening new one...");
      if (keyfile.empty() && password.empty())
      {
        // interactively ask user for password
        if (!mforms::Utilities::find_or_ask_for_password(_("Open SSH Tunnel"),
                                                        strfmt("ssh@%s", server.c_str()),
                                                        username,
                                                        reset_password,
                                                        password))
          // we need to throw an exception to signal that tunnel could not be opened (and not that it was not needed)
          throw std::runtime_error("SSH password input cancelled by user");
      }
      if (!keyfile.empty())
      {
        bool encrypted = true;
        char *contents = NULL;
        gsize length;
        // check if the keyfile is encrypted
        if (g_file_get_contents(keyfile.c_str(), &contents, &length, NULL) && contents)
        {
          if (!g_strstr_len(contents, length, "ENCRYPTED"))
            encrypted = false;
        }

        // interactively ask user for SSH key passphrase
        if (encrypted && !mforms::Utilities::find_or_ask_for_password(_("Open SSH Tunnel"),
                                                                     strfmt("ssh_keyfile@%s", keyfile.c_str()),
                                                                     username,
                                                                     reset_password,
                                                                     password))
          // we need to throw an exception to signal that tunnel could not be opened (and not that it was not needed)
          throw std::runtime_error("SSH key passphrase input cancelled by user");
      }
      
      _wb->get_grt_manager()->replace_status_text("Opening SSH tunnel to "+server+"...");
      try
      {
        tunnel_port = open_tunnel(server.c_str(), username.c_str(), password.c_str(), keyfile.c_str(), target.c_str());
        
        _wb->get_grt_manager()->replace_status_text("SSH tunnel opened, connecting...");
        
        tunnel = new ::SSHTunnel(this, tunnel_port);
      }
      catch (tunnel_auth_error &exc)
      {
        _wb->get_grt_manager()->replace_status_text("Authentication error opening SSH tunnel");
        if (mforms::Utilities::show_error("Could not connect the SSH Tunnel", exc.what(), _("Retry"), _("Cancel")) == mforms::ResultOk)
        {
          reset_password= true;
          goto retry;
        }
        else
          throw std::runtime_error("Cancelled");
      }
      catch (std::exception &exc)
      {
        _wb->get_grt_manager()->replace_status_text("Could not open SSH tunnel");
        throw std::runtime_error(std::string("Cannot open SSH Tunnel: ").append(exc.what()));
      }
    }
  }

  if (tunnel)
  {
    tunnel->connect(connectionProperties);
  }

  return tunnel;
}


// Destruction asks the Python-side helper to shut down via shutdown().
TunnelManager::~TunnelManager()
{
  this->shutdown();
}
