1 В избранное 0 Ответвления 0

OSCHINA-MIRROR/hanchuanchuan-goInception

Присоединиться к Gitlife
Откройте для себя и примите участие в публичных проектах с открытым исходным кодом с участием более 10 миллионов разработчиков. Приватные репозитории также полностью бесплатны :)
Присоединиться бесплатно
Клонировать/Скачать
builtin_arithmetic.go 31 КБ
Копировать Редактировать Исходные данные Просмотреть построчно История
hanchuanchuan Отправлено 6 лет назад 27f3c5a
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936
// Copyright 2017 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package expression
import (
"fmt"
"math"
"github.com/cznic/mathutil"
"github.com/hanchuanchuan/goInception/mysql"
"github.com/hanchuanchuan/goInception/sessionctx"
"github.com/hanchuanchuan/goInception/terror"
"github.com/hanchuanchuan/goInception/types"
"github.com/hanchuanchuan/goInception/util/chunk"
"github.com/pingcap/errors"
"github.com/pingcap/tipb/go-tipb"
)
// Compile-time assertions that every arithmetic operator class satisfies functionClass.
var (
	_ functionClass = (*arithmeticPlusFunctionClass)(nil)
	_ functionClass = (*arithmeticMinusFunctionClass)(nil)
	_ functionClass = (*arithmeticDivideFunctionClass)(nil)
	_ functionClass = (*arithmeticMultiplyFunctionClass)(nil)
	_ functionClass = (*arithmeticIntDivideFunctionClass)(nil)
	_ functionClass = (*arithmeticModFunctionClass)(nil)
)
// Compile-time assertions that every arithmetic signature satisfies builtinFunc.
var (
	_ builtinFunc = (*builtinArithmeticPlusRealSig)(nil)
	_ builtinFunc = (*builtinArithmeticPlusDecimalSig)(nil)
	_ builtinFunc = (*builtinArithmeticPlusIntSig)(nil)
	_ builtinFunc = (*builtinArithmeticMinusRealSig)(nil)
	_ builtinFunc = (*builtinArithmeticMinusDecimalSig)(nil)
	_ builtinFunc = (*builtinArithmeticMinusIntSig)(nil)
	_ builtinFunc = (*builtinArithmeticDivideRealSig)(nil)
	_ builtinFunc = (*builtinArithmeticDivideDecimalSig)(nil)
	_ builtinFunc = (*builtinArithmeticMultiplyRealSig)(nil)
	_ builtinFunc = (*builtinArithmeticMultiplyDecimalSig)(nil)
	_ builtinFunc = (*builtinArithmeticMultiplyIntUnsignedSig)(nil)
	_ builtinFunc = (*builtinArithmeticMultiplyIntSig)(nil)
	_ builtinFunc = (*builtinArithmeticIntDivideIntSig)(nil)
	_ builtinFunc = (*builtinArithmeticIntDivideDecimalSig)(nil)
	_ builtinFunc = (*builtinArithmeticModIntSig)(nil)
	_ builtinFunc = (*builtinArithmeticModRealSig)(nil)
	_ builtinFunc = (*builtinArithmeticModDecimalSig)(nil)
)
// precIncrement is the number of extra fractional digits added to the scale
// of a DECIMAL result produced by the "/" operator (see setType4DivDecimal).
// NOTE(review): presumably mirrors MySQL's default div_precision_increment.
const (
	precIncrement = 4
)
// numericContextResultType maps the field type of an arithmetic argument to the
// evaluation type used for it in a numeric context. The result is always one of
// types.ETInt, types.ETDecimal or types.ETReal:
//   - temporal types evaluate as ETDecimal when they have fractional digits
//     (Decimal > 0) and as ETInt otherwise;
//   - binary strings evaluate as ETInt;
//   - hybrid types evaluate as ETReal;
//   - any other type keeps its own EvalType when that is ETInt or ETDecimal,
//     and falls back to ETReal otherwise.
func numericContextResultType(ft *types.FieldType) types.EvalType {
	switch {
	case types.IsTypeTemporal(ft.Tp):
		if ft.Decimal > 0 {
			return types.ETDecimal
		}
		return types.ETInt
	case types.IsBinaryStr(ft):
		return types.ETInt
	case ft.Hybrid():
		return types.ETReal
	}
	switch et := ft.EvalType(); et {
	case types.ETInt, types.ETDecimal:
		return et
	default:
		return types.ETReal
	}
}
// setFlenDecimal4Int sets the Flen and Decimal of an integer arithmetic result:
// no fractional digits and the maximum integer display width. The argument
// types a and b mirror the signature of setFlenDecimal4RealOrDecimal but are
// not consulted.
func setFlenDecimal4Int(retTp, a, b *types.FieldType) {
	retTp.Flen, retTp.Decimal = mysql.MaxIntWidth, 0
}
// setFlenDecimal4RealOrDecimal sets the Flen and Decimal of a REAL (isReal) or
// DECIMAL (!isReal) arithmetic result from the argument types a and b.
//
// If either argument's scale is unspecified, a REAL result gets unspecified
// width and scale, and a DECIMAL result gets the maximum width and scale.
// Otherwise the result scale is a.Decimal+b.Decimal (capped at
// mysql.MaxDecimalScale for DECIMAL only). If either width is unspecified, the
// result width is unspecified. Otherwise it is the wider integer part of the
// two arguments plus the result scale plus 3, capped at mysql.MaxRealWidth or
// mysql.MaxDecimalWidth.
//
// NOTE(review): a summed scale matches MySQL for multiplication, but MySQL
// documents max(scale) for + and -; confirm before relying on it for those.
func setFlenDecimal4RealOrDecimal(retTp, a, b *types.FieldType, isReal bool) {
	if a.Decimal == types.UnspecifiedLength || b.Decimal == types.UnspecifiedLength {
		if isReal {
			retTp.Flen, retTp.Decimal = types.UnspecifiedLength, types.UnspecifiedLength
			return
		}
		retTp.Flen, retTp.Decimal = mysql.MaxDecimalWidth, mysql.MaxDecimalScale
		return
	}

	scale := a.Decimal + b.Decimal
	if !isReal && scale > mysql.MaxDecimalScale {
		scale = mysql.MaxDecimalScale
	}
	retTp.Decimal = scale

	if a.Flen == types.UnspecifiedLength || b.Flen == types.UnspecifiedLength {
		retTp.Flen = types.UnspecifiedLength
		return
	}

	widthLimit := mysql.MaxDecimalWidth
	if isReal {
		widthLimit = mysql.MaxRealWidth
	}
	intDigits := mathutil.Max(a.Flen-a.Decimal, b.Flen-b.Decimal)
	retTp.Flen = mathutil.Min(intDigits+scale+3, widthLimit)
}
// setType4DivDecimal sets the Flen and Decimal of a DECIMAL "/" result.
// An unspecified scale on either argument is treated as 0.
// The result scale is the dividend's scale plus precIncrement, capped at
// mysql.MaxDecimalScale. The width is mysql.MaxDecimalWidth when the
// dividend's width is unspecified. Otherwise it is the dividend's width plus
// the divisor's scale plus precIncrement, capped at mysql.MaxDecimalWidth.
func (c *arithmeticDivideFunctionClass) setType4DivDecimal(retTp, a, b *types.FieldType) {
	scaleOf := func(ft *types.FieldType) int {
		if ft.Decimal == types.UnspecifiedFsp {
			return 0
		}
		return ft.Decimal
	}

	retTp.Decimal = mathutil.Min(scaleOf(a)+precIncrement, mysql.MaxDecimalScale)
	if a.Flen == types.UnspecifiedLength {
		retTp.Flen = mysql.MaxDecimalWidth
		return
	}
	retTp.Flen = mathutil.Min(a.Flen+scaleOf(b)+precIncrement, mysql.MaxDecimalWidth)
}
// setType4DivReal sets the Flen and Decimal of a REAL "/" result: maximum
// real display width with a non-fixed number of fractional digits.
func (c *arithmeticDivideFunctionClass) setType4DivReal(retTp *types.FieldType) {
	retTp.Flen, retTp.Decimal = mysql.MaxRealWidth, mysql.NotFixedDec
}
// arithmeticPlusFunctionClass builds the built-in signatures for binary "+".
type arithmeticPlusFunctionClass struct{ baseFunctionClass }

