package session

import (
	"bytes"
	"fmt"
	"strings"

	"github.com/hanchuanchuan/goInception/ast"
	"github.com/hanchuanchuan/goInception/util/sqlexec"
	log "github.com/sirupsen/logrus"
	"golang.org/x/net/context"
)

// splitCommand routes one parsed statement into the session's split buffer.
//
// Nothing happens unless the split option is enabled. For statements that act
// on a database or a table, the switch below only works out the target
// (schema, table) and whether the statement is recorded as DML; the statement
// text is then handed to addSplitNode in a single place. Queries and the other
// listed read-only/control statements are skipped, and unrecognised statement
// types are logged and skipped.
//
// ctx is not used. The returned record sets and error are always nil.
func (s *session) splitCommand(ctx context.Context, stmtNode ast.StmtNode,
	sql string) ([]sqlexec.RecordSet, error) {
	log.Debug("splitCommand")

	if !s.opt.split {
		return nil, nil
	}

	// Target of the statement. Values left empty are passed to addSplitNode
	// as empty strings; isDML stays false for DDL statements.
	var (
		schema, table string
		isDML         bool
	)

	switch node := stmtNode.(type) {
	case *ast.UseStmt:
		// Switch the session's current database before recording the statement.
		s.dbName = node.DBName
		schema, isDML = s.dbName, true

	case *ast.InsertStmt:
		// NOTE(review): the result of getSingleTableName is dereferenced
		// without a nil check — confirm it cannot return nil here.
		target := getSingleTableName(node.Table)
		schema, table, isDML = target.Schema.O, target.Name.O, true

	case *ast.DeleteStmt:
		isDML = true
		if node.Tables != nil {
			// Explicit delete-table list: only its first entry is used.
			if names := node.Tables.Tables; len(names) > 0 {
				schema, table = names[0].Schema.O, names[0].Name.O
			}
		} else {
			// Otherwise use the first plain table name found in the table refs.
			for _, src := range extractTableList(node.TableRefs.TableRefs, nil) {
				if name, ok := src.Source.(*ast.TableName); ok {
					schema, table = name.Schema.O, name.Name.O
					break
				}
			}
		}

	case *ast.UpdateStmt:
		isDML = true

		// Table qualifier of the first assigned column, if there is one.
		var qualifier string
		if len(node.List) > 0 {
			qualifier = node.List[0].Column.Table.L
		}

		// Pick the first table source that is a plain table name and, when a
		// qualifier was given, whose name or alias equals that qualifier.
		for _, src := range extractTableList(node.TableRefs.TableRefs, nil) {
			name, ok := src.Source.(*ast.TableName)
			if !ok {
				continue
			}
			if qualifier == "" || qualifier == name.Name.L || qualifier == src.AsName.L {
				schema, table = name.Schema.L, name.Name.L
				break
			}
		}

	case *ast.CreateDatabaseStmt:
		schema = node.Name

	case *ast.DropDatabaseStmt:
		schema = node.Name

	case *ast.CreateTableStmt:
		schema, table = node.Table.Schema.O, node.Table.Name.O

	case *ast.AlterTableStmt:
		schema, table = node.Table.Schema.O, node.Table.Name.O

	case *ast.DropTableStmt:
		// Only the first dropped table is used; with no tables at all the
		// statement is not recorded.
		if len(node.Tables) == 0 {
			return nil, nil
		}
		schema, table = node.Tables[0].Schema.O, node.Tables[0].Name.O

	case *ast.RenameTableStmt:
		schema, table = node.OldTable.Schema.O, node.OldTable.Name.O

	case *ast.TruncateTableStmt:
		// TRUNCATE is recorded with the DML flag set.
		schema, table, isDML = node.Table.Schema.O, node.Table.Name.O, true

	case *ast.CreateIndexStmt:
		schema, table = node.Table.Schema.O, node.Table.Name.O

	case *ast.DropIndexStmt:
		schema, table = node.Table.Schema.O, node.Table.Name.O

	case *ast.UnionStmt, *ast.SelectStmt, *ast.CreateViewStmt, *ast.ShowStmt,
		*ast.InceptionSetStmt, *ast.ExplainStmt, *ast.ShowOscStmt, *ast.KillStmt:
		// These statements are never written to the split buffer.
		return nil, nil

	default:
		log.Infof("无匹配类型:%T\n", stmtNode)
		return nil, nil
	}

	s.addSplitNode(schema, table, isDML, stmtNode, sql)
	return nil, nil
}

// addSplitNode appends one statement to the current split group, opening a
// new group first when required.
//
// db and tableName identify the statement's target; an empty db falls back to
// the session's current database. The lower-cased "db.table" string is the key
// under which the group remembers whether the target was last touched by a
// DML (isDML == true) or a non-DML statement.
//
// A new group is opened when no group exists yet, or when the same target was
// already recorded in the current group with the opposite isDML value. Each
// new group starts with a "use `<db>`;" line, unless the statement itself is
// a USE statement or no current database is set.
//
// currentSql is written to the group's buffer followed by ";\n". ALTER TABLE
// and DROP TABLE statements additionally set the group's ddlflag to 1.
func (s *session) addSplitNode(db, tableName string, isDML bool, stmtNode ast.StmtNode, currentSql string) {
	if db == "" {
		db = s.dbName
	}
	key := strings.ToLower(db + "." + tableName)

	// addNewSplitNode creates splitSets lazily, so a nil splitSets is treated
	// like "no group opened yet" instead of being dereferenced.
	needNewGroup := s.splitSets == nil || s.splitSets.id == 0
	if !needNewGroup {
		if prevIsDML, seen := s.splitSets.tableList[key]; seen && prevIsDML != isDML {
			needNewGroup = true
		}
	}

	if needNewGroup {
		s.addNewSplitNode()
		if _, isUse := stmtNode.(*ast.UseStmt); !isUse && s.dbName != "" {
			fmt.Fprintf(s.splitSets.sqlBuf, "use `%s`;\n", s.dbName)
		}
	}

	s.splitSets.tableList[key] = isDML

	switch stmtNode.(type) {
	case *ast.AlterTableStmt, *ast.DropTableStmt:
		s.splitSets.ddlflag = 1
	}

	s.splitSets.sqlBuf.WriteString(currentSql)
	s.splitSets.sqlBuf.WriteString(";\n")
}

// addNewSplitNode closes the current split group and opens a fresh one.
//
// The split state is created on first use. If a group is already open
// (id > 0) and its buffer holds SQL text, that text is passed to Append.
// The state is then reset for the next group: the id is incremented and the
// table map, ddlflag and SQL buffer start over empty.
func (s *session) addNewSplitNode() {
	if s.splitSets == nil {
		s.splitSets = NewSplitSets()
	}
	sets := s.splitSets

	// Flush the SQL collected so far, but only once a group has been opened.
	if pending := sets.sqlBuf.String(); sets.id > 0 && len(pending) > 0 {
		sets.Append(pending, "")
	}

	sets.id++
	sets.ddlflag = 0
	sets.tableList = make(map[string]bool)
	sets.sqlBuf = new(bytes.Buffer)
}