package component

import (
	"context"
	"errors"
	"fmt"
	"time"

	glogger "gitea.party/public-messag-service/config/logger"
	"gorm.io/gorm/logger"
	"gorm.io/gorm/utils"
)

// Format strings for Trace output. The leading %s is the caller location
// from utils.FileWithLineNum. The error and warning variants take a second
// %s for the error value or the slow-query notice. The remaining verbs are
// the elapsed time in milliseconds, the affected row count ("-" when the
// count is -1) and the SQL text.
const (
	_TraceStr     = "\n\t%s\n\t[%.3fms] [rows:%v] %s"
	_TraceErrStr  = "\n\t%s %s\n\t[%.3fms] [rows:%v] %s"
	_TraceWarnStr = "\n\t%s %s\n\t[%.3fms] [rows:%v] %s"
)

// Compile-time check that *MysqlLogger implements gorm's logger.Interface.
var _ logger.Interface = (*MysqlLogger)(nil)

// MysqlLogger adapts a glogger.Interface so it can be used as gorm's logger.
//
// Output is gated by level. Info, Warn and Error forward a message only when
// level is at least the matching gorm severity. Trace writes nothing at
// logger.Silent or below. A non-zero slowThreshold enables slow-query
// warnings in Trace; the zero value disables them. The zero value of level
// is below logger.Silent, so a MysqlLogger whose level was never set is
// silent.
//
// NOTE(review): the glogger.Interface methods are used printf-style
// (format, args...) throughout this file; confirm against that package.
type MysqlLogger struct {
	log           glogger.Interface
	level         logger.LogLevel
	slowThreshold time.Duration
}

// SetLogger sets the backend that receives all output. The logging methods
// call it without a nil check, so it must be set before the logger is used.
func (l *MysqlLogger) SetLogger(log glogger.Interface) {
	l.log = log
}

// LogMode sets the level on l itself and returns l.
//
// NOTE(review): gorm's built-in logger returns a modified copy instead, and
// gorm relies on that when deriving per-session loggers (db.Debug() calls
// LogMode(logger.Info)). Because this implementation mutates the shared
// receiver, a single db.Debug() raises the level for every session and races
// with concurrent Trace calls. Switching to copy semantics
// (cp := *l; cp.level = level; return &cp) is preferable once callers that
// ignore the return value have been audited.
func (l *MysqlLogger) LogMode(level logger.LogLevel) logger.Interface {
	l.level = level
	return l
}

// Info forwards a gorm info message when level is at least logger.Info.
//
// msg is a printf-style format and data are its arguments. The caller
// location is prepended to data, so the format gets a matching %s prefix.
// Without it, every argument would be shifted by one verb.
func (l *MysqlLogger) Info(_ context.Context, msg string, data ...interface{}) {
	if l.level >= logger.Info {
		l.log.Info("\n\t%s\n\t"+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
	}
}

// Warn forwards a gorm warning when level is at least logger.Warn.
//
// As with Info, the caller location is prepended to data and a matching %s
// is prepended to the msg format.
func (l *MysqlLogger) Warn(_ context.Context, msg string, data ...interface{}) {
	if l.level >= logger.Warn {
		l.log.Warn("\n\t%s\n\t"+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
	}
}

// Error forwards a gorm error message when level is at least logger.Error.
//
// As with Info, the caller location is prepended to data and a matching %s
// is prepended to the msg format.
func (l *MysqlLogger) Error(_ context.Context, msg string, data ...interface{}) {
	if l.level >= logger.Error {
		l.log.Error("\n\t%s\n\t"+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
	}
}

// Trace logs one executed statement. fc is called only when a line is
// actually written; it supplies the SQL text and affected row count.
//
// At most one line is written, chosen in this order:
//   - an error, when err is non-nil, level >= logger.Error and err is not
//     logger.ErrRecordNotFound;
//   - a slow-query warning, when slowThreshold is non-zero, the elapsed time
//     exceeds it and level >= logger.Warn;
//   - an info line for every statement, when level == logger.Info.
func (l *MysqlLogger) Trace(_ context.Context, begin time.Time, fc func() (string, int64), err error) {
	if l.level <= logger.Silent {
		return
	}

	elapsed := time.Since(begin)
	ms := float64(elapsed.Nanoseconds()) / 1e6

	// A row count of -1 is printed as "-", as gorm's own logger does.
	// Presumably -1 marks an unknown count; confirm against gorm.
	rowsField := func(rows int64) interface{} {
		if rows == -1 {
			return "-"
		}
		return rows
	}

	// utils.FileWithLineNum is called directly here (not from a helper) so
	// the stack depth it inspects is the same as for the other methods.
	switch {
	case err != nil && l.level >= logger.Error && !errors.Is(err, logger.ErrRecordNotFound):
		sql, rows := fc()
		l.log.Error(_TraceErrStr, utils.FileWithLineNum(), err, ms, rowsField(rows), sql)
	case elapsed > l.slowThreshold && l.slowThreshold != 0 && l.level >= logger.Warn:
		sql, rows := fc()
		slowLog := fmt.Sprintf("SLOW SQL >= %v", l.slowThreshold)
		l.log.Warn(_TraceWarnStr, utils.FileWithLineNum(), slowLog, ms, rowsField(rows), sql)
	case l.level == logger.Info:
		sql, rows := fc()
		l.log.Info(_TraceStr, utils.FileWithLineNum(), ms, rowsField(rows), sql)
	}
}