// Copyright 2018 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package expression

import (
	"github.com/hanchuanchuan/goInception/ast"
	"github.com/hanchuanchuan/goInception/model"
	"github.com/hanchuanchuan/goInception/mysql"
	"github.com/hanchuanchuan/goInception/parser"
	"github.com/hanchuanchuan/goInception/parser/opcode"
	"github.com/hanchuanchuan/goInception/sessionctx"
	"github.com/hanchuanchuan/goInception/types"
	"github.com/pingcap/errors"
)

// simpleRewriter walks an ast expression tree (through its Enter and Leave
// methods) and builds the equivalent Expression. Intermediate results are
// kept on the embedded exprStack; column names are resolved against schema.
type simpleRewriter struct {
	// exprStack holds the Expressions built so far; the Leave handlers pop
	// their operands from it and push their result back.
	exprStack

	// schema is the set of columns that column-name nodes are looked up in.
	schema *Schema
	// err records a failure hit during the walk; the Rewrite* entry points
	// check it after Accept returns.
	err    error
	// ctx is the session context handed to the function builders.
	ctx    sessionctx.Context
}

// ParseSimpleExprWithTableInfo parses a simple expression string into an Expression.
// The expression string may only reference columns of tableInfo.
//
// The text is parsed as the field list of a SELECT ("select " is prepended),
// and only the first select field is rewritten. An error is returned when the
// text does not parse, when it parses to something other than a plain SELECT
// statement, or when the first field carries no expression (presumably a
// wildcard such as `*`).
func ParseSimpleExprWithTableInfo(ctx sessionctx.Context, exprStr string, tableInfo *model.TableInfo) (Expression, error) {
	stmts, _, err := parser.New().Parse("select "+exprStr, "", "")
	if err != nil {
		return nil, errors.Trace(err)
	}
	if len(stmts) == 0 {
		return nil, errors.Errorf("invalid simple expression: %s", exprStr)
	}
	// Use a checked assertion: the first statement is not guaranteed to be a
	// *ast.SelectStmt, and an unchecked assertion would panic instead of
	// reporting the bad input.
	sel, ok := stmts[0].(*ast.SelectStmt)
	if !ok || sel.Fields == nil || len(sel.Fields.Fields) == 0 {
		return nil, errors.Errorf("invalid simple expression: %s", exprStr)
	}
	expr := sel.Fields.Fields[0].Expr
	if expr == nil {
		// Nothing to rewrite; calling Accept on a nil node would panic.
		return nil, errors.Errorf("invalid simple expression: %s", exprStr)
	}
	return RewriteSimpleExprWithTableInfo(ctx, tableInfo, expr)
}

// ParseSimpleExprCastWithTableInfo parses a simple expression string with
// ParseSimpleExprWithTableInfo and wraps the result in a cast to targetFt.
// Any parse or rewrite error is returned unchanged apart from tracing.
func ParseSimpleExprCastWithTableInfo(ctx sessionctx.Context, exprStr string, tableInfo *model.TableInfo, targetFt *types.FieldType) (Expression, error) {
	parsed, err := ParseSimpleExprWithTableInfo(ctx, exprStr, tableInfo)
	if err != nil {
		return nil, errors.Trace(err)
	}
	return BuildCastFunction(ctx, parsed, targetFt), nil
}

// RewriteSimpleExprWithTableInfo converts a simple ast.ExprNode into an
// expression.Expression. Column references are resolved against the columns
// of tbl, qualified with the session's current database name.
func RewriteSimpleExprWithTableInfo(ctx sessionctx.Context, tbl *model.TableInfo, expr ast.ExprNode) (Expression, error) {
	currentDB := model.NewCIStr(ctx.GetSessionVars().CurrentDB)
	tblColumns := ColumnInfos2ColumnsWithDBName(ctx, currentDB, tbl.Name, tbl.Columns)
	sr := &simpleRewriter{
		ctx:    ctx,
		schema: NewSchema(tblColumns...),
	}
	expr.Accept(sr)
	if err := sr.err; err != nil {
		return nil, errors.Trace(err)
	}
	// The walk leaves the finished expression on top of the stack.
	return sr.pop(), nil
}

