Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 56 additions & 23 deletions datafusion/proto-common/src/from_proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ use arrow::datatypes::{
DataType, Field, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema,
TimeUnit, UnionFields, UnionMode, i256,
};
use arrow::ipc::{reader::read_record_batch, root_as_message};
use arrow::ipc::{
convert::fb_to_schema,
reader::{read_dictionary, read_record_batch},
root_as_message,
writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions},
};

use datafusion_common::{
Column, ColumnStatistics, Constraint, Constraints, DFSchema, DFSchemaRef,
Expand Down Expand Up @@ -397,7 +402,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
Value::Float32Value(v) => Self::Float32(Some(*v)),
Value::Float64Value(v) => Self::Float64(Some(*v)),
Value::Date32Value(v) => Self::Date32(Some(*v)),
// ScalarValue::List is serialized using arrow IPC format
// Nested ScalarValue types are serialized using arrow IPC format
Value::ListValue(v)
| Value::FixedSizeListValue(v)
| Value::LargeListValue(v)
Expand All @@ -414,55 +419,83 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
schema_ref.try_into()?
} else {
return Err(Error::General(
"Invalid schema while deserializing ScalarValue::List"
"Invalid schema while deserializing nested ScalarValue"
.to_string(),
));
};

// IPC dictionary batch IDs are assigned when encoding the schema, but our protobuf
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like somehow I'm missing something key -- this seems like a pretty massive overhead to create / decode a single value here.

That being said, it seems like the existing code is also incurring this overhead, so maybe it is fine for now 🤔

I wonder if we could pull this logic for creating the schema into its own function to try and reduce the size of the overall method / make it easier to understand

// `Schema` doesn't preserve those IDs. Reconstruct them deterministically by
// round-tripping the schema through IPC.
let schema: Schema = {
let ipc_gen = IpcDataGenerator {};
let write_options = IpcWriteOptions::default();
let mut dict_tracker = DictionaryTracker::new(false);
let encoded_schema = ipc_gen.schema_to_bytes_with_dictionary_tracker(
&schema,
&mut dict_tracker,
&write_options,
);
let message =
root_as_message(encoded_schema.ipc_message.as_slice()).map_err(
|e| {
Error::General(format!(
"Error IPC schema message while deserializing nested ScalarValue: {e}"
))
},
)?;
let ipc_schema = message.header_as_schema().ok_or_else(|| {
Error::General(
"Unexpected message type deserializing nested ScalarValue schema"
.to_string(),
)
})?;
fb_to_schema(ipc_schema)
};

let message = root_as_message(ipc_message.as_slice()).map_err(|e| {
Error::General(format!(
"Error IPC message while deserializing ScalarValue::List: {e}"
"Error IPC message while deserializing nested ScalarValue: {e}"
))
})?;
let buffer = Buffer::from(arrow_data.as_slice());

let ipc_batch = message.header_as_record_batch().ok_or_else(|| {
Error::General(
"Unexpected message type deserializing ScalarValue::List"
"Unexpected message type deserializing nested ScalarValue"
.to_string(),
)
})?;

let dict_by_id: HashMap<i64,ArrayRef> = dictionaries.iter().map(|protobuf::scalar_nested_value::Dictionary { ipc_message, arrow_data }| {
let mut dict_by_id: HashMap<i64, ArrayRef> = HashMap::new();
for protobuf::scalar_nested_value::Dictionary {
ipc_message,
arrow_data,
} in dictionaries
{
let message = root_as_message(ipc_message.as_slice()).map_err(|e| {
Error::General(format!(
"Error IPC message while deserializing ScalarValue::List dictionary message: {e}"
"Error IPC message while deserializing nested ScalarValue dictionary message: {e}"
))
})?;
let buffer = Buffer::from(arrow_data.as_slice());

let dict_batch = message.header_as_dictionary_batch().ok_or_else(|| {
Error::General(
"Unexpected message type deserializing ScalarValue::List dictionary message"
"Unexpected message type deserializing nested ScalarValue dictionary message"
.to_string(),
)
})?;

let id = dict_batch.id();

let record_batch = read_record_batch(
read_dictionary(
&buffer,
dict_batch.data().unwrap(),
Arc::new(schema.clone()),
&Default::default(),
None,
dict_batch,
&schema,
&mut dict_by_id,
&message.version(),
)?;

let values: ArrayRef = Arc::clone(record_batch.column(0));

Ok((id, values))
}).collect::<datafusion_common::Result<HashMap<_, _>>>()?;
)
.map_err(|e| arrow_datafusion_err!(e))
.map_err(|e| e.context("Decoding nested ScalarValue dictionary"))?;
}

