From 373c63e18d1f72dc9f050195b148612e5cace6d0 Mon Sep 17 00:00:00 2001 From: Charles Korn Date: Mon, 15 Dec 2025 16:04:47 +1100 Subject: [PATCH 1/3] Add limit --- config.go | 9 +++ misc_tests/max_marshalled_size_test.go | 77 ++++++++++++++++++++++++++ reflect_native.go | 1 + stream.go | 63 +++++++++++++++++---- stream_float.go | 2 + stream_int.go | 12 ++++ stream_str.go | 2 + 7 files changed, 155 insertions(+), 11 deletions(-) create mode 100644 misc_tests/max_marshalled_size_test.go diff --git a/config.go b/config.go index 2adcdc3b..3f437ec5 100644 --- a/config.go +++ b/config.go @@ -25,6 +25,13 @@ type Config struct { ValidateJsonRawMessage bool ObjectFieldMustBeSimpleString bool CaseSensitive bool + + // MaxMarshalledBytes limits the maximum size of the marshalled output, in bytes; 0 (the default) means no limit. + // + // While it guarantees not to return more bytes than MaxMarshalledBytes, + // it does not guarantee that the internal buffer will be smaller than MaxMarshalledBytes. + // In most cases, the internal buffer exceeds the limit by only a few bytes. + MaxMarshalledBytes uint64 } // API the public interface of this package. 
@@ -80,6 +87,7 @@ type frozenConfig struct { streamPool *sync.Pool iteratorPool *sync.Pool caseSensitive bool + maxMarshalledBytes uint64 } func (cfg *frozenConfig) initCache() { @@ -134,6 +142,7 @@ func (cfg Config) Froze() API { onlyTaggedField: cfg.OnlyTaggedField, disallowUnknownFields: cfg.DisallowUnknownFields, caseSensitive: cfg.CaseSensitive, + maxMarshalledBytes: cfg.MaxMarshalledBytes, } api.streamPool = &sync.Pool{ New: func() interface{} { diff --git a/misc_tests/max_marshalled_size_test.go b/misc_tests/max_marshalled_size_test.go new file mode 100644 index 00000000..ab81b7fc --- /dev/null +++ b/misc_tests/max_marshalled_size_test.go @@ -0,0 +1,77 @@ +package misc_tests + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + jsoniter "github.com/json-iterator/go" +) + +func TestMaxMarshalledSize(t *testing.T) { + testCases := []interface{}{ + nil, + "", + false, + 123, + 123.123, + []string{"foo", "bar"}, + map[string]int{"foo": 123}, + } + + for _, testCase := range testCases { + t.Run(fmt.Sprintf("%#v", testCase), func(t *testing.T) { + expectedBytes, err := jsoniter.Marshal(testCase) + require.NoError(t, err) + + expectedLength := uint64(len(expectedBytes)) + + expectSuccessfulMarshalling := func(t *testing.T, limit uint64) { + cfg := jsoniter.Config{ + MaxMarshalledBytes: limit, + } + + api := cfg.Froze() + actualBytes, err := api.Marshal(testCase) + require.NoError(t, err) + require.Equal(t, expectedBytes, actualBytes) + } + + expectFailedMarshalling := func(t *testing.T, limit uint64) { + cfg := jsoniter.Config{ + MaxMarshalledBytes: limit, + } + + api := cfg.Froze() + actualBytes, err := api.Marshal(testCase) + require.ErrorContains(t, err, fmt.Sprintf("marshalling produced a result over the configured limit of %d bytes", limit)) + require.Nil(t, actualBytes) + } + + t.Run("limit set to 0 (unlimited)", func(t *testing.T) { + expectSuccessfulMarshalling(t, 0) + }) + + t.Run("limit set to exact length of output", func(t 
*testing.T) { + expectSuccessfulMarshalling(t, expectedLength) + }) + + t.Run("limit set to just under length of output", func(t *testing.T) { + expectFailedMarshalling(t, expectedLength-1) + }) + + t.Run("limit set to well under length of output", func(t *testing.T) { + expectFailedMarshalling(t, 1) + }) + + t.Run("limit set to just over length of output", func(t *testing.T) { + expectSuccessfulMarshalling(t, expectedLength+1) + }) + + t.Run("limit set to well over length of output", func(t *testing.T) { + expectSuccessfulMarshalling(t, expectedLength+100) + }) + }) + } +} diff --git a/reflect_native.go b/reflect_native.go index f88722d1..b591b13c 100644 --- a/reflect_native.go +++ b/reflect_native.go @@ -444,6 +444,7 @@ func (codec *base64Codec) Encode(ptr unsafe.Pointer, stream *Stream) { buf := make([]byte, size) encoding.Encode(buf, src) stream.buf = append(stream.buf, buf...) + stream.enforceMaxBytes() } stream.writeByte('"') } diff --git a/stream.go b/stream.go index 23d8a3ad..9586194e 100644 --- a/stream.go +++ b/stream.go @@ -1,18 +1,24 @@ package jsoniter import ( + "fmt" "io" ) // stream is a io.Writer like object, with JSON specific write functions. // Error is not returned as return value, but stored as Error member on this stream instance. type Stream struct { - cfg *frozenConfig - out io.Writer - buf []byte - Error error - indention int - Attachment interface{} // open for customized encoder + cfg *frozenConfig + out io.Writer + buf []byte + Error error + indention int + Attachment interface{} // open for customized encoder + enforceMarshalledBytesLimit bool + + // Number of bytes remaining before marshalled size exceeds cfg.maxMarshalledBytes. + // This is tracked as an amount remaining to account for bytes already flushed in Write(). + marshalledBytesLimitRemaining uint64 } // NewStream create new stream instance. @@ -20,12 +26,15 @@ type Stream struct { // out can be nil if write to internal buffer. 
// bufSize is the initial size for the internal buffer in bytes. func NewStream(cfg API, out io.Writer, bufSize int) *Stream { + config := cfg.(*frozenConfig) return &Stream{ - cfg: cfg.(*frozenConfig), - out: out, - buf: make([]byte, 0, bufSize), - Error: nil, - indention: 0, + cfg: config, + out: out, + buf: make([]byte, 0, bufSize), + Error: nil, + indention: 0, + enforceMarshalledBytesLimit: config.maxMarshalledBytes > 0, + marshalledBytesLimitRemaining: config.maxMarshalledBytes, } } @@ -38,6 +47,7 @@ func (stream *Stream) Pool() StreamPool { func (stream *Stream) Reset(out io.Writer) { stream.out = out stream.buf = stream.buf[:0] + stream.marshalledBytesLimitRemaining = stream.cfg.maxMarshalledBytes } // Available returns how many bytes are unused in the buffer. @@ -66,9 +76,12 @@ func (stream *Stream) SetBuffer(buf []byte) { // why the write is short. func (stream *Stream) Write(p []byte) (nn int, err error) { stream.buf = append(stream.buf, p...) + stream.enforceMaxBytes() + if stream.out != nil { nn, err = stream.out.Write(stream.buf) stream.buf = stream.buf[nn:] + stream.marshalledBytesLimitRemaining -= uint64(nn) return } return len(p), nil @@ -77,22 +90,48 @@ func (stream *Stream) Write(p []byte) (nn int, err error) { // WriteByte writes a single byte. 
func (stream *Stream) writeByte(c byte) { stream.buf = append(stream.buf, c) + stream.enforceMaxBytes() } func (stream *Stream) writeTwoBytes(c1 byte, c2 byte) { stream.buf = append(stream.buf, c1, c2) + stream.enforceMaxBytes() } func (stream *Stream) writeThreeBytes(c1 byte, c2 byte, c3 byte) { stream.buf = append(stream.buf, c1, c2, c3) + stream.enforceMaxBytes() } func (stream *Stream) writeFourBytes(c1 byte, c2 byte, c3 byte, c4 byte) { stream.buf = append(stream.buf, c1, c2, c3, c4) + stream.enforceMaxBytes() } func (stream *Stream) writeFiveBytes(c1 byte, c2 byte, c3 byte, c4 byte, c5 byte) { stream.buf = append(stream.buf, c1, c2, c3, c4, c5) + stream.enforceMaxBytes() +} + +func (stream *Stream) enforceMaxBytes() { + if !stream.enforceMarshalledBytesLimit || stream.Error != nil { + return + } + + if uint64(len(stream.buf)) > stream.marshalledBytesLimitRemaining { + // Why do we do this rather than return an error? + // Most of the writing methods on Stream do not return an error, and introducing this would be a + // breaking change. + stream.Error = exceededMaxMarshalledBytesError{stream.cfg.maxMarshalledBytes} + } +} + +type exceededMaxMarshalledBytesError struct { + maxBytes uint64 +} + +func (err exceededMaxMarshalledBytesError) Error() string { + return fmt.Sprintf("marshalling produced a result over the configured limit of %d bytes", err.maxBytes) } // Flush writes any buffered data to the underlying io.Writer. @@ -117,6 +156,7 @@ func (stream *Stream) Flush() error { // WriteRaw write string out without quotes, just like []byte func (stream *Stream) WriteRaw(s string) { stream.buf = append(stream.buf, s...) 
+ stream.enforceMaxBytes() } // WriteNil write null to stream @@ -207,4 +247,5 @@ func (stream *Stream) writeIndention(delta int) { for i := 0; i < toWrite; i++ { stream.buf = append(stream.buf, ' ') } + stream.enforceMaxBytes() } diff --git a/stream_float.go b/stream_float.go index eddd831b..dd826f8f 100644 --- a/stream_float.go +++ b/stream_float.go @@ -35,6 +35,7 @@ func (stream *Stream) WriteFloat32(val float32) { stream.buf = stream.buf[:n-1] } } + stream.enforceMaxBytes() } // WriteFloat32Lossy write float32 to stream with ONLY 6 digits precision although much much faster @@ -92,6 +93,7 @@ func (stream *Stream) WriteFloat64(val float64) { stream.buf = stream.buf[:n-1] } } + stream.enforceMaxBytes() } // WriteFloat64Lossy write float64 to stream with ONLY 6 digits precision although much much faster diff --git a/stream_int.go b/stream_int.go index d1059ee4..2935c9ef 100644 --- a/stream_int.go +++ b/stream_int.go @@ -32,6 +32,7 @@ func writeBuf(buf []byte, v uint32) []byte { // WriteUint8 write uint8 to stream func (stream *Stream) WriteUint8(val uint8) { stream.buf = writeFirstBuf(stream.buf, digits[val]) + stream.enforceMaxBytes() } // WriteInt8 write int8 to stream @@ -44,6 +45,7 @@ func (stream *Stream) WriteInt8(nval int8) { val = uint8(nval) } stream.buf = writeFirstBuf(stream.buf, digits[val]) + stream.enforceMaxBytes() } // WriteUint16 write uint16 to stream @@ -56,6 +58,7 @@ func (stream *Stream) WriteUint16(val uint16) { r1 := val - q1*1000 stream.buf = writeFirstBuf(stream.buf, digits[q1]) stream.buf = writeBuf(stream.buf, digits[r1]) + stream.enforceMaxBytes() return } @@ -76,6 +79,7 @@ func (stream *Stream) WriteUint32(val uint32) { q1 := val / 1000 if q1 == 0 { stream.buf = writeFirstBuf(stream.buf, digits[val]) + stream.enforceMaxBytes() return } r1 := val - q1*1000 @@ -83,6 +87,7 @@ func (stream *Stream) WriteUint32(val uint32) { if q2 == 0 { stream.buf = writeFirstBuf(stream.buf, digits[q1]) stream.buf = writeBuf(stream.buf, digits[r1]) + 
stream.enforceMaxBytes() return } r2 := q1 - q2*1000 @@ -96,6 +101,7 @@ func (stream *Stream) WriteUint32(val uint32) { } stream.buf = writeBuf(stream.buf, digits[r2]) stream.buf = writeBuf(stream.buf, digits[r1]) + stream.enforceMaxBytes() } // WriteInt32 write int32 to stream @@ -115,6 +121,7 @@ func (stream *Stream) WriteUint64(val uint64) { q1 := val / 1000 if q1 == 0 { stream.buf = writeFirstBuf(stream.buf, digits[val]) + stream.enforceMaxBytes() return } r1 := val - q1*1000 @@ -122,6 +129,7 @@ func (stream *Stream) WriteUint64(val uint64) { if q2 == 0 { stream.buf = writeFirstBuf(stream.buf, digits[q1]) stream.buf = writeBuf(stream.buf, digits[r1]) + stream.enforceMaxBytes() return } r2 := q1 - q2*1000 @@ -130,6 +138,7 @@ func (stream *Stream) WriteUint64(val uint64) { stream.buf = writeFirstBuf(stream.buf, digits[q2]) stream.buf = writeBuf(stream.buf, digits[r2]) stream.buf = writeBuf(stream.buf, digits[r1]) + stream.enforceMaxBytes() return } r3 := q2 - q3*1000 @@ -139,6 +148,7 @@ func (stream *Stream) WriteUint64(val uint64) { stream.buf = writeBuf(stream.buf, digits[r3]) stream.buf = writeBuf(stream.buf, digits[r2]) stream.buf = writeBuf(stream.buf, digits[r1]) + stream.enforceMaxBytes() return } r4 := q3 - q4*1000 @@ -149,6 +159,7 @@ func (stream *Stream) WriteUint64(val uint64) { stream.buf = writeBuf(stream.buf, digits[r3]) stream.buf = writeBuf(stream.buf, digits[r2]) stream.buf = writeBuf(stream.buf, digits[r1]) + stream.enforceMaxBytes() return } r5 := q4 - q5*1000 @@ -165,6 +176,7 @@ func (stream *Stream) WriteUint64(val uint64) { stream.buf = writeBuf(stream.buf, digits[r3]) stream.buf = writeBuf(stream.buf, digits[r2]) stream.buf = writeBuf(stream.buf, digits[r1]) + stream.enforceMaxBytes() } // WriteInt64 write int64 to stream diff --git a/stream_str.go b/stream_str.go index 54c2ba0b..d4d3a122 100644 --- a/stream_str.go +++ b/stream_str.go @@ -233,6 +233,7 @@ func (stream *Stream) WriteStringWithHTMLEscaped(s string) { } if i == valLen { 
stream.buf = append(stream.buf, '"') + stream.enforceMaxBytes() return } writeStringSlowPathWithHTMLEscaped(stream, i, s, valLen) @@ -323,6 +324,7 @@ func (stream *Stream) WriteString(s string) { } if i == valLen { stream.buf = append(stream.buf, '"') + stream.enforceMaxBytes() return } writeStringSlowPath(stream, i, s, valLen) From 1880ee20f5377e17a1df50b71a68b8be47cb3965 Mon Sep 17 00:00:00 2001 From: Charles Korn Date: Mon, 15 Dec 2025 16:14:00 +1100 Subject: [PATCH 2/3] Panic instead of using an error --- config.go | 15 ++++++++++++++- stream.go | 9 ++++++--- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/config.go b/config.go index 3f437ec5..ecd3e385 100644 --- a/config.go +++ b/config.go @@ -302,9 +302,22 @@ func (cfg *frozenConfig) MarshalToString(v interface{}) (string, error) { return string(stream.Buffer()), nil } -func (cfg *frozenConfig) Marshal(v interface{}) ([]byte, error) { +func (cfg *frozenConfig) Marshal(v interface{}) (_ []byte, err error) { stream := cfg.BorrowStream(nil) defer cfg.ReturnStream(stream) + + defer func() { + // See Stream.enforceMaxBytes() for an explanation of this. + if r := recover(); r != nil { + if limitError, ok := r.(exceededMaxMarshalledBytesError); ok { + err = limitError + return + } + + panic(r) + } + }() + stream.WriteVal(v) if stream.Error != nil { return nil, stream.Error diff --git a/stream.go b/stream.go index 9586194e..6c502d4e 100644 --- a/stream.go +++ b/stream.go @@ -114,15 +114,18 @@ func (stream *Stream) writeFiveBytes(c1 byte, c2 byte, c3 byte, c4 byte, c5 byte } func (stream *Stream) enforceMaxBytes() { - if !stream.enforceMarshalledBytesLimit || stream.Error != nil { + if !stream.enforceMarshalledBytesLimit { return } if uint64(len(stream.buf)) > stream.marshalledBytesLimitRemaining { // Why do we do this rather than return an error? // Most of the writing methods on Stream do not return an error, and introducing this would be a - // breaking change. 
- stream.Error = exceededMaxMarshalledBytesError{stream.cfg.maxMarshalledBytes} + // breaking change for custom encoders. + // Furthermore, nothing checks if the stream has failed until the object has been completely written + // so if we don't panic here, we'd continue writing the rest of the object, negating the purpose of + // this limit. + panic(exceededMaxMarshalledBytesError{stream.cfg.maxMarshalledBytes}) } } From 8df468a862476225352d17b022656a4bd871d205 Mon Sep 17 00:00:00 2001 From: Charles Korn Date: Mon, 15 Dec 2025 16:41:29 +1100 Subject: [PATCH 3/3] Expose error type --- config.go | 2 +- stream.go | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/config.go b/config.go index ecd3e385..1f2b43a3 100644 --- a/config.go +++ b/config.go @@ -309,7 +309,7 @@ func (cfg *frozenConfig) Marshal(v interface{}) (_ []byte, err error) { defer func() { // See Stream.enforceMaxBytes() for an explanation of this. if r := recover(); r != nil { - if limitError, ok := r.(exceededMaxMarshalledBytesError); ok { + if limitError, ok := r.(ExceededMaxMarshalledBytesError); ok { err = limitError return } diff --git a/stream.go b/stream.go index 6c502d4e..a07a5130 100644 --- a/stream.go +++ b/stream.go @@ -125,16 +125,16 @@ func (stream *Stream) enforceMaxBytes() { // Furthermore, nothing checks if the stream has failed until the object has been completely written // so if we don't panic here, we'd continue writing the rest of the object, negating the purpose of // this limit. 
- panic(exceededMaxMarshalledBytesError{stream.cfg.maxMarshalledBytes}) + panic(ExceededMaxMarshalledBytesError{stream.cfg.maxMarshalledBytes}) } } -type exceededMaxMarshalledBytesError struct { - maxBytes uint64 +type ExceededMaxMarshalledBytesError struct { + MaxMarshalledBytes uint64 } -func (err exceededMaxMarshalledBytesError) Error() string { - return fmt.Sprintf("marshalling produced a result over the configured limit of %d bytes", err.maxBytes) +func (err ExceededMaxMarshalledBytesError) Error() string { + return fmt.Sprintf("marshalling produced a result over the configured limit of %d bytes", err.MaxMarshalledBytes) } // Flush writes any buffered data to the underlying io.Writer.