/* 
 * Copyright (c) 2007, 2010, Oracle and/or its affiliates. All rights reserved.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; version 2 of the
 * License.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
 * 02110-1301  USA
 */


#include "stdafx.h"

#include "mysql_sql_syntax_check.h"
#include "mysql_sql_parser_fe.h"
#include "grtdb/charset_utils.h"
#include "string_utilities.h"


using namespace grt;
using namespace base;


// Limit handed to Mysql_sql_parser_fe::max_insert_statement_size in check_sql_statement();
// determines the max length of passed SQL to be checked. Never modified, hence const.
static const size_t MAX_INSERT_SQL_LENGTH= 8 * 1024;


/*
 * Scope guard placed (via NULL_STATE_KEEPER) at the start of this class's public entry points.
 * When it goes out of scope it disconnects the per-call _check_sql_statement slot, so the AST
 * checker installed by check_sql_statement() does not stay attached after the call returns.
 * NOTE(review): any additional state reset presumably happens in a base-class keeper
 * destructor, which is not visible in this file - confirm against the header.
 */
Mysql_sql_syntax_check::Null_state_keeper::~Null_state_keeper()
{
  _sql_parser->_check_sql_statement.disconnect();
}
// Declares a scoped Null_state_keeper named _nsk bound to `this`. The trailing ';' is part of
// the macro, which is why call sites write NULL_STATE_KEEPER without one.
#define NULL_STATE_KEEPER Null_state_keeper _nsk(this);


/*
 * Creates a MySQL syntax checker; every base class is handed the same GRT instance.
 */
Mysql_sql_syntax_check::Mysql_sql_syntax_check(grt::GRT *grt)
  : Sql_parser_base(grt), Mysql_sql_parser_base(grt), Sql_syntax_check(grt)
{
  // The keeper lives only for the constructor body; its destructor leaves
  // _check_sql_statement disconnected, so a fresh object starts without a checker slot.
  NULL_STATE_KEEPER
}


/*
 * Classifies a SQL statement by its first token (as returned by the parser front end's
 * get_first_sql_token, with "UNKNOWN" as the fallback token).
 * Returns the matching Statement_type, or sql_unknown when the token is not in the table.
 */
Mysql_sql_syntax_check::Statement_type Mysql_sql_syntax_check::determine_statement_type(const std::string &sql)
{
  NULL_STATE_KEEPER

  // Leading keyword -> statement category; populated once, on the first call.
  typedef std::map<std::string, Statement_type> Keyword_types;
  static Keyword_types keyword_types;
  struct Keyword_types_filler
  {
    Keyword_types_filler()
    {
      struct Entry
      {
        const char *keyword;
        Statement_type type;
      };
      static const Entry entries[]=
      {
        { "", sql_empty },
        { "CREATE", sql_create },
        { "ALTER", sql_alter },
        { "DROP", sql_drop },
        { "INSERT", sql_insert },
        { "DELETE", sql_delete },
        { "UPDATE", sql_update },
        { "SELECT", sql_select },
        { "DESC", sql_describe },
        { "DESCRIBE", sql_describe },
        { "SHOW", sql_show },
        { "USE", sql_use },
        { "LOAD", sql_load },
        { "EDIT", sql_edit }
      };
      for (size_t i= 0; i < sizeof(entries) / sizeof(entries[0]); ++i)
        keyword_types[entries[i].keyword]= entries[i].type;
    }
  };
  static Keyword_types_filler keyword_types_filler;

  Mysql_sql_parser_fe parser_fe(_grtm->get_grt());
  const std::string first_token= parser_fe.get_first_sql_token(sql, "UNKNOWN");

  Keyword_types::const_iterator match= keyword_types.find(first_token);
  if (keyword_types.end() == match)
    return sql_unknown;
  return match->second;
}


bool Mysql_sql_syntax_check::parse_edit_statement(const std::string &sql,
  std::string &schema_name, std::string &table_name, std::string &statement_tail)
{
  NULL_STATE_KEEPER
  _messages_enabled= false;
  _use_delimiter= false;
  Check_sql_statement do_check_slot=
    sigc::bind(sigc::mem_fun(this, &Mysql_sql_syntax_check::do_parse_edit_statement),
      sigc::ref(schema_name), sigc::ref(table_name), sigc::ref(statement_tail));
  int err_count= check_sql_statement(sql, do_check_slot, ot_none);
  return (0 == err_count);
}


