
import sys
import grt
import copy

def dump_tree(f, ast, depth=0):
    """Write an indented, XML-like dump of an AST node and its subtree to f.

    f is any object with a write() method and ast is a
    (symbol, value, children) tuple.  A node without children is written
    as one self-closing line; any other node is written as an opening
    line, the dump of each child and a closing line.
    """
    sym, value, children = ast
    padding = "  " * depth

    # leaf: nothing to recurse into
    if not children:
        f.write("%s<%s, %s/>\n" % (padding, sym, value))
        return

    f.write("%s<%s, %s>\n" % (padding, sym, value))
    for child in children:
        # depth grows by 2 per level, i.e. four more spaces of padding
        dump_tree(f, child, depth + 2)
    f.write("%s</%s, %s>\n" % (padding, sym, value))


def flatten_node(node, sep = ""):
    """Return the values found in node's subtree, joined with sep.

    The tree is walked depth-first, a node before its children; values
    that are None or empty are left out.
    """
    def values_of(expr_node):
        # yields the non-empty values of expr_node's subtree in pre-order
        _symbol, text, kids = expr_node
        if text:
            yield text
        for kid in kids:
            for item in values_of(kid):
                yield item

    return sep.join(values_of(node))


class SQLTreeTraverser:
    """Depth-first walker over an AST made of (symbol, value, children) tuples.

    For every node, traverse() looks up a method named "sym_<symbol>" and
    calls it with (value, children) before descending into the children,
    and a method named "endsym_<symbol>" which it calls with (value) once
    the children were visited.  Nodes without a "sym_" method are handed
    to fallback().
    """

    def __init__(self):
        # symbols of the ancestors of the node currently being handled
        self.token_path = []

        # stack of (begin, end) function pairs. begin is called before the
        # normal sym_ handler and end before the normal endsym_ handler.
        # If the function returns True, the normal handler is skipped
        self.token_filters = []


    def fallback(self, symbol, value, children):
        """Handle a node that has no sym_<symbol> method; does nothing here."""
        pass


    def push_token_filter(self, begin, end):
        """Install a filter pair; begin(symbol, value, children) and
        end(symbol, value) may each be None."""
        self.token_filters.append((begin, end))

    def pop_token_filter(self):
        """Remove the most recently pushed filter pair."""
        self.token_filters.pop()


    def traverse(self, ast, token_path = None):
        """Visit ast and, recursively, all of its children.

        token_path is the list of symbols of the ancestors of ast; it is
        exposed as self.token_path while the handlers of ast run.
        """
        if token_path is None:
            token_path = []
        self.token_path = token_path

        symbol, value, children = ast

        # the filters are picked before any handler runs, so a handler that
        # pushes or pops a filter does not affect the node it was called for
        if self.token_filters:
            pre, post = self.token_filters[-1]
        else:
            pre, post = None, None

        if pre:
            skip = pre(symbol, value, children)
        else:
            skip = False

        if not skip:
            meth = getattr(self, "sym_"+symbol, None)
            if meth:
                meth(value, children)
            else:
                self.fallback(symbol, value, children)

        for ch in children:
            self.traverse(ch, token_path + [symbol])

        # visiting the children replaced self.token_path with their own
        # paths; put back the path of this node so that the end handlers
        # see the same ancestors as the begin handlers did
        self.token_path = token_path

        if post:
            skip = post(symbol, value)
        else:
            skip = False

        if not skip:
            meth = getattr(self, "endsym_"+symbol, None)
            if meth:
                meth(value)


#######################################################################################


