// Copyright 2013 The ql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSES/QL-LICENSE file.

// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package session

import (
	"bytes"
	"fmt"
	"runtime"
	"strings"
	"time"

	"github.com/hanchuanchuan/goInception/ast"
	"github.com/hanchuanchuan/goInception/config"
	"github.com/hanchuanchuan/goInception/parser"
	"github.com/hanchuanchuan/goInception/sessionctx/variable"
	"github.com/hanchuanchuan/goInception/util"
	"github.com/hanchuanchuan/goInception/util/sqlexec"
	"github.com/hanchuanchuan/goInception/util/timeutil"
	"github.com/jinzhu/gorm"
	"github.com/pingcap/errors"
	log "github.com/sirupsen/logrus"
	"golang.org/x/net/context"
	// "vitess.io/vitess/go/vt/sqlparser"
)

// makeNewResult returns a value copy of every record accumulated in
// s.recordSets. Before copying, each stored record's SeqNo is rewritten to its
// position in the list and r.cut() is called on it (both mutate the stored
// record, since the list holds pointers).
//
// s.recordSets is only allocated in the default check/execute mode, and only
// after checkOptions succeeded; when it is nil (print/split mode, or an audit
// that stopped early) nil is returned instead of dereferencing it.
func (s *session) makeNewResult() ([]Record, error) {
	if s.recordSets == nil {
		return nil, nil
	}

	records := make([]Record, len(s.recordSets.records))
	for i, r := range s.recordSets.records {
		r.SeqNo = i
		r.cut()
		records[i] = *r
	}
	return records, nil
}

// NewInception creates a standalone session flagged for API use (isAPI),
// with a fresh SQL parser, fresh session variables and lowerCaseTableNames
// set to 1. The session itself is installed as the global-variable accessor.
// As a side effect, the inferred system time zone is installed via
// timeutil.SetSystemTZ.
func NewInception() *session {
	se := new(session)
	se.parser = parser.New()
	se.sessionVars = variable.NewSessionVars()
	se.lowerCaseTableNames = 1
	se.isAPI = true

	// Global variable lookups are served by the session itself.
	se.sessionVars.GlobalVarsAccessor = se

	timeutil.SetSystemTZ(timeutil.InferSystemTZ())

	return se
}

// init resets the per-run session state before an audit/execute run: it
// clears connection-level flags, allocates empty metadata and backup caches,
// reloads the inc/osc/ghost settings from the global configuration, applies
// the EnableAnyStatement permission rule, normalises the language tag and
// rebuilds the custom audit levels.
func (s *session) init() {
	// Connection-level state.
	s.dbName = ""
	s.haveBegin, s.haveCommit = false, false
	s.threadID = 0
	s.isClusterNode = false

	// Fresh table/database metadata caches and backup bookkeeping.
	s.tableCacheList = make(map[string]*TableInfo)
	s.dbCacheList = make(map[string]*DBInfo)
	s.backupDBCacheList = make(map[string]bool)
	s.backupTableCacheList = make(map[string]bool)

	s.inc = config.GetGlobalConfig().Inc
	s.osc = config.GetGlobalConfig().Osc
	s.ghost = config.GetGlobalConfig().Ghost

	// Only root may use EnableAnyStatement: when the current process info
	// belongs to another user, the option is switched off for this run.
	// (The original intent was to apply this only when goInception
	// authentication is on, i.e. `!s.inc.SkipGrantTable`; that condition is
	// currently disabled, so the rule always applies.)
	if s.inc.EnableAnyStatement {
		if pi, ok := s.processInfo.Load().(util.ProcessInfo); ok && pi.User != "root" {
			log.Warnf("Insufficient permissions to enable any statement! user: %s", pi.User)
			s.inc.EnableAnyStatement = false
		}
	}

	// Lower-case the language tag and turn its first '-' into '_'.
	s.inc.Lang = strings.Replace(strings.ToLower(s.inc.Lang), "-", "_", 1)

	s.sqlFingerprint = nil
	s.dbType, s.dbVersion = DBTypeMysql, 0

	// Custom audit levels, built by parsing the IncLevel section of the
	// global configuration.
	s.parseIncLevel()
}

