// Copyright 2016 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // See the License for the specific language governing permissions and // limitations under the License. package expression import ( "strconv" "strings" "time" "unicode" "github.com/hanchuanchuan/goInception/ast" "github.com/hanchuanchuan/goInception/mysql" "github.com/hanchuanchuan/goInception/parser/opcode" "github.com/hanchuanchuan/goInception/sessionctx" "github.com/hanchuanchuan/goInception/terror" "github.com/hanchuanchuan/goInception/types" "github.com/hanchuanchuan/goInception/util/chunk" "github.com/hanchuanchuan/goInception/util/hack" "github.com/pingcap/errors" ) // Filter the input expressions, append the results to result. func Filter(result []Expression, input []Expression, filter func(Expression) bool) []Expression { for _, e := range input { if filter(e) { result = append(result, e) } } return result } // ExtractColumns extracts all columns from an expression. func ExtractColumns(expr Expression) (cols []*Column) { // Pre-allocate a slice to reduce allocation, 8 doesn't have special meaning. result := make([]*Column, 0, 8) return extractColumns(result, expr, nil) } // ExtractCorColumns extracts correlated column from given expression. func ExtractCorColumns(expr Expression) (cols []*CorrelatedColumn) { switch v := expr.(type) { case *CorrelatedColumn: return []*CorrelatedColumn{v} case *ScalarFunction: for _, arg := range v.GetArgs() { cols = append(cols, ExtractCorColumns(arg)...) } } return } // ExtractColumnsFromExpressions is a more efficient version of ExtractColumns for batch operation. 
// filter can be nil, or a function to filter the result column. // It's often observed that the pattern of the caller like this: // // cols := ExtractColumns(...) // for _, col := range cols { // if xxx(col) {...} // } // // Provide an additional filter argument, this can be done in one step. // To avoid allocation for cols that not need. func ExtractColumnsFromExpressions(result []*Column, exprs []Expression, filter func(*Column) bool) []*Column { for _, expr := range exprs { result = extractColumns(result, expr, filter) } return result } func extractColumns(result []*Column, expr Expression, filter func(*Column) bool) []*Column { switch v := expr.(type) { case *Column: if filter == nil || filter(v) { result = append(result, v) } case *ScalarFunction: for _, arg := range v.GetArgs() { result = extractColumns(result, arg, filter) } } return result } // ColumnSubstitute substitutes the columns in filter to expressions in select fields. // e.g. select * from (select b as a from t) k where a < 10 => select * from (select b as a from t where b < 10) k. func ColumnSubstitute(expr Expression, schema *Schema, newExprs []Expression) Expression { switch v := expr.(type) { case *Column: id := schema.ColumnIndex(v) if id == -1 { return v } return newExprs[id] case *ScalarFunction: if v.FuncName.L == ast.Cast { newFunc := v.Clone().(*ScalarFunction) newFunc.GetArgs()[0] = ColumnSubstitute(newFunc.GetArgs()[0], schema, newExprs) return newFunc } newArgs := make([]Expression, 0, len(v.GetArgs())) for _, arg := range v.GetArgs() { newArgs = append(newArgs, ColumnSubstitute(arg, schema, newExprs)) } return NewFunctionInternal(v.GetCtx(), v.FuncName.L, v.RetType, newArgs...) } return expr } // getValidPrefix gets a prefix of string which can parsed to a number with base. the minimum base is 2 and the maximum is 36. 
func getValidPrefix(s string, base int64) string { var ( validLen int upper rune ) switch { case base >= 2 && base <= 9: upper = rune('0' + base) case base <= 36: upper = rune('A' + base - 10) default: return "" } Loop: for i := 0; i < len(s); i++ { c := rune(s[i]) switch { case unicode.IsDigit(c) || unicode.IsLower(c) || unicode.IsUpper(c): c = unicode.ToUpper(c) if c < upper { validLen = i + 1 } else { break Loop } case c == '+' || c == '-': if i != 0 { break Loop } default: break Loop } } if validLen > 1 && s[0] == '+' { return s[1:validLen] } return s[:validLen] } // SubstituteCorCol2Constant will substitute correlated column to constant value which it contains. // If the args of one scalar function are all constant, we will substitute it to constant. func SubstituteCorCol2Constant(expr Expression) (Expression, error) { switch x := expr.(type) { case *ScalarFunction: allConstant := true newArgs := make([]Expression, 0, len(x.GetArgs())) for _, arg := range x.GetArgs() { newArg, err := SubstituteCorCol2Constant(arg) if err != nil { return nil, errors.Trace(err) } _, ok := newArg.(*Constant) newArgs = append(newArgs, newArg) allConstant = allConstant && ok } if allConstant { val, err := x.Eval(chunk.Row{}) if err != nil { return nil, errors.Trace(err) } return &Constant{Value: val, RetType: x.GetType()}, nil } var newSf Expression if x.FuncName.L == ast.Cast { newSf = BuildCastFunction(x.GetCtx(), newArgs[0], x.RetType) } else { newSf = NewFunctionInternal(x.GetCtx(), x.FuncName.L, x.GetType(), newArgs...) } return newSf, nil case *CorrelatedColumn: return &Constant{Value: *x.Data, RetType: x.GetType()}, nil case *Constant: if x.DeferredExpr != nil { newExpr := FoldConstant(x) return &Constant{Value: newExpr.(*Constant).Value, RetType: x.GetType()}, nil } } return expr, nil } // timeZone2Duration converts timezone whose format should satisfy the regular condition // `(^(+|-)(0?[0-9]|1[0-2]):[0-5]?\d$)|(^+13:00$)` to time.Duration. 
func timeZone2Duration(tz string) time.Duration { sign := 1 if strings.HasPrefix(tz, "-") { sign = -1 } i := strings.Index(tz, ":") h, err := strconv.Atoi(tz[1:i]) terror.Log(errors.Trace(err)) m, err := strconv.Atoi(tz[i+1:]) terror.Log(errors.Trace(err)) return time.Duration(sign) * (time.Duration(h)*time.Hour + time.Duration(m)*time.Minute) } var oppositeOp = map[string]string{ ast.LT: ast.GE, ast.GE: ast.LT, ast.GT: ast.LE, ast.LE: ast.GT, ast.EQ: ast.NE, ast.NE: ast.EQ, } // a op b is equal to b symmetricOp a var symmetricOp = map[opcode.Op]opcode.Op{ opcode.LT: opcode.GT, opcode.GE: opcode.LE, opcode.GT: opcode.LT, opcode.LE: opcode.GE, opcode.EQ: opcode.EQ, opcode.NE: opcode.NE, opcode.NullEQ: opcode.NullEQ, } // PushDownNot pushes the `not` function down to the expression's arguments. func PushDownNot(ctx sessionctx.Context, expr Expression, not bool) Expression { if f, ok := expr.(*ScalarFunction); ok { switch f.FuncName.L { case ast.UnaryNot: return PushDownNot(f.GetCtx(), f.GetArgs()[0], !not) case ast.LT, ast.GE, ast.GT, ast.LE, ast.EQ, ast.NE: if not { return NewFunctionInternal(f.GetCtx(), oppositeOp[f.FuncName.L], f.GetType(), f.GetArgs()...) } for i, arg := range f.GetArgs() { f.GetArgs()[i] = PushDownNot(f.GetCtx(), arg, false) } return f case ast.LogicAnd: if not { args := f.GetArgs() for i, a := range args { args[i] = PushDownNot(f.GetCtx(), a, true) } return NewFunctionInternal(f.GetCtx(), ast.LogicOr, f.GetType(), args...) } for i, arg := range f.GetArgs() { f.GetArgs()[i] = PushDownNot(f.GetCtx(), arg, false) } return f case ast.LogicOr: if not { args := f.GetArgs() for i, a := range args { args[i] = PushDownNot(f.GetCtx(), a, true) } return NewFunctionInternal(f.GetCtx(), ast.LogicAnd, f.GetType(), args...) 
} for i, arg := range f.GetArgs() { f.GetArgs()[i] = PushDownNot(f.GetCtx(), arg, false) } return f } } if not { expr = NewFunctionInternal(ctx, ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), expr) } return expr } // Contains tests if `exprs` contains `e`. func Contains(exprs []Expression, e Expression) bool { for _, expr := range exprs { if e == expr { return true } } return false } // ExtractFiltersFromDNFs checks whether the cond is DNF. If so, it will get the extracted part and the remained part. // The original DNF will be replaced by the remained part or just be deleted if remained part is nil. // And the extracted part will be appended to the end of the orignal slice. func ExtractFiltersFromDNFs(ctx sessionctx.Context, conditions []Expression) []Expression { var allExtracted []Expression for i := len(conditions) - 1; i >= 0; i-- { if sf, ok := conditions[i].(*ScalarFunction); ok && sf.FuncName.L == ast.LogicOr { extracted, remained := extractFiltersFromDNF(ctx, sf) allExtracted = append(allExtracted, extracted...) if remained == nil { conditions = append(conditions[:i], conditions[i+1:]...) } else { conditions[i] = remained } } } return append(conditions, allExtracted...) } // extractFiltersFromDNF extracts the same condition that occurs in every DNF item and remove them from dnf leaves. 
func extractFiltersFromDNF(ctx sessionctx.Context, dnfFunc *ScalarFunction) ([]Expression, Expression) { dnfItems := FlattenDNFConditions(dnfFunc) sc := ctx.GetSessionVars().StmtCtx codeMap := make(map[string]int) hashcode2Expr := make(map[string]Expression) for i, dnfItem := range dnfItems { innerMap := make(map[string]struct{}) cnfItems := SplitCNFItems(dnfItem) for _, cnfItem := range cnfItems { code := cnfItem.HashCode(sc) if i == 0 { codeMap[hack.String(code)] = 1 hashcode2Expr[hack.String(code)] = cnfItem } else if _, ok := codeMap[hack.String(code)]; ok { // We need this check because there may be the case like `select * from t, t1 where (t.a=t1.a and t.a=t1.a) or (something). // We should make sure that the two `t.a=t1.a` contributes only once. // TODO: do this out of this function. if _, ok = innerMap[hack.String(code)]; !ok { codeMap[hack.String(code)]++ innerMap[hack.String(code)] = struct{}{} } } } } // We should make sure that this item occurs in every DNF item. for hashcode, cnt := range codeMap { if cnt < len(dnfItems) { delete(hashcode2Expr, hashcode) } } if len(hashcode2Expr) == 0 { return nil, dnfFunc } newDNFItems := make([]Expression, 0, len(dnfItems)) onlyNeedExtracted := false for _, dnfItem := range dnfItems { cnfItems := SplitCNFItems(dnfItem) newCNFItems := make([]Expression, 0, len(cnfItems)) for _, cnfItem := range cnfItems { code := cnfItem.HashCode(sc) _, ok := hashcode2Expr[hack.String(code)] if !ok { newCNFItems = append(newCNFItems, cnfItem) } } // If the extracted part is just one leaf of the DNF expression. Then the value of the total DNF expression is // always the same with the value of the extracted part. 
if len(newCNFItems) == 0 { onlyNeedExtracted = true break } newDNFItems = append(newDNFItems, ComposeCNFCondition(ctx, newCNFItems...)) } extractedExpr := make([]Expression, 0, len(hashcode2Expr)) for _, expr := range hashcode2Expr { extractedExpr = append(extractedExpr, expr) } if onlyNeedExtracted { return extractedExpr, nil } return extractedExpr, ComposeDNFCondition(ctx, newDNFItems...) } // DeriveRelaxedFiltersFromDNF given a DNF expression, derive a relaxed DNF expression which only contains columns // in specified schema; the derived expression is a superset of original expression, i.e, any tuple satisfying // the original expression must satisfy the derived expression. Return nil when the derived expression is universal set. // A running example is: for schema of t1, `(t1.a=1 and t2.a=1) or (t1.a=2 and t2.a=2)` would be derived as // `t1.a=1 or t1.a=2`, while `t1.a=1 or t2.a=1` would get nil. func DeriveRelaxedFiltersFromDNF(expr Expression, schema *Schema) Expression { sf, ok := expr.(*ScalarFunction) if !ok || sf.FuncName.L != ast.LogicOr { return nil } ctx := sf.GetCtx() dnfItems := FlattenDNFConditions(sf) newDNFItems := make([]Expression, 0, len(dnfItems)) for _, dnfItem := range dnfItems { cnfItems := SplitCNFItems(dnfItem) newCNFItems := make([]Expression, 0, len(cnfItems)) for _, cnfItem := range cnfItems { if itemSF, ok := cnfItem.(*ScalarFunction); ok && itemSF.FuncName.L == ast.LogicOr { relaxedCNFItem := DeriveRelaxedFiltersFromDNF(cnfItem, schema) if relaxedCNFItem != nil { newCNFItems = append(newCNFItems, relaxedCNFItem) } // If relaxed expression for embedded DNF is universal set, just drop this CNF item continue } // This cnfItem must be simple expression now // If it cannot be fully covered by schema, just drop this CNF item if ExprFromSchema(cnfItem, schema) { newCNFItems = append(newCNFItems, cnfItem) } } // If this DNF item involves no column of specified schema, the relaxed expression must be universal set if len(newCNFItems) == 0 { 
return nil } newDNFItems = append(newDNFItems, ComposeCNFCondition(ctx, newCNFItems...)) } return ComposeDNFCondition(ctx, newDNFItems...) } // GetRowLen gets the length if the func is row, returns 1 if not row. func GetRowLen(e Expression) int { if f, ok := e.(*ScalarFunction); ok && f.FuncName.L == ast.RowFunc { return len(f.GetArgs()) } return 1 } // CheckArgsNotMultiColumnRow checks the args are not multi-column row. func CheckArgsNotMultiColumnRow(args ...Expression) error { for _, arg := range args { if GetRowLen(arg) != 1 { return ErrOperandColumns.GenWithStackByArgs(1) } } return nil } // GetFuncArg gets the argument of the function at idx. func GetFuncArg(e Expression, idx int) Expression { if f, ok := e.(*ScalarFunction); ok { return f.GetArgs()[idx] } return nil } // PopRowFirstArg pops the first element and returns the rest of row. // e.g. After this function (1, 2, 3) becomes (2, 3). func PopRowFirstArg(ctx sessionctx.Context, e Expression) (ret Expression, err error) { if f, ok := e.(*ScalarFunction); ok && f.FuncName.L == ast.RowFunc { args := f.GetArgs() if len(args) == 2 { return args[1], nil } ret, err = NewFunction(ctx, ast.RowFunc, f.GetType(), args[1:]...) return ret, errors.Trace(err) } return } // exprStack is a stack of expressions. type exprStack struct { stack []Expression } // pop pops an expression from the stack. func (s *exprStack) pop() Expression { if s.len() == 0 { return nil } lastIdx := s.len() - 1 expr := s.stack[lastIdx] s.stack = s.stack[:lastIdx] return expr } // popN pops n expressions from the stack. // If n greater than stack length or n is negative, it pops all the expressions. func (s *exprStack) popN(n int) []Expression { if n > s.len() || n < 0 { n = s.len() } idx := s.len() - n exprs := s.stack[idx:] s.stack = s.stack[:idx] return exprs } // push pushes one expression to the stack. func (s *exprStack) push(expr Expression) { s.stack = append(s.stack, expr) } // len returns the length of th stack. 
func (s *exprStack) len() int {
	depth := len(s.stack)
	return depth
}

// ColumnSliceIsIntersect reports whether the two column slices have at least
// one column in common. Columns are matched by their UniqueID.
func ColumnSliceIsIntersect(s1, s2 []*Column) bool {
	// Collect the IDs of s1 first, then probe with the IDs of s2.
	seen := make(map[int64]struct{}, len(s1))
	for i := range s1 {
		seen[s1[i].UniqueID] = struct{}{}
	}
	for i := range s2 {
		if _, hit := seen[s2[i].UniqueID]; hit {
			return true
		}
	}
	return false
}