class SQLPrettifier(SQLTreeTraverser):
    """Base class of the statement prettifiers.

    Output is collected on a stack of token lists: a handler push()es a
    new list when a construct begins, tokens emitted with out() land in
    the list on top of the stack, and the matching end handler pops that
    list, joins it and emits the result into the list below it.
    """

    def __init__(self, ast):
        SQLTreeTraverser.__init__(self)
        # self.output is a stack of tuples (output_items, skip_commas)
        # skip_commas is stored for the handlers of comma symbols (44) to
        # consult; out() itself does not look at it
        # self.out() will always output the token to the top of the stack
        self.output = [([], False)]

        self.ast = ast
        # a visible indentation such as "~"*3+"|" can be used for debugging
        self.indentation = " "*4


    def indent(self, text, count=1):
        """Return text with every line prefixed by count indentation units."""
        prefix = self.indentation*count
        return prefix + ("\n"+prefix).join(text.split("\n"))

    def push(self, skip_commas = False):
        """Start a new output list on top of the stack."""
        self.output.append(([], skip_commas))


    def pop(self):
        """Remove the top output list from the stack and return its items."""
        return self.output.pop()[0]

    def pop_joined(self, sep=""):
        """Pop the top output list and return its items joined with sep."""
        return sep.join(self.pop())

    def out(self, value):
        """Append value to the output list on top of the stack."""
        self.output[-1][0].append(value)

    # generic begin/end handlers, meant to be assigned to sym_/endsym_ names
    # by the subclasses

    def begin_tokens(self, value, children):
        """Begin handler: collect the tokens of the construct separately."""
        self.push()

    def end_spaced_tokens(self, value):
        """End handler: emit the collected tokens separated by spaces."""
        self.out(self.pop_joined(" "))

    def end_tokens(self, value):
        """End handler: emit the collected tokens without a separator."""
        self.out(self.pop_joined(""))

    def begin_comma_list(self, value, children):
        """Begin handler: collect the tokens of a comma separated list."""
        self.push()

    def end_comma_list(self, value):
        """End handler: emit the collected tokens, writing each "," as ", "."""
        items = self.pop()
        self.out("".join([", " if i == "," else i for i in items]))


    def sym_ident(self, value, children):
        self.out(value)

    # NOTE: a dedicated sym_TEXT_STRING handler that re-quoted string values
    # used to be defined here and is disabled; preprocess() rewrites those
    # nodes instead.

    def fallback(self, symbol, value, children):
        """Emit the value of any node that has no handler of its own."""
        if value:
            self.out(value)


    def preprocess(self, ast):
        """Simplify the tree in place by rewriting and joining some tokens.

        String nodes become TEXT_STRING_Q nodes whose value is the quoted
        string, and keyword sequences such as GROUP BY, ORDER BY, NOT IN or
        LEFT JOIN are merged into a single node.  The children lists are
        modified in place and the pass recurses into every child.

        Defined on the base class so that every statement prettifier whose
        run() calls self.preprocess() has it available.
        """
        sym, nam, children = ast

        string_symbols = ["TEXT_STRING", "text_string", "TEXT_STRING_filesystem", "TEXT_STRING_literal", "TEXT_STRING_sys"]
        join_prefixes = ["opt_outer", "NATURAL", "LEFT", "RIGHT", "INNER_SYM", "CROSS"]

        # one merge can make another one possible (e.g. LEFT opt_outer
        # JOIN_SYM), so repeat until a pass changes nothing
        while True:
            to_delete = []
            changed = False
            for i, ch in enumerate(children):
                s, n, c = ch
                if i < len(children)-1:
                    ns, nn, nc = children[i+1]
                else:
                    ns, nn, nc = None, None, None

                if s in string_symbols:
                    children[i] = "TEXT_STRING_Q", "'%s'"%n.replace("'", r"\'"), c
                    changed = True

                # the keyword absorbs the token that follows it
                if (s in ["GROUP_SYM", "ORDER_SYM"] and ns == "BY") or (s == "not" and ns == "IN_SYM"):
                    to_delete.append(i+1)
                    children[i] = s, n+" "+nn, c
                    changed = True

                # the JOIN keyword absorbs the modifier in front of it
                if s in join_prefixes and ns == "JOIN_SYM":
                    to_delete.append(i)
                    children[i+1] = ns, n+" "+nn, nc
                    changed = True

            # delete back to front so the remaining indexes stay valid
            for i in reversed(to_delete):
                del children[i]

            if not changed:
                break

        for ch in children:
            self.preprocess(ch)


    def format_expression(self, expression):
        """Return expression (an AST node) formatted by SQLExpressionPrettifier."""
        return SQLExpressionPrettifier(expression).run()


#######################################################################################