let record_batch = read_record_batch(
&buffer,
Expand All @@ -473,7 +506,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
&message.version(),
)
.map_err(|e| arrow_datafusion_err!(e))
.map_err(|e| e.context("Decoding ScalarValue::List Value"))?;
.map_err(|e| e.context("Decoding nested ScalarValue value"))?;
let arr = record_batch.column(0);
match value {
Value::ListValue(_) => {
Expand Down
13 changes: 10 additions & 3 deletions datafusion/proto-common/src/to_proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1031,21 +1031,28 @@ fn create_proto_scalar<I, T: FnOnce(&I) -> protobuf::scalar_value::Value>(
Ok(protobuf::ScalarValue { value: Some(value) })
}

// ScalarValue::List / FixedSizeList / LargeList / Struct / Map are serialized using
// Nested ScalarValue types (List / FixedSizeList / LargeList / Struct / Map) are serialized using
// Arrow IPC messages as a single column RecordBatch
fn encode_scalar_nested_value(
arr: ArrayRef,
val: &ScalarValue,
) -> Result<protobuf::ScalarValue, Error> {
let batch = RecordBatch::try_from_iter(vec![("field_name", arr)]).map_err(|e| {
Error::General(format!(
"Error creating temporary batch while encoding ScalarValue::List: {e}"
"Error creating temporary batch while encoding nested ScalarValue: {e}"
))
})?;

let ipc_gen = IpcDataGenerator {};
let mut dict_tracker = DictionaryTracker::new(false);
let write_options = IpcWriteOptions::default();
// The IPC writer requires pre-allocated dictionary IDs (normally assigned when
// serializing the schema). Populate `dict_tracker` by encoding the schema first.
ipc_gen.schema_to_bytes_with_dictionary_tracker(
batch.schema().as_ref(),
&mut dict_tracker,
&write_options,
);
let mut compression_context = CompressionContext::default();
let (encoded_dictionaries, encoded_message) = ipc_gen
.encode(
Expand All @@ -1055,7 +1062,7 @@ fn encode_scalar_nested_value(
&mut compression_context,
)
.map_err(|e| {
Error::General(format!("Error encoding ScalarValue::List as IPC: {e}"))
Error::General(format!("Error encoding nested ScalarValue as IPC: {e}"))
})?;

let schema: protobuf::Schema = batch.schema().try_into()?;
Expand Down
19 changes: 19 additions & 0 deletions datafusion/proto/tests/cases/roundtrip_physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2566,6 +2566,25 @@ fn custom_proto_converter_intercepts() -> Result<()> {
Ok(())
}

#[test]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the code change reverted, this test fails like

---- cases::roundtrip_physical_plan::roundtrip_call_null_scalar_struct_dict stdout ----
Error: Plan("General error: Error encoding ScalarValue::List as IPC: Ipc error: no dict id for field item")

/// Regression test for serializing a nested ScalarValue that contains a
/// Dictionary-typed field. Per the reviewer note above this hunk, before the
/// accompanying encode/decode changes this round-trip failed with:
/// `Error encoding ScalarValue::List as IPC: Ipc error: no dict id for field item`.
fn roundtrip_call_null_scalar_struct_dict() -> Result<()> {
// A Struct whose single field is dictionary-encoded — the exact shape the
// IPC writer previously could not encode (no dictionary ID assigned).
let data_type = DataType::Struct(Fields::from(vec![Field::new(
"item",
DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
true,
)]));

let schema = Arc::new(Schema::new(vec![Field::new("a", data_type.clone(), true)]));
let scan = Arc::new(EmptyExec::new(Arc::clone(&schema)));
// Presumably a null scalar of the struct type (test name says "null scalar");
// TODO confirm ScalarValue::try_from(DataType) yields the type's null value.
let scalar = lit(ScalarValue::try_from(data_type)?);
// Embed the scalar in a filter predicate so the plan round-trip exercises
// the nested-ScalarValue proto encode/decode path.
let filter = Arc::new(FilterExec::try_new(
Arc::new(BinaryExpr::new(scalar, Operator::Eq, col("a", &schema)?)),
scan,
)?);

roundtrip_test(filter)
}

/// Test that expression deduplication works during deserialization.
/// When the same expression Arc is serialized multiple times, it should be
/// deduplicated on deserialization (sharing the same Arc).
Expand Down