// getFunction selects the "+" signature from the numeric-context evaluation
// types of both arguments (see numericContextResultType). REAL takes priority
// over DECIMAL, which takes priority over INT. The INT result is UNSIGNED when
// either argument is UNSIGNED.
func (c *arithmeticPlusFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
	if err := c.verifyArgs(args); err != nil {
		return nil, errors.Trace(err)
	}
	lhsEvalTp := numericContextResultType(args[0].GetType())
	rhsEvalTp := numericContextResultType(args[1].GetType())
	eitherIs := func(et types.EvalType) bool { return lhsEvalTp == et || rhsEvalTp == et }

	// Argument types are re-read from args after newBaseBuiltinFuncWithTp, in
	// case that call replaced the arguments (e.g. with casts). TODO confirm.
	switch {
	case eitherIs(types.ETReal):
		bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETReal, types.ETReal, types.ETReal)
		setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), true)
		sig := &builtinArithmeticPlusRealSig{bf}
		sig.setPbCode(tipb.ScalarFuncSig_PlusReal)
		return sig, nil
	case eitherIs(types.ETDecimal):
		bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETDecimal, types.ETDecimal, types.ETDecimal)
		setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), false)
		sig := &builtinArithmeticPlusDecimalSig{bf}
		sig.setPbCode(tipb.ScalarFuncSig_PlusDecimal)
		return sig, nil
	default:
		bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
		anyUnsigned := mysql.HasUnsignedFlag(args[0].GetType().Flag) || mysql.HasUnsignedFlag(args[1].GetType().Flag)
		if anyUnsigned {
			bf.tp.Flag |= mysql.UnsignedFlag
		}
		setFlenDecimal4Int(bf.tp, args[0].GetType(), args[1].GetType())
		sig := &builtinArithmeticPlusIntSig{bf}
		sig.setPbCode(tipb.ScalarFuncSig_PlusInt)
		return sig, nil
	}
}
// builtinArithmeticPlusIntSig evaluates "+" on two INT arguments. Each argument
// is treated as signed or unsigned according to its UNSIGNED flag.
type builtinArithmeticPlusIntSig struct {
	baseBuiltinFunc
}

// Clone returns a new builtinArithmeticPlusIntSig initialised from s via cloneFrom.
func (s *builtinArithmeticPlusIntSig) Clone() builtinFunc {
	newSig := &builtinArithmeticPlusIntSig{}
	newSig.cloneFrom(&s.baseBuiltinFunc)
	return newSig
}

// evalInt returns a + b for the row; a NULL argument yields NULL.
//
// The sum is range-checked before it is computed. It must fit in BIGINT when
// both arguments are signed, and in BIGINT UNSIGNED otherwise. An out-of-range
// sum yields NULL plus types.ErrOverflow. For an unsigned result, the returned
// int64 carries the uint64 bit pattern.
func (s *builtinArithmeticPlusIntSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) {
	a, isNull, err := s.args[0].EvalInt(s.ctx, row)
	if isNull || err != nil {
		return 0, isNull, errors.Trace(err)
	}
	b, isNull, err := s.args[1].EvalInt(s.ctx, row)
	if isNull || err != nil {
		return 0, isNull, errors.Trace(err)
	}
	overflow := func(typeName string) (int64, bool, error) {
		return 0, true, types.ErrOverflow.GenWithStackByArgs(typeName, fmt.Sprintf("(%s + %s)", s.args[0].String(), s.args[1].String()))
	}

	isLHSUnsigned := mysql.HasUnsignedFlag(s.args[0].GetType().Flag)
	isRHSUnsigned := mysql.HasUnsignedFlag(s.args[1].GetType().Flag)
	switch {
	case isLHSUnsigned && isRHSUnsigned:
		if uint64(a) > math.MaxUint64-uint64(b) {
			return overflow("BIGINT UNSIGNED")
		}
	case isLHSUnsigned && !isRHSUnsigned:
		// A negative signed addend must not exceed the unsigned one in magnitude.
		// For b == MinInt64, -b wraps to MinInt64, whose uint64 value 1<<63 is
		// exactly its magnitude.
		if b < 0 && uint64(-b) > uint64(a) {
			return overflow("BIGINT UNSIGNED")
		}
		if b > 0 && uint64(a) > math.MaxUint64-uint64(b) {
			return overflow("BIGINT UNSIGNED")
		}
	case !isLHSUnsigned && isRHSUnsigned:
		if a < 0 && uint64(-a) > uint64(b) {
			return overflow("BIGINT UNSIGNED")
		}
		// The result is BIGINT UNSIGNED, so the upper bound is MaxUint64, mirroring
		// the case above. A MaxInt64 bound wrongly rejected valid sums such as
		// 1 + 9223372036854775808.
		if a > 0 && uint64(b) > math.MaxUint64-uint64(a) {
			return overflow("BIGINT UNSIGNED")
		}
	default:
		if (a > 0 && b > math.MaxInt64-a) || (a < 0 && b < math.MinInt64-a) {
			return overflow("BIGINT")
		}
	}
	return a + b, false, nil
}
// builtinArithmeticPlusDecimalSig evaluates "+" on two DECIMAL arguments.
type builtinArithmeticPlusDecimalSig struct{ baseBuiltinFunc }