int Mysql_sql_syntax_check::check_sql(const std::string &sql)
{
  NULL_STATE_KEEPER
  _messages_enabled= false;
  _use_delimiter= false;

  Check_sql_statement do_check_slot;
  switch (_object_type)
  {
  case ot_trigger:
    do_check_slot= sigc::mem_fun(this, &Mysql_sql_syntax_check::do_check_trigger);
    break;
  case ot_view:
    do_check_slot= sigc::mem_fun(this, &Mysql_sql_syntax_check::do_check_view);
    break;
  case ot_routine:
    do_check_slot= sigc::mem_fun(this, &Mysql_sql_syntax_check::do_check_routine);
    break;
  default:
    do_check_slot= sigc::mem_fun(this, &Mysql_sql_syntax_check::do_check_sql);
    break;
  }

  int err_count= check_sql_statement(sql, do_check_slot, _object_type);
  return (err_count ? 0 : 1);
}


int Mysql_sql_syntax_check::check_trigger(const std::string &sql)
{
  NULL_STATE_KEEPER
  _messages_enabled= false;
  _use_delimiter= true;
  int err_count= check_sql_statement(sql, sigc::mem_fun(this,
    &Mysql_sql_syntax_check::do_check_trigger), ot_trigger);
  return (err_count ? 0 : 1);
}


int Mysql_sql_syntax_check::check_view(const std::string &sql)
{
  NULL_STATE_KEEPER
  _messages_enabled= false;
  _use_delimiter= true;
  int err_count= check_sql_statement(sql, sigc::mem_fun(this,
    &Mysql_sql_syntax_check::do_check_view), ot_view);
  return (err_count ? 0 : 1);
}


int Mysql_sql_syntax_check::check_routine(const std::string &sql)
{
  NULL_STATE_KEEPER
  _messages_enabled= false;
  _use_delimiter= true;
  int err_count= check_sql_statement(sql, sigc::mem_fun(this,
    &Mysql_sql_syntax_check::do_check_routine), ot_routine);
  return (err_count ? 0 : 1);
}


/*
 * Shared driver for all public check entry points.
 * Installs check_sql_statement as the per-statement AST checker and process_sql_statement
 * (with object_type bound) as the statement processor. It then configures a parser front end
 * and runs parse_sql_script over sql, optionally wrapped in a DELIMITER block.
 * Returns the value of parse_sql_script (callers treat it as the error count).
 */
int Mysql_sql_syntax_check::check_sql_statement(const std::string &sql, Check_sql_statement check_sql_statement, ObjectType object_type)
{
  _check_sql_statement= check_sql_statement;
  _process_sql_statement= sigc::bind(sigc::mem_fun(this, &Mysql_sql_syntax_check::process_sql_statement), object_type);

  Mysql_sql_parser_fe sql_parser_fe(_grtm->get_grt());
  sql_parser_fe.is_ast_generation_enabled= _is_ast_generation_enabled;
  sql_parser_fe.ignore_dml= false;
  sql_parser_fe.max_insert_statement_size= MAX_INSERT_SQL_LENGTH;
  {
    // Error cap: the Workbench option if present, otherwise 100. The options dict is not
    // guaranteed to exist at this path, and calling get_int on an invalid DictRef
    // would dereference a null value, so check validity first.
    DictRef options= DictRef::cast_from(_grt->get("/wb/options/options"));
    if (options.is_valid())
      sql_parser_fe.max_err_count= options.get_int("SqlEditor::SyntaxCheck::MaxErrCount", 100);
    else
      sql_parser_fe.max_err_count= 100;
  }

  if (!_use_delimiter)
    return parse_sql_script(sql_parser_fe, sql);

  // Object bodies (trigger/view/routine) are checked inside a DELIMITER block using the
  // non-standard delimiter, presumably so that ';' inside the body does not end the statement.
  const std::string delimited_sql= "DELIMITER " + _non_std_sql_delimiter + EOL + sql + EOL + _non_std_sql_delimiter;
  return parse_sql_script(sql_parser_fe, delimited_sql);
}