// ParseSimpleExprsWithSchema parses a comma-separated list of simple
// expressions into Expressions, one per select field.
// The expression string may only reference columns in the given schema.
//
// The text is parsed as the field list of a SELECT ("select " is prepended).
// An error is returned when the text does not parse, when it parses to
// something other than a plain SELECT statement, or when a field carries no
// expression (presumably a wildcard such as `*`).
func ParseSimpleExprsWithSchema(ctx sessionctx.Context, exprStr string, schema *Schema) ([]Expression, error) {
	stmts, _, err := parser.New().Parse("select "+exprStr, "", "")
	if err != nil {
		return nil, errors.Trace(err)
	}
	if len(stmts) == 0 {
		return nil, errors.Errorf("invalid simple expression: %s", exprStr)
	}
	// Use a checked assertion: the first statement is not guaranteed to be a
	// *ast.SelectStmt, and an unchecked assertion would panic instead of
	// reporting the bad input.
	sel, ok := stmts[0].(*ast.SelectStmt)
	if !ok || sel.Fields == nil {
		return nil, errors.Errorf("invalid simple expression: %s", exprStr)
	}
	fields := sel.Fields.Fields
	exprs := make([]Expression, 0, len(fields))
	for _, field := range fields {
		if field == nil || field.Expr == nil {
			// Nothing to rewrite; calling Accept on a nil node would panic.
			return nil, errors.Errorf("invalid simple expression: %s", exprStr)
		}
		expr, err := RewriteSimpleExprWithSchema(ctx, field.Expr, schema)
		if err != nil {
			return nil, errors.Trace(err)
		}
		exprs = append(exprs, expr)
	}
	return exprs, nil
}

// RewriteSimpleExprWithSchema converts a simple ast.ExprNode into an
// expression.Expression, resolving column references against schema.
func RewriteSimpleExprWithSchema(ctx sessionctx.Context, expr ast.ExprNode, schema *Schema) (Expression, error) {
	sr := &simpleRewriter{
		ctx:    ctx,
		schema: schema,
	}
	expr.Accept(sr)
	if err := sr.err; err != nil {
		return nil, errors.Trace(err)
	}
	// The walk leaves the finished expression on top of the stack.
	return sr.pop(), nil
}

// rewriteColumn looks the referenced column up in the rewriter's schema by its
// lower-cased name. It returns a bad-field error, carrying the name as the
// user wrote it, when the schema has no such column.
func (sr *simpleRewriter) rewriteColumn(nodeColName *ast.ColumnNameExpr) (*Column, error) {
	colName := nodeColName.Name.Name
	if found := sr.schema.FindColumnByName(colName.L); found != nil {
		return found, nil
	}
	return nil, errBadField.GenWithStackByArgs(colName.O, "expression")
}

// Enter is the pre-visit hook of the walk. It does no work: the node is handed
// back unchanged together with false, and all rewriting happens in Leave.
func (sr *simpleRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
	const secondResult = false
	return inNode, secondResult
}

// Leave is the post-visit hook of the walk. It dispatches on the concrete node
// type: operands are taken from the expression stack and the rewritten
// Expression is pushed back. On failure sr.err is set and ok is false; on
// success the original node is returned with ok set to true.
func (sr *simpleRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok bool) {
	switch node := originInNode.(type) {
	case *ast.ParenthesesExpr, *ast.ColumnName:
		// Nothing to build: the stack is left exactly as the children left it.
	case *ast.ValueExpr:
		sr.push(&Constant{Value: node.Datum, RetType: &node.Type})
	case *ast.ParamMarkerExpr:
		paramTp := types.NewFieldType(mysql.TypeUnspecified)
		types.DefaultParamTypeForValue(node.GetValue(), paramTp)
		sr.push(&Constant{Value: node.Datum, RetType: paramTp})
	case *ast.ColumnNameExpr:
		resolved, err := sr.rewriteColumn(node)
		if err != nil {
			sr.err = errors.Trace(err)
			return originInNode, false
		}
		sr.push(resolved)
	case *ast.FuncCastExpr:
		operand := sr.pop()
		if sr.err = errors.Trace(CheckArgsNotMultiColumnRow(operand)); sr.err != nil {
			return nil, false
		}
		sr.push(BuildCastFunction(sr.ctx, operand, node.Tp))
	case *ast.PatternInExpr:
		// Only the value-list form is rewritten here; nothing is done when
		// Sel is set.
		if node.Sel == nil {
			sr.inToExpression(len(node.List), node.Not, &node.Type)
		}
	case *ast.FuncCallExpr:
		sr.funcCallToExpression(node)
	case *ast.BinaryOperationExpr:
		sr.binaryOpToExpression(node)
	case *ast.UnaryOperationExpr:
		sr.unaryOpToExpression(node)
	case *ast.BetweenExpr:
		sr.betweenToExpression(node)
	case *ast.IsNullExpr:
		sr.isNullToExpression(node)
	case *ast.IsTruthExpr:
		sr.isTrueToScalarFunc(node)
	case *ast.PatternLikeExpr:
		sr.likeToScalarFunc(node)
	case *ast.PatternRegexpExpr:
		sr.regexpToScalarFunc(node)
	case *ast.RowExpr:
		sr.rowToScalarFunc(node)
	default:
		sr.err = errors.Errorf("UnknownType: %T", node)
		return nil, false
	}
	// The helpers report failure through sr.err rather than a return value.
	if sr.err != nil {
		return nil, false
	}
	return originInNode, true
}

