#include "stdafx.h"
#include <map>
#include "grtpp_module_cpp.h"
#include "cppdbc.h"

#include "grts/structs.db.mgmt.h"

#include <cppdbc.h>
#include <memory>

class DbMySQLQueryImpl : public grt::ModuleImplBase
{
public:
  DbMySQLQueryImpl(grt::CPPModuleLoader *loader) 
  : grt::ModuleImplBase(loader),
  _last_error_code(0), _connection_id(0), _resultset_id(0), _tunnel_id(0)
  {
    _mutex = g_mutex_new();
  }

  virtual ~DbMySQLQueryImpl()
  {
    g_mutex_free(_mutex);
  }

  DEFINE_INIT_MODULE("1.0", "MySQL AB", grt::ModuleImplBase,
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::openConnection), 
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::closeConnection),
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::lastError),
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::lastErrorCode),
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::execute),
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::executeQuery),
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::resultNumRows),
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::resultNumFields),
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::resultFieldType),
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::resultFieldName),
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::resultNextRow),
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::resultFieldIntValue),
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::resultFieldDoubleValue),
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::resultFieldStringValue),
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::resultFieldIntValueByName),
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::resultFieldDoubleValueByName),
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::resultFieldStringValueByName),
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::closeResult),
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::loadSchemata),
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::loadSchemaObjects),
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::loadSchemaList),
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::loadSchemaObjectList),

                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::generateDdlScript),
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::openTunnel),
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::getTunnelPort),
                  DECLARE_MODULE_FUNCTION(DbMySQLQueryImpl::closeTunnel),
                  NULL);

  // returns connection-id or -1 for error
  int openConnection(const db_mgmt_ConnectionRef &info);
  int closeConnection(int conn);

  std::string lastError();
  int lastErrorCode();

  // returns 1/0 for ok, -1 for error
  int execute(int conn, const std::string &query);

  // returns result-id or -1 for error
  int executeQuery(int conn, const std::string &query);

  int resultNumRows(int result);
  int resultNumFields(int result);
  std::string resultFieldType(int result, int field);
  std::string resultFieldName(int result, int field);
  // returns 1 if ok, 0 if no more rows
  int resultNextRow(int result);
  int resultFieldIntValue(int result, int field);
  double resultFieldDoubleValue(int result, int field);
  std::string resultFieldStringValue(int result, int field);
  
  int resultFieldIntValueByName(int result, const std::string &field);
  double resultFieldDoubleValueByName(int result, const std::string &field);
  std::string resultFieldStringValueByName(int result, const std::string &field);
  
  int closeResult(int result);

  int loadSchemata(int conn, grt::StringListRef schemata);
  int loadSchemaObjects(int conn, grt::StringRef schema, grt::StringRef object_type, grt::DictRef objects);

  grt::StringListRef loadSchemaList(int conn);
  grt::DictRef loadSchemaObjectList(int conn, grt::StringRef schema, grt::StringRef object_type);

  std::string generateDdlScript(grt::StringRef schema, grt::DictRef objects);

  
  // open SSH tunnel using the connection info
  // returns tunnel id or 0 if no tunnel needed
  int openTunnel(const db_mgmt_ConnectionRef &info);
  int getTunnelPort(int tunnel);
  int closeTunnel(int tunnel);

  std::string scramblePassword(const std::string& pass);

private:
  GMutex *_mutex;
  std::map<int, sql::ConnectionWrapper> _connections;
  std::map<int, sql::ResultSet*> _resultsets;
  std::map<int, sql::TunnelConnection*> _tunnels;
  std::string _last_error;
  int _last_error_code;

  int _connection_id;
  int _resultset_id;
  int _tunnel_id;
};



// GRT module entry point for this class (macro from the GRT C++ module headers;
// presumably exports the factory the CPP module loader uses - confirm there).
GRT_MODULE_ENTRY_POINT(DbMySQLQueryImpl);

// Resets the error state reported by lastError()/lastErrorCode(). Used by most module
// functions before doing any work; the do/while(0) makes it behave as a single statement.
#define CLEAR_ERROR() do { _last_error.clear(); _last_error_code = 0; } while (0)


/**
 * Scoped guard for a GMutex: locks in the constructor and unlocks in the
 * destructor, so the mutex is released on every exit path, including exceptions.
 */
struct Lock
{
  GMutex *mutex;

  explicit Lock(GMutex *mtx)
  : mutex(mtx)
  {
    g_mutex_lock(mutex);
  }

  ~Lock()
  {
    g_mutex_unlock(mutex);
  }

private:
  // Non-copyable: a copy would unlock the same mutex a second time.
  Lock(const Lock &);
  Lock &operator=(const Lock &);
};