class SQLExpressionPrettifier:
    """Formats an expression AST by repeatedly collapsing simple subtrees.

    Nodes are (symbol, value, children) tuples.  Symbols starting with "*"
    (*text, *simple_expression, *complex_expression) are created by
    simplify() itself to mark subtrees that were already reduced to text.
    """
    def __init__(self, ast):
        self.ast = ast

        # only read by the always-true condition in simplify(), so it
        # currently has no effect
        self.max_subexpression_length = 60
        

    def simplify(self, node):
        """Apply one round of reductions to node and return the result.

        At most one reduction rule is applied to node itself; the returned
        node is either node or a new terminal node replacing it.  When node
        is kept, its children are simplified recursively and replaced in
        its children list in place.
        """
        def is_terminal(node):
            # a terminal carries a value and has no children
            symbol, value, children = node
            return value is not None and not children

        def strip_symbol(node_list, symbol):
            # remove, in place and at every depth, the nodes named symbol
            del_list = []
            for i, node in enumerate(node_list):
                if node[0] == symbol:
                    del_list.append(i)
                strip_symbol(node[2], symbol)

            # delete back to front so the remaining indexes stay valid
            for i in reversed(del_list):
                del node_list[i]

        def nameof(node):
            return node[0]


        s, v, c = node

        # the rules below are tried in order and only the first match applies
        if s == "simple_ident" and c:
            # collapse an identifier into one terminal holding the flattened text of its first child
            text = flatten_node(c[0])
            node = s, text, []
            c = []
        elif len(c) == 1:
            # a node wrapping a single terminal child takes over the child's value
            cs, cv, cc = c[0]
            if is_terminal(c[0]):
                if cs == "TEXT_STRING_Q":
                    node = "*text", cv, []
                elif cs[0] == "*": # our own symbols (starting with *) assimilate the other ones
                    node = cs, cv, []
                else:
                    node = s, cv, []
                c = []
        elif len(c) == 3 and c[0][1] == '(' and c[-1][1] == ')' and nameof(c[1]) == '*simple_expression': # reduce anything like ( *simple_expression ) to *simple_expression
            node = c[1][0], "(%s)"%c[1][1], []
            c = []

        elif len(c) == 3 and c[0][1] == '(' and c[-1][1] == ')' and nameof(c[1]) == 'expr' and len(c[1][2])==3 and nameof(c[1][2][0]) == "*simple_expression" and nameof(c[1][2][-1]) == "*simple_expression": # reduce anything like ( *simple_expression op *simple_expression ) to *complex_expression
            text = flatten_node(c[1], " ")
            # NOTE: the "1 or" makes this condition always true, so the length limit is not enforced
            if 1 or len(text) < self.max_subexpression_length:
                node = "*complex_expression", "(%s)"%text, []
                c = []

        elif s in ["bool_pri", "bit_expr"] and len(c) == 3: # reduce anything like terminal <op> terminal into a *simple_expression
            ok = True
            for n in c:
                if not is_terminal(n):
                    ok = False
            if ok:
                v = flatten_node(node, " ")
                node = "*simple_expression", v, []
                c = []

        elif s == "predicate" and v is None and c[0][0] == "bit_expr":
            if len(c) > 1:
                # predicate can be something like bit_expr <pred> <stuff> or a bit_expr alone
                bit_expr = c[0]
                pred = c[1]
                rest = c[2:]
                # NOTE(review): rest is empty when the predicate has exactly two children, which would raise IndexError here - confirm that cannot happen
                if rest[0][1] == "(" and rest[-1][1] == ")":
                    # parenthesised list: work on a copy so that stripping the commas leaves the tree untouched
                    items = copy.deepcopy(rest[1:-1])
                    strip_symbol(items, "44") # strip commas
                    text = "%s %s (%s)" % (flatten_node(bit_expr), pred[1], flatten_node((s, None, items), ", "))
                else:
                    text = "%s %s %s" % (flatten_node(bit_expr), pred[1], flatten_node((s, None, rest), " "))
                node = "*simple_expression", text, []
                c = []
            else:
                node = "*simple_expression", flatten_node(node, " "), []
                c = []

        elif s in ["expr_list"]:
            # a list made only of terminals becomes one ", " separated expression
            ok = True
            for n in c:
                if not is_terminal(n):
                    ok = False
            if ok:
                v = ", ".join([cc[1] for cc in c if cc[1] != ","])
                node = "*simple_expression", v, []
                c = []

        # c is empty when node was replaced above; otherwise it is the node's own children list, updated in place
        for i, n in enumerate(c):
            nn = self.simplify(n)
            if n != nn:
                c[i] = nn
        return node


    def run(self):
        """Simplify the expression until nothing changes and return its text.

        The values of the remaining nodes are joined with single spaces;
        the value of a node whose symbol is "and" or "or" is preceded by a
        line break.
        """
        # the copy module is imported at the top of the file as well
        import copy
        orig = self.ast

        # simplify() modifies the tree it is given, so each round works on a copy and is compared with the previous result
        while True:
            simplified = self.simplify(copy.deepcopy(orig))
            if simplified == orig:
                break
            orig = simplified

        
        def flattenificate(node, out):
            # collect the values in pre-order, starting a new line before and/or
            s, n, c = node
            if n is not None:
                if s in ["and", "or"]:
                    out.append("\n"+n)
                else:
                    out.append(n)
            for i in c:
                flattenificate(i, out)
        #dump_tree(sys.stdout, simplified)
        l = []
        flattenificate(simplified, l)
        #print l
        return " ".join(l)


