1 В избранное 0 Ответвления 0

OSCHINA-MIRROR/hanchuanchuan-goInception

Присоединиться к Gitlife
Откройте для себя и примите участие в публичных проектах с открытым исходным кодом с участием более 10 миллионов разработчиков. Приватные репозитории также полностью бесплатны :)
Присоединиться бесплатно
Это зеркальный репозиторий, синхронизируется ежедневно с исходного репозитория.
Клонировать/Скачать
session_masking.go 18 КБ
Копировать Редактировать Исходные данные Просмотреть построчно История
hanchuanchuan Отправлено 3 года назад e46a69b
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package session
import (
"bytes"
"fmt"
"strings"
"golang.org/x/net/context"
json "github.com/CorgiMan/json2"
"github.com/hanchuanchuan/goInception/ast"
"github.com/hanchuanchuan/goInception/model"
"github.com/hanchuanchuan/goInception/util/sqlexec"
log "github.com/sirupsen/logrus"
)
// masking holds the working state of one masking analysis over a SELECT or
// UNION statement.
type masking struct {
	// maskingFields receives the resolved source columns of every top-level
	// (level 0) select-list entry.
	maskingFields []MaskingFieldInfo
	// session provides the current database name, the table cache and the
	// error-level configuration.
	session *session
	// colIndex is the output position assigned to the next top-level column.
	colIndex int
	// buf accumulates error messages, one per line.
	buf *bytes.Buffer
	// newItemIndex int
}
// func (m *masking) copy() *masking {
// p := &masking{
// session: m.session,
// newItemIndex: 0,
// colIndex: 0,
// buf: nil,
// }
// copy(m.maskingFields, p.maskingFields)
// return p
// }
// maskingCommand analyses one statement for data masking and appends the
// outcome to s.printSets.
//
// A USE statement only switches s.dbName. For SELECT and UNION statements the
// select list is resolved to its source columns; the result is JSON-encoded
// and appended together with any error messages gathered on the way. Every
// other statement type is reported as "not support".
//
// The returned record sets and error are always nil.
func (s *session) maskingCommand(ctx context.Context, stmtNode ast.StmtNode,
	currentSql string) ([]sqlexec.RecordSet, error) {
	log.Debug("maskingCommand")

	s.maskingFields = make([]MaskingFieldInfo, 0)

	switch node := stmtNode.(type) {
	case *ast.UseStmt:
		s.dbName = node.DBName

	case *ast.UnionStmt, *ast.SelectStmt:
		analyzer := masking{
			session:       s,
			maskingFields: make([]MaskingFieldInfo, 0),
			buf:           new(bytes.Buffer),
		}
		// The return values are not needed here: the top-level columns are
		// accumulated in analyzer.maskingFields while the tree is walked.
		_, _ = analyzer.checkSelectItem(node, 0)

		encoded, err := json.Marshal(analyzer.maskingFields)
		if err != nil {
			log.Error(err)
			s.printSets.Append(2, currentSql, "", err.Error())
			break
		}
		// String() of an empty buffer is "", so no separate branch is needed
		// for the "no messages" case.
		s.printSets.Append(0, currentSql, string(encoded), analyzer.buf.String())

	default:
		s.printSets.Append(2, currentSql, "", "not support")
	}

	return nil, nil
}
// checkSelectItem walks a result-set node and returns the tables it refers to
// together with the masking fields produced by nested selects.
//
// level is the nesting depth of node; it is passed on unchanged for the
// members of a UNION and a plain SELECT, and incremented for join operands
// and table sources. A nil node yields nil results.
func (s *masking) checkSelectItem(node ast.ResultSetNode, level int) (
	tables []*TableInfo, fields []MaskingFieldInfo) {
	if node == nil {
		return nil, nil
	}

	switch x := node.(type) {
	case *ast.UnionStmt:
		// Every member select contributes its own tables and fields.
		for _, member := range x.SelectList.Selects {
			memberTables, memberFields := s.checkSubSelectItem(member, level)
			tables = append(tables, memberTables...)
			fields = append(fields, memberFields...)
		}

	case *ast.SelectStmt:
		return s.checkSubSelectItem(x, level)

	case *ast.Join:
		// Left operand first, then right, each one level deeper.
		for _, side := range []ast.ResultSetNode{x.Left, x.Right} {
			sideTables, sideFields := s.checkSelectItem(side, level+1)
			tables = append(tables, sideTables...)
			fields = append(fields, sideFields...)
		}

	case *ast.TableSource:
		switch src := x.Source.(type) {
		case *ast.TableName:
			schema := src.Schema.O
			if schema == "" {
				schema = s.session.dbName
			}
			cached := s.session.getTableFromCache(schema, src.Name.O, false)
			if cached == nil {
				return nil, nil
			}
			if x.AsName.L == "" {
				return []*TableInfo{cached}, nil
			}
			// NOTE(review): AsName is written on the object returned by
			// getTableFromCache before it is copied — confirm this is
			// intended and does not leak the alias into the cache entry.
			cached.AsName = x.AsName.O
			return []*TableInfo{cached.copy()}, nil

		case *ast.SelectStmt:
			return s.checkSubSelectItem(src, level+1)

		default:
			// Unions and every other source kind are handled by recursion.
			return s.checkSelectItem(src, level+1)
		}

	default:
		log.Infof("%T", x)
	}

	return tables, fields
}
// checkSubSelectItem analyses a single SELECT statement.
//
// It first resolves the table sources of the FROM clause (cached tables,
// nested selects and other result-set nodes) into tableInfoList, then walks
// the select list. Non-wildcard entries are resolved through checkSelectField;
// wildcard entries ("*" or "tbl.*") are expanded into one synthetic column
// reference per column, and each of those is resolved through
// checkSelectField as well. s.colIndex is advanced only when level is 0.
//
// Side effects visible in this function: node.Fields.Fields is replaced by
// the expanded list when the expansion produced more entries than the
// original list, and the Schema of unqualified table names is filled in with
// the session's current database while expanding a bare "*".
//
// Returns the tables found in the FROM clause and the masking fields
// collected from nested selects, aliased select-list entries and wildcard
// expansion.
func (s *masking) checkSubSelectItem(node *ast.SelectStmt, level int) (tableInfoList []*TableInfo, fields []MaskingFieldInfo) {
log.Debug("checkSubSelectItem")
var tableList []*ast.TableSource
if node.From != nil {
tableList = extractTableList(node.From.TableRefs, tableList)
}
// Step 1: resolve every table source of the FROM clause.
for _, tblSource := range tableList {
switch x := tblSource.Source.(type) {
case *ast.TableName:
tblName := x
dbName := tblName.Schema.O
if dbName == "" {
dbName = s.session.dbName
}
t := s.session.getTableFromCache(dbName, tblName.Name.O, false)
if t != nil {
if tblSource.AsName.L != "" {
// NOTE(review): AsName is set on the object returned by the cache
// before copying — confirm the cache entry may carry the alias.
t.AsName = tblSource.AsName.O
tableInfoList = append(tableInfoList, t.copy())
} else {
tableInfoList = append(tableInfoList, t)
}
} else {
// Table not found in the cache: keep a placeholder that carries only
// the schema and table name (no column information).
tableInfoList = append(tableInfoList, &TableInfo{
Schema: tblName.Schema.String(),
Name: tblName.Name.String(),
})
}
case *ast.SelectStmt:
// Recursively check the sub-query.
tmpTables, tmpFields := s.checkSubSelectItem(x, level+1)
// A derived-table alias is applied to every table the sub-query returned.
if tblSource.AsName.L != "" {
for _, f := range tmpTables {
f.AsName = tblSource.AsName.String()
}
}
// Attach each aliased sub-query field to the table(s) whose name matches
// the field's table, so that outer references to the alias can be
// resolved later through TableInfo.maskingFields.
for _, f := range tmpFields {
if f.Alias == "" {
continue
}
for _, t := range tmpTables {
if f.Table == t.Name {
if t.maskingFields == nil {
t.maskingFields = make([]MaskingFieldInfo, 0)
}
t.maskingFields = append(t.maskingFields, f)
}
}
}
// for _, t := range tmpTables {
// log.Errorf("t: %v, maskingFields: %v", t.Name, len(t.maskingFields))
// for _, f := range t.maskingFields {
// log.Errorf("f: %#v ", f)
// }
// }
tableInfoList = append(tableInfoList, tmpTables...)
fields = append(fields, tmpFields...)
default:
// Any other source kind is delegated to checkSelectItem.
tmpTables, tmpFields := s.checkSelectItem(x, level+1)
tableInfoList = append(tableInfoList, tmpTables...)
fields = append(fields, tmpFields...)
}
}
// Step 2: walk the select list.
if node.Fields != nil {
// newFields is the select list with wildcards expanded into explicit columns.
newFields := make([]*ast.SelectField, 0)
for _, field := range node.Fields.Fields {
var tmpFields []MaskingFieldInfo
// Non-wildcard column.
if field.WildCard == nil {
// Columns that carry an alias get special handling: only those are
// added to the returned fields, with the alias recorded on each.
subFields := s.checkSelectField(field, tableInfoList, level, s.colIndex)
if len(subFields) > 0 && field.AsName.L != "" {
for index := range subFields {
(&subFields[index]).Alias = field.AsName.String()
}
fields = append(fields, subFields...)
}
newFields = append(newFields, field)
if level == 0 {
s.colIndex++
}
continue
}
// WildCard != nil means this is a wildcard (star) column.
db := field.WildCard.Schema.L
wildTable := field.WildCard.Table.L
if wildTable == "" {
// Bare "*": expand every table source of the FROM clause.
for _, tblSource := range tableList {
tblName, ok := tblSource.Source.(*ast.TableName)
if ok {
// Fills in the current database on the AST node itself.
if tblName.Schema.L == "" {
tblName.Schema = model.NewCIStr(s.session.dbName)
}
t := s.session.getTableFromCache(tblName.Schema.O, tblName.Name.O, false)
if t != nil {
tmpFields = append(tmpFields,
Convert(tblName.Schema.O, tblName.Name.O, t.Fields)...)
for _index, f := range t.Fields {
// The synthetic column is qualified with the alias when there is one,
// otherwise with the table name.
tableName := tblSource.AsName.String()
if tableName == "" {
tableName = tblName.Name.O
}
newField := &ast.SelectField{
Expr: &ast.ColumnNameExpr{
Name: &ast.ColumnName{
Name: model.NewCIStr(f.Field),
Table: model.NewCIStr(tableName),
},
},
}
s.checkSelectField(newField, tableInfoList,
level, s.colIndex+_index)
newFields = append(newFields, newField)
}
if level == 0 {
s.colIndex += len(t.Fields)
}
} else {
s.appendErrorNo(ER_TABLE_NOT_EXISTED_ERROR,
fmt.Sprintf("`%s`.`%s`", tblName.Schema.O, tblName.Name.String()))
}
} else {
// Source is not a plain table; presumably getSubSelectColumns returns
// the column names of the sub-query — verify against its definition.
cols := s.session.getSubSelectColumns(tblSource.Source)
for _index, f := range cols {
newField := &ast.SelectField{
Expr: &ast.ColumnNameExpr{
Name: &ast.ColumnName{
Name: model.NewCIStr(f),
Table: model.NewCIStr(tblSource.AsName.String()),
},
},
}
s.checkSelectField(newField, tableInfoList,
level, s.colIndex+_index)
newFields = append(newFields, newField)
}
if level == 0 {
s.colIndex += len(cols)
}
}
}
} else {
// Qualified "tbl.*": expand only the sources whose alias (or, without an
// alias, table name) equals the wildcard's table qualifier.
for _, tblSource := range tableList {
var tableName string
tblName, ok := tblSource.Source.(*ast.TableName)
if tblSource.AsName.L != "" {
tableName = tblSource.AsName.L
} else if ok {
tableName = tblName.Name.L
}
// For plain tables the schema qualifier must also match when it is given.
if (ok && (db == "" || db == tblName.Schema.L) && wildTable == tableName) || (!ok && wildTable == tableName) {
if ok {
dbName := tblName.Schema.O
if dbName == "" {
dbName = s.session.dbName
}
t := s.session.getTableFromCache(dbName, tblName.Name.O, false)
if t != nil {
tmpFields = append(tmpFields,
Convert(tblName.Schema.O, tblName.Name.O, t.Fields)...)
for _index, f := range t.Fields {
tableName := tblSource.AsName.String()
if tableName == "" {
tableName = tblName.Name.O
}
newField := &ast.SelectField{
Expr: &ast.ColumnNameExpr{
Name: &ast.ColumnName{
Name: model.NewCIStr(f.Field),
Table: model.NewCIStr(tableName),
},
},
}
s.checkSelectField(newField, tableInfoList,
level, s.colIndex+_index)
newFields = append(newFields, newField)
}
if level == 0 {
s.colIndex += len(t.Fields)
}
} else {
s.appendErrorNo(ER_TABLE_NOT_EXISTED_ERROR,
fmt.Sprintf("`%s`.`%s`", tblName.Schema.O, tblName.Name.String()))
}
} else {
cols := s.session.getSubSelectColumns(tblSource.Source)
for _index, f := range cols {
newField := &ast.SelectField{
Expr: &ast.ColumnNameExpr{
Name: &ast.ColumnName{
Name: model.NewCIStr(f),
Table: model.NewCIStr(field.WildCard.Table.O),
},
},
}
s.checkSelectField(newField, tableInfoList,
level, s.colIndex+_index)
newFields = append(newFields, newField)
}
if level == 0 {
s.colIndex += len(cols)
}
}
}
}
}
// Columns gathered from wildcard expansion of cached tables.
if tmpFields != nil {
fields = append(fields, tmpFields...)
}
}
// Replace the select list only when wildcard expansion actually added entries.
if len(newFields) > len(node.Fields.Fields) {
node.Fields.Fields = newFields
}
}
return tableInfoList, fields
}
// checkSelectField resolves the source columns behind one select-list entry.
//
// The resolved columns are returned unchanged. When level is 0, a copy of each
// column is additionally recorded in s.maskingFields with Index set to
// colIndex and Alias set to the entry's explicit alias or, when there is
// none, to the name derived by getExprAliasName.
func (s *masking) checkSelectField(field *ast.SelectField,
	tableInfoList []*TableInfo, level, colIndex int) (fields []MaskingFieldInfo) {
	fields = s.checkItem(field.Expr, tableInfoList)
	if level != 0 {
		return fields
	}

	// src is a per-iteration copy, so the returned slice is left untouched.
	for _, src := range fields {
		src.Index = uint16(colIndex)
		if alias := field.AsName.String(); alias != "" {
			src.Alias = alias
		} else {
			src.Alias = s.getExprAliasName(field)
		}
		s.maskingFields = append(s.maskingFields, src)
	}
	return fields
}
// checkItem returns the source columns referenced by expr, resolving column
// names against tables and recursing into operands, function arguments and
// sub-queries. A nil expr yields nil. Expression types that are not handled
// are logged and contribute nothing.
func (s *masking) checkItem(expr ast.ExprNode, tables []*TableInfo) (fields []MaskingFieldInfo) {
	if expr == nil {
		return
	}

	// collect resolves each operand in order and appends its columns to the
	// result.
	collect := func(operands ...ast.ExprNode) {
		for _, operand := range operands {
			fields = append(fields, s.checkItem(operand, tables)...)
		}
	}

	switch e := expr.(type) {
	case *ast.ColumnNameExpr:
		if resolved := s.checkFieldItem(e.Name, tables); resolved != nil {
			for _, f := range resolved {
				fields = append(fields, *f)
			}
		} else {
			// Unknown column: report it, but still emit an entry built from
			// the names as written, defaulting the schema to the current
			// database.
			s.appendErrorNo(ER_COLUMN_NOT_EXISTED, e.Name.OrigColName())
			schema := e.Name.Schema.O
			if schema == "" {
				schema = s.session.dbName
			}
			fields = append(fields, MaskingFieldInfo{
				Schema: schema,
				Table:  e.Name.Table.String(),
				Field:  e.Name.Name.String(),
			})
		}
		if e.Refer != nil {
			collect(e.Refer.Expr)
		}

	case *ast.BinaryOperationExpr:
		collect(e.L, e.R)

	case *ast.UnaryOperationExpr:
		collect(e.V)

	case *ast.FuncCallExpr:
		fields = append(fields, s.checkFuncItem(e, tables)...)

	case *ast.FuncCastExpr:
		collect(e.Expr)

	case *ast.AggregateFuncExpr:
		return s.checkAggregateFuncItem(e, tables)

	case *ast.PatternInExpr:
		collect(e.Expr)
		for _, item := range e.List {
			collect(item)
		}
		if e.Sel != nil {
			collect(e.Sel)
		}

	case *ast.PatternLikeExpr:
		return s.checkItem(e.Expr, tables)

	case *ast.PatternRegexpExpr:
		return s.checkItem(e.Expr, tables)

	case *ast.SubqueryExpr:
		_, fields = s.checkSelectItem(e.Query, 1)
		return fields

	case *ast.CompareSubqueryExpr:
		collect(e.L, e.R)

	case *ast.ExistsSubqueryExpr:
		_, fields = s.checkSelectItem(e.Sel, 1)
		return fields

	case *ast.IsNullExpr:
		return s.checkItem(e.Expr, tables)

	case *ast.IsTruthExpr:
		return s.checkItem(e.Expr, tables)

	case *ast.BetweenExpr:
		collect(e.Expr, e.Left, e.Right)

	case *ast.CaseExpr:
		collect(e.Value)
		for _, when := range e.WhenClauses {
			collect(when.Expr, when.Result)
		}
		collect(e.ElseClause)

	case *ast.ParenthesesExpr:
		collect(e.Expr)

	case *ast.RowExpr:
		for _, value := range e.Values {
			collect(value)
		}

	case *ast.ValuesExpr:
		for _, f := range s.checkFieldItem(e.Column.Name, tables) {
			fields = append(fields, *f)
		}

	case *ast.VariableExpr:
		return s.checkItem(e.Value, tables)

	case *ast.DefaultExpr, *ast.ValueExpr, *ast.ParamMarkerExpr, *ast.PositionExpr:
		// These expressions reference no columns.

	default:
		log.Infof("checkItem: %#v", e)
	}

	return
}
// getExprAliasName returns the name used for a select-list entry that has no
// explicit alias: the bare column name for a plain column reference, and the
// entry's text for every other expression. An entry without an expression
// yields "". Expression types outside the known list are logged before their
// text is returned.
func (s *masking) getExprAliasName(field *ast.SelectField) string {
	if field.Expr == nil {
		return ""
	}

	switch e := field.Expr.(type) {
	case *ast.ColumnNameExpr:
		return e.Name.Name.String()

	case *ast.BinaryOperationExpr, *ast.UnaryOperationExpr, *ast.FuncCallExpr,
		*ast.FuncCastExpr,
		*ast.AggregateFuncExpr, *ast.PatternInExpr, *ast.PatternLikeExpr,
		*ast.PatternRegexpExpr, *ast.SubqueryExpr, *ast.CompareSubqueryExpr,
		*ast.ExistsSubqueryExpr, *ast.IsNullExpr, *ast.IsTruthExpr,
		*ast.BetweenExpr, *ast.CaseExpr, *ast.ParenthesesExpr,
		*ast.RowExpr, *ast.ValuesExpr, *ast.VariableExpr,
		*ast.ValueExpr, *ast.ParamMarkerExpr, *ast.PositionExpr:
		// Known expression kinds: nothing to log.

	default:
		log.Infof("getExprAliasName default: %#v", e)
	}

	return field.Text()
}
// checkFieldItem looks the column up in tables and returns the matching field
// information, or nil when nothing matches. A name that is a column alias of a
// sub-query may refer to several original columns, in which case all of them
// are returned.
//
// When the column is qualified with a table name, only tables whose alias (or
// name, when there is no alias) matches are considered, and the schema must
// match too when one is given. Within a candidate table the lookup order is:
//  1. a table without column information matches unconditionally;
//  2. masking fields whose alias equals the column name;
//  3. the first non-deleted column with that name.
func (s *masking) checkFieldItem(name *ast.ColumnName, tables []*TableInfo) []*MaskingFieldInfo {
	schema := name.Schema.L
	qualifier := name.Table.L
	column := name.Name.L

	for _, t := range tables {
		if qualifier != "" {
			effective := t.Name
			if t.AsName != "" {
				effective = t.AsName
			}
			schemaMismatch := schema != "" && !strings.EqualFold(t.Schema, schema)
			if schemaMismatch || !strings.EqualFold(effective, qualifier) {
				continue
			}
		}

		// No column information available: accept the name as written.
		if len(t.Fields) == 0 {
			return []*MaskingFieldInfo{{
				Schema: t.Schema,
				Table:  t.Name,
				Field:  name.Name.String(),
			}}
		}

		var viaAlias []*MaskingFieldInfo
		for i := range t.maskingFields {
			// Take a copy so the returned pointer does not alias the table's slice.
			candidate := t.maskingFields[i]
			if strings.EqualFold(candidate.Alias, column) {
				viaAlias = append(viaAlias, &candidate)
			}
		}
		if len(viaAlias) > 0 {
			return viaAlias
		}

		for _, col := range t.Fields {
			if col.IsDeleted || !strings.EqualFold(col.Field, column) {
				continue
			}
			return []*MaskingFieldInfo{{
				Schema: t.Schema,
				Table:  t.Name,
				Field:  name.Name.String(),
				Type:   col.Type,
			}}
		}
	}

	return nil
}
// checkFuncItem returns the source columns referenced by the arguments of a
// function call, in argument order.
func (s *masking) checkFuncItem(f *ast.FuncCallExpr, tables []*TableInfo) (fields []MaskingFieldInfo) {
	for i := range f.Args {
		fields = append(fields, s.checkItem(f.Args[i], tables)...)
	}
	return fields
}
// checkAggregateFuncItem returns the source columns referenced by the
// arguments of an aggregate function, in argument order.
func (s *masking) checkAggregateFuncItem(f *ast.AggregateFuncExpr, tables []*TableInfo) (fields []MaskingFieldInfo) {
	for i := range f.Args {
		fields = append(fields, s.checkItem(f.Args[i], tables)...)
	}
	return fields
}
// Convert builds one MaskingFieldInfo per entry of fs, copying the column name
// and type and tagging each entry with the given schema and table. The result
// has the same length and order as fs and is never nil.
func Convert(schema, table string, fs []FieldInfo) []MaskingFieldInfo {
	converted := make([]MaskingFieldInfo, 0, len(fs))
	for i := range fs {
		converted = append(converted, MaskingFieldInfo{
			Schema: schema,
			Table:  table,
			Field:  fs[i].Field,
			Type:   fs[i].Type,
		})
	}
	return converted
}
// appendErrorNo writes the message for the given error code to s.buf,
// followed by a newline.
//
// Nothing is written when the session's checkInceptionVariables rejects the
// code, or when the effective level is 0. The level comes from the session's
// incLevel map when it has an entry for the code, and from GetErrorLevel
// otherwise. When values are supplied, the message is used as a format string
// for them.
func (s *masking) appendErrorNo(number ErrorCode, values ...interface{}) {
	// Leave early when this check is switched off.
	if !s.session.checkInceptionVariables(number) {
		return
	}

	level, overridden := s.session.incLevel[number.String()]
	if !overridden {
		level = GetErrorLevel(number)
	}
	if level == 0 {
		return
	}

	msg := s.session.getErrorMessage(number)
	if len(values) > 0 {
		// Format only when arguments exist, so a literal '%' in an
		// argument-less message is written verbatim.
		msg = fmt.Sprintf(msg, values...)
	}
	s.buf.WriteString(msg)
	s.buf.WriteString("\n")
}

Комментарий ( 0 )

Вы можете оставить комментарий после входа в систему

1
https://gitlife.ru/oschina-mirror/hanchuanchuan-goInception.git
git@gitlife.ru:oschina-mirror/hanchuanchuan-goInception.git
oschina-mirror
hanchuanchuan-goInception
hanchuanchuan-goInception
v1.3.0