// Clone returns a new builtinArithmeticPlusDecimalSig initialised from s via cloneFrom.
func (s *builtinArithmeticPlusDecimalSig) Clone() builtinFunc {
	cloned := new(builtinArithmeticPlusDecimalSig)
	cloned.cloneFrom(&s.baseBuiltinFunc)
	return cloned
}

// evalDecimal returns a + b, or NULL when either argument is NULL. An error
// from types.DecimalAdd is returned together with a NULL result.
func (s *builtinArithmeticPlusDecimalSig) evalDecimal(row chunk.Row) (*types.MyDecimal, bool, error) {
	lhs, isNull, err := s.args[0].EvalDecimal(s.ctx, row)
	if isNull || err != nil {
		return nil, isNull, errors.Trace(err)
	}
	rhs, isNull, err := s.args[1].EvalDecimal(s.ctx, row)
	if isNull || err != nil {
		return nil, isNull, errors.Trace(err)
	}
	sum := new(types.MyDecimal)
	if err = types.DecimalAdd(lhs, rhs, sum); err != nil {
		return nil, true, errors.Trace(err)
	}
	return sum, false, nil
}
// builtinArithmeticPlusRealSig evaluates "+" on two REAL arguments.
type builtinArithmeticPlusRealSig struct{ baseBuiltinFunc }

// Clone returns a new builtinArithmeticPlusRealSig initialised from s via cloneFrom.
func (s *builtinArithmeticPlusRealSig) Clone() builtinFunc {
	cloned := new(builtinArithmeticPlusRealSig)
	cloned.cloneFrom(&s.baseBuiltinFunc)
	return cloned
}

// evalReal returns a + b, or NULL when either argument is NULL. If the sum
// would exceed ±math.MaxFloat64, it yields NULL plus a DOUBLE types.ErrOverflow.
func (s *builtinArithmeticPlusRealSig) evalReal(row chunk.Row) (float64, bool, error) {
	lhs, isNull, err := s.args[0].EvalReal(s.ctx, row)
	if isNull || err != nil {
		return 0, isNull, errors.Trace(err)
	}
	rhs, isNull, err := s.args[1].EvalReal(s.ctx, row)
	if isNull || err != nil {
		return 0, isNull, errors.Trace(err)
	}
	tooLarge := lhs > 0 && rhs > math.MaxFloat64-lhs
	tooSmall := lhs < 0 && rhs < -math.MaxFloat64-lhs
	if tooLarge || tooSmall {
		return 0, true, types.ErrOverflow.GenWithStackByArgs("DOUBLE", fmt.Sprintf("(%s + %s)", s.args[0].String(), s.args[1].String()))
	}
	return lhs + rhs, false, nil
}
// arithmeticMinusFunctionClass builds the built-in signatures for binary "-".
type arithmeticMinusFunctionClass struct{ baseFunctionClass }

// getFunction selects the "-" signature from the numeric-context evaluation
// types of both arguments. REAL takes priority over DECIMAL, which takes
// priority over INT. The INT result is UNSIGNED when either argument is
// UNSIGNED, unless the session SQL mode has NO_UNSIGNED_SUBTRACTION.
func (c *arithmeticMinusFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
	if err := c.verifyArgs(args); err != nil {
		return nil, errors.Trace(err)
	}
	lhsEvalTp := numericContextResultType(args[0].GetType())
	rhsEvalTp := numericContextResultType(args[1].GetType())
	eitherIs := func(et types.EvalType) bool { return lhsEvalTp == et || rhsEvalTp == et }

	// Argument types are re-read from args after newBaseBuiltinFuncWithTp, in
	// case that call replaced the arguments (e.g. with casts). TODO confirm.
	switch {
	case eitherIs(types.ETReal):
		bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETReal, types.ETReal, types.ETReal)
		setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), true)
		sig := &builtinArithmeticMinusRealSig{bf}
		sig.setPbCode(tipb.ScalarFuncSig_MinusReal)
		return sig, nil
	case eitherIs(types.ETDecimal):
		bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETDecimal, types.ETDecimal, types.ETDecimal)
		setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), false)
		sig := &builtinArithmeticMinusDecimalSig{bf}
		sig.setPbCode(tipb.ScalarFuncSig_MinusDecimal)
		return sig, nil
	default:
		bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
		setFlenDecimal4Int(bf.tp, args[0].GetType(), args[1].GetType())
		anyUnsigned := mysql.HasUnsignedFlag(args[0].GetType().Flag) || mysql.HasUnsignedFlag(args[1].GetType().Flag)
		if anyUnsigned && !ctx.GetSessionVars().SQLMode.HasNoUnsignedSubtractionMode() {
			bf.tp.Flag |= mysql.UnsignedFlag
		}
		sig := &builtinArithmeticMinusIntSig{baseBuiltinFunc: bf}
		sig.setPbCode(tipb.ScalarFuncSig_MinusInt)
		return sig, nil
	}
}
// builtinArithmeticMinusRealSig evaluates "-" on two REAL arguments.
type builtinArithmeticMinusRealSig struct{ baseBuiltinFunc }

// Clone returns a new builtinArithmeticMinusRealSig initialised from s via cloneFrom.
func (s *builtinArithmeticMinusRealSig) Clone() builtinFunc {
	cloned := new(builtinArithmeticMinusRealSig)
	cloned.cloneFrom(&s.baseBuiltinFunc)
	return cloned
}

// evalReal returns a - b, or NULL when either argument is NULL. It treats the
// subtraction as a + (-b). A result beyond ±math.MaxFloat64 yields NULL plus a
// DOUBLE types.ErrOverflow.
func (s *builtinArithmeticMinusRealSig) evalReal(row chunk.Row) (float64, bool, error) {
	lhs, isNull, err := s.args[0].EvalReal(s.ctx, row)
	if isNull || err != nil {
		return 0, isNull, errors.Trace(err)
	}
	rhs, isNull, err := s.args[1].EvalReal(s.ctx, row)
	if isNull || err != nil {
		return 0, isNull, errors.Trace(err)
	}
	negRHS := -rhs
	if (lhs > 0 && negRHS > math.MaxFloat64-lhs) || (lhs < 0 && negRHS < -math.MaxFloat64-lhs) {
		return 0, true, types.ErrOverflow.GenWithStackByArgs("DOUBLE", fmt.Sprintf("(%s - %s)", s.args[0].String(), s.args[1].String()))
	}
	return lhs - rhs, false, nil
}
// builtinArithmeticMinusDecimalSig evaluates "-" on two DECIMAL arguments.
type builtinArithmeticMinusDecimalSig struct{ baseBuiltinFunc }

// Clone returns a new builtinArithmeticMinusDecimalSig initialised from s via cloneFrom.
func (s *builtinArithmeticMinusDecimalSig) Clone() builtinFunc {
	cloned := new(builtinArithmeticMinusDecimalSig)
	cloned.cloneFrom(&s.baseBuiltinFunc)
	return cloned
}