int Mysql_sql_syntax_check::process_sql_statement(const SqlAstNode *tree, ObjectType object_type)
{
  do_report_sql_statement_border(_stmt_begin_lineno, _stmt_begin_line_pos, _stmt_end_lineno, _stmt_end_line_pos);

  if (!_is_ast_generation_enabled && !_err_tok_len)
    return 0;
  
  if (!tree)
  {
    report_sql_error(_err_tok_lineno, true, _err_tok_line_pos, _err_tok_len, _err_msg, 2);
    return 1;
  }

  if (tree && (ot_none != object_type))
    tree= tree->subitem(sql::_statement, sql::_create);

  if (!tree)
    return 1;

  if (pr_processed == _check_sql_statement(tree))
    return 0;
  else
    return 1;
}


/*
 * Generic AST handler: delegates to the check_sql(const SqlAstNode *) hook,
 * or yields pr_invalid when there is no tree.
 */
Mysql_sql_parser_base::Parse_result Mysql_sql_syntax_check::do_check_sql(const SqlAstNode *tree)
{
  if (!tree)
    return pr_invalid;
  return check_sql(tree);
}


/*
 * AST handler for EDIT statements.
 * Fills schema_name/table_name from the table identifier and statement_tail with the SQL
 * text restored from the WHERE clause (or, failing that, the ORDER clause) onward. The
 * tail is cleared when neither clause is present.
 * Returns pr_invalid when the tree is missing or is not an EDIT statement, else pr_processed.
 */
Mysql_sql_parser_base::Parse_result Mysql_sql_syntax_check::do_parse_edit_statement(const SqlAstNode *tree,
  std::string &schema_name, std::string &table_name, std::string &statement_tail)
{
  if (NULL == tree)
    return pr_invalid;

  const SqlAstNode *edit_node= tree->subitem(sql::_statement, sql::_edit);
  if (NULL == edit_node)
    return pr_invalid;

  process_obj_full_name_item(edit_node->subitem(sql::_table_ident), schema_name, table_name);

  const SqlAstNode *tail_start= edit_node->subitem(sql::_where_clause);
  if (NULL == tail_start)
    tail_start= edit_node->subitem(sql::_opt_order_clause);

  if (NULL == tail_start)
    statement_tail.clear();
  else
    statement_tail= edit_node->restore_sql_text(_sql_statement, tail_start);

  return pr_processed;
}


/*
 * AST handler for trigger definitions.
 * Locates the trigger tail under either the DEFINER or no-DEFINER variant of the CREATE
 * node. It then hands it to the check_trigger hook if the tail's subsequence starts with
 * the TRIGGER keyword.
 * Returns pr_invalid for a null tree (matching do_check_sql), pr_irrelevant when the
 * statement is not a trigger, otherwise the hook's result.
 */
Mysql_sql_parser_base::Parse_result Mysql_sql_syntax_check::do_check_trigger(const SqlAstNode *tree)
{
  // Guard like the sibling handlers do; search_by_paths would otherwise dereference null.
  if (!tree)
    return pr_invalid;

  const SqlAstNode *trigger_tail= NULL;
  {
    static sql::symbol path1[]= { sql::_view_or_trigger_or_sp_or_event, sql::_definer_tail, sql::_ };
    static sql::symbol path2[]= { sql::_view_or_trigger_or_sp_or_event, sql::_no_definer_tail, sql::_ };
    static sql::symbol * paths[]= { path1, path2 };

    trigger_tail= tree->search_by_paths(paths, ARR_CAPACITY(paths));
    if (trigger_tail)
      trigger_tail= trigger_tail->subitem(sql::_trigger_tail);
  }

  if (trigger_tail && trigger_tail->subseq(sql::_TRIGGER_SYM))
    return check_trigger(tree, trigger_tail);
  else
    return pr_irrelevant;
}