// clear releases everything accumulated by one Audit/RunExecute run. The
// open database handles are closed on return (best effort; Close errors are
// ignored), and the caches, result sets, options, channels and buffers are
// dropped so they can be garbage collected.
func (s *session) clear() {
	// Each defer captures its handle now; the closes run after the resets
	// below, in LIFO order: backupdb, then ddlDB, then db.
	if s.db != nil {
		defer s.db.Close()
	}
	if s.ddlDB != nil {
		defer s.ddlDB.Close()
	}
	if s.backupdb != nil {
		defer s.backupdb.Close()
	}

	s.dbName = ""
	s.haveBegin, s.haveCommit = false, false
	s.threadID = 0
	s.isClusterNode = false

	// Detach the field/index/partition slices of every cached table before
	// emptying the cache itself.
	for name, tbl := range s.tableCacheList {
		tbl.Fields, tbl.Indexes, tbl.Partitions = nil, nil, nil
		delete(s.tableCacheList, name)
	}

	// Caches and fingerprints.
	s.tableCacheList, s.dbCacheList = nil, nil
	s.backupDBCacheList, s.backupTableCacheList = nil, nil
	s.sqlFingerprint = nil
	s.incLevel = nil

	// Result sets.
	s.recordSets, s.printSets, s.splitSets = nil, nil, nil

	// Options, channels and execution buffers.
	s.statsCollector = nil
	s.opt = nil
	s.ch, s.chBackupRecord = nil, nil
	s.insertBuffer = nil
	s.statistics = nil
	s.alterRollbackBuffer = nil
}

// Audit checks sql against the data source configured via LoadOptions
// without executing it (opt.Check is forced on) and returns the audit
// records.
//
// It returns an error when no options were loaded, or when the audit failed
// before any record set was allocated (for example invalid source options or
// a connection failure). Building a result would dereference a nil set in
// that case. Any other audit error is logged, and the records collected so
// far are returned.
func (s *session) Audit(ctx context.Context, sql string) ([]Record, error) {
	if s.opt == nil {
		return nil, errors.New("未配置数据源信息!")
	}

	s.init()
	defer s.clear()

	s.opt.Check = true
	if err := s.audit(ctx, sql); err != nil {
		log.Error(err)
		if s.recordSets == nil {
			return nil, err
		}
	}
	return s.makeNewResult()
}

// RunExecute audits sql against the data source configured via LoadOptions
// and, if the audit completed cleanly, executes it via executeCommit. It
// returns the resulting records.
//
// Execution is skipped in two cases: when the audit itself returned an error
// (it stopped early, so later statements were never checked), or when
// hasErrorBefore reports an error. If the audit failed before any record set
// was allocated (for example invalid source options or a connection
// failure), that error is returned.
func (s *session) RunExecute(ctx context.Context, sql string) ([]Record, error) {
	if s.opt == nil {
		return nil, errors.New("未配置数据源信息!")
	}

	s.init()
	defer s.clear()

	s.opt.Check = false
	s.opt.Execute = true

	if err := s.audit(ctx, sql); err != nil {
		log.Error(err)
		if s.recordSets == nil {
			return nil, err
		}
		// Never execute a batch whose audit did not run to completion.
		return s.makeNewResult()
	}

	if s.hasErrorBefore() {
		return s.makeNewResult()
	}
	s.executeCommit(ctx)

	return s.makeNewResult()
}

// LoadOptions stores a private copy of opt as the data source options for
// the next Audit or RunExecute call (both refuse to run while no options are
// loaded). It currently never fails and always returns nil.
func (s *session) LoadOptions(opt SourceOptions) error {
	local := opt
	s.opt = &local
	return nil
}