#######################################################################################


class SQLPrettifier_SELECT(SQLPrettifier):
    """Prettifier for SELECT statements.

    run() first merges multi-keyword tokens (preprocess) and then walks
    the tree; the sym_/endsym_ handlers below lay out the select list,
    the FROM/JOIN part and the WHERE, GROUP BY, HAVING, ORDER BY, LIMIT
    and UNION clauses.  Handlers named after a number handle the token
    with that character code (44 is the comma).
    """

    def __init__(self, ast):
        SQLPrettifier.__init__(self, ast)

        # nesting counters, incremented by a begin handler and decremented
        # by the matching end handler
        self.inside_union = 0
        self.inside_subselect = 0
        self.inside_join = 0
        self.inside_expr = 0
        self.inside_where = 0

        # root expr node of the WHERE clause being processed and the
        # number of expr nodes currently open inside that clause
        self.where_expr = None
        self.nested_expr_count = 0
        # (nested_expr_count, where_expr) of the enclosing WHERE clauses,
        # saved while a WHERE clause nested in a subselect is processed
        self.where_state_stack = []

        # number of JOIN keywords seen since the last table list began
        self.join_count = 0

        self.indent_select_body = True

        self.max_select_item_width = 80
        self.max_group_item_width = 80


    def sym_expr(self, value, children):
        # an outermost expr inside a join is a join condition; any expr
        # inside a WHERE clause is counted as part of the where expression
        if "join_table" in self.token_path and self.inside_expr == 0:
            self.sym_join_expr(value, children)
        elif self.inside_where:
            self.sym_where_expr(value, children)
        self.inside_expr += 1


    def endsym_expr(self, value):
        self.inside_expr -= 1
        if "join_table" in self.token_path and self.inside_expr == 0:
            self.endsym_join_expr(value)
        elif self.inside_where:
            self.endsym_where_expr(value)


    ###

    def sym_44(self, value, children): # comma ,
        # lists that place their own separators ask for commas to be dropped
        skip_commas = self.output[-1][1]
        if not skip_commas:
            self.out(value)

    def sym_42(self, value, children): # star *
        self.out(value)

    sym_40 = sym_42 # (
    sym_41 = sym_42 # )
    sym_46 = sym_42 # .

    def sym_comp_op(self, value, children): # =, >, < etc
        self.out(" %s " % value)

    def sym_and(self, value, children):
        self.out(" %s " % value)

    sym_or = sym_and
    sym_not = sym_and
    sym_IN_SYM = sym_and
    sym_BETWEEN_SYM = sym_and
    sym_AND_SYM = sym_and

    def clause_keyword(self, keyword, children):
        """Emit a clause keyword on a line of its own."""
        self.out("\n")
        self.out(keyword)
        self.out("\n")

    def indented_clause_keyword(self, keyword, children):
        """Emit a clause keyword on its own line and indent the next line."""
        self.out("\n")
        self.out(keyword)
        self.out("\n")
        self.out(self.indentation)

    ###

    def sym_IDENT_sys(self, value, children):
        self.out(value)

    def sym_simple_ident_q(self, value, children):
        self.push()

    def endsym_simple_ident_q(self, value):
        self.out(self.pop_joined())


    def sym_AS(self, value, children):
        self.out(value+" ")

    def sym_BY(self, value, children):
        self.out(value)

    ### SELECT
    def sym_SELECT_SYM(self, value, children):
        self.out(value)

    def sym_select_part2(self, value, children):
        self.push()
    sym_select_part2_derived = sym_select_part2 # subselects

    def endsym_select_part2(self, value):
        if self.indent_select_body:
            # indent the whole body but keep its first line next to SELECT
            self.out(" "+self.indent(self.pop_joined()).lstrip(" "))
        else:
            self.out(self.pop_joined())
        self.out("\n")
    endsym_select_part2_derived = endsym_select_part2 # subselects

    def sym_select_paren(self, value, children):
        self.out(" ")

    def sym_select_options(self, value, children):
        self.push()

    def endsym_select_options(self, value):
        self.out(" "+self.pop_joined(" "))


    def sym_select_item_list(self, value, children):
        # commas are dropped, the separators are added when the list ends
        self.push(True)


    def endsym_select_item_list(self, value):
        items = self.pop()
        # one line if it fits, otherwise one item per line
        text = "\n"+", ".join(items)
        if len(text) > self.max_select_item_width:
            text = "\n"+",\n".join(items)

        self.out(self.indent(text))

    def sym_select_item(self, value, children):
        self.push()

    def endsym_select_item(self, value):
        self.out(self.pop_joined())


    def sym_into(self, value, children):
        self.out("\n")
        self.push()

    def endsym_into(self, value):
        self.out(self.pop_joined(" "))


    def sym_subselect(self, value, children):
        self.push()
        self.inside_subselect += 1


    def endsym_subselect(self, value):
        self.inside_subselect -= 1
        self.out(self.pop_joined().strip())
        self.out("\n")

    def sym_select_alias(self, value, children):
        self.out(" ") # space before subselect alias

    ### FROM / JOINs
    sym_FROM = clause_keyword

    def sym_join_table_list(self, value, children):
        self.join_count = 0
        self.push(True)

    def endsym_join_table_list(self, value):
        self.out(self.indent(self.pop_joined(",\n")))

    def sym_join_table(self, value, children):
        self.push()

    def endsym_join_table(self, value):
        self.out(self.pop_joined())

    def sym_JOIN_SYM(self, value, children):
        # the first JOIN stays on the line, the following ones start a new one
        if self.join_count > 0:
            self.out("\n%s " % value)
        else:
            self.out(" %s " % value)
        self.join_count += 1

    def sym_ON(self, value, children):
        self.out("\n%s " % self.indent(value))

    def sym_join_expr(self, value, children):
        self.push()

    def endsym_join_expr(self, value):
        items = self.pop()
        # a condition containing a parenthesis goes on its own indented line
        if "(" in items:
            self.out("\n"+self.indent("".join(items)))
        else:
            self.out("".join(items))

    def sym_table_ref(self, value, children):
        self.push()

    def endsym_table_ref(self, value):
        table_ref = self.pop()
        self.out("".join(table_ref))

    def sym_opt_table_alias(self, value, children):
        self.out(" ")

    ### WHERE clause
    def sym_where_clause(self, value, children):
        # a WHERE clause can appear inside the expression of another one
        # (through a subselect); keep the state of the enclosing clause so
        # that it can still be finished once this one is done
        self.where_state_stack.append((self.nested_expr_count, self.where_expr))
        self.inside_where += 1
        self.nested_expr_count = 0
        self.push()

    def endsym_where_clause(self, value):
        self.out(self.pop_joined())
        self.inside_where -= 1
        self.nested_expr_count, self.where_expr = self.where_state_stack.pop()

    def sym_where_expr(self, value, children):
        # only the outermost expr of the clause is remembered and formatted
        if self.nested_expr_count == 0:
            self.where_expr = ("expr", value, children)
            self.push()
        self.nested_expr_count += 1

    def endsym_where_expr(self, value):
        self.nested_expr_count -= 1
        if self.nested_expr_count == 0:
            self.pop() # discard everything
            # the expression is formatted from its tree instead
            self.out(self.indent(self.format_expression(self.where_expr)))

    sym_WHERE = clause_keyword

    ### GROUP clause

    def sym_group_clause(self, value, children):
        self.push()

    def endsym_group_clause(self, value):
        self.out(self.pop_joined())

    sym_GROUP_SYM = clause_keyword

    def sym_group_list(self, value, children):
        self.push(True)

    def endsym_group_list(self, value):
        items = self.pop()
        # one line if it fits, otherwise one item per line
        text = ", ".join(items)
        if len(text) > self.max_group_item_width:
            text = ",\n".join(items)
        self.out(self.indent(text))


    ### HAVING clause
    def sym_having_clause(self, value, children):
        self.push()

    def endsym_having_clause(self, value):
        self.out(self.pop_joined())

    sym_HAVING = indented_clause_keyword

    ### ORDER

    sym_ORDER_SYM = indented_clause_keyword

    def sym_opt_order_clause(self, value, children):
        self.push()

    def endsym_opt_order_clause(self, value):
        self.out("%s"%self.pop_joined())

    def sym_order_dir(self, value, children):
        self.out(" "+value)

    ### LIMIT

    def sym_opt_limit_clause(self, value, children):
        self.out("\n")
        self.push()

    def endsym_opt_limit_clause(self, value):
        self.out("%s"%self.pop_joined(" "))

    ### UNION clause

    def sym_union_clause(self, value, children):
        self.out("\nUNION ")
        self.push()
        self.inside_union += 1

    def endsym_union_clause(self, value):
        self.inside_union -= 1
        self.out(self.pop_joined())

    def sym_UNION_SYM(self, value, children):
        # inside a union_clause the keyword was already written by sym_union_clause
        if not self.inside_union:
            self.out("UNION ")

    #####
    def preprocess(self, ast):
        """Simplify the tree in place by rewriting and joining some tokens.

        String nodes become TEXT_STRING_Q nodes whose value is the quoted
        string, and keyword sequences (like GROUP BY, ORDER BY, NOT IN,
        RIGHT JOIN etc) are merged into a single node.  The children lists
        are modified in place and the pass recurses into every child.
        """
        sym, nam, children = ast

        string_symbols = ["TEXT_STRING", "text_string", "TEXT_STRING_filesystem", "TEXT_STRING_literal", "TEXT_STRING_sys"]
        join_prefixes = ["opt_outer", "NATURAL", "LEFT", "RIGHT", "INNER_SYM", "CROSS"]

        # one merge can make another one possible (e.g. LEFT opt_outer
        # JOIN_SYM), so repeat until a pass changes nothing
        while True:
            to_delete = []
            changed = False
            for i, ch in enumerate(children):
                s, n, c = ch
                if i < len(children)-1:
                    ns, nn, nc = children[i+1]
                else:
                    ns, nn, nc = None, None, None

                if s in string_symbols:
                    children[i] = "TEXT_STRING_Q", "'%s'"%n.replace("'", r"\'"), c
                    changed = True

                # the keyword absorbs the token that follows it
                if (s in ["GROUP_SYM", "ORDER_SYM"] and ns == "BY") or (s == "not" and ns == "IN_SYM"):
                    to_delete.append(i+1)
                    children[i] = s, n+" "+nn, c
                    changed = True

                # the JOIN keyword absorbs the modifier in front of it
                if s in join_prefixes and ns == "JOIN_SYM":
                    to_delete.append(i)
                    children[i+1] = ns, n+" "+nn, nc
                    changed = True

            # delete back to front so the remaining indexes stay valid
            for i in reversed(to_delete):
                del children[i]

            if not changed:
                break


        for ch in children:
            self.preprocess(ch)


    def run(self):
        """Format self.ast and return the resulting statement text."""
        self.preprocess(self.ast)
        self.traverse(self.ast)
        return "".join(self.pop())