// evalDecimal returns a - b, or NULL when either argument is NULL. An error
// from types.DecimalSub is returned together with a NULL result.
func (s *builtinArithmeticMinusDecimalSig) evalDecimal(row chunk.Row) (*types.MyDecimal, bool, error) {
	lhs, isNull, err := s.args[0].EvalDecimal(s.ctx, row)
	if isNull || err != nil {
		return nil, isNull, errors.Trace(err)
	}
	rhs, isNull, err := s.args[1].EvalDecimal(s.ctx, row)
	if isNull || err != nil {
		return nil, isNull, errors.Trace(err)
	}
	diff := new(types.MyDecimal)
	if err = types.DecimalSub(lhs, rhs, diff); err != nil {
		return nil, true, errors.Trace(err)
	}
	return diff, false, nil
}
// builtinArithmeticMinusIntSig evaluates "-" on two INT arguments. Each argument
// is treated as signed or unsigned according to its UNSIGNED flag and the
// session's NO_UNSIGNED_SUBTRACTION SQL mode.
type builtinArithmeticMinusIntSig struct {
	baseBuiltinFunc
}

// Clone returns a new builtinArithmeticMinusIntSig initialised from s via cloneFrom.
func (s *builtinArithmeticMinusIntSig) Clone() builtinFunc {
	newSig := &builtinArithmeticMinusIntSig{}
	newSig.cloneFrom(&s.baseBuiltinFunc)
	return newSig
}

// evalInt returns a - b for the row; a NULL argument yields NULL.
//
// Under NO_UNSIGNED_SUBTRACTION, both arguments are treated as signed. An
// UNSIGNED argument whose value exceeds math.MaxInt64 then cannot be
// represented and is reported as an overflow. Otherwise the difference must
// fit in BIGINT UNSIGNED when either argument is unsigned, and in BIGINT when
// both are signed. An out-of-range difference yields NULL plus
// types.ErrOverflow.
func (s *builtinArithmeticMinusIntSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) {
	a, isNull, err := s.args[0].EvalInt(s.ctx, row)
	if isNull || err != nil {
		return 0, isNull, errors.Trace(err)
	}
	b, isNull, err := s.args[1].EvalInt(s.ctx, row)
	if isNull || err != nil {
		return 0, isNull, errors.Trace(err)
	}
	overflow := func(typeName string) (int64, bool, error) {
		return 0, true, types.ErrOverflow.GenWithStackByArgs(typeName, fmt.Sprintf("(%s - %s)", s.args[0].String(), s.args[1].String()))
	}

	lhsFlagUnsigned := mysql.HasUnsignedFlag(s.args[0].GetType().Flag)
	rhsFlagUnsigned := mysql.HasUnsignedFlag(s.args[1].GetType().Flag)
	forceToSigned := s.ctx.GetSessionVars().SQLMode.HasNoUnsignedSubtractionMode()
	// An UNSIGNED value above MaxInt64 has a negative int64 bit pattern and
	// cannot be reinterpreted as signed.
	if forceToSigned && ((lhsFlagUnsigned && a < 0) || (rhsFlagUnsigned && b < 0)) {
		return overflow("BIGINT UNSIGNED")
	}

	isLHSUnsigned := !forceToSigned && lhsFlagUnsigned
	isRHSUnsigned := !forceToSigned && rhsFlagUnsigned
	switch {
	case isLHSUnsigned && isRHSUnsigned:
		if uint64(a) < uint64(b) {
			return overflow("BIGINT UNSIGNED")
		}
	case isLHSUnsigned && !isRHSUnsigned:
		if b >= 0 && uint64(a) < uint64(b) {
			return overflow("BIGINT UNSIGNED")
		}
		// For b == MinInt64, -b wraps to MinInt64, whose uint64 value 1<<63 is
		// exactly its magnitude.
		if b < 0 && uint64(a) > math.MaxUint64-uint64(-b) {
			return overflow("BIGINT UNSIGNED")
		}
	case !isLHSUnsigned && isRHSUnsigned:
		// The result is BIGINT UNSIGNED. A negative minuend, or one smaller than
		// the subtrahend, gives a negative difference that must be rejected rather
		// than wrapped into a huge unsigned value (e.g. 0 - 5).
		if a < 0 || uint64(a) < uint64(b) {
			return overflow("BIGINT UNSIGNED")
		}
	default:
		// The bounds are expressed via b rather than -b, so b == MinInt64 (whose
		// negation overflows) is handled. For example, 0 - MinInt64 overflows,
		// while -1 - MinInt64 == MaxInt64 does not.
		if (b > 0 && a < math.MinInt64+b) || (b < 0 && a > math.MaxInt64+b) {
			return overflow("BIGINT")
		}
	}
	return a - b, false, nil
}
// arithmeticMultiplyFunctionClass builds the built-in signatures for binary "*".
type arithmeticMultiplyFunctionClass struct{ baseFunctionClass }

// getFunction selects the "*" signature from the numeric-context evaluation
// types of both arguments. REAL takes priority over DECIMAL, which takes
// priority over INT. For INT, the UNSIGNED signature (and UNSIGNED result flag)
// is used when either original argument type is UNSIGNED.
func (c *arithmeticMultiplyFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
	if err := c.verifyArgs(args); err != nil {
		return nil, errors.Trace(err)
	}
	lhsTp, rhsTp := args[0].GetType(), args[1].GetType()
	lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp)
	eitherIs := func(et types.EvalType) bool { return lhsEvalTp == et || rhsEvalTp == et }

	switch {
	case eitherIs(types.ETReal):
		bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETReal, types.ETReal, types.ETReal)
		setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), true)
		sig := &builtinArithmeticMultiplyRealSig{bf}
		sig.setPbCode(tipb.ScalarFuncSig_MultiplyReal)
		return sig, nil
	case eitherIs(types.ETDecimal):
		bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETDecimal, types.ETDecimal, types.ETDecimal)
		setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), false)
		sig := &builtinArithmeticMultiplyDecimalSig{bf}
		sig.setPbCode(tipb.ScalarFuncSig_MultiplyDecimal)
		return sig, nil
	}

	bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
	if !mysql.HasUnsignedFlag(lhsTp.Flag) && !mysql.HasUnsignedFlag(rhsTp.Flag) {
		setFlenDecimal4Int(bf.tp, args[0].GetType(), args[1].GetType())
		sig := &builtinArithmeticMultiplyIntSig{bf}
		sig.setPbCode(tipb.ScalarFuncSig_MultiplyInt)
		return sig, nil
	}
	bf.tp.Flag |= mysql.UnsignedFlag
	setFlenDecimal4Int(bf.tp, args[0].GetType(), args[1].GetType())
	sig := &builtinArithmeticMultiplyIntUnsignedSig{bf}
	sig.setPbCode(tipb.ScalarFuncSig_MultiplyInt)
	return sig, nil
}
// builtinArithmeticMultiplyRealSig evaluates "*" on two REAL arguments.
type builtinArithmeticMultiplyRealSig struct {
	baseBuiltinFunc
}

