diff --git a/parquet-avro/src/main/java/org/apache/parquet/avro/AvroConverters.java b/parquet-avro/src/main/java/org/apache/parquet/avro/AvroConverters.java
index f3e8c21483..e34cc9b0b2 100644
--- a/parquet-avro/src/main/java/org/apache/parquet/avro/AvroConverters.java
+++ b/parquet-avro/src/main/java/org/apache/parquet/avro/AvroConverters.java
@@ -20,6 +20,8 @@
 import java.lang.reflect.Constructor;
 import java.lang.reflect.InvocationTargetException;
+import java.math.BigDecimal;
+import java.math.BigInteger;
 import java.nio.ByteBuffer;
 import org.apache.avro.Schema;
 import org.apache.avro.generic.GenericData;
@@ -29,6 +31,7 @@
 import org.apache.parquet.io.api.Binary;
 import org.apache.parquet.io.api.GroupConverter;
 import org.apache.parquet.io.api.PrimitiveConverter;
+import org.apache.parquet.schema.LogicalTypeAnnotation;
 import org.apache.parquet.schema.PrimitiveStringifier;
 import org.apache.parquet.schema.PrimitiveType;
@@ -339,4 +342,36 @@ public String convert(Binary binary) {
             return stringifier.stringify(binary);
         }
     }
+
+    static final class FieldDecimalIntConverter extends AvroPrimitiveConverter {
+        private final int scale;
+
+        public FieldDecimalIntConverter(ParentValueContainer parent, PrimitiveType type) {
+            super(parent);
+            LogicalTypeAnnotation.DecimalLogicalTypeAnnotation decimalType =
+                    (LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) type.getLogicalTypeAnnotation();
+            this.scale = decimalType.getScale();
+        }
+
+        @Override
+        public void addInt(int value) {
+            parent.add(new BigDecimal(BigInteger.valueOf(value), scale));
+        }
+    }
+
+    static final class FieldDecimalLongConverter extends AvroPrimitiveConverter {
+        private final int scale;
+
+        public FieldDecimalLongConverter(ParentValueContainer parent, PrimitiveType type) {
+            super(parent);
+            LogicalTypeAnnotation.DecimalLogicalTypeAnnotation decimalType =
+                    (LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) type.getLogicalTypeAnnotation();
+            this.scale = decimalType.getScale();
+        }
+
+        @Override
+        public void addLong(long value) {
+            parent.add(new BigDecimal(BigInteger.valueOf(value), scale));
+        }
+    }
 }
diff --git a/parquet-avro/src/main/java/org/apache/parquet/avro/AvroRecordConverter.java b/parquet-avro/src/main/java/org/apache/parquet/avro/AvroRecordConverter.java
index 340dc77220..66ffe64f6b 100644
--- a/parquet-avro/src/main/java/org/apache/parquet/avro/AvroRecordConverter.java
+++ b/parquet-avro/src/main/java/org/apache/parquet/avro/AvroRecordConverter.java
@@ -337,6 +337,14 @@ private static Converter newConverter(
         return newConverter(schema, type, model, null, setter, validator);
     }
 
+    private static boolean isDecimalType(Type type) {
+        if (!type.isPrimitive()) {
+            return false;
+        }
+        LogicalTypeAnnotation annotation = type.getLogicalTypeAnnotation();
+        return annotation instanceof LogicalTypeAnnotation.DecimalLogicalTypeAnnotation;
+    }
+
     private static Converter newConverter(
             Schema schema,
             Type type,
@@ -359,6 +367,9 @@
             case BOOLEAN:
                 return new AvroConverters.FieldBooleanConverter(parent);
             case INT:
+                if (isDecimalType(type)) {
+                    return new AvroConverters.FieldDecimalIntConverter(parent, type.asPrimitiveType());
+                }
                 Class<?> intDatumClass = getDatumClass(conversion, knownClass, schema, model);
                 if (intDatumClass == null) {
                     return new AvroConverters.FieldIntegerConverter(parent);
@@ -374,6 +385,9 @@
                 }
                 return new AvroConverters.FieldIntegerConverter(parent);
             case LONG:
+                if (isDecimalType(type)) {
+                    return new AvroConverters.FieldDecimalLongConverter(parent, type.asPrimitiveType());
+                }
                 return new AvroConverters.FieldLongConverter(parent);
             case FLOAT:
                 return new AvroConverters.FieldFloatConverter(parent);
diff --git a/parquet-avro/src/test/java/org/apache/parquet/avro/TestReadWrite.java b/parquet-avro/src/test/java/org/apache/parquet/avro/TestReadWrite.java
index a8cb1214ac..4fb5b72b44 100644
--- a/parquet-avro/src/test/java/org/apache/parquet/avro/TestReadWrite.java
+++ b/parquet-avro/src/test/java/org/apache/parquet/avro/TestReadWrite.java
@@ -19,6 +19,9 @@
 package org.apache.parquet.avro;
 
 import static org.apache.parquet.avro.AvroTestUtil.optional;
+import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT32;
+import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64;
+import static org.apache.parquet.schema.Type.Repetition.REQUIRED;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
@@ -61,9 +64,12 @@
 import org.apache.parquet.conf.ParquetConfiguration;
 import org.apache.parquet.conf.PlainParquetConfiguration;
 import org.apache.parquet.example.data.Group;
+import org.apache.parquet.example.data.GroupFactory;
+import org.apache.parquet.example.data.simple.SimpleGroupFactory;
 import org.apache.parquet.hadoop.ParquetReader;
 import org.apache.parquet.hadoop.ParquetWriter;
 import org.apache.parquet.hadoop.api.WriteSupport;
+import org.apache.parquet.hadoop.example.ExampleParquetWriter;
 import org.apache.parquet.hadoop.example.GroupReadSupport;
 import org.apache.parquet.hadoop.util.HadoopCodecs;
 import org.apache.parquet.io.InputFile;
@@ -71,7 +77,10 @@
 import org.apache.parquet.io.LocalOutputFile;
 import org.apache.parquet.io.api.Binary;
 import org.apache.parquet.io.api.RecordConsumer;
+import org.apache.parquet.schema.LogicalTypeAnnotation;
+import org.apache.parquet.schema.MessageType;
 import org.apache.parquet.schema.MessageTypeParser;
+import org.apache.parquet.schema.PrimitiveType;
 import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
@@ -400,6 +409,68 @@ public void testFixedDecimalValues() throws Exception {
         Assert.assertEquals("Content should match", expected, records);
     }
 
+    @Test
+    public void testDecimalIntegerValues() throws Exception {
+
+        File file = temp.newFile("test_decimal_integer_values.parquet");
+        file.delete();
+        Path path = new Path(file.toString());
+
+        MessageType parquetSchema = new MessageType(
+                "test_decimal_integer_values",
+                new PrimitiveType(REQUIRED, INT32, "decimal_age")
+                        .withLogicalTypeAnnotation(LogicalTypeAnnotation.decimalType(2, 5)),
+                new PrimitiveType(REQUIRED, INT64, "decimal_salary")
+                        .withLogicalTypeAnnotation(LogicalTypeAnnotation.decimalType(1, 10)));
+
+        try (ParquetWriter<Group> writer =
+                ExampleParquetWriter.builder(path).withType(parquetSchema).build()) {
+
+            GroupFactory factory = new SimpleGroupFactory(parquetSchema);
+
+            Group group1 = factory.newGroup();
+            group1.add("decimal_age", 2534);
+            group1.add("decimal_salary", 234L);
+            writer.write(group1);
+
+            Group group2 = factory.newGroup();
+            group2.add("decimal_age", 4267);
+            group2.add("decimal_salary", 1203L);
+            writer.write(group2);
+        }
+
+        GenericData decimalSupport = new GenericData();
+        decimalSupport.addLogicalTypeConversion(new Conversions.DecimalConversion());
+
+        List<GenericRecord> records = Lists.newArrayList();
+        try (ParquetReader<GenericRecord> reader = AvroParquetReader.<GenericRecord>builder(path)
+                .withDataModel(decimalSupport)
+                .build()) {
+            GenericRecord rec;
+            while ((rec = reader.read()) != null) {
+                records.add(rec);
+            }
+        }
+
+        Assert.assertEquals("Should read 2 records", 2, records.size());
+
+        // INT32 values
+        Object firstAge = records.get(0).get("decimal_age");
+        Object secondAge = records.get(1).get("decimal_age");
+
+        Assert.assertTrue("Should be BigDecimal, but is " + firstAge.getClass(), firstAge instanceof BigDecimal);
+        Assert.assertEquals("Should be 25.34, but is " + firstAge, new BigDecimal("25.34"), firstAge);
+        Assert.assertEquals("Should be 42.67, but is " + secondAge, new BigDecimal("42.67"), secondAge);
+
+        // INT64 values
+        Object firstSalary = records.get(0).get("decimal_salary");
+        Object secondSalary = records.get(1).get("decimal_salary");
+
+        Assert.assertTrue("Should be BigDecimal, but is " + firstSalary.getClass(), firstSalary instanceof BigDecimal);
+        Assert.assertEquals("Should be 23.4, but is " + firstSalary, new BigDecimal("23.4"), firstSalary);
+        Assert.assertEquals("Should be 120.3, but is " + secondSalary, new BigDecimal("120.3"), secondSalary);
+    }
+
     @Test
     public void testAll() throws Exception {
         Schema schema =
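Note on the conversion semantics: both new converters treat the raw INT32/INT64 value as the unscaled decimal and apply the scale taken from the column's DecimalLogicalTypeAnnotation, i.e. new BigDecimal(BigInteger.valueOf(value), scale). Below is a minimal standalone sketch of that mapping, reusing the values from testDecimalIntegerValues; the class name is illustrative only and not part of the patch.

import java.math.BigDecimal;
import java.math.BigInteger;

public class DecimalScaleSketch {
    public static void main(String[] args) {
        // decimal(precision = 5, scale = 2) stored as INT32: raw 2534 becomes 25.34
        BigDecimal age = new BigDecimal(BigInteger.valueOf(2534), 2);
        // decimal(precision = 10, scale = 1) stored as INT64: raw 234 becomes 23.4
        BigDecimal salary = new BigDecimal(BigInteger.valueOf(234L), 1);
        System.out.println(age); // prints 25.34
        System.out.println(salary); // prints 23.4
    }
}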