/* 
 * (c) 2009-2010 Sun Microsystems, Inc.
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 59 Temple Place - Suite 330,
 * Boston, MA 02111-1307, USA.
 */

#include "stdafx.h"

#ifndef _WIN32
#include <list>
#include <string>
#endif

#ifdef _WIN32
#include <windows.h>
#endif

#include "driver_manager.h"
#include <gmodule.h>
#include <cppconn/connection.h>
#include <cppconn/statement.h>
#include <cppconn/exception.h>
#include <boost/foreach.hpp>
#include <boost/scoped_array.hpp>
#include <sstream>

#include "boost/foreach.hpp"

namespace sql {

// Maps a driver parameter name to its declared parameter type string
// (filled from db_mgmt_DriverParameter name()/paramType() in getConnection()).
typedef std::map<std::string, std::string> Param_types;


/**
 * Converts one GRT parameter value into a connector option and stores it in
 * @a properties under @a key.
 *
 * Integer values are stored as bool when @a param_types declares the key as
 * "boolean", as 0/1 int when declared "tristate", and as plain int otherwise.
 * Double and string values are stored as-is. Values of any other GRT type are
 * ignored.
 *
 * @param key          name of the parameter being converted
 * @param value        GRT value to convert
 * @param properties   option map that receives the converted value
 * @param param_types  declared type per parameter name (see Param_types)
 * @return always true
 */
static bool conv_to_dbc_value(const std::string &key, const grt::ValueRef value, ConnectOptionsMap &properties, Param_types &param_types)
{
  switch (value.type())
  {
  case grt::IntegerType:
    {
      grt::IntegerRef int_value= grt::IntegerRef::cast_from(value);

      // Look the declared type up once; an unknown key yields an empty type name.
      std::string declared_type;
      Param_types::iterator type_pos= param_types.find(key);
      if (type_pos != param_types.end())
        declared_type= type_pos->second;

      ConnectPropertyVal converted;
      if (declared_type == "tristate")
        converted = (int)(*int_value != 0);
      else if (declared_type == "boolean")
        converted = (bool)(*int_value != 0);
      else
        converted = (int)(*int_value);
      properties[key] = converted;
    }
    break;

  case grt::DoubleType:
    {
      grt::DoubleRef double_value= grt::DoubleRef::cast_from(value);
      ConnectPropertyVal converted;
      converted = *double_value;
      properties[key] = converted;
    }
    break;

  case grt::StringType:
    {
      grt::StringRef string_value= grt::StringRef::cast_from(value);
      ConnectPropertyVal converted;
      converted = SQLString(string_value.c_str(), (*string_value).length());
      properties[key] = converted;
    }
    break;

  // Nothing to convert for the remaining types.
  case grt::UnknownType:
  //case grt::AnyType: // equal to grt::UnknownType
  case grt::ListType:
  case grt::DictType:
  case grt::ObjectType:
    break;
  }

  return true;
}


/**
 * Returns the process-wide DriverManager instance.
 *
 * @return pointer to the shared instance; it is never deleted by this function
 */
DriverManager *DriverManager::getDriverManager()
{
  // Allocated on the first call only; later calls hand back the same pointer.
  static DriverManager *instance= new DriverManager;
  return instance;
}


/**
 * Builds a manager that looks for driver libraries in the current directory (".")
 * and starts with a zeroed password-cache timestamp.
 */
DriverManager::DriverManager()
:
_driver_path(".")
{
  // getConnection() evaluates "time(NULL) - _cacheTime" before it ever assigns
  // _cacheTime, so give it a defined value here instead of leaving it indeterminate.
  // Assumes _cacheTime is an arithmetic (time_t-like) member declared in
  // driver_manager.h -- TODO confirm.
  _cacheTime = 0;
}

  

/**
 * Installs the callback used to create a tunnel for a connection.
 * getTunnel() and getConnection() invoke it when it is set.
 *
 * @param function  tunnel factory callback to store
 */
void DriverManager::setTunnelFactoryFunction(TunnelFactoryFunction function)
{
  this->_createTunnel = function;
}

  
/**
 * Installs the callback getConnection() uses to look up an already stored
 * password for a connection.
 *
 * @param function  password lookup callback to store
 */
void DriverManager::setPasswordFindFunction(PasswordFindFunction function)
{
  this->_findPassword = function;
}


/**
 * Installs the callback getConnection() uses to request a password after an
 * authentication failure.
 *
 * @param function  password request callback to store
 */
void DriverManager::setPasswordRequestFunction(PasswordRequestFunction function)
{
  this->_requestPassword = function;
}
  

  
void DriverManager::set_driver_dir(const std::string &path)
{
  _driver_path= path;
}

/**
 * Asks the installed tunnel factory for a tunnel matching the given connection.
 *
 * @param connectionProperties  connection the tunnel is requested for
 * @return whatever the factory returns, or 0 when no factory has been installed
 */
TunnelConnection *DriverManager::getTunnel(const db_mgmt_ConnectionRef &connectionProperties)
{
  // Without a factory there is nothing that could create a tunnel.
  if (!_createTunnel)
    return 0;

  return _createTunnel(connectionProperties);
}
  
// Maximum age of the cached password; getConnection() compares it against a
// difference of time() values (presumably seconds -- confirm for the target platform).
#define MYSQL_PASSWORD_CACHE_TIMEOUT 10

/**
 * Opens a database connection for the given connection settings.
 *
 * Steps performed:
 *  0. derive the platform-specific driver library file name from the connection's driver,
 *  1. load that library from the configured driver directory and resolve its
 *     "sql_mysql_get_driver_instance" entry point,
 *  2. convert the connection's parameter values into connector options, add default
 *     timeouts, redirect host/port to a tunnel if the tunnel factory provides one,
 *     resolve the password (cached, stored, empty, or requested), and connect.
 *
 * If the server rejects the login with error 1045, the password is requested once more
 * (bypassing the cache) and the connection attempt is repeated a single time.
 *
 * @param connectionProperties  connection settings (driver, parameter values)
 * @param connection_init_slot  invoked with the new connection and the settings right
 *                              after a successful connect, before the default schema is set
 * @return wrapper holding the new connection and the tunnel (if any)
 * @throws SQLException if the driver is undefined, the library or its entry point cannot be
 *         loaded, or no driver instance is returned; exceptions raised while connecting
 *         are rethrown
 */
ConnectionWrapper DriverManager::getConnection(const db_mgmt_ConnectionRef &connectionProperties, ConnectionInitSlot connection_init_slot)
{
  TunnelConnection* tunnel = NULL;
  // 0. determine correct driver filename

  db_mgmt_DriverRef drv = connectionProperties->driver();
  if (!drv.is_valid())
    throw SQLException("Invalid connection settings: undefined connection driver");

  std::string library;
  library = drv->driverLibraryName();
#ifdef _WIN32
  library.append(".dll");
#elif defined(__APPLE__)
  library.append(".dylib");
#else
  library.append(".so");
#endif

  // 1. find driver

  // Build the full path once; it is needed for both the open call and the error message.
  const std::string library_path= _driver_path + "/" + library;

  GModule *gmodule= g_module_open(library_path.c_str(), G_MODULE_BIND_LOCAL);
  if (NULL == gmodule)
  {
    fprintf(stderr, "Error: %s", g_module_error());
    throw SQLException(std::string("Database driver: Failed to open library '").append(library_path).append("'. Check settings.").c_str());
  }

  Driver *(* get_driver_instance)()= NULL;
  g_module_symbol(gmodule, "sql_mysql_get_driver_instance", (gpointer*)&get_driver_instance);
  if (NULL == get_driver_instance)
  {
    // The module is useless without its entry point; release the handle opened above.
    g_module_close(gmodule);
    throw SQLException("Database driver: Failed to get library instance. Check settings.");
  }

  // 2. call driver->connect()
  Param_types param_types;
  {
    grt::ListRef<db_mgmt_DriverParameter> params= drv->parameters();
    for (size_t n= 0, count= params.count(); n < count; ++n)
    {
      db_mgmt_DriverParameterRef param= params.get(n);
      param_types[param->name()]= param->paramType();
    }
  }
  //std::map<std::string, ConnectPropertyVal> properties;
  ConnectOptionsMap properties;
  grt::DictRef parameter_values= connectionProperties->parameterValues();
  parameter_values.foreach(sigc::bind(&conv_to_dbc_value, sigc::ref(properties), sigc::ref(param_types)));

  {
    // Apply default timeouts only where the connection did not specify its own.
    const int conn_timeout = 30;
    if (properties.find("OPT_CONNECT_TIMEOUT") == properties.end())
      properties["OPT_CONNECT_TIMEOUT"]= conn_timeout;
    if (properties.find("OPT_READ_TIMEOUT") == properties.end())
      properties["OPT_READ_TIMEOUT"]= conn_timeout;
  }
  properties["CLIENT_MULTI_STATEMENTS"]= true;

#ifdef _WIN32
  // If we are on a pipe connection then set the host name explicitly.
  // However, pipe connections can only be established on the local box (Win only).
  if (drv->name() == "MysqlNativeSocket")
  {
    ConnectPropertyVal host = std::string(".");
    properties["hostName"]= host;
  }
#endif

  if (_createTunnel)
  {
    // NOTE(review): the tunnel is handed to ConnectionWrapper only on success; if it is
    // owned by this function, the throwing paths below leave it unreleased -- confirm
    // ownership against the tunnel factory before changing this.
    tunnel = _createTunnel(connectionProperties);

    if (tunnel)
    {
      // this can throw an exception if the tunnel can't be created
      //!tunnel->connect();

      // make the driver connect to the local tunnel port
      properties["port"]= tunnel->get_port();
      properties["hostName"]= sql::SQLString("127.0.0.1");
    }
  }


  // Check if there is a stored or cached password, if there isn't try without one (blank)
  // If we get an auth error, then we ask for the password
  bool force_ask_password = false;
retry:
  // password not in profile (and no keyfile provided)
  if (_requestPassword && (force_ask_password || parameter_values.get_string("password") == ""))
  {
    // check if we have cached the password for this connection
    if (time(NULL) - _cacheTime > MYSQL_PASSWORD_CACHE_TIMEOUT || force_ask_password)
    {
      _cacheKey.clear();
      _cachedPassword.clear();
    }

    std::string key = connectionProperties.repr();
    if (key == _cacheKey)
      properties["password"] = _cachedPassword;
    else
    {
      if (force_ask_password)
        _cachedPassword = _requestPassword(connectionProperties, force_ask_password);
      else
      {
        // Only _requestPassword was checked above; the find callback may not be installed,
        // in which case we behave as if no stored password was found.
        bool is_cached_password_found= false;
        if (_findPassword)
          is_cached_password_found= _findPassword(connectionProperties, _cachedPassword);
        if (!is_cached_password_found)
          _cachedPassword= ""; // try no password
      }
      properties["password"] = _cachedPassword;
      _cacheKey = key;
    }
    _cacheTime = time(NULL);
  }

  // passing some empty values confuse the connector
  {
    std::list<std::string> prop_names;
    prop_names.push_back("socket");
    prop_names.push_back("schema");
    BOOST_FOREACH (const std::string &prop_name, prop_names)
    {
      ConnectOptionsMap::iterator prop_iter= properties.find(prop_name);
      if (properties.end() != prop_iter)
      {
        // Pointer form of boost::get yields NULL (instead of throwing boost::bad_get)
        // when the option does not hold a string; such values are left untouched.
        sql::SQLString *val= boost::get<sql::SQLString>(&prop_iter->second);
        if (val && (*val)->empty())
          properties.erase(prop_iter);
      }
    }
  }

  Driver *driver= get_driver_instance();
  if (NULL == driver)
    throw SQLException("Database driver: Failed to get driver instance. Check  settings.");

  try
  {
    std::auto_ptr<Connection> conn(driver->connect(properties));

    connection_init_slot(conn.get(), connectionProperties);

    std::string def_schema= parameter_values.get_string("schema", "");
    if (!def_schema.empty())
      conn->setSchema(def_schema);

    return ConnectionWrapper(conn, tunnel);
  }
  catch (sql::SQLException &exc)
  {
    // authentication error
    if (exc.getErrorCode() == 1045)
    {
      if (!force_ask_password)
      {
        // ask for password again, this time disabling the password caching
        force_ask_password = true;
        goto retry;
      }
    }
    throw;
  }
  catch (...)
  {
    // Unknown failure: drop the cached password so it is not reused.
    _cacheKey.clear();
    _cachedPassword.clear();

    throw;
  }
}


} // namespace sql