// Clone returns a new builtinArithmeticMultiplyRealSig initialised from s via cloneFrom.
func (s *builtinArithmeticMultiplyRealSig) Clone() builtinFunc {
	cloned := new(builtinArithmeticMultiplyRealSig)
	cloned.cloneFrom(&s.baseBuiltinFunc)
	return cloned
}

// builtinArithmeticMultiplyDecimalSig evaluates "*" on two DECIMAL arguments.
type builtinArithmeticMultiplyDecimalSig struct {
	baseBuiltinFunc
}

// Clone returns a new builtinArithmeticMultiplyDecimalSig initialised from s via cloneFrom.
func (s *builtinArithmeticMultiplyDecimalSig) Clone() builtinFunc {
	cloned := new(builtinArithmeticMultiplyDecimalSig)
	cloned.cloneFrom(&s.baseBuiltinFunc)
	return cloned
}

// builtinArithmeticMultiplyIntUnsignedSig evaluates "*" on INT arguments when
// at least one of them is UNSIGNED.
type builtinArithmeticMultiplyIntUnsignedSig struct {
	baseBuiltinFunc
}

// Clone returns a new builtinArithmeticMultiplyIntUnsignedSig initialised from s via cloneFrom.
func (s *builtinArithmeticMultiplyIntUnsignedSig) Clone() builtinFunc {
	cloned := new(builtinArithmeticMultiplyIntUnsignedSig)
	cloned.cloneFrom(&s.baseBuiltinFunc)
	return cloned
}

// builtinArithmeticMultiplyIntSig evaluates "*" on two signed INT arguments.
type builtinArithmeticMultiplyIntSig struct {
	baseBuiltinFunc
}

// Clone returns a new builtinArithmeticMultiplyIntSig initialised from s via cloneFrom.
func (s *builtinArithmeticMultiplyIntSig) Clone() builtinFunc {
	cloned := new(builtinArithmeticMultiplyIntSig)
	cloned.cloneFrom(&s.baseBuiltinFunc)
	return cloned
}
// evalReal returns a * b, or NULL when either argument is NULL. An infinite
// product yields NULL plus a DOUBLE types.ErrOverflow.
func (s *builtinArithmeticMultiplyRealSig) evalReal(row chunk.Row) (float64, bool, error) {
	lhs, isNull, err := s.args[0].EvalReal(s.ctx, row)
	if isNull || err != nil {
		return 0, isNull, errors.Trace(err)
	}
	rhs, isNull, err := s.args[1].EvalReal(s.ctx, row)
	if isNull || err != nil {
		return 0, isNull, errors.Trace(err)
	}
	if product := lhs * rhs; !math.IsInf(product, 0) {
		return product, false, nil
	}
	return 0, true, types.ErrOverflow.GenWithStackByArgs("DOUBLE", fmt.Sprintf("(%s * %s)", s.args[0].String(), s.args[1].String()))
}
// evalDecimal returns a * b, or NULL when either argument is NULL. A
// types.ErrTruncated from types.DecimalMul is tolerated and the product is
// still returned. Any other error yields NULL plus that error.
func (s *builtinArithmeticMultiplyDecimalSig) evalDecimal(row chunk.Row) (*types.MyDecimal, bool, error) {
	lhs, isNull, err := s.args[0].EvalDecimal(s.ctx, row)
	if isNull || err != nil {
		return nil, isNull, errors.Trace(err)
	}
	rhs, isNull, err := s.args[1].EvalDecimal(s.ctx, row)
	if isNull || err != nil {
		return nil, isNull, errors.Trace(err)
	}
	product := new(types.MyDecimal)
	err = types.DecimalMul(lhs, rhs, product)
	if err == nil || terror.ErrorEqual(err, types.ErrTruncated) {
		return product, false, nil
	}
	return nil, true, errors.Trace(err)
}
// evalInt returns a * b when at least one argument is UNSIGNED, so the result
// is BIGINT UNSIGNED. The returned int64 carries the uint64 bit pattern. A NULL
// argument yields NULL.
//
// A product that does not fit in BIGINT UNSIGNED yields NULL plus
// types.ErrOverflow. This includes any product of a negative SIGNED argument
// and a non-zero value, which would otherwise be silently reinterpreted as a
// huge positive uint64.
func (s *builtinArithmeticMultiplyIntUnsignedSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) {
	a, isNull, err := s.args[0].EvalInt(s.ctx, row)
	if isNull || err != nil {
		return 0, isNull, errors.Trace(err)
	}
	b, isNull, err := s.args[1].EvalInt(s.ctx, row)
	if isNull || err != nil {
		return 0, isNull, errors.Trace(err)
	}
	overflow := func() (int64, bool, error) {
		return 0, true, types.ErrOverflow.GenWithStackByArgs("BIGINT UNSIGNED", fmt.Sprintf("(%s * %s)", s.args[0].String(), s.args[1].String()))
	}

	// A negative signed factor times a non-zero factor is negative, which is out
	// of range for BIGINT UNSIGNED. A zero product is still valid.
	isLHSUnsigned := mysql.HasUnsignedFlag(s.args[0].GetType().Flag)
	isRHSUnsigned := mysql.HasUnsignedFlag(s.args[1].GetType().Flag)
	if (!isLHSUnsigned && a < 0 && b != 0) || (!isRHSUnsigned && b < 0 && a != 0) {
		return overflow()
	}

	unsignedA, unsignedB := uint64(a), uint64(b)
	result := unsignedA * unsignedB
	// The multiplication wrapped iff dividing back does not recover the other factor.
	if unsignedA != 0 && result/unsignedA != unsignedB {
		return overflow()
	}
	return int64(result), false, nil
}
// evalInt returns a * b for two signed INT arguments; a NULL argument yields
// NULL. A product outside the BIGINT range yields NULL plus types.ErrOverflow.
func (s *builtinArithmeticMultiplyIntSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) {
	a, isNull, err := s.args[0].EvalInt(s.ctx, row)
	if isNull || err != nil {
		return 0, isNull, errors.Trace(err)
	}
	b, isNull, err := s.args[1].EvalInt(s.ctx, row)
	if isNull || err != nil {
		return 0, isNull, errors.Trace(err)
	}
	result := a * b
	// The product wrapped iff dividing back does not recover b. The one exception
	// is -1 * MinInt64: Go defines MinInt64 / -1 == MinInt64, so the division test
	// alone would accept the wrapped result. Check that case explicitly.
	if (a == -1 && b == math.MinInt64) || (a != 0 && result/a != b) {
		return 0, true, types.ErrOverflow.GenWithStackByArgs("BIGINT", fmt.Sprintf("(%s * %s)", s.args[0].String(), s.args[1].String()))
	}
	return result, false, nil
}
// arithmeticDivideFunctionClass builds the built-in signatures for "/".
type arithmeticDivideFunctionClass struct{ baseFunctionClass }