// binaryOpToExpression pops the two operands of a binary operation and pushes
// the resulting function. Comparison operators go through
// constructBinaryOpFunction; every other operator requires both operands to
// have row length 1 and is built directly with NewFunction.
func (sr *simpleRewriter) binaryOpToExpression(v *ast.BinaryOperationExpr) {
	// Operands were pushed left first, so the right one is on top.
	rhs := sr.pop()
	lhs := sr.pop()

	var built Expression
	switch v.Op {
	case opcode.EQ, opcode.NE, opcode.NullEQ, opcode.GT, opcode.GE, opcode.LT, opcode.LE:
		built, sr.err = sr.constructBinaryOpFunction(lhs, rhs, v.Op.String())
	default:
		lhsLen, rhsLen := GetRowLen(lhs), GetRowLen(rhs)
		if !(lhsLen == 1 && rhsLen == 1) {
			sr.err = ErrOperandColumns.GenWithStackByArgs(1)
			return
		}
		resultTp := types.NewFieldType(mysql.TypeUnspecified)
		built, sr.err = NewFunction(sr.ctx, v.Op.String(), resultTp, lhs, rhs)
	}

	if sr.err == nil {
		sr.push(built)
		return
	}
	sr.err = errors.Trace(sr.err)
}

// funcCallToExpression pops the call's arguments from the expression stack and
// pushes the function built from them. Arguments must not be multi-column
// rows. Calls that need special handling are delegated to rewriteFuncCall;
// everything else is built with NewFunction. On failure sr.err is set and
// nothing is pushed.
func (sr *simpleRewriter) funcCallToExpression(v *ast.FuncCallExpr) {
	args := sr.popN(len(v.Args))
	sr.err = errors.Trace(CheckArgsNotMultiColumnRow(args...))
	if sr.err != nil {
		return
	}
	// The arguments are already off the stack, so they are handed over
	// explicitly; the helper must not pop them a second time.
	if sr.rewriteFuncCall(v, args) {
		return
	}
	function, err := NewFunction(sr.ctx, v.FnName.L, &v.Type, args...)
	if err != nil {
		sr.err = err
		return
	}
	sr.push(function)
}

// rewriteFuncCall handles function calls that are rewritten into other
// functions instead of being built directly. args are the call's arguments,
// already popped from the expression stack by the caller. It returns true when
// it handled the call (successfully, or by setting sr.err) and false when the
// caller should build the function itself.
//
// Currently only NULLIF is rewritten:
// nullif(param1, param2) becomes if(param1 = param2, NULL, param1).
func (sr *simpleRewriter) rewriteFuncCall(v *ast.FuncCallExpr, args []Expression) bool {
	switch v.FnName.L {
	case ast.Nullif:
		if len(args) != 2 {
			sr.err = ErrIncorrectParameterCount.GenWithStackByArgs(v.FnName.O)
			return true
		}
		param1, param2 := args[0], args[1]
		// param1 = param2
		funcCompare, err := sr.constructBinaryOpFunction(param1, param2, ast.EQ)
		if err != nil {
			sr.err = err
			return true
		}
		// NULL
		nullTp := types.NewFieldType(mysql.TypeNull)
		nullTp.Flen, nullTp.Decimal = mysql.GetDefaultFieldLengthAndDecimal(mysql.TypeNull)
		paramNull := &Constant{
			Value:   types.NewDatum(nil),
			RetType: nullTp,
		}
		// if(param1 = param2, NULL, param1)
		funcIf, err := NewFunction(sr.ctx, ast.If, &v.Type, funcCompare, paramNull, param1)
		if err != nil {
			sr.err = err
			return true
		}
		sr.push(funcIf)
		return true
	default:
		return false
	}
}

