// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package session

import (
	"fmt"

	"golang.org/x/net/context"

	json "github.com/CorgiMan/json2"
	"github.com/hanchuanchuan/goInception/ast"
	"github.com/hanchuanchuan/goInception/model"
	"github.com/hanchuanchuan/goInception/util/sqlexec"
	log "github.com/sirupsen/logrus"
)

// printCommand serializes the parsed statement tree to JSON and records the
// result in s.printSets.
//
// On success a row with level 0 and the JSON text is appended; if marshaling
// fails the error is logged and a row with level 2 and the error message is
// appended instead (2 appears to be the error level — confirm against
// printSets). Errors are reported through printSets rather than returned, so
// the function always returns (nil, nil). ctx is currently unused.
//
// A per-statement-type dispatch (USE / INSERT / DELETE / UPDATE) was once
// sketched here but is disabled; every statement is printed as its raw tree.
func (s *session) printCommand(ctx context.Context, stmtNode ast.StmtNode,
	currentSql string) ([]sqlexec.RecordSet, error) {
	log.Debug("printCommand")

	tree, err := json.Marshal(stmtNode)
	if err == nil {
		s.printSets.Append(0, currentSql, string(tree), "")
		return nil, nil
	}

	log.Error(err)
	s.printSets.Append(2, currentSql, "", err.Error())
	return nil, nil
}

// printInsert builds a map describing an INSERT statement and logs it at info
// level. The map contains:
//   - "command":              always "insert"
//   - "table_object":         {"db", "table"} of the insert target
//   - "fields":               the explicit column list, if any
//   - "many_values":          one entry per VALUES row, if any
//   - "select_insert_values": present (currently empty) for INSERT ... SELECT
//
// The AST is only read, never modified. Column references without an explicit
// schema/table are qualified with the insert target's resolved database and
// table name. The sql parameter is currently unused.
func (s *session) printInsert(node *ast.InsertStmt, sql string) {
	log.Debug("printInsert")

	t := getSingleTableName(node.Table)

	// An explicit schema on the target table wins; otherwise fall back to
	// the session's current database.
	dbName := t.Schema.String()
	if t.Schema.O == "" {
		dbName = s.dbName
	}

	res := make(map[string]interface{}, 3)
	res["command"] = "insert"
	res["table_object"] = map[string]interface{}{
		"db":    dbName,
		"table": t.Name.String(),
	}

	if len(node.Columns) > 0 {
		fields := make([]interface{}, 0, len(node.Columns))
		for _, c := range node.Columns {
			// Resolve defaults into locals instead of writing them back into
			// the column node, which belongs to the caller's AST.
			colSchema, colTable := c.Schema, c.Table
			if colSchema.O == "" {
				colSchema = model.NewCIStr(dbName)
			}
			if colTable.O == "" {
				colTable = model.NewCIStr(t.Name.O)
			}

			fields = append(fields, map[string]string{
				"type":  "FIELD_ITEM",
				"field": c.Name.String(),
				"db":    colSchema.String(),
				"table": colTable.String(),
			})
		}
		res["fields"] = fields
	}

	if len(node.Lists) > 0 {
		manyValues := make([]interface{}, 0, len(node.Lists))
		for _, list := range node.Lists {
			values := make([]map[string]string, 0, len(list))
			for _, vv := range list {
				// printItem only describes literal values and returns nil
				// for any other expression (function calls, column refs,
				// ...). The comma-ok form avoids a panic in that case and
				// appends a nil map so the row stays aligned with its
				// column list.
				item, _ := printItem(vv).(map[string]string)
				values = append(values, item)
			}
			manyValues = append(manyValues, values)
		}
		res["many_values"] = manyValues
	}

	// INSERT ... SELECT statement.
	if node.Select != nil {
		res["select_insert_values"] = make(map[string]interface{}, 1)
		s.printSelectItem(node.Select)
	}

	log.Info(res)
}

// printItem describes a single expression as a small string map.
//
// Only literal values (*ast.ValueExpr) are supported. For them the result is a
// map[string]string with:
//   - "type":  the Go type of the underlying datum value (via %T)
//   - "value": the value rendered with fmt.Sprint
//
// A nil expression, or any other expression kind, yields nil.
func printItem(expr ast.ExprNode) interface{} {
	if expr == nil {
		return nil
	}

	literal, ok := expr.(*ast.ValueExpr)
	if !ok {
		return nil
	}

	raw := literal.GetDatum().GetValue()
	return map[string]string{
		"type":  fmt.Sprintf("%T", raw),
		"value": fmt.Sprint(raw),
	}
}