// audit splits sql into statements and runs each one through the pipeline
// selected by s.opt:
//   - printCommand in Print mode,
//   - splitCommand in split mode,
//   - processCommand otherwise (check/execute).
//
// It first calls checkOptions. It then allocates exactly one of s.printSets,
// s.splitSets or s.recordSets to collect the results.
//
// A statement ends at a line ending in ";" (or ";\r") while the single quotes
// seen so far are balanced, or at the last line of the input. The first
// options, parse, processing or cancellation error is returned. Parse and
// cancellation errors are also appended to the active result set.
func (s *session) audit(ctx context.Context, sql string) (err error) {
	sqlList := strings.Split(sql, "\n")

	// General-log switch, captured before processing starts.
	logging := s.inc.GeneralLog

	defer func() {
		// Report the number of results as affected rows. The result sets are
		// only allocated after checkOptions succeeds, so audit can return
		// before any of them exists; each one is nil-checked.
		if s.sessionVars.StmtCtx.AffectedRows() == 0 {
			switch {
			case s.opt != nil && s.opt.Print:
				if s.printSets != nil {
					s.sessionVars.StmtCtx.AddAffectedRows(uint64(s.printSets.rc.count))
				}
			case s.opt != nil && s.opt.split:
				if s.splitSets != nil {
					s.sessionVars.StmtCtx.AddAffectedRows(uint64(s.splitSets.rc.count))
				}
			default:
				if s.recordSets != nil {
					s.sessionVars.StmtCtx.AddAffectedRows(uint64(len(s.recordSets.records)))
				}
			}
		}

		if logging {
			logQuery(sql, s.sessionVars)
		}
	}()

	connID := s.sessionVars.ConnectionID
	charsetInfo, collation := s.sessionVars.GetCharsetInfo()

	// Index of the last input line. Reaching it always flushes the pending
	// statement, and lineCount+1 is the denominator of the reported progress.
	lineCount := len(sqlList) - 1

	if tmp := s.processInfo.Load(); tmp != nil {
		pi := tmp.(util.ProcessInfo)
		pi.OperState = "CHECKING"
		pi.Percent = 0
		s.processInfo.Store(pi)
	}

	s.stage = StageCheck

	if err = s.checkOptions(); err != nil {
		return err
	}

	// Exactly one result set is allocated, matching the mode.
	switch {
	case s.opt.Print:
		s.printSets = NewPrintSets()
	case s.opt.split:
		s.splitSets = NewSplitSets()
	default:
		s.recordSets = NewRecordSets()
	}

	// SQL fingerprinting is on if either the request or the config enables it.
	if s.opt.fingerprint {
		s.inc.EnableFingerprint = true
	}
	if s.inc.EnableFingerprint {
		s.sqlFingerprint = make(map[string]*Record, 64)
	}

	// Lines of the statement currently being accumulated.
	var buf []string

	// true while the number of single quotes seen so far is even, i.e. the
	// current line does not end inside a quoted string.
	quotaIsDouble := true
	for i, sqlLine := range sqlList {
		if strings.Count(sqlLine, "'")%2 == 1 {
			quotaIsDouble = !quotaIsDouble
		}

		// Parse once the statement is terminated by a semicolon outside a
		// quoted string, or when the last line is reached.
		if ((strings.HasSuffix(sqlLine, ";") || strings.HasSuffix(sqlLine, ";\r")) &&
			quotaIsDouble) || i == lineCount {
			buf = append(buf, sqlLine)
			s1 := strings.Join(buf, "\n")

			s1 = strings.TrimRight(s1, ";")

			stmtNodes, err := s.ParseSQL(ctx, s1, charsetInfo, collation)

			if err == nil && len(stmtNodes) == 0 {
				tmpSQL := strings.TrimSpace(s1)
				// Nothing was parsed: unless the text is empty or starts as a
				// comment, report it as a parser failure.
				if tmpSQL != "" &&
					!strings.HasPrefix(tmpSQL, "#") &&
					!strings.HasPrefix(tmpSQL, "--") &&
					!strings.HasPrefix(tmpSQL, "/*") {
					err = errors.New("解析失败! 可能是解析器bug,请联系作者.")
				}
			}

			if err != nil {
				log.Errorf("con:%d 解析失败! %s", connID, err)
				log.Error(s1)
				if s.opt != nil && s.opt.Print {
					s.printSets.Append(2, strings.TrimSpace(s1), "", err.Error())
				} else if s.opt != nil && s.opt.split {
					s.addNewSplitNode()
					s.splitSets.Append(strings.TrimSpace(s1), err.Error())
				} else {
					s.recordSets.Append(&Record{
						Sql:          strings.TrimSpace(s1),
						ErrLevel:     2,
						ErrorMessage: err.Error(),
					})
				}
				return err
			}

			for i, stmtNode := range stmtNodes {
				// Trim spaces, semicolons, control whitespace and the
				// non-breaking space U+00A0 (code point 160).
				currentSql := strings.Trim(stmtNode.Text(), " ;\t\n\v\f\r\u00a0")

				s.myRecord = &Record{
					Sql:   currentSql,
					Buf:   new(bytes.Buffer),
					Type:  stmtNode,
					Stage: StageCheck,
				}

				s.SetMyProcessInfo(currentSql, time.Now(), float64(i)/float64(lineCount+1))

				var result []sqlexec.RecordSet
				var err error
				if s.opt != nil && s.opt.Print {
					result, err = s.printCommand(ctx, stmtNode, currentSql)
				} else if s.opt != nil && s.opt.split {
					result, err = s.splitCommand(ctx, stmtNode, currentSql)
				} else {
					result, err = s.processCommand(ctx, stmtNode, currentSql)
				}
				if err != nil {
					return err
				}
				if result != nil {
					return nil
				}

				// The process was killed.
				if err := checkClose(ctx); err != nil {
					log.Warn("Killed: ", err)
					s.appendErrorMessage("Operation has been killed!")
					if s.opt != nil && s.opt.Print {
						s.printSets.Append(2, "", "", strings.TrimSpace(s.myRecord.Buf.String()))
					} else if s.opt != nil && s.opt.split {
						s.addNewSplitNode()
						s.splitSets.Append("", strings.TrimSpace(s.myRecord.Buf.String()))
					} else {
						s.recordSets.Append(s.myRecord)
					}
					return err
				}

				switch {
				case s.opt != nil && s.opt.Print:
					// Print mode: s.recordSets is not allocated; nothing to append.
				case s.opt != nil && s.opt.split:
					// Split mode: s.recordSets is not allocated; nothing to append.
				default:
					// Hide local "inception set" commands during remote operation.
					if _, ok := stmtNode.(*ast.InceptionSetStmt); ok && s.myRecord.ErrLevel == 0 {
						log.Info(currentSql)
					} else {
						s.recordSets.Append(s.myRecord)
					}
				}
			}

			buf = nil

		} else if i < lineCount {
			buf = append(buf, sqlLine)
		}
	}

	return nil
}

