// Copyright 2015 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // See the License for the specific language governing permissions and // limitations under the License. package session_test import ( "errors" "flag" "fmt" "path" "runtime" "sort" "strconv" "strings" "testing" "time" _ "github.com/go-sql-driver/mysql" "github.com/hanchuanchuan/goInception/ast" "github.com/hanchuanchuan/goInception/config" "github.com/hanchuanchuan/goInception/domain" "github.com/hanchuanchuan/goInception/kv" "github.com/hanchuanchuan/goInception/mysql" "github.com/hanchuanchuan/goInception/parser" "github.com/hanchuanchuan/goInception/server" "github.com/hanchuanchuan/goInception/session" "github.com/hanchuanchuan/goInception/store/mockstore" "github.com/hanchuanchuan/goInception/store/mockstore/mocktikv" "github.com/hanchuanchuan/goInception/util/auth" "github.com/hanchuanchuan/goInception/util/logutil" "github.com/hanchuanchuan/goInception/util/testkit" "github.com/hanchuanchuan/goInception/util/testleak" "github.com/jinzhu/gorm" . "github.com/pingcap/check" repllog "github.com/siddontang/go-log/log" log "github.com/sirupsen/logrus" "golang.org/x/net/context" ) // var _ = Suite(&testCommon{}) // 数据库类型 const ( DBTypeMysql = iota DBTypeMariaDB DBTypeTiDB ) var sql string // 是否测试api接口 var isAPI bool func init() { flag.BoolVar(&isAPI, "api", false, "test api interface") } // func TestCommonTest(t *testing.T) { // TestingT(t) // } type testCommon struct { cluster *mocktikv.Cluster mvccStore mocktikv.MVCCStore store kv.Storage dom *domain.Domain tk *testkit.TestKit db *gorm.DB dbAddr string // 执行结果集 // res *testkit.Result rows [][]interface{} DBVersion int DBType int sqlMode string innodbLargePrefix bool // 时间戳类型是否需要明确指定默认值 explicitDefaultsForTimestamp bool // 强制执行GTID一致性 enforeGtidConsistency bool // 是否忽略大小写(lower_case_table_names为1和2时忽略,否则不忽略) ignoreCase bool realRowCount bool remoteBackupTable string parser *parser.Parser session session.Session defaultInc config.Inc defaultIncLevel config.IncLevel // 测试数据库,默认为test_inc,该参数用以测试未指定数据库情况下的审核 useDB string // API调用 isAPI bool // session API调用 sessionService session.Session // API返回结果 records []session.Record } func (s *testCommon) initSetUp(c *C) { if testing.Short() { c.Skip("skipping test; in TRAVIS mode") } flag.Parse() log.SetLevel(log.ErrorLevel) repllog.SetLevel(repllog.LevelFatal) s.realRowCount = true s.useDB = "use test_inc;" testleak.BeforeTest() s.cluster = mocktikv.NewCluster() mocktikv.BootstrapWithSingleStore(s.cluster) s.mvccStore = mocktikv.MustNewMVCCStore() store, err := mockstore.NewMockTikvStore( mockstore.WithCluster(s.cluster), mockstore.WithMVCCStore(s.mvccStore), ) c.Assert(err, IsNil) s.store = store session.SetSchemaLease(0) session.SetStatsLease(0) s.dom, err = session.BootstrapSession(s.store) c.Assert(err, IsNil) if s.tk == nil { s.tk = testkit.NewTestKitWithInit(c, s.store) } server := &server.Server{} server.InitOscProcessList() s.tk.Se.SetSessionManager(server) s.session = s.tk.Se vars := s.session.GetSessionVars() if vars.User == nil { vars.User = &auth.UserIdentity{ Username: "root", Hostname: "127.0.0.1", } } // 重新设置进程信息,写入user@host s.session.SetProcessInfo("", time.Now(), mysql.ComQuery) cfg := config.GetGlobalConfig() _, localFile, _, _ := runtime.Caller(0) localFile = path.Dir(localFile) configFile := path.Join(localFile[0:len(localFile)-len("session")], "config/config.toml.default") c.Assert(cfg.Load(configFile), IsNil) inc := &config.GetGlobalConfig().Inc inc.BackupHost = "127.0.0.1" inc.BackupPort = 3306 inc.BackupUser = "test" inc.BackupPassword = "test" inc.Lang = "en-US" inc.EnableFingerprint = true inc.SqlSafeUpdates = 0 inc.EnableDropTable = true incLevel := &config.GetGlobalConfig().IncLevel incLevel.ErUseEnum = 1 incLevel.ErJsonTypeSupport = 1 incLevel.ErUseTextOrBlob = 1 // mysql5.6测试用例会出错(docker映射对外的端口不一致) config.GetGlobalConfig().Ghost.GhostAliyunRds = true s.defaultInc = *inc s.defaultIncLevel = *incLevel s.remoteBackupTable = "$_$Inception_backup_information$_$" s.parser = parser.New() c.Assert(s.mysqlServerVersion(), IsNil) c.Assert(s.sqlMode, Not(Equals), "") // log.Error("数据库版本: ", s.DBVersion, " type: ", s.DBType) // 测试API接口时自动忽略之前的测试方法 // log.Errorf("is api: %v", isAPI) s.isAPI = isAPI if isAPI { s.sessionService = session.NewInception() s.sessionService.SetSessionManager(server) s.sessionService.LoadOptions(session.SourceOptions{ Host: inc.BackupHost, Port: int(inc.BackupPort), User: inc.BackupUser, Password: inc.BackupPassword, RealRowCount: s.realRowCount, }) } } func (s *testCommon) tearDownSuite(c *C) { if testing.Short() { c.Skip("skipping test; in TRAVIS mode") } else { s.dom.Close() s.store.Close() testleak.AfterTest(c)() if s.db != nil { s.db.Close() } } } func (s *testCommon) tearDownTest(c *C) { if testing.Short() { c.Skip("skipping test; in TRAVIS mode") } if s.tk == nil { s.tk = testkit.NewTestKitWithInit(c, s.store) } saved := config.GetGlobalConfig().Inc defer func() { config.GetGlobalConfig().Inc = saved }() config.GetGlobalConfig().Inc.EnableDropTable = true s.runCheck("show tables") c.Assert(int(s.getAffectedRows()), GreaterEqual, 1) row := s.rows[s.getAffectedRows()-1] sql := row[5] // exec := `/*%s;--execute=1;--backup=0;--enable-ignore-warnings;*/ // inception_magic_start; // %s // %s; // inception_magic_commit;` for _, name := range strings.Split(sql.(string), "\n") { if strings.HasPrefix(name, "show tables") { continue } n := strings.Replace(name, "'", "", -1) // var res *testkit.Result var sql string if strings.HasPrefix(n, "v_") { // res = s.tk.MustQueryInc( // fmt.Sprintf(exec, s.getAddr(), s.useDB, "drop view `"+n+"`")) sql = "drop view test_inc.`" + n + "`" } else { // res = s.tk.MustQueryInc( // fmt.Sprintf(exec, s.getAddr(), s.useDB, "drop table "+s.useDB+".`"+n+"`")) sql = "drop table test_inc.`" + n + "`" } res := s.db.Exec(sql) c.Assert(res.Error, IsNil, Commentf("sql:%v", sql)) // c.Assert(s.getAffectedRows(), Equals, 2) // row := res.Rows()[s.getAffectedRows()-1] // c.Assert(row[2], Equals, "0", Commentf("%v", row)) // c.Assert(row[3], Equals, "Execute Successfully", Commentf("%v", row)) } } func (s *testCommon) runCheck(sql string) { if s.isAPI { s.sessionService.LoadOptions(session.SourceOptions{ Host: s.defaultInc.BackupHost, Port: int(s.defaultInc.BackupPort), User: s.defaultInc.BackupUser, Password: s.defaultInc.BackupPassword, RealRowCount: s.realRowCount, IgnoreWarnings: true, }) result, _ := s.sessionService.Audit(context.Background(), s.useDB+sql) s.records = result s.rows = make([][]interface{}, len(result)) for index, row := range result { // c.Assert(row.ErrLevel, Not(Equals), uint8(2), Commentf("%v", result)) s.rows[index] = row.List() } return } a := `/*%s;--check=1;--backup=0;--enable-ignore-warnings;real_row_count=%v;--db=test_inc;*/ inception_magic_start; %s; inception_magic_commit;` res := s.tk.MustQueryInc(fmt.Sprintf(a, s.getAddr(), s.realRowCount, sql)) s.rows = res.Rows() return } func (s *testCommon) mustCheck(c *C, sql string) *testkit.Result { if s.isAPI { s.sessionService.LoadOptions(session.SourceOptions{ Host: s.defaultInc.BackupHost, Port: int(s.defaultInc.BackupPort), User: s.defaultInc.BackupUser, Password: s.defaultInc.BackupPassword, RealRowCount: s.realRowCount, }) result, err := s.sessionService.Audit(context.Background(), s.useDB+sql) c.Assert(err, IsNil) s.records = result s.rows = make([][]interface{}, len(result)) for index, row := range result { c.Assert(row.ErrLevel, Not(Equals), uint8(2), Commentf("%v", result)) s.rows[index] = row.List() } return nil } a := `/*%s;--check=1;--backup=0;--enable-ignore-warnings;real_row_count=%v;*/ inception_magic_start; %s %s; inception_magic_commit;` res := s.tk.MustQueryInc(fmt.Sprintf(a, s.getAddr(), s.realRowCount, s.useDB, sql)) for _, row := range res.Rows() { c.Assert(row[2], Not(Equals), "2", Commentf("%v", row)) } s.rows = res.Rows() return res } func (s *testCommon) runExec(sql string) *testkit.Result { if s.isAPI { s.sessionService.LoadOptions(session.SourceOptions{ Host: s.defaultInc.BackupHost, Port: int(s.defaultInc.BackupPort), User: s.defaultInc.BackupUser, Password: s.defaultInc.BackupPassword, RealRowCount: s.realRowCount, IgnoreWarnings: true, }) result, _ := s.sessionService.RunExecute(context.Background(), s.useDB+sql) s.records = result s.rows = make([][]interface{}, len(result)) for index, row := range result { // c.Assert(row.ErrLevel, Not(Equals), uint8(2), Commentf("%v", result)) s.rows[index] = row.List() } return nil } a := `/*%s;--execute=1;--backup=0;--enable-ignore-warnings;real_row_count=%v;*/ inception_magic_start; %s %s; inception_magic_commit;` res := s.tk.MustQueryInc(fmt.Sprintf(a, s.getAddr(), s.realRowCount, s.useDB, sql)) s.rows = res.Rows() return res } func (s *testCommon) mustRunExec(c *C, sql string) *testkit.Result { config.GetGlobalConfig().Inc.EnableDropTable = true if s.isAPI { s.sessionService.LoadOptions(session.SourceOptions{ Host: s.defaultInc.BackupHost, Port: int(s.defaultInc.BackupPort), User: s.defaultInc.BackupUser, Password: s.defaultInc.BackupPassword, RealRowCount: s.realRowCount, IgnoreWarnings: true, }) result, err := s.sessionService.RunExecute(context.Background(), s.useDB+sql) c.Assert(err, IsNil) s.records = result s.rows = make([][]interface{}, len(result)) for index, row := range result { c.Assert(row.ErrLevel, Not(Equals), uint8(2), Commentf("%v", result)) s.rows[index] = row.List() } return nil } a := `/*%s;--execute=1;--backup=0;--enable-ignore-warnings;real_row_count=%v;*/ inception_magic_start; %s %s; inception_magic_commit;` res := s.tk.MustQueryInc(fmt.Sprintf(a, s.getAddr(), s.realRowCount, s.useDB, sql)) for _, row := range res.Rows() { c.Assert(strings.Contains(row[3].(string), "Execute Successfully"), Equals, true, Commentf("%v", res.Rows())) c.Assert(row[2], Not(Equals), "2", Commentf("%v", row)) } s.rows = res.Rows() return res } func (s *testCommon) runBackup(sql string) *testkit.Result { if s.isAPI { s.sessionService.LoadOptions(session.SourceOptions{ Host: s.defaultInc.BackupHost, Port: int(s.defaultInc.BackupPort), User: s.defaultInc.BackupUser, Password: s.defaultInc.BackupPassword, RealRowCount: s.realRowCount, Backup: true, IgnoreWarnings: true, }) result, _ := s.sessionService.RunExecute(context.Background(), s.useDB+sql) s.records = result s.rows = make([][]interface{}, len(result)) for index, row := range result { // c.Assert(row.ErrLevel, Not(Equals), uint8(2), Commentf("%v", result)) s.rows[index] = row.List() } return nil } a := `/*%s;--execute=1;--backup=1;--enable-ignore-warnings;real_row_count=%v;*/ inception_magic_start; %s %s; inception_magic_commit;` res := s.tk.MustQueryInc(fmt.Sprintf(a, s.getAddr(), s.realRowCount, s.useDB, sql)) s.rows = res.Rows() return res } func (s *testCommon) mustRunBackup(c *C, sql string) *testkit.Result { if s.isAPI { s.sessionService.LoadOptions(session.SourceOptions{ Host: s.defaultInc.BackupHost, Port: int(s.defaultInc.BackupPort), User: s.defaultInc.BackupUser, Password: s.defaultInc.BackupPassword, RealRowCount: s.realRowCount, Backup: true, IgnoreWarnings: true, }) result, err := s.sessionService.RunExecute(context.Background(), s.useDB+sql) c.Assert(err, IsNil) s.records = result s.rows = make([][]interface{}, len(result)) for index, row := range result { c.Assert(row.ErrLevel, Not(Equals), uint8(2), Commentf("%v", result)) s.rows[index] = row.List() } return nil } a := `/*%s;--execute=1;--backup=1;--enable-ignore-warnings;real_row_count=%v;*/ inception_magic_start; %s %s; inception_magic_commit;` res := s.tk.MustQueryInc(fmt.Sprintf(a, s.getAddr(), s.realRowCount, s.useDB, sql)) // 需要成功执行 for _, row := range res.Rows() { // c.Assert(strings.Contains(row[3].(string), "Backup Successfully"), // Equals, true, Commentf("%v", res.Rows())) c.Assert(row[2], Not(Equals), "2", Commentf("%v", row)) } s.rows = res.Rows() return res } func (s *testCommon) mustRunBackupTran(c *C, sql string) *testkit.Result { a := `/*%s;--execute=1;--backup=1;--enable-ignore-warnings;real_row_count=%v;--trans=3;*/ inception_magic_start; %s %s; inception_magic_commit;` res := s.tk.MustQueryInc(fmt.Sprintf(a, s.getAddr(), s.realRowCount, s.useDB, sql)) // 需要成功执行 for _, row := range res.Rows() { c.Assert(row[2], Not(Equals), "2", Commentf("%v", row)) } s.rows = res.Rows() return res } func (s *testCommon) runTranSQL(sql string, batch int) *testkit.Result { a := `/*%s;--execute=1;--backup=1;--execute=1;--enable-ignore-warnings;real_row_count=%v;--trans=%d;*/ inception_magic_start; %s %s; inception_magic_commit;` res := s.tk.MustQueryInc(fmt.Sprintf(a, s.getAddr(), s.realRowCount, batch, s.useDB, sql)) s.rows = res.Rows() return res } func (s *testCommon) mustrunTranSQL(c *C, sql string) *testkit.Result { a := `/*%s;--execute=1;--backup=1;--execute=1;--enable-ignore-warnings;real_row_count=%v;--trans=10;*/ inception_magic_start; %s %s; inception_magic_commit;` res := s.tk.MustQueryInc(fmt.Sprintf(a, s.getAddr(), s.realRowCount, s.useDB, sql)) // 需要成功执行 for _, row := range res.Rows() { c.Assert(row[2], Not(Equals), "2", Commentf("%v", row)) } s.rows = res.Rows() return res } func (s *testCommon) getAddr() string { if s.dbAddr != "" { return s.dbAddr } inc := config.GetGlobalConfig().Inc s.dbAddr = fmt.Sprintf(`--host=%s;--port=%d;--user=%s;--password=%s;`, inc.BackupHost, inc.BackupPort, inc.BackupUser, inc.BackupPassword) return s.dbAddr } func (s *testCommon) mysqlServerVersion() error { inc := config.GetGlobalConfig().Inc if s.db == nil || s.db.DB().Ping() != nil { addr := fmt.Sprintf("%s:%s@tcp(%s:%d)/mysql?charset=utf8mb4&parseTime=True&loc=Local&maxAllowedPacket=4194304&autocommit=1", inc.BackupUser, inc.BackupPassword, inc.BackupHost, inc.BackupPort) db, err := gorm.Open("mysql", addr) if err != nil { return err } // 禁用日志记录器,不显示任何日志 db.LogMode(false) s.db = db } var name, value string sql := `show variables where Variable_name in ('explicit_defaults_for_timestamp','innodb_large_prefix', 'version','sql_mode','lower_case_table_names','enforce_gtid_consistency');` rows, err := s.db.Raw(sql).Rows() if err != nil { return err } // emptyInnodbLargePrefix := true for rows.Next() { rows.Scan(&name, &value) switch name { case "version": if strings.Contains(strings.ToLower(value), "mariadb") { s.DBType = DBTypeMariaDB } else if strings.Contains(strings.ToLower(value), "tidb") { s.DBType = DBTypeTiDB } else { s.DBType = DBTypeMysql } versionStr := strings.Split(value, "-")[0] versionSeg := strings.Split(versionStr, ".") if len(versionSeg) == 3 { versionStr = fmt.Sprintf("%s%02s%02s", versionSeg[0], versionSeg[1], versionSeg[2]) version, err := strconv.Atoi(versionStr) if err != nil { return err } s.DBVersion = version } else { return errors.New(fmt.Sprintf("无法解析版本号:%s", value)) } log.Debug("db version: ", s.DBVersion) case "innodb_large_prefix": // emptyInnodbLargePrefix = false s.innodbLargePrefix = value == "ON" || value == "1" case "sql_mode": s.sqlMode = value case "lower_case_table_names": if v, err := strconv.Atoi(value); err != nil { return err } else { s.ignoreCase = v > 0 } case "explicit_defaults_for_timestamp": s.explicitDefaultsForTimestamp = value == "ON" || value == "1" case "enforce_gtid_consistency": s.enforeGtidConsistency = value == "ON" || value == "1" } } // 如果没有innodb_large_prefix系统变量 // if emptyInnodbLargePrefix { // if s.DBVersion > 50700 { // s.innodbLargePrefix = true // } else { // s.innodbLargePrefix = false // } // } return nil } func (s *testCommon) assertRows(c *C, rows [][]interface{}, rollbackSqls ...string) error { c.Assert(len(rows), Not(Equals), 0) inc := config.GetGlobalConfig().Inc if s.db == nil || s.db.DB().Ping() != nil { addr := fmt.Sprintf("%s:%s@tcp(%s:%d)/mysql?charset=utf8mb4&parseTime=True&loc=Local&maxAllowedPacket=4194304&autocommit=1", inc.BackupUser, inc.BackupPassword, inc.BackupHost, inc.BackupPort) db, err := gorm.Open("mysql", addr) if err != nil { log.Info(err) return err } // 禁用日志记录器,不显示任何日志 db.LogMode(false) s.db = db } // 有可能是 不同的表,不同的库 result := []string{} // affectedRows := 0 // opid := "" // backupDBName := "" // sqlIndex := 0 var lastTable, currentTable string var ids []string for _, row := range rows { opid := "" backupDBName := "" // runSql := "" // if row[5] != nil { // runSql = row[5].(string) // } // affectedRows := 0 // if row[6] != nil { // a := row[6].(string) // affectedRows, _ = strconv.Atoi(a) // } if row[7] != nil { opid = row[7].(string) } if row[8] != nil { backupDBName = row[8].(string) } currentSql := "" if row[5] != nil { currentSql = row[5].(string) } if !strings.Contains(row[3].(string), "Backup Successfully") || strings.HasSuffix(opid, "00000000") { continue } tableName := s.getObjectName(currentSql) // 表名没有时,查询一下 if tableName == "" { sql := "select tablename from `%s`.`%s` where opid_time = ?" sql = fmt.Sprintf(sql, backupDBName, s.remoteBackupTable) tableRows, err := s.db.Raw(sql, opid).Rows() c.Assert(err, IsNil) for tableRows.Next() { tableRows.Scan(&tableName) } tableRows.Close() } c.Assert(tableName, Not(Equals), "", Commentf("%v", row)) currentTable = fmt.Sprintf("%s.`%s`", backupDBName, tableName) if lastTable == "" { lastTable = currentTable } // 如果表改变了,或者超过500行了 if lastTable != currentTable || len(ids) >= 500 { if len(ids) > 0 { sql := "select rollback_statement from %s where opid_time in (?) order by opid_time,id;" sql = fmt.Sprintf(sql, lastTable) rows, err := s.db.Raw(sql, ids).Rows() c.Assert(err, IsNil) str := "" result1 := []string{} for rows.Next() { rows.Scan(&str) result1 = append(result1, s.trim(str)) } rows.Close() c.Assert(len(result1), Not(Equals), 0, Commentf("-----------: %v,%v", sql, ids)) result = append(result, result1...) ids = nil } lastTable = currentTable } ids = append(ids, opid) } // for循环可能提前退出,所以最后的查询放在外面 if len(ids) > 0 { sql := "select rollback_statement from %s where opid_time in (?) order by opid_time,id;" sql = fmt.Sprintf(sql, currentTable) rollbackRows, err := s.db.Raw(sql, ids).Rows() c.Assert(err, IsNil) str := "" result1 := []string{} for rollbackRows.Next() { rollbackRows.Scan(&str) result1 = append(result1, s.trim(str)) } rollbackRows.Close() c.Assert(len(result1), Not(Equals), 0, Commentf("------2-----: %v", rows)) result = append(result, result1...) } c.Assert(len(result), Equals, len(rollbackSqls), Commentf("%v", result)) // 如果是UPDATE多表操作,此时回滚的SQL可能是无序的 if len(result) > 1 && strings.HasPrefix(result[0], "UPDATE") { prefix := "" for i, sql := range result { if i == 0 { prefix = strings.Fields(sql)[1] continue } if prefix != strings.Fields(sql)[1] { sort.Strings(result) sort.Strings(rollbackSqls) break } } } for i := range result { c.Assert(result[i], Equals, rollbackSqls[i], Commentf("%v", result)) } return nil } func (s *testCommon) trim(str string) string { if strings.Contains(str, " ") { return s.trim(strings.Replace(str, " ", " ", -1)) } return str } func getLeftTable(node ast.ResultSetNode) *ast.TableSource { if node == nil { return nil } log.Infof("%T", node) switch x := node.(type) { case *ast.Join: return getLeftTable(x.Left) case *ast.TableSource: return x case *ast.SelectStmt: if x.From != nil { return getLeftTable(x.From.TableRefs) } case *ast.UnionStmt: for _, sel := range x.SelectList.Selects { return getLeftTable(sel) } } return nil } // getObjectName 解析操作表名 func (s *testCommon) getObjectName(sql string) (name string) { stmtNodes, _, _ := s.parser.Parse(sql, "utf8mb4", "utf8mb4_bin") for _, stmtNode := range stmtNodes { switch node := stmtNode.(type) { case *ast.InsertStmt: tableRefs := node.Table if tableRefs == nil || tableRefs.TableRefs == nil || tableRefs.TableRefs.Right != nil { return "" } tblSrc, ok := tableRefs.TableRefs.Left.(*ast.TableSource) if !ok { return "" } if tblSrc.AsName.L != "" { return "" } tblName, ok := tblSrc.Source.(*ast.TableName) if !ok { return "" } name = tblName.Name.String() case *ast.UpdateStmt: return "" // tblSrc := getLeftTable(node.TableRefs.TableRefs) // if tblSrc == nil { // log.Errorf("未找到表名!!! sql: %s", sql) // return "" // } // tblName, ok := tblSrc.Source.(*ast.TableName) // if !ok { // log.Infof("%#v", tblSrc.Source) // return "" // } // name = tblName.Name.String() case *ast.DeleteStmt: tableRefs := node.TableRefs if tableRefs == nil || tableRefs.TableRefs == nil || tableRefs.TableRefs.Right != nil { return "" } tblSrc, ok := tableRefs.TableRefs.Left.(*ast.TableSource) if !ok { return "" } if tblSrc.AsName.L != "" { return "" } tblName, ok := tblSrc.Source.(*ast.TableName) if !ok { return "" } name = tblName.Name.String() case *ast.CreateDatabaseStmt, *ast.DropDatabaseStmt: case *ast.CreateTableStmt: name = node.Table.Name.String() case *ast.AlterTableStmt: name = node.Table.Name.String() case *ast.DropTableStmt: for _, t := range node.Tables { name = t.Name.String() break } case *ast.RenameTableStmt: name = node.OldTable.Name.String() case *ast.TruncateTableStmt: name = node.Table.Name.String() case *ast.CreateIndexStmt: name = node.Table.Name.String() case *ast.DropIndexStmt: name = node.Table.Name.String() default: } return name } return "" } func (s *testCommon) queryStatistics() []int { inc := config.GetGlobalConfig().Inc if s.db == nil || s.db.DB().Ping() != nil { addr := fmt.Sprintf("%s:%s@tcp(%s:%d)/mysql?charset=utf8mb4&parseTime=True&loc=Local&maxAllowedPacket=4194304&autocommit=1", inc.BackupUser, inc.BackupPassword, inc.BackupHost, inc.BackupPort) db, err := gorm.Open("mysql", addr) if err != nil { fmt.Println(err) } // 禁用日志记录器,不显示任何日志 db.LogMode(false) s.db = db } sql := `select usedb, deleting, inserting, updating, selecting, altertable, renaming, createindex, dropindex, addcolumn, dropcolumn, changecolumn, alteroption, alterconvert, createtable, droptable, CREATEDB, truncating from inception.statistic order by id desc limit 1;` values := make([]int, 18) rows, err := s.db.Raw(sql).Rows() if err != nil { fmt.Println(err) panic(err) } else { defer rows.Close() for rows.Next() { rows.Scan(&values[0], &values[1], &values[2], &values[3], &values[4], &values[5], &values[6], &values[7], &values[8], &values[9], &values[10], &values[11], &values[12], &values[13], &values[14], &values[15], &values[16], &values[17]) } } return values } func trim(s string) string { if strings.Contains(s, " ") { return trim(strings.Replace(s, " ", " ", -1)) } return s } func (s *testCommon) query(table, opid string) string { inc := config.GetGlobalConfig().Inc if s.db == nil || s.db.DB().Ping() != nil { addr := fmt.Sprintf("%s:%s@tcp(%s:%d)/mysql?charset=utf8mb4&parseTime=True&loc=Local&maxAllowedPacket=4194304&autocommit=1", inc.BackupUser, inc.BackupPassword, inc.BackupHost, inc.BackupPort) db, err := gorm.Open("mysql", addr) if err != nil { fmt.Println(err) return err.Error() } // 禁用日志记录器,不显示任何日志 db.LogMode(false) s.db = db } result := []string{} sql := "select rollback_statement from 127_0_0_1_%d_test_inc.`%s` where opid_time = ?;" sql = fmt.Sprintf(sql, inc.BackupPort, table) rows, err := s.db.Raw(sql, opid).Rows() if err != nil { fmt.Println(err) return err.Error() } else { defer rows.Close() for rows.Next() { str := "" rows.Scan(&str) result = append(result, trim(str)) } } return strings.Join(result, "\n") } // parserStmt 解析sql变成ast语法 func (s *testCommon) parserStmt(sql string) ast.StmtNode { stmtNodes, _, _ := s.parser.Parse(sql, "utf8mb4", "utf8mb4_bin") for _, stmtNode := range stmtNodes { return stmtNode } return nil } func (s *testCommon) reset() { config.GetGlobalConfig().Inc = s.defaultInc config.GetGlobalConfig().IncLevel = s.defaultIncLevel log.SetLevel(log.ErrorLevel) // log.SetReportCaller(true) _ = logutil.InitLogger(&logutil.LogConfig{ Level: "error", Format: "text", DisableTimestamp: false, }) } // getAffectedRows 获取受影响行数, 区分api返回和mysql会话返回 func (s *testCommon) getAffectedRows() int { if s.isAPI { return len(s.rows) } return int(s.session.AffectedRows()) } func (s *testCommon) testAffectedRows(c *C, affectedRows ...int) { if len(s.rows) == 0 && len(s.records) == 0 { return } count := len(affectedRows) for i, affectedRow := range affectedRows { if s.isAPI { row := s.records[len(s.records)-(count-i)] c.Assert(row.AffectedRows, Equals, affectedRow, Commentf("%#v", row)) } else { row := s.rows[len(s.rows)-(count-i)] c.Assert(row[6], Equals, strconv.Itoa(affectedRow), Commentf("%v", row)) } } } func (s *testCommon) getResultRows() [][]interface{} { return s.rows }