int DbMySQLQueryImpl::openConnection(const db_mgmt_ConnectionRef &info)
{
  sql::DriverManager *dm = sql::DriverManager::getDriverManager();
  
  if (!info.is_valid())
    throw std::invalid_argument("connection info is NULL");

  int new_connection_id = -1;

  CLEAR_ERROR();
  try
  {
    Lock lock(_mutex);
    sql::ConnectionWrapper conn(dm->getConnection(info));
    new_connection_id = ++_connection_id;
    _connections[new_connection_id] = conn;
  }
  catch (sql::SQLException &exc)
  {
    _last_error = exc.what();
    _last_error_code = exc.getErrorCode();
    Lock lock(_mutex);
    if (_connections.find(new_connection_id) != _connections.end())
      _connections.erase(new_connection_id);
    return -1;    
  }
  catch (std::exception &exc)
  {
    _last_error = exc.what();
    Lock lock(_mutex);
    if (_connections.find(new_connection_id) != _connections.end())
      _connections.erase(new_connection_id);
    return -1;
  }
  
  return new_connection_id;
}


int DbMySQLQueryImpl::closeConnection(int conn)
{
  CLEAR_ERROR();

  Lock lock(_mutex);
  if (_connections.find(conn) == _connections.end())
    throw std::invalid_argument("Invalid connection");

  _connections.erase(conn);

  return 0;
}


/// Message recorded by the most recent failing call; empty if the error state was
/// cleared and nothing failed since.
std::string DbMySQLQueryImpl::lastError()
{
  const std::string &message = _last_error;
  return message;
}


/// Server error code recorded by the most recent failing SQL call; 0 when none was
/// recorded (non-SQL std::exception failures leave it at 0).
int DbMySQLQueryImpl::lastErrorCode()
{
  const int code = _last_error_code;
  return code;
}


int DbMySQLQueryImpl::execute(int conn, const std::string &query)
{
  CLEAR_ERROR();
  
  sql::Connection *con= 0;
  {
    Lock lock(_mutex);
    if (_connections.find(conn) == _connections.end())
      throw std::invalid_argument("Invalid connection");
    con = _connections[conn].get();
  }

  try
  {
    std::auto_ptr<sql::Statement> pstmt(con->createStatement());
    return pstmt->execute(query) ? 1 : 0;
  }
  catch (sql::SQLException &exc)
  {
    _last_error = exc.what();
    _last_error_code = exc.getErrorCode();
    return -1;
  }
  catch (std::exception &e)
  {
    _last_error = e.what();
    return -1;
  }
  
  return -1;
}


int DbMySQLQueryImpl::executeQuery(int conn, const std::string &query)
{
  CLEAR_ERROR();
  
  sql::Connection *con= 0;
  {
    Lock lock(_mutex);
    if (_connections.find(conn) == _connections.end())
      throw std::invalid_argument("Invalid connection");
    con = _connections[conn].get();
  }

  try
  {
    std::auto_ptr<sql::Statement> pstmt(con->createStatement());
    sql::ResultSet *res = pstmt->executeQuery(query);

    ++_resultset_id;

    Lock lock(_mutex);
    _resultsets[_resultset_id] = res;
  }
  catch (sql::SQLException &exc)
  {
    _last_error = exc.what();
    _last_error_code = exc.getErrorCode();
    return -1;
  }
  catch (std::exception &e)
  {
    _last_error = e.what();
    return -1;
  }

  return _resultset_id;
}


int DbMySQLQueryImpl::resultNumRows(int result)
{
  Lock lock(_mutex);
  CLEAR_ERROR();
  if (_resultsets.find(result) == _resultsets.end())
    throw std::invalid_argument("Invalid resultset");
  sql::ResultSet *res = _resultsets[result];

  return res->rowsCount();
}


int DbMySQLQueryImpl::resultNumFields(int result)
{
  Lock lock(_mutex);
  CLEAR_ERROR();
  if (_resultsets.find(result) == _resultsets.end())
    throw std::invalid_argument("Invalid resultset");
  sql::ResultSet *res = _resultsets[result];
  
  return res->getMetaData()->getColumnCount();
}


std::string DbMySQLQueryImpl::resultFieldType(int result, int field)
{
  Lock lock(_mutex);
  CLEAR_ERROR();
  if (_resultsets.find(result) == _resultsets.end())
    throw std::invalid_argument("Invalid resultset");
  sql::ResultSet *res = _resultsets[result];
  
  return res->getMetaData()->getColumnTypeName(field);  
}


std::string DbMySQLQueryImpl::resultFieldName(int result, int field)
{
  Lock lock(_mutex);
  CLEAR_ERROR();
  if (_resultsets.find(result) == _resultsets.end())
    throw std::invalid_argument("Invalid resultset");
  sql::ResultSet *res = _resultsets[result];

  return res->getMetaData()->getColumnName(field);
}