// constructBinaryOpFunction builds the comparison `l op r`, expanding row
// operands column by column:
//  1. When both sides have row length 1, it is a plain `l op r`. Sides of
//     different row length are rejected with ErrOperandColumns.
//  2. For EQ, NE and NullEQ, (a0,a1,a2) op (b0,b1,b2) becomes
//     (a0 op b0) and (a1 op b1) and (a2 op b2).
//  3. For LE and GE, (a0,a1,a2) op (b0,b1,b2) becomes
//     `IF( (a0 op b0) EQ 0, 0,
//          IF( (a1 op b1) EQ 0, 0, a2 op b2))`
//  4. For LT and GT, (a0,a1,a2) op (b0,b1,b2) becomes
//     `IF( a0 NE b0, a0 op b0,
//          IF( a1 NE b1,
//              a1 op b1,
//              a2 op b2)
//     )`
func (sr *simpleRewriter) constructBinaryOpFunction(l Expression, r Expression, op string) (Expression, error) {
	leftWidth, rightWidth := GetRowLen(l), GetRowLen(r)
	if leftWidth == 1 && rightWidth == 1 {
		return NewFunction(sr.ctx, op, types.NewFieldType(mysql.TypeTiny), l, r)
	}
	if leftWidth != rightWidth {
		return nil, ErrOperandColumns.GenWithStackByArgs(leftWidth)
	}

	if op == ast.EQ || op == ast.NE || op == ast.NullEQ {
		// Compare column by column and AND the results together.
		perColumn := make([]Expression, 0, leftWidth)
		for idx := 0; idx < leftWidth; idx++ {
			cmp, err := sr.constructBinaryOpFunction(GetFuncArg(l, idx), GetFuncArg(r, idx), op)
			if err != nil {
				return nil, errors.Trace(err)
			}
			perColumn = append(perColumn, cmp)
		}
		return ComposeCNFCondition(sr.ctx, perColumn...), nil
	}

	// Ordering comparison: decide on the first column, and recurse on the
	// remaining columns for the final IF branch.
	lHead, rHead := GetFuncArg(l, 0), GetFuncArg(r, 0)
	var cond, whenTrue Expression
	switch op {
	case ast.LE, ast.GE:
		headCmp := NewFunctionInternal(sr.ctx, op, types.NewFieldType(mysql.TypeTiny), lHead, rHead)
		cond = NewFunctionInternal(sr.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), headCmp, Zero)
		whenTrue = Zero
	case ast.LT, ast.GT:
		cond = NewFunctionInternal(sr.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), lHead, rHead)
		whenTrue = NewFunctionInternal(sr.ctx, op, types.NewFieldType(mysql.TypeTiny), lHead, rHead)
	}

	lTail, err := PopRowFirstArg(sr.ctx, l)
	if err != nil {
		return nil, errors.Trace(err)
	}
	rTail, err := PopRowFirstArg(sr.ctx, r)
	if err != nil {
		return nil, errors.Trace(err)
	}
	otherwise, err := sr.constructBinaryOpFunction(lTail, rTail, op)
	if err != nil {
		return nil, errors.Trace(err)
	}
	return NewFunction(sr.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), cond, whenTrue, otherwise)
}