#######################################################################################

class SQLPrettifier_CREATE_TABLE(SQLPrettifier):
    """Prettifier for CREATE TABLE statements.

    Most constructs only need their tokens collected and joined again,
    which is done by assigning the generic begin/end handlers of
    SQLPrettifier to the matching sym_/endsym_ names.
    """

    def __init__(self, ast):
        SQLPrettifier.__init__(self, ast)

    ### constructs emitted as their tokens separated by spaces

    sym_column_def = SQLPrettifier.begin_tokens
    endsym_column_def = SQLPrettifier.end_spaced_tokens

    sym_key_def = SQLPrettifier.begin_tokens
    endsym_key_def = SQLPrettifier.end_spaced_tokens

    sym_opt_constraint = SQLPrettifier.begin_tokens
    endsym_opt_constraint = SQLPrettifier.end_spaced_tokens

    sym_field_spec = SQLPrettifier.begin_tokens
    endsym_field_spec = SQLPrettifier.end_spaced_tokens

    sym_opt_attribute_list = SQLPrettifier.begin_tokens
    endsym_opt_attribute_list = SQLPrettifier.end_spaced_tokens

    sym_opt_create_table_options = SQLPrettifier.begin_tokens
    endsym_opt_create_table_options = SQLPrettifier.end_spaced_tokens

    ### constructs emitted as their tokens glued together

    sym_type = SQLPrettifier.begin_tokens
    endsym_type = SQLPrettifier.end_tokens

    sym_field_length = SQLPrettifier.begin_tokens
    endsym_field_length = SQLPrettifier.end_tokens

    ### comma separated lists kept on one line

    sym_float_options = SQLPrettifier.begin_comma_list
    endsym_float_options = SQLPrettifier.end_comma_list

    sym_string_list = SQLPrettifier.begin_comma_list
    endsym_string_list = SQLPrettifier.end_comma_list

    ### field list: one indented definition per line

    sym_field_list = SQLPrettifier.begin_tokens

    def endsym_field_list(self, value):
        definitions = []
        for token in self.pop():
            # the separators are written again by the join below
            if token != ",":
                definitions.append(token)
        body = self.indent(",\n".join(definitions))
        self.out("\n%s\n" % body)

    ### field options get a leading space

    sym_field_opt_list = SQLPrettifier.begin_tokens

    def endsym_field_opt_list(self, value):
        options = self.pop_joined(" ")
        self.out(" " + options)

    ### foreign keys

    def sym_FOREIGN(self, value, children):
        # the keyword starts a new, indented line
        indented = self.indent(value)
        self.out("\n%s " % indented)

    sym_REFERENCES = sym_FOREIGN

    def sym_opt_on_update_delete(self, value, children):
        self.out("\n")
        self.push()

    def endsym_opt_on_update_delete(self, value):
        indented = self.indent(self.pop_joined(" "))
        # the first character of the indentation is dropped
        self.out(indented[1:])


    def run(self):
        """Format self.ast and return the pieces joined with spaces.

        NOTE(review): relies on a preprocess() method being available on
        the instance; this class does not define one itself - confirm the
        base class provides it.
        """
        self.preprocess(self.ast)
        self.traverse(self.ast)

        return self.pop_joined(" ")


