From 17017b251622bf120fe4b87bc102747a50109cce Mon Sep 17 00:00:00 2001 From: Martin Holst Swende Date: Mon, 30 Jan 2023 12:43:12 -0500 Subject: [PATCH] log: better sanitation (#26556) --- log/format.go | 38 ++++++++++++++++++++++++++++++++------ log/format_test.go | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 6 deletions(-) diff --git a/log/format.go b/log/format.go index 613dc33be..42525ea6d 100644 --- a/log/format.go +++ b/log/format.go @@ -86,6 +86,7 @@ type TerminalStringer interface { // [DBUG] [May 16 20:58:45] remove route ns=haproxy addr=127.0.0.1:50002 func TerminalFormat(usecolor bool) Format { return FormatFunc(func(r *Record) []byte { + msg := escapeMessage(r.Msg) var color = 0 if usecolor { switch r.Lvl { @@ -122,19 +123,19 @@ func TerminalFormat(usecolor bool) Format { // Assemble and print the log heading if color > 0 { - fmt.Fprintf(b, "\x1b[%dm%s\x1b[0m[%s|%s]%s %s ", color, lvl, r.Time.Format(termTimeFormat), location, padding, r.Msg) + fmt.Fprintf(b, "\x1b[%dm%s\x1b[0m[%s|%s]%s %s ", color, lvl, r.Time.Format(termTimeFormat), location, padding, msg) } else { - fmt.Fprintf(b, "%s[%s|%s]%s %s ", lvl, r.Time.Format(termTimeFormat), location, padding, r.Msg) + fmt.Fprintf(b, "%s[%s|%s]%s %s ", lvl, r.Time.Format(termTimeFormat), location, padding, msg) } } else { if color > 0 { - fmt.Fprintf(b, "\x1b[%dm%s\x1b[0m[%s] %s ", color, lvl, r.Time.Format(termTimeFormat), r.Msg) + fmt.Fprintf(b, "\x1b[%dm%s\x1b[0m[%s] %s ", color, lvl, r.Time.Format(termTimeFormat), msg) } else { - fmt.Fprintf(b, "%s[%s] %s ", lvl, r.Time.Format(termTimeFormat), r.Msg) + fmt.Fprintf(b, "%s[%s] %s ", lvl, r.Time.Format(termTimeFormat), msg) } } // try to justify the log output for short messages - length := utf8.RuneCountInString(r.Msg) + length := utf8.RuneCountInString(msg) if len(r.Ctx) > 0 && length < termMsgJust { b.Write(bytes.Repeat([]byte{' '}, termMsgJust-length)) } @@ -167,6 +168,8 @@ func logfmt(buf *bytes.Buffer, ctx []interface{}, color int, term bool) { v := formatLogfmtValue(ctx[i+1], term) if !ok { k, v = errorKey, formatLogfmtValue(k, term) + } else { + k = escapeString(k) } // XXX: we should probably check that all of your key bytes aren't invalid @@ -471,7 +474,7 @@ func formatLogfmtBigInt(n *big.Int) string { func escapeString(s string) string { needsQuoting := false for _, r := range s { - // We quote everything below " (0x34) and above~ (0x7E), plus equal-sign + // We quote everything below " (0x22) and above~ (0x7E), plus equal-sign if r <= '"' || r > '~' || r == '=' { needsQuoting = true break @@ -482,3 +485,26 @@ func escapeString(s string) string { } return strconv.Quote(s) } + +// escapeMessage checks if the provided string needs escaping/quoting, similarly +// to escapeString. The difference is that this method is more lenient: it allows +// for spaces and linebreaks to occur without needing quoting. +func escapeMessage(s string) string { + needsQuoting := false + for _, r := range s { + // Carriage return and Line feed are ok + if r == 0xa || r == 0xd { + continue + } + // We quote everything below (0x20) and above~ (0x7E), + // plus equal-sign + if r < ' ' || r > '~' || r == '=' { + needsQuoting = true + break + } + } + if !needsQuoting { + return s + } + return strconv.Quote(s) +} diff --git a/log/format_test.go b/log/format_test.go index d7e0a9576..cfcfe8580 100644 --- a/log/format_test.go +++ b/log/format_test.go @@ -1,9 +1,11 @@ package log import ( + "fmt" "math" "math/big" "math/rand" + "strings" "testing" ) @@ -93,3 +95,47 @@ func BenchmarkPrettyUint64Logfmt(b *testing.B) { sink = FormatLogfmtUint64(rand.Uint64()) } } + +func TestSanitation(t *testing.T) { + msg := "\u001b[1G\u001b[K\u001b[1A" + msg2 := "\u001b \u0000" + msg3 := "NiceMessage" + msg4 := "Space Message" + msg5 := "Enter\nMessage" + + for i, tt := range []struct { + msg string + want string + }{ + { + msg: msg, + want: fmt.Sprintf("] %q %q=%q\n", msg, msg, msg), + }, + { + msg: msg2, + want: fmt.Sprintf("] %q %q=%q\n", msg2, msg2, msg2), + }, + { + msg: msg3, + want: fmt.Sprintf("] %s %s=%s\n", msg3, msg3, msg3), + }, + { + msg: msg4, + want: fmt.Sprintf("] %s %q=%q\n", msg4, msg4, msg4), + }, + { + msg: msg5, + want: fmt.Sprintf("] %s %q=%q\n", msg5, msg5, msg5), + }, + } { + var ( + logger = New() + out = new(strings.Builder) + ) + logger.SetHandler(LvlFilterHandler(LvlInfo, StreamHandler(out, TerminalFormat(false)))) + logger.Info(tt.msg, tt.msg, tt.msg) + if have := out.String()[24:]; tt.want != have { + t.Fatalf("test %d: want / have: \n%v\n%v", i, tt.want, have) + } + } +}