// unaryOpToExpression rewrites a unary operation. Unary plus leaves the
// operand on the stack untouched; minus, bitwise negation and NOT pop the
// operand, require it to have row length 1, and push the corresponding
// function. On failure sr.err is set and nothing is pushed.
func (sr *simpleRewriter) unaryOpToExpression(v *ast.UnaryOperationExpr) {
	var op string
	switch v.Op {
	case opcode.Plus:
		// expression (+ a) is equal to a
		return
	case opcode.Minus:
		op = ast.UnaryMinus
	case opcode.BitNeg:
		op = ast.BitNeg
	case opcode.Not:
		op = ast.UnaryNot
	default:
		// Report the operator's numeric code: %T would only ever print the
		// static type name, which says nothing about the offending operator.
		sr.err = errors.Errorf("Unknown Unary Op %d", v.Op)
		return
	}
	expr := sr.pop()
	if GetRowLen(expr) != 1 {
		sr.err = ErrOperandColumns.GenWithStackByArgs(1)
		return
	}
	newExpr, err := NewFunction(sr.ctx, op, &v.Type, expr)
	if err != nil {
		// Do not push a nil expression when the function cannot be built.
		sr.err = err
		return
	}
	sr.push(newExpr)
}

// likeToScalarFunc pops the pattern and the matched expression, rejects
// multi-column rows, and pushes a LIKE function (negated when v.Not is set)
// whose third argument is the escape character as an integer constant.
func (sr *simpleRewriter) likeToScalarFunc(v *ast.PatternLikeExpr) {
	// The pattern was pushed last, so it comes off the stack first.
	likePattern := sr.pop()
	target := sr.pop()
	if sr.err = errors.Trace(CheckArgsNotMultiColumnRow(target, likePattern)); sr.err != nil {
		return
	}

	escapeTp := new(types.FieldType)
	types.DefaultTypeForValue(int(v.Escape), escapeTp)
	escapeArg := &Constant{
		Value:   types.NewIntDatum(int64(v.Escape)),
		RetType: escapeTp,
	}
	sr.push(sr.notToExpression(v.Not, ast.Like, &v.Type, target, likePattern, escapeArg))
}

// regexpToScalarFunc pops the pattern and the matched expression, rejects
// multi-column rows, and pushes a REGEXP function (negated when v.Not is set).
func (sr *simpleRewriter) regexpToScalarFunc(v *ast.PatternRegexpExpr) {
	// The pattern was pushed last, so it comes off the stack first.
	rePattern := sr.pop()
	target := sr.pop()
	if sr.err = errors.Trace(CheckArgsNotMultiColumnRow(target, rePattern)); sr.err != nil {
		return
	}
	sr.push(sr.notToExpression(v.Not, ast.Regexp, &v.Type, target, rePattern))
}

// rowToScalarFunc pops one expression per row value and pushes a row function
// over them, typed after the first member.
func (sr *simpleRewriter) rowToScalarFunc(v *ast.RowExpr) {
	members := sr.popN(len(v.Values))
	rowFn, err := NewFunction(sr.ctx, ast.RowFunc, members[0].GetType(), members...)
	if err == nil {
		sr.push(rowFn)
		return
	}
	sr.err = errors.Trace(err)
}

// betweenToExpression rewrites `expr BETWEEN low AND high` as
// `(expr >= low) AND (expr <= high)`, wrapped in NOT when v.Not is set.
// The tested expression must not be a multi-column row.
func (sr *simpleRewriter) betweenToExpression(v *ast.BetweenExpr) {
	// Stack order, top first: upper bound, lower bound, tested expression.
	upper := sr.pop()
	lower := sr.pop()
	target := sr.pop()
	if sr.err = errors.Trace(CheckArgsNotMultiColumnRow(target)); sr.err != nil {
		return
	}

	atLeast, err := NewFunction(sr.ctx, ast.GE, &v.Type, target, lower)
	if err != nil {
		sr.err = errors.Trace(err)
		return
	}
	atMost, err := NewFunction(sr.ctx, ast.LE, &v.Type, target, upper)
	if err != nil {
		sr.err = errors.Trace(err)
		return
	}
	result, err := NewFunction(sr.ctx, ast.LogicAnd, &v.Type, atLeast, atMost)
	if err == nil && v.Not {
		result, err = NewFunction(sr.ctx, ast.UnaryNot, &v.Type, result)
	}
	if err != nil {
		sr.err = errors.Trace(err)
		return
	}
	sr.push(result)
}