#########################################################################

class SQLPrettifier_CREATE_VIEW(SQLPrettifier):
    """Placeholder for a CREATE VIEW prettifier; adds nothing to SQLPrettifier."""


#########################################################################

class SQLPrettifier_CREATE_FUNCTION(SQLPrettifier):
    """Placeholder for a CREATE FUNCTION prettifier; adds nothing to SQLPrettifier."""

#########################################################################


class SQLPrettifier_CREATE_PROCEDURE(SQLPrettifier):
    """Placeholder for a CREATE PROCEDURE prettifier; adds nothing to SQLPrettifier."""


#########################################################################

def formatter_for_statement_ast(ast):
    """Return the prettifier class able to format the statement in ast.

    ast is a (symbol, value, children) tree; the statement node is looked
    up at ast[2][0][2][0].  Returns SQLPrettifier_SELECT for a "select"
    statement, SQLPrettifier_CREATE_TABLE for a "create" statement whose
    second child is a TABLE_SYM, and None for everything else, including
    a tree that is too shallow to hold a statement.
    """
    try:
        statement = ast[2][0][2][0]
    except IndexError:
        # no statement node in the tree, so there is nothing to format
        return None

    kind = statement[0]
    if kind == "select":
        return SQLPrettifier_SELECT

    if kind == "create":
        parts = statement[2]
        # the second child names what is created (presumably the first one
        # is the CREATE keyword itself); a shorter node has no formatter
        if len(parts) > 1 and parts[1][0] == "TABLE_SYM":
            return SQLPrettifier_CREATE_TABLE
        # VIEW_SYM and any other object have no dedicated formatter

    return None