// checkOptions validates the loaded source options, normalises the
// mode-dependent flags and opens the connections this run needs:
//   - s.db, the target database,
//   - s.ddlDB, a second target connection opened only when tranBatch > 1,
//   - s.backupdb, opened only when Backup is enabled.
//
// It also records the destination in the process info, then calls
// mysqlServerVersion, setSqlSafeUpdates and setLockWaitTimeout.
//
// Handles are assigned to the session as soon as they are opened, so clear()
// closes them even when a later step fails.
func (s *session) checkOptions() error {
	if s.opt == nil {
		return errors.New("未配置数据源信息!")
	}

	if s.opt.split || s.opt.Check || s.opt.Print {
		s.opt.Execute = false
		s.opt.Backup = false

		// Ignore warnings automatically while auditing so the audit does not stop early.
		s.opt.IgnoreWarnings = true
	}

	// sleepRows is meaningless without sleep; otherwise it is at least 1.
	if s.opt.sleep <= 0 {
		s.opt.sleepRows = 0
	} else if s.opt.sleepRows < 1 {
		s.opt.sleepRows = 1
	}

	if s.opt.split || s.opt.Print {
		s.opt.Check = false
	}

	// An empty password is accepted; host, port and user are required.
	if s.opt.Host == "" || s.opt.Port == 0 || s.opt.User == "" {
		// Log the options without exposing the password.
		masked := *s.opt
		if masked.Password != "" {
			masked.Password = "******"
		}
		log.Warningf("%#v", &masked)

		msg := ""
		if s.opt.Host == "" {
			msg += "主机名为空,"
		}
		if s.opt.Port == 0 {
			msg += "端口为0,"
		}
		if s.opt.User == "" {
			msg += "用户名为空,"
		}
		return fmt.Errorf(s.getErrorMessage(ER_SQL_INVALID_SOURCE), strings.TrimRight(msg, ","))
	}

	var addr string
	if s.opt.middlewareExtend == "" {
		tlsValue, err := s.getTLSConfig()
		if err != nil {
			return err
		}
		addr = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=True&loc=Local&maxAllowedPacket=%d&tls=%s",
			s.opt.User, s.opt.Password, s.opt.Host, s.opt.Port, s.opt.db,
			s.inc.DefaultCharset, s.inc.MaxAllowedPacket, tlsValue)
	} else {
		s.opt.middlewareExtend = fmt.Sprintf("/*%s*/",
			strings.Replace(s.opt.middlewareExtend, ": ", "=", 1))

		addr = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=True&loc=Local&maxAllowedPacket=%d&maxOpen=100&maxLifetime=60",
			s.opt.User, s.opt.Password, s.opt.Host, s.opt.Port,
			s.opt.middlewareDB, s.inc.DefaultCharset, s.inc.MaxAllowedPacket)
	}

	if s.inc.SqlMode != "" {
		addr = fmt.Sprintf("%s&sql_mode=%s",
			addr, s.inc.SqlMode)
	}

	db, err := gorm.Open("mysql", fmt.Sprintf("%s&autocommit=1", addr))
	if err != nil {
		return fmt.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
	}

	// Disable gorm's logger so no SQL is logged.
	db.LogMode(false)
	s.db = db
	s.dbName = s.opt.db

	if s.opt.tranBatch > 1 {
		// The error used to be discarded, leaving a failed (possibly nil)
		// handle in s.ddlDB.
		ddlDB, err := gorm.Open("mysql", fmt.Sprintf("%s&autocommit=1", addr))
		if err != nil {
			return fmt.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
		}
		ddlDB.LogMode(false)
		s.ddlDB = ddlDB
	}

	if s.opt.Execute && s.opt.Backup && !s.checkBinlogIsOn() {
		return errors.New("binlog日志未开启,无法备份!")
	}

	if s.opt.Backup {
		// An empty backup password is accepted; host, port and user are required.
		if s.inc.BackupHost == "" || s.inc.BackupPort == 0 || s.inc.BackupUser == "" {
			return errors.New(s.getErrorMessage(ER_INVALID_BACKUP_HOST_INFO))
		}

		backupAddr := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=%s&parseTime=True&loc=Local&autocommit=1",
			s.inc.BackupUser, s.inc.BackupPassword, s.inc.BackupHost, s.inc.BackupPort,
			s.inc.DefaultCharset)
		backupdb, err := gorm.Open("mysql", backupAddr)
		if err != nil {
			return fmt.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
		}

		backupdb.LogMode(false)
		s.backupdb = backupdb
	}

	if tmp := s.processInfo.Load(); tmp != nil {
		pi := tmp.(util.ProcessInfo)
		pi.DestHost = s.opt.Host
		pi.DestPort = s.opt.Port
		pi.DestUser = s.opt.User

		if s.opt.Check {
			pi.Command = "CHECK"
		} else if s.opt.Execute {
			pi.Command = "EXECUTE"
		}
		s.processInfo.Store(pi)
	}

	s.mysqlServerVersion()
	s.setSqlSafeUpdates()
	s.setLockWaitTimeout()

	if s.opt.Backup && s.dbType == DBTypeTiDB {
		s.appendErrorMessage("TiDB暂不支持备份功能.")
	}

	return nil
}

// printMemStats logs, at error level, the current Alloc, TotalAlloc and Sys
// memory counters (in whole MiB, truncated) and the number of completed GC
// cycles. It is a debugging helper; its call sites in this file are
// commented out.
func printMemStats() {
	const mib = 1024 * 1024

	stats := new(runtime.MemStats)
	runtime.ReadMemStats(stats)

	log.Errorf("Alloc = %vMB TotalAlloc = %vMB Sys = %vMB NumGC = %v\n",
		stats.Alloc/mib, stats.TotalAlloc/mib, stats.Sys/mib, stats.NumGC)
}