diff --git a/activemq-client/src/main/java/org/apache/activemq/command/WireFormatInfo.java b/activemq-client/src/main/java/org/apache/activemq/command/WireFormatInfo.java index f1b4a3650ba..6c7ae8e4afc 100644 --- a/activemq-client/src/main/java/org/apache/activemq/command/WireFormatInfo.java +++ b/activemq-client/src/main/java/org/apache/activemq/command/WireFormatInfo.java @@ -39,10 +39,15 @@ public class WireFormatInfo implements Command, MarshallAware { public static final byte DATA_STRUCTURE_TYPE = CommandTypes.WIREFORMAT_INFO; - private static final int MAX_PROPERTY_SIZE = 1024 * 4; - private static final byte MAGIC[] = new byte[] {'A', 'c', 't', 'i', 'v', 'e', 'M', 'Q'}; - - protected byte magic[] = MAGIC; + // Max number of properties allowed in the map is 64 + static final int MAX_PROPERTY_SIZE = 64; + // Used to validate property values that allocate buffers, limit to 512 bytes + static final int MAX_PROPERTY_BUFFER_SIZE = 512; + // Do not allow any nested collections in properties + static final int MAX_PROPERTY_DEPTH = 0; + private static final byte[] MAGIC = new byte[] {'A', 'c', 't', 'i', 'v', 'e', 'M', 'Q'}; + + protected byte[] magic = MAGIC; protected int version; protected ByteSequence marshalledProperties; @@ -136,7 +141,7 @@ public Object getProperty(String name) throws IOException { if (marshalledProperties == null) { return null; } - properties = unmarsallProperties(marshalledProperties); + properties = unmarshalProperties(marshalledProperties); } return properties.get(name); } @@ -147,7 +152,7 @@ public Map getProperties() throws IOException { if (marshalledProperties == null) { return Collections.EMPTY_MAP; } - properties = unmarsallProperties(marshalledProperties); + properties = unmarshalProperties(marshalledProperties); } return Collections.unmodifiableMap(properties); } @@ -167,14 +172,15 @@ protected void lazyCreateProperties() throws IOException { if (marshalledProperties == null) { properties = new HashMap(); } else { - properties = unmarsallProperties(marshalledProperties); + properties = unmarshalProperties(marshalledProperties); marshalledProperties = null; } } } - private Map unmarsallProperties(ByteSequence marshalledProperties) throws IOException { - return MarshallingSupport.unmarshalPrimitiveMap(new DataInputStream(new ByteArrayInputStream(marshalledProperties)), MAX_PROPERTY_SIZE); + private Map unmarshalProperties(ByteSequence marshalledProperties) throws IOException { + return MarshallingSupport.unmarshalPrimitiveMap(new DataInputStream(new ByteArrayInputStream(marshalledProperties)), + MAX_PROPERTY_SIZE, MAX_PROPERTY_BUFFER_SIZE, MAX_PROPERTY_DEPTH); } @Override diff --git a/activemq-client/src/main/java/org/apache/activemq/util/MarshallingSupport.java b/activemq-client/src/main/java/org/apache/activemq/util/MarshallingSupport.java index 54bd8e4e7c8..bf61f307c36 100644 --- a/activemq-client/src/main/java/org/apache/activemq/util/MarshallingSupport.java +++ b/activemq-client/src/main/java/org/apache/activemq/util/MarshallingSupport.java @@ -67,15 +67,11 @@ public static void marshalPrimitiveMap(Map map, DataOutputStream } public static Map unmarshalPrimitiveMap(DataInputStream in) throws IOException { - return unmarshalPrimitiveMap(in, Integer.MAX_VALUE); + return unmarshalPrimitiveMap(in, Integer.MAX_VALUE, Integer.MAX_VALUE, Integer.MAX_VALUE); } - public static Map unmarshalPrimitiveMap(DataInputStream in, boolean force) throws IOException { - return unmarshalPrimitiveMap(in, Integer.MAX_VALUE, force); - } - - public static Map unmarshalPrimitiveMap(DataInputStream in, int maxPropertySize) throws IOException { - return unmarshalPrimitiveMap(in, maxPropertySize, false); + public static Map unmarshalPrimitiveMap(DataInputStream in, int maxPropertySize, int maxBufferSize, int maxDepth) throws IOException { + return unmarshalPrimitiveMap(in, false, maxPropertySize, maxBufferSize, maxDepth, 0); } /** @@ -84,7 +80,11 @@ public static Map unmarshalPrimitiveMap(DataInputStream in, int * @throws IOException * @throws IOException */ - public static Map unmarshalPrimitiveMap(DataInputStream in, int maxPropertySize, boolean force) throws IOException { + private static Map unmarshalPrimitiveMap(DataInputStream in, boolean force, int maxPropertySize, int maxBufferSize, + int maxDepth, int currentDepth) throws IOException { + // increment after validation, so future calls get the incremented depth + validateDepth(maxDepth, currentDepth++); + int size = in.readInt(); if (size > maxPropertySize) { throw new IOException("Primitive map is larger than the allowed size: " + size); @@ -92,10 +92,11 @@ public static Map unmarshalPrimitiveMap(DataInputStream in, int if (size < 0) { return null; } else { - Map rc = new HashMap(size); + // Size here was already validated above + Map rc = new HashMap<>(size); for (int i = 0; i < size; i++) { String name = in.readUTF(); - rc.put(name, unmarshalPrimitive(in, force)); + rc.put(name, unmarshalPrimitive(in, force, maxPropertySize, maxBufferSize, maxDepth, currentDepth)); } return rc; } @@ -108,15 +109,20 @@ public static void marshalPrimitiveList(List list, DataOutputStream out) } } - public static List unmarshalPrimitiveList(DataInputStream in) throws IOException { - return unmarshalPrimitiveList(in, false); + public static List unmarshalPrimitiveList(DataInputStream in, int maxPropertySize, int maxBufferSize, + int maxDepth) throws IOException { + return unmarshalPrimitiveList(in, false, maxPropertySize, maxBufferSize, maxDepth, 0); } - public static List unmarshalPrimitiveList(DataInputStream in, boolean force) throws IOException { - int size = in.readInt(); - List answer = new ArrayList(size); + private static List unmarshalPrimitiveList(DataInputStream in, boolean force, int maxPropertySize, int maxBufferSize, + int maxDepth, int currentDepth) throws IOException { + // increment after validation, so future calls get the incremented depth + validateDepth(maxDepth, currentDepth++); + + int size = validateBufferSize(maxBufferSize, in.readInt()); + List answer = new ArrayList<>(size); while (size-- > 0) { - answer.add(unmarshalPrimitive(in, force)); + answer.add(unmarshalPrimitive(in, force, maxPropertySize, maxBufferSize, maxDepth, currentDepth)); } return answer; } @@ -158,10 +164,11 @@ public static void marshalPrimitive(DataOutputStream out, Object value) throws I } public static Object unmarshalPrimitive(DataInputStream in) throws IOException { - return unmarshalPrimitive(in, false); + return unmarshalPrimitive(in, false, Integer.MAX_VALUE, Integer.MAX_VALUE, Integer.MAX_VALUE, 0); } - public static Object unmarshalPrimitive(DataInputStream in, boolean force) throws IOException { + private static Object unmarshalPrimitive(DataInputStream in, boolean force, int maxPropertySize, int maxBufferSize, + int maxDepth, int currentDepth) throws IOException { Object value = null; byte type = in.readByte(); switch (type) { @@ -190,29 +197,29 @@ public static Object unmarshalPrimitive(DataInputStream in, boolean force) throw value = in.readDouble(); break; case BYTE_ARRAY_TYPE: - value = new byte[in.readInt()]; + value = new byte[validateBufferSize(maxBufferSize, in.readInt())]; in.readFully((byte[])value); break; case STRING_TYPE: if (force) { value = in.readUTF(); } else { - value = readUTF(in, in.readUnsignedShort()); + value = readUTF(in, maxBufferSize, in.readUnsignedShort()); } break; case BIG_STRING_TYPE: { if (force) { - value = readUTF8(in); + value = readUTF8(in, maxBufferSize); } else { - value = readUTF(in, in.readInt()); + value = readUTF(in, maxBufferSize, in.readInt()); } break; } case MAP_TYPE: - value = unmarshalPrimitiveMap(in, true); + value = unmarshalPrimitiveMap(in, true, maxPropertySize, maxBufferSize, maxDepth, currentDepth); break; case LIST_TYPE: - value = unmarshalPrimitiveList(in, true); + value = unmarshalPrimitiveList(in, true, maxPropertySize, maxBufferSize, maxDepth, currentDepth); break; case NULL: value = null; @@ -223,8 +230,9 @@ public static Object unmarshalPrimitive(DataInputStream in, boolean force) throw return value; } - public static UTF8Buffer readUTF(DataInputStream in, int length) throws IOException { - byte data[] = new byte[length]; + public static UTF8Buffer readUTF(DataInputStream in, int maxLength, int length) throws IOException { + validateBufferSize(maxLength, length); + byte[] data = new byte[length]; in.readFully(data); return new UTF8Buffer(data); } @@ -350,7 +358,11 @@ public static int writeUTFBytesToBuffer(String str, long count, } public static String readUTF8(DataInput dataIn) throws IOException { - int utflen = dataIn.readInt(); + return readUTF8(dataIn, Integer.MAX_VALUE); + } + + public static String readUTF8(DataInput dataIn, int maxBufferSize) throws IOException { + int utflen = validateBufferSize(maxBufferSize, dataIn.readInt()); if (utflen > -1) { byte bytearr[] = new byte[utflen]; char chararr[] = new char[utflen]; @@ -419,4 +431,17 @@ public static String truncate64(String text) { } return text; } + + private static int validateBufferSize(int maxSize, int size) throws IOException { + if (size > maxSize) { + throw new IOException("Max buffer size: " + maxSize + " exceeded, size: " + size); + } + return size; + } + + private static void validateDepth(int maxDepth, int currentDepth) throws IOException { + if (currentDepth > maxDepth) { + throw new IOException("Max unmarshaling depth: " + maxDepth + " exceeded, depth: " + currentDepth); + } + } } diff --git a/activemq-client/src/test/java/org/apache/activemq/command/WireFormatInfoTest.java b/activemq-client/src/test/java/org/apache/activemq/command/WireFormatInfoTest.java new file mode 100644 index 00000000000..88571a247a5 --- /dev/null +++ b/activemq-client/src/test/java/org/apache/activemq/command/WireFormatInfoTest.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.activemq.command; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.activemq.openwire.OpenWireFormat; +import org.apache.activemq.util.ByteSequence; +import org.junit.Test; + +public class WireFormatInfoTest { + + private final OpenWireFormat wireFormat = new OpenWireFormat(); + + @Test + public void testWireFormatLimits() throws Exception { + WireFormatInfo wfi = new WireFormatInfo(); + byte[] bytes = "a".repeat(WireFormatInfo.MAX_PROPERTY_BUFFER_SIZE).getBytes(); + + // We allow up to 64 properties with 512 bytes + for (int i = 0; i < WireFormatInfo.MAX_PROPERTY_SIZE; i++) { + wfi.setProperty("prop" + i, bytes); + } + + // Marshal/unmarshal so we have only the marshaled properties stored + wfi = marshal(wfi); + assertMarshaled(wfi); + + // Trigger unmarshal and verify it works + assertEquals(WireFormatInfo.MAX_PROPERTY_SIZE, wfi.getProperties().size()); + assertArrayEquals(bytes, (byte[])wfi.getProperties().get("prop0")); + } + + @Test + public void testWireFormatInfoTooManyProperties() throws Exception { + WireFormatInfo wfi = new WireFormatInfo(); + + // We allow up to 64 properties, so add too many + for (int i = 0; i < WireFormatInfo.MAX_PROPERTY_SIZE + 1; i++) { + wfi.setProperty("prop" + i, "test"); + } + + // Marshal/unmarshal so we have only the marshaled properties stored + wfi = marshal(wfi); + assertMarshaled(wfi); + + try { + // Trigger unmarshaling the properties. This should throw an error because + // We have too many properties in the map + wfi.getProperties(); + fail("should have exception"); + } catch (IOException e) { + assertTrue(e.getMessage().contains("map is larger than the allowed size")); + } + } + + @Test + public void testWireFormatInfoTooLargeBytes() throws Exception { + testWireFormatInfoValidation("a".repeat(WireFormatInfo.MAX_PROPERTY_BUFFER_SIZE + 1).getBytes(), + "Max buffer size: 512 exceeded, size: 513"); + } + + @Test + public void testWireFormatInfoTooLargeString() throws Exception { + testWireFormatInfoValidation("a".repeat(WireFormatInfo.MAX_PROPERTY_BUFFER_SIZE + 1), + "Max buffer size: 512 exceeded, size: 513"); + } + + @Test + public void testWireFormatInfoTooLargeList() throws Exception { + // buffer is too large but we also don't allow lists + testWireFormatInfoValidation("a".repeat(WireFormatInfo.MAX_PROPERTY_BUFFER_SIZE + 1).chars() + .mapToObj(c -> (char) c) + .collect(Collectors.toList()), "Max unmarshaling depth: 0 exceeded"); + } + + // Make sure nested collections not allowed + @Test + public void testWireFormatInfoNestedList() throws Exception { + // we don't allow any lists at all + testWireFormatInfoValidation(List.of("blocked"), "Max unmarshaling depth: 0 exceeded"); + + // also check nested + List> list = List.of(List.of("blocked")); + testWireFormatInfoValidation(list, "Max unmarshaling depth: 0 exceeded"); + } + + @Test + public void testWireFormatInfoTooLargeMap() throws Exception { + Map map = new HashMap<>(); + for (int i = 0; i < WireFormatInfo.MAX_PROPERTY_BUFFER_SIZE + 1; i++) { + map.put("prop" + i, true); + } + // too large of a map but we also block maps in general + testWireFormatInfoValidation(map, "Max unmarshaling depth: 0 exceeded"); + } + + @Test + public void testWireFormatInfoNestedMap() throws Exception { + Map> map = Map.of("blocked", Map.of("a", "b")); + testWireFormatInfoValidation(map, "Max unmarshaling depth: 0 exceeded"); + } + + @Test + public void testWireFormatInfoNestedMapOfList() throws Exception { + Map> map = Map.of("blocked", List.of("test")); + testWireFormatInfoValidation(map, "Max unmarshaling depth: 0 exceeded"); + } + + @Test + public void testWireFormatInfoNestedListOfMap() throws Exception { + List> list = List.of(Map.of("blocked", "blocked")); + testWireFormatInfoValidation(list, "Max unmarshaling depth: 0 exceeded"); + } + + // test for MarshallingSupport.BIG_STRING_TYPE + // Anythin over Short.MAX_VALUE / 4 will be a big string type + @Test + public void testWireFormatInfoTooLargeBigString() throws Exception { + testWireFormatInfoValidation("a".repeat((Short.MAX_VALUE / 4) + 1), + "Max buffer size: 512 exceeded, size: 8192"); + } + + private void testWireFormatInfoValidation(Object buffer, String error) throws Exception { + WireFormatInfo wfi = new WireFormatInfo(); + + // We allow up to 1024 buffer size, add an extra byte + wfi.setProperty("prop", buffer); + + // Marshal/unmarshal so we have only the marshaled properties stored + wfi = marshal(wfi); + assertMarshaled(wfi); + + try { + // Trigger unmarshaling the properties. This should throw an error because + // the buffer is too large + wfi.getProperties(); + fail("should have exception"); + } catch (IOException e) { + assertTrue(e.getMessage().contains(error)); + } + } + + // Marshal/unmarshal so we have only the marshaled properties stored + private WireFormatInfo marshal(WireFormatInfo wfi) throws IOException { + ByteSequence s = wireFormat.marshal(wfi); + return (WireFormatInfo) wireFormat.unmarshal(s); + } + + private static void assertMarshaled(WireFormatInfo wfi) { + assertNull(wfi.properties); + assertNotNull(wfi.marshalledProperties); + } + +} diff --git a/activemq-client/src/test/java/org/apache/activemq/util/MarshallingSupportTest.java b/activemq-client/src/test/java/org/apache/activemq/util/MarshallingSupportTest.java index 03cad553b0c..59a5e602d61 100644 --- a/activemq-client/src/test/java/org/apache/activemq/util/MarshallingSupportTest.java +++ b/activemq-client/src/test/java/org/apache/activemq/util/MarshallingSupportTest.java @@ -16,8 +16,20 @@ */ package org.apache.activemq.util; +import static org.apache.activemq.util.MarshallingSupport.unmarshalPrimitiveList; +import static org.apache.activemq.util.MarshallingSupport.unmarshalPrimitiveMap; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.Properties; import org.junit.Test; @@ -42,4 +54,87 @@ public void testPropertiesToString() throws Exception { Properties props2 = MarshallingSupport.stringToProperties(str); assertEquals(props, props2); } + + @Test + public void testUnmarshalPrimitiveMap() throws Exception { + ByteArrayOutputStream bytesOut = new ByteArrayOutputStream(); + DataOutputStream dataOut = new DataOutputStream(bytesOut); + + Map testMap = new HashMap<>(); + testMap.put("key1", "value1"); + // nested map with nested collection, should require depth of 2 being allowed + testMap.put("key2", Map.of("string","string", "nested", + List.of("one"))); + + MarshallingSupport.marshalPrimitiveMap(testMap, dataOut); + DataInputStream dataIn = new DataInputStream(new ByteArrayInputStream(bytesOut.toByteArray())); + + //allowed + assertEquals(2, unmarshalPrimitiveMap(dataIn, 100, 1024, 10).size()); + + // too many properties + dataIn.reset(); + assertException(() -> unmarshalPrimitiveMap(dataIn, 1, 1024, 10), + "map is larger than the allowed size"); + + // buffers too large + dataIn.reset(); + assertException(() -> unmarshalPrimitiveMap(dataIn, 100, 2, 10), + "Max buffer size: 2 exceeded, size: 6"); + + // max depth violated + dataIn.reset(); + assertException(() -> unmarshalPrimitiveMap(dataIn, 100, 1024, 1), + "Max unmarshaling depth: 1 exceeded, depth: 2"); + } + + @Test + public void testUnmarshalPrimitiveList() throws Exception { + ByteArrayOutputStream bytesOut = new ByteArrayOutputStream(); + DataOutputStream dataOut = new DataOutputStream(bytesOut); + + List testList = new ArrayList<>(); + testList.add("value"); + // nested list with nested collection, should require depth of 2 being allowed + testList.add(List.of(List.of("one"))); + testList.add(Map.of("one","two","three","four")); + + MarshallingSupport.marshalPrimitiveList(testList, dataOut); + DataInputStream dataIn = new DataInputStream(new ByteArrayInputStream(bytesOut.toByteArray())); + + //allowed + assertEquals(3, unmarshalPrimitiveList(dataIn, 100, 1024, 2).size()); + + // Max property size applies to maps, so this will check the nested map in the list + // and should fail because the map has 2 properties + dataIn.reset(); + assertException(() -> unmarshalPrimitiveList(dataIn, 1, 1024, 2), + "map is larger than the allowed size"); + + // buffers too large + dataIn.reset(); + assertException(() -> unmarshalPrimitiveList(dataIn, 100, 2, 2), + "Max buffer size: 2 exceeded, size: 3"); + + // max depth violated + dataIn.reset(); + assertException(() -> unmarshalPrimitiveList(dataIn, 100, 1024, 1), + "Max unmarshaling depth: 1 exceeded, depth: 2"); + } + + private void assertException(ThrowableRunnable runnable, String error) throws Exception { + try { + // Trigger unmarshaling the properties. This should throw an error because + // the buffer is too large + runnable.run(); + fail("should have IO exception"); + } catch (IOException e) { + System.out.println(e.getMessage()); + assertTrue(e.getMessage().contains(error)); + } + } + + private interface ThrowableRunnable { + void run() throws T; + } }