// isNullToExpression pops the operand of an IS [NOT] NULL test and pushes the
// IsNull function, negated when v.Not is set. The operand must have row
// length 1.
func (sr *simpleRewriter) isNullToExpression(v *ast.IsNullExpr) {
	operand := sr.pop()
	if GetRowLen(operand) == 1 {
		sr.push(sr.notToExpression(v.Not, ast.IsNull, &v.Type, operand))
		return
	}
	sr.err = ErrOperandColumns.GenWithStackByArgs(1)
}

// notToExpression builds the function op(args...) with return type tp and,
// when hasNot is true, wraps it in a unary NOT. If either step fails, sr.err
// is set to the traced error and nil is returned.
func (sr *simpleRewriter) notToExpression(hasNot bool, op string, tp *types.FieldType, args ...Expression) Expression {
	built, err := NewFunction(sr.ctx, op, tp, args...)
	if err == nil && hasNot {
		built, err = NewFunction(sr.ctx, ast.UnaryNot, tp, built)
	}
	if err != nil {
		sr.err = errors.Trace(err)
		return nil
	}
	return built
}

// isTrueToScalarFunc pops the operand of an IS [NOT] TRUE/FALSE test and
// pushes the matching function: IsFalsity when v.True is 0, IsTruth otherwise,
// negated when v.Not is set. The operand must have row length 1.
func (sr *simpleRewriter) isTrueToScalarFunc(v *ast.IsTruthExpr) {
	operand := sr.pop()
	if GetRowLen(operand) != 1 {
		sr.err = ErrOperandColumns.GenWithStackByArgs(1)
		return
	}

	var fnName string
	switch v.True {
	case 0:
		fnName = ast.IsFalsity
	default:
		fnName = ast.IsTruth
	}
	sr.push(sr.notToExpression(v.Not, fnName, &v.Type, operand))
}

// inToExpression converts an IN expression to a scalar function. lLen is the
// length of the in-list, not tells whether the expression is NOT IN, and tp is
// the type of the resulting expression.
//
// The left operand and the lLen list elements are popped from the stack; every
// element must have the same row length as the left operand. A left operand of
// NULL type yields a NULL constant. When the left operand is a single column
// and every non-NULL element compares with it in the left operand's own eval
// type, a single In function is built. Otherwise `a in (b, c, d)` is rewritten
// as `(a = b) or (a = c) or (a = d)`.
func (sr *simpleRewriter) inToExpression(lLen int, not bool, tp *types.FieldType) {
	operands := sr.popN(lLen + 1)
	// candidates shares storage with operands, so constants refined below are
	// also seen by the In function built from operands.
	target, candidates := operands[0], operands[1:]
	width, targetTp := GetRowLen(target), target.GetType()
	for idx := 0; idx < lLen; idx++ {
		if GetRowLen(candidates[idx]) != width {
			sr.err = ErrOperandColumns.GenWithStackByArgs(width)
			return
		}
	}

	if targetTp.Tp == mysql.TypeNull {
		sr.push(Null.Clone())
		return
	}

	targetEt := targetTp.EvalType()
	if targetEt == types.ETInt {
		for idx, cand := range candidates {
			if constant, isConst := cand.(*Constant); isConst {
				candidates[idx], _ = RefineComparedConstant(sr.ctx, mysql.HasUnsignedFlag(targetTp.Flag), constant, opcode.EQ)
			}
		}
	}

	sameCmpType := true
	for _, cand := range candidates {
		if cand.GetType().Tp == mysql.TypeNull {
			continue
		}
		if GetAccurateCmpType(target, cand) != targetEt {
			sameCmpType = false
			break
		}
	}

	if sameCmpType && width == 1 {
		sr.push(sr.notToExpression(not, ast.In, tp, operands...))
		return
	}

	// Fall back to a disjunction of equality comparisons.
	equalities := make([]Expression, 0, lLen)
	for _, cand := range candidates {
		eq, err := sr.constructBinaryOpFunction(target, cand, ast.EQ)
		if err != nil {
			sr.err = err
			return
		}
		equalities = append(equalities, eq)
	}
	result := ComposeDNFCondition(sr.ctx, equalities...)
	if not {
		negated, err := NewFunction(sr.ctx, ast.UnaryNot, tp, result)
		if err != nil {
			sr.err = err
			return
		}
		result = negated
	}
	sr.push(result)
}