// getFunction selects the "/" signature: DECIMAL division unless either
// argument evaluates as REAL in a numeric context, in which case REAL division.
// "/" never produces an INT result.
func (c *arithmeticDivideFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
	if err := c.verifyArgs(args); err != nil {
		return nil, errors.Trace(err)
	}
	lhsTp, rhsTp := args[0].GetType(), args[1].GetType()
	lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp)

	if lhsEvalTp != types.ETReal && rhsEvalTp != types.ETReal {
		bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETDecimal, types.ETDecimal, types.ETDecimal)
		c.setType4DivDecimal(bf.tp, lhsTp, rhsTp)
		sig := &builtinArithmeticDivideDecimalSig{bf}
		sig.setPbCode(tipb.ScalarFuncSig_DivideDecimal)
		return sig, nil
	}

	bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETReal, types.ETReal, types.ETReal)
	c.setType4DivReal(bf.tp)
	sig := &builtinArithmeticDivideRealSig{bf}
	sig.setPbCode(tipb.ScalarFuncSig_DivideReal)
	return sig, nil
}
// builtinArithmeticDivideRealSig evaluates "/" on two REAL arguments.
type builtinArithmeticDivideRealSig struct {
	baseBuiltinFunc
}

// Clone returns a new builtinArithmeticDivideRealSig initialised from s via cloneFrom.
func (s *builtinArithmeticDivideRealSig) Clone() builtinFunc {
	cloned := new(builtinArithmeticDivideRealSig)
	cloned.cloneFrom(&s.baseBuiltinFunc)
	return cloned
}

// builtinArithmeticDivideDecimalSig evaluates "/" on two DECIMAL arguments.
type builtinArithmeticDivideDecimalSig struct {
	baseBuiltinFunc
}

// Clone returns a new builtinArithmeticDivideDecimalSig initialised from s via cloneFrom.
func (s *builtinArithmeticDivideDecimalSig) Clone() builtinFunc {
	cloned := new(builtinArithmeticDivideDecimalSig)
	cloned.cloneFrom(&s.baseBuiltinFunc)
	return cloned
}
// evalReal returns a / b, or NULL when either argument is NULL. A zero divisor
// yields NULL plus whatever handleDivisionByZeroError returns. An infinite
// quotient yields NULL plus a DOUBLE types.ErrOverflow.
func (s *builtinArithmeticDivideRealSig) evalReal(row chunk.Row) (float64, bool, error) {
	dividend, isNull, err := s.args[0].EvalReal(s.ctx, row)
	if isNull || err != nil {
		return 0, isNull, errors.Trace(err)
	}
	divisor, isNull, err := s.args[1].EvalReal(s.ctx, row)
	if isNull || err != nil {
		return 0, isNull, errors.Trace(err)
	}
	if divisor == 0 {
		return 0, true, errors.Trace(handleDivisionByZeroError(s.ctx))
	}
	quotient := dividend / divisor
	if !math.IsInf(quotient, 0) {
		return quotient, false, nil
	}
	return 0, true, types.ErrOverflow.GenWithStackByArgs("DOUBLE", fmt.Sprintf("(%s / %s)", s.args[0].String(), s.args[1].String()))
}
// evalDecimal returns a / b computed by types.DecimalDiv with
// types.DivFracIncr, or NULL when either argument is NULL.
// A division by zero yields NULL plus whatever handleDivisionByZeroError
// returns. On success, a quotient with fewer fractional digits than the result
// type's Decimal is passed through Round to that scale. Any other DecimalDiv
// error is returned alongside a non-NULL quotient.
func (s *builtinArithmeticDivideDecimalSig) evalDecimal(row chunk.Row) (*types.MyDecimal, bool, error) {
	dividend, isNull, err := s.args[0].EvalDecimal(s.ctx, row)
	if isNull || err != nil {
		return nil, isNull, errors.Trace(err)
	}
	divisor, isNull, err := s.args[1].EvalDecimal(s.ctx, row)
	if isNull || err != nil {
		return nil, isNull, errors.Trace(err)
	}
	quotient := new(types.MyDecimal)
	switch err = types.DecimalDiv(dividend, divisor, quotient, types.DivFracIncr); err {
	case types.ErrDivByZero:
		return quotient, true, errors.Trace(handleDivisionByZeroError(s.ctx))
	case nil:
		if _, frac := quotient.PrecisionAndFrac(); frac < s.tp.Decimal {
			err = quotient.Round(quotient, s.tp.Decimal, types.ModeHalfEven)
		}
	}
	return quotient, false, errors.Trace(err)
}
// arithmeticIntDivideFunctionClass builds the built-in signatures for "DIV".
type arithmeticIntDivideFunctionClass struct{ baseFunctionClass }

// getFunction selects the DIV signature. The result is always INT. The
// arguments are evaluated as INT when both are INT in a numeric context, and
// as DECIMAL otherwise. The result is UNSIGNED when either original argument
// type is UNSIGNED.
func (c *arithmeticIntDivideFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
	if err := c.verifyArgs(args); err != nil {
		return nil, errors.Trace(err)
	}
	lhsTp, rhsTp := args[0].GetType(), args[1].GetType()
	lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp)
	markUnsigned := func(retTp *types.FieldType) {
		if mysql.HasUnsignedFlag(lhsTp.Flag) || mysql.HasUnsignedFlag(rhsTp.Flag) {
			retTp.Flag |= mysql.UnsignedFlag
		}
	}

	if lhsEvalTp != types.ETInt || rhsEvalTp != types.ETInt {
		bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETDecimal, types.ETDecimal)
		markUnsigned(bf.tp)
		return &builtinArithmeticIntDivideDecimalSig{bf}, nil
	}
	bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
	markUnsigned(bf.tp)
	return &builtinArithmeticIntDivideIntSig{bf}, nil
}
// builtinArithmeticIntDivideIntSig evaluates "DIV" on two INT arguments.
type builtinArithmeticIntDivideIntSig struct {
	baseBuiltinFunc
}

// Clone returns a new builtinArithmeticIntDivideIntSig initialised from s via cloneFrom.
func (s *builtinArithmeticIntDivideIntSig) Clone() builtinFunc {
	cloned := new(builtinArithmeticIntDivideIntSig)
	cloned.cloneFrom(&s.baseBuiltinFunc)
	return cloned
}