int DbMySQLQueryImpl::resultNextRow(int result)
{
  Lock lock(_mutex);
  CLEAR_ERROR();
  if (_resultsets.find(result) == _resultsets.end())
    throw std::invalid_argument("Invalid resultset");
  sql::ResultSet *res = _resultsets[result];  
  return res->next() ? 1 : 0;
}


int DbMySQLQueryImpl::resultFieldIntValue(int result, int field)
{
  Lock lock(_mutex);
  CLEAR_ERROR();
  if (_resultsets.find(result) == _resultsets.end())
    throw std::invalid_argument("Invalid resultset");
  sql::ResultSet *res = _resultsets[result];

  return res->getInt(field);
}


double DbMySQLQueryImpl::resultFieldDoubleValue(int result, int field)
{
  Lock lock(_mutex);
  CLEAR_ERROR();
  if (_resultsets.find(result) == _resultsets.end())
    throw std::invalid_argument("Invalid resultset");
  sql::ResultSet *res = _resultsets[result];
  
  return res->getDouble(field);
}


std::string DbMySQLQueryImpl::resultFieldStringValue(int result, int field)
{
  Lock lock(_mutex);
  CLEAR_ERROR();
  if (_resultsets.find(result) == _resultsets.end())
    throw std::invalid_argument("Invalid resultset");

  sql::ResultSet *res = _resultsets[result];
  return res->getString(field);
}


int DbMySQLQueryImpl::resultFieldIntValueByName(int result, const std::string &field)
{
  Lock lock(_mutex);
  CLEAR_ERROR();
  if (_resultsets.find(result) == _resultsets.end())
    throw std::invalid_argument("Invalid resultset");
  sql::ResultSet *res = _resultsets[result];

  return res->getInt(field);
}


double DbMySQLQueryImpl::resultFieldDoubleValueByName(int result, const std::string &field)
{
  Lock lock(_mutex);
  CLEAR_ERROR();
  if (_resultsets.find(result) == _resultsets.end())
    throw std::invalid_argument("Invalid resultset");
  sql::ResultSet *res = _resultsets[result];
  
  return res->getDouble(field);
}


std::string DbMySQLQueryImpl::resultFieldStringValueByName(int result, const std::string &field)
{
  Lock lock(_mutex);
  CLEAR_ERROR();
  if (_resultsets.find(result) == _resultsets.end())
    throw std::invalid_argument("Invalid resultset");
  sql::ResultSet *res = _resultsets[result];
  
  return res->getString(field);
}


int DbMySQLQueryImpl::closeResult(int result)
{
  Lock lock(_mutex);
  CLEAR_ERROR();
  if (_resultsets.find(result) == _resultsets.end())
    return -1;
  sql::ResultSet *res = _resultsets[result];
  delete res;
  _resultsets.erase(result);
  return 0;
}


int DbMySQLQueryImpl::loadSchemata(int conn, grt::StringListRef schemata)
{
  CLEAR_ERROR();
  
  sql::Connection *con = 0;
  {
    Lock lock(_mutex);
    if (_connections.find(conn) == _connections.end())
      throw std::invalid_argument("Invalid connection");
    con = _connections[conn].get();
  }

  try
  {
    sql::DatabaseMetaData *dbc_meta(con->getMetaData());
    std::auto_ptr<sql::ResultSet> rset(dbc_meta->getSchemaObjects("", "", "schema"));
    while (rset->next())
    {
      std::string name = rset->getString("name");
      schemata.insert(name);
      //schemata_ddl.insert(rset->getString("ddl"));
    }
  }
  catch (sql::SQLException &exc)
  {
    _last_error = exc.what();
    _last_error_code = exc.getErrorCode();
    return -1;
  }
  catch (std::exception &e)
  {
    _last_error = e.what();
    return -1;
  }
  
  return 0;
}


/// Returns a new list with the schema names of connection `conn`, or an invalid
/// (unset) list if loading failed.
grt::StringListRef DbMySQLQueryImpl::loadSchemaList(int conn)
{
  grt::StringListRef names(get_grt());
  const int status = loadSchemata(conn, names);
  if (status != 0)
    return grt::StringListRef();
  return names;
}


/**
 * Stores name -> DDL for objects of `schema` on connection `conn` into `objects`.
 * `object_type` selects a single object kind; when it is empty, the kinds
 * "table", "view", "routine" and "trigger" are queried one after another.
 * Returns 0 on success, -1 on error (details in lastError()/lastErrorCode()).
 * Throws std::invalid_argument for an unknown connection id.
 */
