diff --git a/netstringer.go b/netstringer.go index 40e5b76..5f449dc 100644 --- a/netstringer.go +++ b/netstringer.go @@ -7,41 +7,56 @@ import ( ) const ( - PARSE_LENGTH = iota - PARSE_SEPARATOR - PARSE_DATA - PARSE_END + parseLength = iota + parseBinLength + parseSeparator + parseData + parseEnd ) -const BUFFER_COUNT = 10 +const bufferCount = 10 //output buffered channels size + +type TextBinaryMsg struct { + Data []byte + TextLength int +} type NetStringDecoder struct { - parsedData []byte - length int - state int - DataOutput chan []byte - separatorSymbol, endSymbol byte - debugMode bool + parsedData []byte + length int + binLength int // extra property to support mixed text and binary messages + state int + DataOutput chan []byte + TextBinDataOutput chan TextBinaryMsg + separatorSymbol byte + separatorLengthSymbol byte + endSymbol byte + debugMode bool } // Caller receives the parsed parsedData through the output channel. func NewDecoder() NetStringDecoder { return NetStringDecoder{ - length: 0, - state: PARSE_LENGTH, - DataOutput: make(chan []byte, BUFFER_COUNT), - separatorSymbol: byte(':'), - endSymbol: byte(','), - debugMode: false, + length: 0, + state: parseLength, + DataOutput: make(chan []byte, bufferCount), + TextBinDataOutput: make(chan TextBinaryMsg, bufferCount), + separatorSymbol: byte(':'), + separatorLengthSymbol: byte(','), + endSymbol: byte(','), + debugMode: false, } +} +func (decoder *NetStringDecoder) SetEndSymbol(symbol byte) { + decoder.endSymbol = symbol } func (decoder *NetStringDecoder) SetDebugMode(mode bool) { decoder.debugMode = mode } -func (decoder NetStringDecoder) DebugLog(v ...interface{}) { +func (decoder NetStringDecoder) DebugLog(v ...any) { if decoder.debugMode { log.Println(v...) } @@ -49,12 +64,13 @@ func (decoder NetStringDecoder) DebugLog(v ...interface{}) { func (decoder *NetStringDecoder) reset() { decoder.length = 0 + decoder.binLength = 0 decoder.parsedData = []byte{} - decoder.state = PARSE_LENGTH + decoder.state = parseLength } func (decoder *NetStringDecoder) FeedData(data []byte) { - // New incoming parsedData packets are feeded into the decoder using this method. + // New incoming parsedData packets are fed into the decoder using this method. // Call this method every time we have a new set of parsedData. i := 0 for i < len(data) { @@ -64,13 +80,15 @@ func (decoder *NetStringDecoder) FeedData(data []byte) { func (decoder *NetStringDecoder) parse(i int, data []byte) int { switch decoder.state { - case PARSE_LENGTH: + case parseLength: i = decoder.parseLength(i, data) - case PARSE_SEPARATOR: + case parseBinLength: + i = decoder.parseBinLength(i, data) + case parseSeparator: i = decoder.parseSeparator(i, data) - case PARSE_DATA: + case parseData: i = decoder.parseData(i, data) - case PARSE_END: + case parseEnd: i = decoder.parseEnd(i, data) } return i @@ -83,23 +101,36 @@ func (decoder *NetStringDecoder) parseLength(i int, data []byte) int { decoder.length = (decoder.length * 10) + (int(symbol) - 48) i++ } else { - decoder.state = PARSE_SEPARATOR + decoder.state = parseSeparator } + return i +} +func (decoder *NetStringDecoder) parseBinLength(i int, data []byte) int { + symbol := data[i] + decoder.DebugLog("Parsing bin length, symbol =", string(symbol)) + if symbol >= '0' && symbol <= '9' { + decoder.binLength = (decoder.binLength * 10) + (int(symbol) - 48) + i++ + } else { + decoder.state = parseSeparator + } return i } func (decoder *NetStringDecoder) parseSeparator(i int, data []byte) int { decoder.DebugLog("Parsing separator, symbol =", string(data[i])) - if data[i] != decoder.separatorSymbol { - // Something is wrong with the parsedData. - // let's reset everything to start looking for next valid parsedData + switch data[i] { + case decoder.separatorSymbol: + decoder.length = decoder.length + decoder.binLength + decoder.state = parseData + case decoder.separatorLengthSymbol: + decoder.state = parseBinLength + default: + // Something is wrong with the parsedData. let's reset everything to start looking for next valid parsedData decoder.reset() - } else { - decoder.state = PARSE_DATA } - i++ - return i + return i + 1 } func (decoder *NetStringDecoder) parseData(i int, data []byte) int { @@ -109,7 +140,7 @@ func (decoder *NetStringDecoder) parseData(i int, data []byte) int { decoder.parsedData = append(decoder.parsedData, data[i:i+dataLength]...) decoder.length = decoder.length - dataLength if decoder.length == 0 { - decoder.state = PARSE_END + decoder.state = parseEnd } // We already parsed till i + dataLength return i + dataLength @@ -118,19 +149,24 @@ func (decoder *NetStringDecoder) parseData(i int, data []byte) int { func (decoder *NetStringDecoder) parseEnd(i int, data []byte) int { decoder.DebugLog("Parsing end.") symbol := data[i] + // Irrespective of what symbol we got we have to reset. + // Since we are looking for new data from now onwards. + defer decoder.reset() if symbol == decoder.endSymbol { // Symbol matches, that means this is valid data decoder.sendData(decoder.parsedData) + return i + 1 } - // Irrespective of what symbol we got we have to reset. - // Since we are looking for new data from now onwards. - decoder.reset() return i } func (decoder *NetStringDecoder) sendData(parsedData []byte) { decoder.DebugLog("Successfully parsed data: ", string(parsedData)) - decoder.DataOutput <- parsedData + if decoder.binLength == 0 { // netstring messages emits on DataOutput channel + decoder.DataOutput <- parsedData + } else { // text binary message emits on TextBinDataOutput channel + decoder.TextBinDataOutput <- TextBinaryMsg{parsedData, len(parsedData) - decoder.binLength} + } } func min(a, b int) int { diff --git a/netstringer_test.go b/netstringer_test.go index 74fb576..e7711d7 100644 --- a/netstringer_test.go +++ b/netstringer_test.go @@ -1,10 +1,12 @@ package netstringer import ( + "reflect" + "sync" "testing" ) -func TestNewNetStringDecoder(t *testing.T) { +func TestNetStringDecoder(t *testing.T) { decoder := NewDecoder() //decoder.SetDebugMode(true) @@ -27,19 +29,131 @@ func TestNewNetStringDecoder(t *testing.T) { "hello world!", } - go func(outputs []string, dataChannel chan []byte) { - for _, output := range outputs { - got := string(<-dataChannel) - if got != output { - t.Error("Got", got, "Expected", output) - } - } - }(expectedOutputs, decoder.DataOutput) + var wg sync.WaitGroup + wg.Add(1) + + // This will verify outputs in background as the decoder emits complete messages + verifyDataOutputsFromDecoder(t, &wg, expectedOutputs, decoder) + + for _, testInput := range testInputs { + decoder.FeedData([]byte(testInput)) + } + + close(decoder.DataOutput) + wg.Wait() +} + +func TestRipMsgDecoder(t *testing.T) { + decoder := NewDecoder() + decoder.SetEndSymbol(';') //Rip messages end with ; character + + testInputs := []string{ + "12:hello world!;", + "17:5:hello,6:world!,;", + "5:hello;6:world!;", + "12:How are you?;9:I am fine;12:this is cool;", + "12:hello", // Partial messages + " world!;", + } + expectedOutputs := []string{ + "hello world!", + "5:hello,6:world!,", + "hello", + "world!", + "How are you?", + "I am fine", + "this is cool", + "hello world!", + } + + var wg sync.WaitGroup + wg.Add(1) + + //this will verify outputs in background as the decoder emits complete messages + verifyDataOutputsFromDecoder(t, &wg, expectedOutputs, decoder) for _, testInput := range testInputs { decoder.FeedData([]byte(testInput)) } + close(decoder.DataOutput) + wg.Wait() +} + +func TestOgpTextBinaryMsgDecoder(t *testing.T) { + decoder := NewDecoder() + //decoder.SetDebugMode(true) + + strBytes := []byte{0x18, 0x2d, 0x44, 0x54} + str := string(strBytes) + testInputs := []string{ + "12:How are you?,", + "12,4:hello world!", + str, + ",12:hello", // Partial messages + " world!,", + } + expectedOutputs := []string{ + "How are you?", + "hello world!", + } + expectedTextBinOutputs := []TextBinaryMsg{ + {[]byte("hello world!" + str), 12}, + } + + var wg sync.WaitGroup + wg.Add(2) + + //verify decoded messages match expectedOutputs + verifyDataOutputsFromDecoder(t, &wg, expectedOutputs, decoder) + + //verify decoded text-binary output messages match textBinOutputs + go func(textBinOutputs []TextBinaryMsg, textBinDataChannel <-chan TextBinaryMsg) { + j := 0 + for { + select { + case msg := <-textBinDataChannel: + if j < len(textBinOutputs) { + if !reflect.DeepEqual(msg, textBinOutputs[j]) { + t.Error("Got", msg, "Expected", textBinOutputs[j]) + } + j++ + } else { + wg.Done() + return + } + } + } + }(expectedTextBinOutputs, decoder.TextBinDataOutput) + + for _, input := range testInputs { + decoder.FeedData([]byte(input)) + } + + close(decoder.DataOutput) + close(decoder.TextBinDataOutput) + wg.Wait() +} + +func verifyDataOutputsFromDecoder(t *testing.T, wg *sync.WaitGroup, expectedOutputs []string, decoder NetStringDecoder) { + go func(outputs []string, dataChannel <-chan []byte) { + i := 0 + for { + select { + case msg := <-dataChannel: + got := string(msg) + if i < len(outputs) { + if got != outputs[i] { + t.Error("Got", got, "Expected", outputs[i]) + } + i++ + } else { + wg.Done() + return + } + } + } + }(expectedOutputs, decoder.DataOutput) } func TestEncode(t *testing.T) { @@ -49,12 +163,12 @@ func TestEncode(t *testing.T) { } testCases := []TestCase{ - TestCase{Input: "hello world!", Expected: "12:hello world!,"}, - TestCase{Input: "5:hello,6:world!,", Expected: "17:5:hello,6:world!,,"}, - TestCase{Input: "hello", Expected: "5:hello,"}, - TestCase{Input: "world!", Expected: "6:world!,"}, - TestCase{Input: "How are you?", Expected: "12:How are you?,"}, - TestCase{Input: "I am fine", Expected: "9:I am fine,"}, + {Input: "hello world!", Expected: "12:hello world!,"}, + {Input: "5:hello,6:world!,", Expected: "17:5:hello,6:world!,,"}, + {Input: "hello", Expected: "5:hello,"}, + {Input: "world!", Expected: "6:world!,"}, + {Input: "How are you?", Expected: "12:How are you?,"}, + {Input: "I am fine", Expected: "9:I am fine,"}, } for _, testCase := range testCases { output := Encode([]byte(testCase.Input)) @@ -62,5 +176,4 @@ func TestEncode(t *testing.T) { t.Error("Got", string(output), "Expected", testCase.Expected) } } - }