// builtinArithmeticIntDivideDecimalSig evaluates "DIV" on DECIMAL arguments,
// producing an INT.
type builtinArithmeticIntDivideDecimalSig struct {
	baseBuiltinFunc
}

// Clone returns a new builtinArithmeticIntDivideDecimalSig initialised from s via cloneFrom.
func (s *builtinArithmeticIntDivideDecimalSig) Clone() builtinFunc {
	cloned := new(builtinArithmeticIntDivideDecimalSig)
	cloned.cloneFrom(&s.baseBuiltinFunc)
	return cloned
}
// evalInt returns a DIV b for two INT arguments.
// The divisor is evaluated first: a NULL divisor yields NULL, and a zero
// divisor yields NULL plus whatever handleDivisionByZeroError returns. The
// division dispatches on the arguments' UNSIGNED flags. Any error from the
// mixed-sign or signed helpers is returned with a NULL result.
func (s *builtinArithmeticIntDivideIntSig) evalInt(row chunk.Row) (int64, bool, error) {
	divisor, isNull, err := s.args[1].EvalInt(s.ctx, row)
	if isNull || err != nil {
		return 0, isNull, errors.Trace(err)
	}
	if divisor == 0 {
		return 0, true, errors.Trace(handleDivisionByZeroError(s.ctx))
	}
	dividend, isNull, err := s.args[0].EvalInt(s.ctx, row)
	if isNull || err != nil {
		return 0, isNull, errors.Trace(err)
	}

	lhsUnsigned := mysql.HasUnsignedFlag(s.args[0].GetType().Flag)
	rhsUnsigned := mysql.HasUnsignedFlag(s.args[1].GetType().Flag)
	var quotient int64
	switch {
	case lhsUnsigned && rhsUnsigned:
		quotient = int64(uint64(dividend) / uint64(divisor))
	case lhsUnsigned:
		var uq uint64
		uq, err = types.DivUintWithInt(uint64(dividend), divisor)
		quotient = int64(uq)
	case rhsUnsigned:
		var uq uint64
		uq, err = types.DivIntWithUint(dividend, uint64(divisor))
		quotient = int64(uq)
	default:
		quotient, err = types.DivInt64(dividend, divisor)
	}
	return quotient, err != nil, errors.Trace(err)
}
// evalInt returns a DIV b when at least one argument is not INT. Both
// arguments are evaluated as DECIMAL and divided with types.DecimalDiv
// (types.DivFracIncr), and the quotient is converted with c.ToInt().
//
// Error handling, in order:
//   - while evaluating an argument, types.ErrTruncated is dropped and
//     types.ErrOverflow is rewritten to errTruncatedWrongValue and passed
//     through sc.HandleOverflow;
//   - a division by zero yields NULL plus whatever handleDivisionByZeroError returns;
//   - ErrTruncated / ErrOverflow from DecimalDiv are passed through
//     sc.HandleTruncate / sc.HandleOverflow, and any remaining error yields NULL;
//   - ErrOverflow from ToInt becomes a BIGINT types.ErrOverflow with NULL, while
//     any other ToInt error is ignored.
//
// NOTE(review): getFunction may mark the result UNSIGNED, but the quotient is
// always converted with ToInt. An unsigned quotient above math.MaxInt64
// therefore presumably reports a BIGINT overflow; confirm against MySQL.
// NOTE(review): DecimalDiv errors are compared with ==, while argument errors
// use terror.ErrorEqual; confirm this difference is intended.
func (s *builtinArithmeticIntDivideDecimalSig) evalInt(row chunk.Row) (ret int64, isNull bool, err error) {
	sc := s.ctx.GetSessionVars().StmtCtx
	var num [2]*types.MyDecimal
	for i, arg := range s.args {
		num[i], isNull, err = arg.EvalDecimal(s.ctx, row)
		// Its behavior is consistent with MySQL.
		// Truncation while converting an argument to DECIMAL is ignored.
		if terror.ErrorEqual(err, types.ErrTruncated) {
			err = nil
		}
		// Overflow is reported as a truncated-wrong-value error; the statement
		// context decides whether it is fatal.
		if terror.ErrorEqual(err, types.ErrOverflow) {
			newErr := errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", arg)
			err = sc.HandleOverflow(newErr, newErr)
		}
		if isNull || err != nil {
			return 0, isNull, errors.Trace(err)
		}
	}
	c := &types.MyDecimal{}
	err = types.DecimalDiv(num[0], num[1], c, types.DivFracIncr)
	if err == types.ErrDivByZero {
		return 0, true, errors.Trace(handleDivisionByZeroError(s.ctx))
	}
	if err == types.ErrTruncated {
		err = sc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c))
	}
	if err == types.ErrOverflow {
		newErr := errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c)
		err = sc.HandleOverflow(newErr, newErr)
	}
	// Whatever the statement context did not absorb is fatal for this row.
	if err != nil {
		return 0, true, errors.Trace(err)
	}
	ret, err = c.ToInt()
	// err returned by ToInt may be ErrTruncated or ErrOverflow, only handle ErrOverflow, ignore ErrTruncated.
	if err == types.ErrOverflow {
		return 0, true, types.ErrOverflow.GenWithStackByArgs("BIGINT", fmt.Sprintf("(%s DIV %s)", s.args[0].String(), s.args[1].String()))
	}
	return ret, false, nil
}
// arithmeticModFunctionClass builds the built-in signatures for "%" / MOD.
type arithmeticModFunctionClass struct{ baseFunctionClass }