// printSelectItem walks a SELECT or UNION result set and checks each SELECT
// via printSubSelectItem.
//
// For a UNION, every SELECT except the last one is rejected (ErrWrongUsage)
// if it carries a LIMIT or ORDER BY clause. Unsupported node types are only
// logged. Returns true when the session has no recorded error afterwards.
func (s *session) printSelectItem(node ast.ResultSetNode) bool {
	log.Debug("printSelectItem")

	switch x := node.(type) {
	case *ast.UnionStmt:
		// Guard against a missing or empty select list: slicing
		// Selects[:len-1] on an empty slice would panic.
		if x.SelectList == nil || len(x.SelectList.Selects) == 0 {
			break
		}
		selects := x.SelectList.Selects

		// Only the final SELECT of a UNION may carry LIMIT / ORDER BY.
		for _, sel := range selects[:len(selects)-1] {
			if sel.Limit != nil {
				s.appendErrorNo(ErrWrongUsage, "UNION", "LIMIT")
			}
			if sel.OrderBy != nil {
				s.appendErrorNo(ErrWrongUsage, "UNION", "ORDER BY")
			}
		}

		for _, sel := range selects {
			s.printSubSelectItem(sel)
		}

	case *ast.SelectStmt:
		s.printSubSelectItem(x)

	default:
		log.Info(x)
		log.Infof("%#v", x)
	}
	return !s.hasError()
}

// printSubSelectItem checks a single SELECT statement. It does three things:
//   - resolves the tables in its FROM clause through the table cache;
//   - recurses into derived tables (sub-SELECTs and UNIONs);
//   - runs checkItem over the select fields (wildcards excluded), GROUP BY,
//     HAVING and ORDER BY expressions against the resolved tables.
//
// Returns true when the session has no recorded error afterwards.
func (s *session) printSubSelectItem(node *ast.SelectStmt) bool {
	log.Debug("printSubSelectItem")

	// Raw AST dumps are diagnostic only; keep them at debug level so they do
	// not flood the info log for every statement.
	log.Debugf("%#v", node)
	log.Debugf("%#v", node.Fields)
	if node.Fields != nil {
		for _, f := range node.Fields.Fields {
			log.Debugf("%#v", f)
		}
	}
	log.Debugf("%#v", node.From)

	var tableList []*ast.TableSource
	if node.From != nil {
		tableList = extractTableList(node.From.TableRefs, tableList)
	}

	var tableInfoList []*TableInfo
	for _, tblSource := range tableList {
		switch x := tblSource.Source.(type) {
		case *ast.TableName:
			t := s.getTableFromCache(x.Schema.O, x.Name.O, false)
			if t == nil {
				break
			}
			if tblSource.AsName.L != "" {
				// The alias belongs to this table reference only. Copy the
				// table info before setting it so the cached entry is not
				// mutated. Otherwise a self-join ("t a JOIN t b") would
				// leave both list entries carrying the last alias. This is a
				// shallow copy; slice fields stay shared and are assumed to
				// be read-only here — TODO confirm for checkItem.
				aliased := *t
				aliased.AsName = tblSource.AsName.O
				t = &aliased
			}
			tableInfoList = append(tableInfoList, t)

		case *ast.SelectStmt:
			s.printSubSelectItem(x)

		case *ast.UnionStmt:
			// Derived tables may also be UNIONs; check them the same way
			// as a top-level UNION.
			s.printSelectItem(x)
		}
	}

	if node.Fields != nil {
		for _, field := range node.Fields.Fields {
			if field.WildCard == nil {
				s.checkItem(field.Expr, tableInfoList)
			}
		}
	}

	if node.GroupBy != nil {
		for _, item := range node.GroupBy.Items {
			s.checkItem(item.Expr, tableInfoList)
		}
	}

	if node.Having != nil {
		s.checkItem(node.Having.Expr, tableInfoList)
	}

	if node.OrderBy != nil {
		for _, item := range node.OrderBy.Items {
			s.checkItem(item.Expr, tableInfoList)
		}
	}

	return !s.hasError()
}