int DbMySQLQueryImpl::loadSchemaObjects(int conn, grt::StringRef schema, grt::StringRef object_type, grt::DictRef objects)
{
  CLEAR_ERROR();
  
  sql::Connection *con = 0;
  {
    Lock lock(_mutex);
    if (_connections.find(conn) == _connections.end())
      throw std::invalid_argument("Invalid connection");
    con = _connections[conn].get();
  }

  try
  {
    std::list<std::string> object_types;
    if (object_type.empty())
    {
      object_types.push_back("table");
      object_types.push_back("view");
      object_types.push_back("routine");
      object_types.push_back("trigger");
    }
    else
    {
      object_types.push_back(object_type);
    }
    sql::DatabaseMetaData *dbc_meta(con->getMetaData());
    for (std::list<std::string>::const_iterator i= object_types.begin(), end= object_types.end(); i != end; ++i)
    {
      // Query the kind for this iteration (*i). Passing object_type here would issue the
      // same empty-type query for every pass when the caller asked for all kinds.
      std::auto_ptr<sql::ResultSet> rset(dbc_meta->getSchemaObjects("", *schema, *i));
      while (rset->next())
      {
        std::string name = rset->getString("name");
        std::string ddl = rset->getString("ddl");
        objects.gset(name, ddl);
      }
    }
  }
  catch (sql::SQLException &exc)
  {
    _last_error = exc.what();
    _last_error_code = exc.getErrorCode();
    return -1;
  }
  catch (std::exception &e)
  {
    _last_error = e.what();
    return -1;
  }
  
  return 0;
}


/// Returns a new dict mapping object name -> DDL for `schema` (see loadSchemaObjects),
/// or an invalid (unset) dict if loading failed.
grt::DictRef DbMySQLQueryImpl::loadSchemaObjectList(int conn, grt::StringRef schema, grt::StringRef object_type)
{
  grt::DictRef name_to_ddl(get_grt());
  const int status = loadSchemaObjects(conn, schema, object_type, name_to_ddl);
  if (status != 0)
    return grt::DictRef();
  return name_to_ddl;
}


// Doubles every backtick in `ident` so it can be placed inside a `...`-quoted MySQL
// identifier (MySQL writes a literal backtick inside such a quote as two backticks).
static std::string escape_identifier_backticks(const std::string &ident)
{
  std::string escaped;
  escaped.reserve(ident.size());
  for (std::string::const_iterator c = ident.begin(); c != ident.end(); ++c)
  {
    escaped += *c;
    if (*c == '`')
      escaped += '`';
  }
  return escaped;
}

/**
 * Builds a script that switches to "$$" as the statement delimiter, selects `schema`
 * with USE, and then emits each object's DDL from `objects` (name -> DDL string),
 * each terminated by the delimiter. Values that are not strings are emitted as empty
 * DDL; DDL that is not valid UTF-8 is replaced by a placeholder line naming the object.
 */
std::string DbMySQLQueryImpl::generateDdlScript(grt::StringRef schema, grt::DictRef objects)
{
  const std::string delimiter= "$$";
  const std::string quoted_schema = "`" + escape_identifier_backticks(*schema) + "`";

  std::string ddl_script = "DELIMITER " + delimiter + "\n\n";
  ddl_script += "USE " + quoted_schema + "\n" + delimiter + "\n\n";
  for (grt::DictRef::const_iterator i = objects.begin(), end = objects.end(); i != end; ++i)
  {
    std::string name = i->first;
    std::string ddl = (grt::StringRef::can_wrap(i->second)) ? grt::StringRef::cast_from(i->second) : "";
    if (g_utf8_validate(ddl.c_str(), -1, NULL))
      ddl_script += ddl;
    else
      ddl_script += "CREATE ... " + quoted_schema + ".`" + escape_identifier_backticks(name) + "`: DDL contains non-UTF symbol(s)";
    ddl_script.append("\n").append(delimiter).append("\n\n");
  }
  return ddl_script;
}




int DbMySQLQueryImpl::openTunnel(const db_mgmt_ConnectionRef &info)
{
  sql::DriverManager *dm = sql::DriverManager::getDriverManager();
  sql::TunnelConnection *tun = dm->getTunnel(info);
  if (tun)
  {
    _tunnels[++_tunnel_id] = tun;
    return _tunnel_id;
  }
  return 0;
}


int DbMySQLQueryImpl::getTunnelPort(int tunnel)
{
  if (_tunnels.find(tunnel) == _tunnels.end())
    throw std::invalid_argument("Invalid tunnel-id");
  return _tunnels[tunnel]->get_port();
}

int DbMySQLQueryImpl::closeTunnel(int tunnel)
{
  if (_tunnels.find(tunnel) == _tunnels.end())
    throw std::invalid_argument("Invalid tunnel-id");
  delete _tunnels[tunnel];
  _tunnels.erase(tunnel);
  return 0;
}