/*
 * AST handler for view definitions.
 * Looks for a view tail under the DEFINER, no-DEFINER, or bare variant of the CREATE
 * node and hands it to the check_view hook.
 * Returns pr_invalid for a null tree (matching do_check_sql), pr_irrelevant when no view
 * tail is found, otherwise the hook's result.
 */
Mysql_sql_parser_base::Parse_result Mysql_sql_syntax_check::do_check_view(const SqlAstNode *tree)
{
  // Guard like the sibling handlers do; search_by_paths would otherwise dereference null.
  if (!tree)
    return pr_invalid;

  const SqlAstNode *view_tail= NULL;
  {
    static sql::symbol path1[]= { sql::_view_or_trigger_or_sp_or_event, sql::_definer_tail, sql::_ };
    static sql::symbol path2[]= { sql::_view_or_trigger_or_sp_or_event, sql::_no_definer_tail, sql::_ };
    static sql::symbol path3[]= { sql::_view_or_trigger_or_sp_or_event, sql::_ };
    static sql::symbol * paths[]= { path1, path2, path3 };

    view_tail= tree->search_by_paths(paths, ARR_CAPACITY(paths));
    if (view_tail)
      view_tail= view_tail->subitem(sql::_view_tail);
  }

  if (view_tail)
    return check_view(tree, view_tail);
  else
    return pr_irrelevant;
}


/*
 * AST handler for stored routine definitions.
 * First finds the DEFINER / no-DEFINER variant of the CREATE node. It then looks below it
 * for an sp_tail or sf_tail node (presumably the procedure and function tails of the MySQL
 * grammar) and hands the match to the check_routine hook.
 * Returns pr_invalid for a null tree (matching do_check_sql), pr_irrelevant when no
 * routine tail is found, otherwise the hook's result.
 */
Mysql_sql_parser_base::Parse_result Mysql_sql_syntax_check::do_check_routine(const SqlAstNode *tree)
{
  // Guard like the sibling handlers do; search_by_paths would otherwise dereference null.
  if (!tree)
    return pr_invalid;

  const SqlAstNode *routine_tail= NULL;
  {
    static sql::symbol definer_path[]= { sql::_view_or_trigger_or_sp_or_event, sql::_definer_tail, sql::_ };
    static sql::symbol no_definer_path[]= { sql::_view_or_trigger_or_sp_or_event, sql::_no_definer_tail, sql::_ };
    static sql::symbol * create_paths[]= { definer_path, no_definer_path };

    routine_tail= tree->search_by_paths(create_paths, ARR_CAPACITY(create_paths));
    if (routine_tail)
    {
      // Distinct names: these used to shadow the outer path1/path2/paths statics.
      static sql::symbol sp_tail_path[]= { sql::_sp_tail, sql::_ };
      static sql::symbol sf_tail_path[]= { sql::_sf_tail, sql::_ };
      static sql::symbol * tail_paths[]= { sp_tail_path, sf_tail_path };

      routine_tail= routine_tail->search_by_paths(tail_paths, ARR_CAPACITY(tail_paths));
    }
  }

  if (routine_tail)
    return check_routine(tree, routine_tail);
  else
    return pr_irrelevant;
}


// Default per-statement hook: accepts every generic statement unconditionally.
// NOTE(review): presumably an extension point for subclasses - confirm in the header.
Mysql_sql_parser_base::Parse_result Mysql_sql_syntax_check::check_sql(const SqlAstNode * /*tree*/)
{
  return pr_processed;
}


// Default trigger hook: accepts every trigger definition unconditionally.
Mysql_sql_parser_base::Parse_result Mysql_sql_syntax_check::check_trigger(const SqlAstNode * /*tree*/, const SqlAstNode * /*trigger_tail*/)
{
  return pr_processed;
}


// Default view hook: accepts every view definition unconditionally.
Mysql_sql_parser_base::Parse_result Mysql_sql_syntax_check::check_view(const SqlAstNode * /*tree*/, const SqlAstNode * /*view_tail*/)
{
  return pr_processed;
}


// Default routine hook: accepts every routine definition unconditionally.
Mysql_sql_parser_base::Parse_result Mysql_sql_syntax_check::check_routine(const SqlAstNode * /*tree*/, const SqlAstNode * /*routine_tail*/)
{
  return pr_processed;
}