// setType4ModRealOrDecimal sets the Flen and Decimal of a REAL (!isDecimal) or
// DECIMAL (isDecimal) MOD result.
// The scale is the larger argument scale, capped at mysql.MaxDecimalScale for
// DECIMAL. The width is the larger argument width, capped at
// mysql.MaxDecimalWidth or mysql.MaxRealWidth. Each is unspecified when either
// argument's corresponding value is unspecified.
func (c *arithmeticModFunctionClass) setType4ModRealOrDecimal(retTp, a, b *types.FieldType, isDecimal bool) {
	scale, width := types.UnspecifiedLength, types.UnspecifiedLength

	if a.Decimal != types.UnspecifiedLength && b.Decimal != types.UnspecifiedLength {
		scale = mathutil.Max(a.Decimal, b.Decimal)
		if isDecimal {
			scale = mathutil.Min(scale, mysql.MaxDecimalScale)
		}
	}

	if a.Flen != types.UnspecifiedLength && b.Flen != types.UnspecifiedLength {
		limit := mysql.MaxRealWidth
		if isDecimal {
			limit = mysql.MaxDecimalWidth
		}
		width = mathutil.Min(mathutil.Max(a.Flen, b.Flen), limit)
	}

	retTp.Decimal, retTp.Flen = scale, width
}
// getFunction selects the MOD signature from the numeric-context evaluation
// types of both arguments. REAL takes priority over DECIMAL, which takes
// priority over INT. Only the left argument's UNSIGNED flag is propagated to
// the result.
func (c *arithmeticModFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
	if err := c.verifyArgs(args); err != nil {
		return nil, errors.Trace(err)
	}
	lhsTp, rhsTp := args[0].GetType(), args[1].GetType()
	lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp)
	markUnsigned := func(retTp *types.FieldType) {
		if mysql.HasUnsignedFlag(lhsTp.Flag) {
			retTp.Flag |= mysql.UnsignedFlag
		}
	}

	switch {
	case lhsEvalTp == types.ETReal || rhsEvalTp == types.ETReal:
		bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETReal, types.ETReal, types.ETReal)
		c.setType4ModRealOrDecimal(bf.tp, lhsTp, rhsTp, false)
		markUnsigned(bf.tp)
		return &builtinArithmeticModRealSig{bf}, nil
	case lhsEvalTp == types.ETDecimal || rhsEvalTp == types.ETDecimal:
		bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETDecimal, types.ETDecimal, types.ETDecimal)
		c.setType4ModRealOrDecimal(bf.tp, lhsTp, rhsTp, true)
		markUnsigned(bf.tp)
		return &builtinArithmeticModDecimalSig{bf}, nil
	default:
		bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
		markUnsigned(bf.tp)
		return &builtinArithmeticModIntSig{bf}, nil
	}
}
// builtinArithmeticModRealSig evaluates MOD on two REAL arguments.
type builtinArithmeticModRealSig struct{ baseBuiltinFunc }

// Clone returns a new builtinArithmeticModRealSig initialised from s via cloneFrom.
func (s *builtinArithmeticModRealSig) Clone() builtinFunc {
	cloned := new(builtinArithmeticModRealSig)
	cloned.cloneFrom(&s.baseBuiltinFunc)
	return cloned
}

// evalReal returns math.Mod(a, b). The divisor is evaluated first: a NULL
// divisor yields NULL, and a zero divisor yields NULL plus whatever
// handleDivisionByZeroError returns. A NULL dividend yields NULL.
func (s *builtinArithmeticModRealSig) evalReal(row chunk.Row) (float64, bool, error) {
	divisor, isNull, err := s.args[1].EvalReal(s.ctx, row)
	switch {
	case isNull || err != nil:
		return 0, isNull, errors.Trace(err)
	case divisor == 0:
		return 0, true, errors.Trace(handleDivisionByZeroError(s.ctx))
	}
	dividend, isNull, err := s.args[0].EvalReal(s.ctx, row)
	if isNull || err != nil {
		return 0, isNull, errors.Trace(err)
	}
	return math.Mod(dividend, divisor), false, nil
}
// builtinArithmeticModDecimalSig evaluates MOD on two DECIMAL arguments.
type builtinArithmeticModDecimalSig struct{ baseBuiltinFunc }

// Clone returns a new builtinArithmeticModDecimalSig initialised from s via cloneFrom.
func (s *builtinArithmeticModDecimalSig) Clone() builtinFunc {
	cloned := new(builtinArithmeticModDecimalSig)
	cloned.cloneFrom(&s.baseBuiltinFunc)
	return cloned
}

// evalDecimal returns a MOD b computed by types.DecimalMod, or NULL when either
// argument is NULL. A zero divisor yields NULL plus whatever
// handleDivisionByZeroError returns. Any other error yields NULL plus that error.
func (s *builtinArithmeticModDecimalSig) evalDecimal(row chunk.Row) (*types.MyDecimal, bool, error) {
	dividend, isNull, err := s.args[0].EvalDecimal(s.ctx, row)
	if isNull || err != nil {
		return nil, isNull, errors.Trace(err)
	}
	divisor, isNull, err := s.args[1].EvalDecimal(s.ctx, row)
	if isNull || err != nil {
		return nil, isNull, errors.Trace(err)
	}
	rem := new(types.MyDecimal)
	switch err = types.DecimalMod(dividend, divisor, rem); err {
	case nil:
		return rem, false, nil
	case types.ErrDivByZero:
		return rem, true, errors.Trace(handleDivisionByZeroError(s.ctx))
	default:
		return rem, true, errors.Trace(err)
	}
}
// builtinArithmeticModIntSig evaluates MOD on two INT arguments. Each argument
// is treated as signed or unsigned according to its UNSIGNED flag.
type builtinArithmeticModIntSig struct{ baseBuiltinFunc }

// Clone returns a new builtinArithmeticModIntSig initialised from s via cloneFrom.
func (s *builtinArithmeticModIntSig) Clone() builtinFunc {
	cloned := new(builtinArithmeticModIntSig)
	cloned.cloneFrom(&s.baseBuiltinFunc)
	return cloned
}

// evalInt returns a MOD b. The divisor is evaluated first: a NULL divisor
// yields NULL, and a zero divisor yields NULL plus whatever
// handleDivisionByZeroError returns. The remainder takes the sign of a signed
// dividend and ignores the sign of a signed divisor.
func (s *builtinArithmeticModIntSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) {
	b, isNull, err := s.args[1].EvalInt(s.ctx, row)
	if isNull || err != nil {
		return 0, isNull, errors.Trace(err)
	}
	if b == 0 {
		return 0, true, errors.Trace(handleDivisionByZeroError(s.ctx))
	}
	a, isNull, err := s.args[0].EvalInt(s.ctx, row)
	if isNull || err != nil {
		return 0, isNull, errors.Trace(err)
	}

	// magnitude returns |v| as uint64. For math.MinInt64, -v wraps back to
	// MinInt64, whose uint64 value 1<<63 is exactly its magnitude.
	magnitude := func(v int64) uint64 {
		if v < 0 {
			return uint64(-v)
		}
		return uint64(v)
	}

	lhsUnsigned := mysql.HasUnsignedFlag(s.args[0].GetType().Flag)
	rhsUnsigned := mysql.HasUnsignedFlag(s.args[1].GetType().Flag)
	switch {
	case lhsUnsigned && rhsUnsigned:
		return int64(uint64(a) % uint64(b)), false, nil
	case lhsUnsigned:
		return int64(uint64(a) % magnitude(b)), false, nil
	case rhsUnsigned:
		rem := int64(magnitude(a) % uint64(b))
		if a < 0 {
			rem = -rem
		}
		return rem, false, nil
	default:
		return a % b, false, nil
	}
}

Опубликовать ( 0 )

Вы можете оставить комментарий после Вход в систему

1
https://gitlife.ru/oschina-mirror/hanchuanchuan-goInception.git
git@gitlife.ru:oschina-mirror/hanchuanchuan-goInception.git
oschina-mirror
hanchuanchuan-goInception
hanchuanchuan-goInception
v1.2.2