parquet-avro/src/main/java/org/apache/parquet/avro/AvroConverters.java
@@ -20,6 +20,8 @@

import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericData;
@@ -29,6 +31,7 @@
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.io.api.GroupConverter;
import org.apache.parquet.io.api.PrimitiveConverter;
import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.parquet.schema.PrimitiveStringifier;
import org.apache.parquet.schema.PrimitiveType;

@@ -339,4 +342,36 @@ public String convert(Binary binary) {
      return stringifier.stringify(binary);
    }
  }

  /** Converts an INT32-backed decimal column to a BigDecimal, applying the scale declared by the logical type. */
  static final class FieldDecimalIntConverter extends AvroPrimitiveConverter {
    private final int scale;

    public FieldDecimalIntConverter(ParentValueContainer parent, PrimitiveType type) {
      super(parent);
      LogicalTypeAnnotation.DecimalLogicalTypeAnnotation decimalType =
          (LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) type.getLogicalTypeAnnotation();
      this.scale = decimalType.getScale();
    }

    @Override
    public void addInt(int value) {
      parent.add(new BigDecimal(BigInteger.valueOf(value), scale));
    }
  }

  /** Converts an INT64-backed decimal column to a BigDecimal, applying the scale declared by the logical type. */
  static final class FieldDecimalLongConverter extends AvroPrimitiveConverter {
    private final int scale;

    public FieldDecimalLongConverter(ParentValueContainer parent, PrimitiveType type) {
      super(parent);
      LogicalTypeAnnotation.DecimalLogicalTypeAnnotation decimalType =
          (LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) type.getLogicalTypeAnnotation();
      this.scale = decimalType.getScale();
    }

    @Override
    public void addLong(long value) {
      parent.add(new BigDecimal(BigInteger.valueOf(value), scale));
    }
  }
}
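
Both converters rebuild the decimal from the raw integer plus the scale carried by the column's decimal annotation. A minimal standalone sketch of that rescaling (the class name is illustrative, not part of the patch):

import java.math.BigDecimal;
import java.math.BigInteger;

public class DecimalRescaleSketch {
  public static void main(String[] args) {
    int unscaled = 2534; // raw value as stored in the INT32 column
    int scale = 2;       // from the decimal logical type annotation
    BigDecimal value = new BigDecimal(BigInteger.valueOf(unscaled), scale);
    System.out.println(value); // prints 25.34
  }
}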
parquet-avro/src/main/java/org/apache/parquet/avro/AvroRecordConverter.java
@@ -337,6 +337,14 @@ private static Converter newConverter(
    return newConverter(schema, type, model, null, setter, validator);
  }

  /** Returns true if the given Parquet type is a primitive annotated with a decimal logical type. */
  private static boolean isDecimalType(Type type) {
    if (!type.isPrimitive()) {
      return false;
    }
    LogicalTypeAnnotation annotation = type.getLogicalTypeAnnotation();
    return annotation instanceof LogicalTypeAnnotation.DecimalLogicalTypeAnnotation;
  }

  private static Converter newConverter(
      Schema schema,
      Type type,
@@ -359,6 +367,9 @@
      case BOOLEAN:
        return new AvroConverters.FieldBooleanConverter(parent);
      case INT:
        // Decimals stored as INT32 carry an unscaled value and must be rescaled,
        // not passed through as plain ints.
        if (isDecimalType(type)) {
          return new AvroConverters.FieldDecimalIntConverter(parent, type.asPrimitiveType());
        }
        Class<?> intDatumClass = getDatumClass(conversion, knownClass, schema, model);
        if (intDatumClass == null) {
          return new AvroConverters.FieldIntegerConverter(parent);
@@ -374,6 +385,9 @@
        }
        return new AvroConverters.FieldIntegerConverter(parent);
      case LONG:
        if (isDecimalType(type)) {
          return new AvroConverters.FieldDecimalLongConverter(parent, type.asPrimitiveType());
        }
        return new AvroConverters.FieldLongConverter(parent);
      case FLOAT:
        return new AvroConverters.FieldFloatConverter(parent);
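
For illustration, the new guard can be exercised against a schema built with the org.apache.parquet.schema.Types builder (this sketch assumes that builder API; note decimalType takes scale first, then precision):

import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.parquet.schema.PrimitiveType;
import org.apache.parquet.schema.Types;

public class DecimalCheckSketch {
  public static void main(String[] args) {
    // INT32 column annotated as decimal(precision = 5, scale = 2)
    PrimitiveType annotated = Types.required(PrimitiveType.PrimitiveTypeName.INT32)
        .as(LogicalTypeAnnotation.decimalType(2, 5))
        .named("decimal_age");
    System.out.println(annotated.getLogicalTypeAnnotation()
        instanceof LogicalTypeAnnotation.DecimalLogicalTypeAnnotation); // true
  }
}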
parquet-avro/src/test/java/org/apache/parquet/avro/TestReadWrite.java
@@ -19,6 +19,9 @@
package org.apache.parquet.avro;

import static org.apache.parquet.avro.AvroTestUtil.optional;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT32;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64;
import static org.apache.parquet.schema.Type.Repetition.REQUIRED;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;

@@ -61,17 +64,23 @@
import org.apache.parquet.conf.ParquetConfiguration;
import org.apache.parquet.conf.PlainParquetConfiguration;
import org.apache.parquet.example.data.Group;
import org.apache.parquet.example.data.GroupFactory;
import org.apache.parquet.example.data.simple.SimpleGroupFactory;
import org.apache.parquet.hadoop.ParquetReader;
import org.apache.parquet.hadoop.ParquetWriter;
import org.apache.parquet.hadoop.api.WriteSupport;
import org.apache.parquet.hadoop.example.ExampleParquetWriter;
import org.apache.parquet.hadoop.example.GroupReadSupport;
import org.apache.parquet.hadoop.util.HadoopCodecs;
import org.apache.parquet.io.InputFile;
import org.apache.parquet.io.LocalInputFile;
import org.apache.parquet.io.LocalOutputFile;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.io.api.RecordConsumer;
import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.MessageTypeParser;
import org.apache.parquet.schema.PrimitiveType;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
@@ -400,6 +409,68 @@ public void testFixedDecimalValues() throws Exception {
    Assert.assertEquals("Content should match", expected, records);
  }

  @Test
  public void testDecimalIntegerValues() throws Exception {
    File file = temp.newFile("test_decimal_integer_values.parquet");
    file.delete();
    Path path = new Path(file.toString());

    MessageType parquetSchema = new MessageType(
        "test_decimal_integer_values",
        new PrimitiveType(REQUIRED, INT32, "decimal_age")
            .withLogicalTypeAnnotation(LogicalTypeAnnotation.decimalType(2, 5)),
        new PrimitiveType(REQUIRED, INT64, "decimal_salary")
            .withLogicalTypeAnnotation(LogicalTypeAnnotation.decimalType(1, 10)));

    try (ParquetWriter<Group> writer =
        ExampleParquetWriter.builder(path).withType(parquetSchema).build()) {

      GroupFactory factory = new SimpleGroupFactory(parquetSchema);

      // Values are written unscaled: 2534 at scale 2 encodes 25.34; 234 at scale 1 encodes 23.4.
      Group group1 = factory.newGroup();
      group1.add("decimal_age", 2534);
      group1.add("decimal_salary", 234L);
      writer.write(group1);

      Group group2 = factory.newGroup();
      group2.add("decimal_age", 4267);
      group2.add("decimal_salary", 1203L);
      writer.write(group2);
    }

    GenericData decimalSupport = new GenericData();
    decimalSupport.addLogicalTypeConversion(new Conversions.DecimalConversion());

    List<GenericRecord> records = Lists.newArrayList();
    try (ParquetReader<GenericRecord> reader = AvroParquetReader.<GenericRecord>builder(path)
        .withDataModel(decimalSupport)
        .build()) {
      GenericRecord rec;
      while ((rec = reader.read()) != null) {
        records.add(rec);
      }
    }

    Assert.assertEquals("Should read 2 records", 2, records.size());

    // INT32-backed decimals
    Object firstAge = records.get(0).get("decimal_age");
    Object secondAge = records.get(1).get("decimal_age");

    Assert.assertTrue("Should be BigDecimal, but is " + firstAge.getClass(), firstAge instanceof BigDecimal);
    Assert.assertEquals("Should be 25.34, but is " + firstAge, new BigDecimal("25.34"), firstAge);
    Assert.assertEquals("Should be 42.67, but is " + secondAge, new BigDecimal("42.67"), secondAge);

    // INT64-backed decimals
    Object firstSalary = records.get(0).get("decimal_salary");
    Object secondSalary = records.get(1).get("decimal_salary");

    Assert.assertTrue("Should be BigDecimal, but is " + firstSalary.getClass(), firstSalary instanceof BigDecimal);
    Assert.assertEquals("Should be 23.4, but is " + firstSalary, new BigDecimal("23.4"), firstSalary);
    Assert.assertEquals("Should be 120.3, but is " + secondSalary, new BigDecimal("120.3"), secondSalary);
  }
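
A note on the assertions above: BigDecimal.equals is scale-sensitive, so the expected literals must carry exactly the scale declared on each column. A quick standalone demonstration (class name illustrative):

import java.math.BigDecimal;

public class ScaleSensitiveEquals {
  public static void main(String[] args) {
    System.out.println(new BigDecimal("23.4").equals(new BigDecimal("23.40")));    // false: scales differ
    System.out.println(new BigDecimal("23.4").compareTo(new BigDecimal("23.40"))); // 0: numerically equal
  }
}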

  @Test
  public void testAll() throws Exception {
    Schema schema =