Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions tensorflow/lite/micro/kernels/xtensa/pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,12 @@ TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
break;
}
case kTfLiteInt16: {
#if defined(HIFI5)
MaxEvalQuantizedHifi(context, node, params, op_data, input, output);
#else
MaxPoolingEvalQuantized<int16_t>(context, node, params, reference_op_data,
input, output);
#endif
break;
}
default: {
Expand Down Expand Up @@ -156,10 +160,4 @@ TFLMRegistration Register_MAX_POOL_2D() {
#endif
}

// Int16 registration is an alias: Register_AVERAGE_POOL_2D() returns a
// registration whose Prepare/Eval dispatch on the input tensor type at
// runtime, so no int16-specific kernel wiring is needed here.
TFLMRegistration Register_AVERAGE_POOL_2D_INT16() {
  return Register_AVERAGE_POOL_2D();
}

TFLMRegistration Register_MAX_POOL_2D_INT16() { return Register_MAX_POOL_2D(); }

} // namespace tflite
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,68 @@ TfLiteStatus MaxEvalInt8(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}

// Eval entry point for int16 AVERAGE_POOL_2D. Validates the tensor type and
// forwards to the reference quantized implementation; any other type is
// rejected with an error log.
TfLiteStatus AverageEvalInt16(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->builtin_data != nullptr);
  TFLITE_DCHECK(node->user_data != nullptr);
  auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);

  const TfLiteEvalTensor* input =
      micro::GetEvalInput(context, node, kPoolingInputTensor);
  TfLiteEvalTensor* output =
      micro::GetEvalOutput(context, node, kPoolingOutputTensor);

  // Inputs and outputs share the same type, guaranteed by the converter.
  if (input->type != kTfLiteInt16) {
    MicroPrintf("Input type %s is not currently supported",
                TfLiteTypeGetName(input->type));
    return kTfLiteError;
  }

  const auto* reference_op_data =
      static_cast<const OpDataPooling*>(node->user_data);
  AveragePoolingEvalQuantized<int16_t>(context, node, params,
                                       reference_op_data, input, output);
  return kTfLiteOk;
}

// Eval entry point for int16 MAX_POOL_2D. On HIFI5 builds the optimized
// Xtensa NN-lib kernel is used (scratch buffer was sized in MaxPrepareHifi);
// otherwise the portable reference kernel runs. Non-int16 inputs are
// rejected with an error log.
TfLiteStatus MaxEvalInt16(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->builtin_data != nullptr);
  TFLITE_DCHECK(node->user_data != nullptr);
  auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);

  const TfLiteEvalTensor* input =
      micro::GetEvalInput(context, node, kPoolingInputTensor);
  TfLiteEvalTensor* output =
      micro::GetEvalOutput(context, node, kPoolingOutputTensor);

  if (input->type != kTfLiteInt16) {
    MicroPrintf("Type %s not currently supported.",
                TfLiteTypeGetName(input->type));
    return kTfLiteError;
  }

#if defined(HIFI5)
  const auto* op_data =
      static_cast<const XtensaOpDataPooling*>(node->user_data);
  MaxEvalQuantizedHifi(context, node, params, op_data, input, output);
#else
  const auto* reference_op_data =
      static_cast<const OpDataPooling*>(node->user_data);
  MaxPoolingEvalQuantized<int16_t>(context, node, params, reference_op_data,
                                   input, output);
#endif
  return kTfLiteOk;
}

} // namespace

#if defined(HIFI5)
Expand All @@ -111,10 +173,11 @@ TfLiteStatus AveragePrepareHifi(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kPoolingInputTensor);

// Hifi5 implementation only works with int8.
if (input->type == kTfLiteInt8) {
const RuntimeShape& input_shape = GetTensorShape(input);
TfLiteTensor* output =
micro_context->AllocateTempInputTensor(node, kPoolingOutputTensor);
micro_context->AllocateTempOutputTensor(node, kPoolingOutputTensor);
const RuntimeShape& output_shape = GetTensorShape(output);
micro_context->DeallocateTempTfLiteTensor(output);

Expand Down Expand Up @@ -155,8 +218,6 @@ TfLiteStatus AverageEvalQuantizedHifi(TfLiteContext* context,
const XtensaOpDataPooling* data,
const TfLiteEvalTensor* input,
TfLiteEvalTensor* output) {
TFLITE_DCHECK(input->type == kTfLiteInt8);

const RuntimeShape& input_shape = tflite::micro::GetTensorShape(input);
const RuntimeShape& output_shape = tflite::micro::GetTensorShape(output);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
Expand All @@ -169,31 +230,42 @@ TfLiteStatus AverageEvalQuantizedHifi(TfLiteContext* context,
void* p_scratch = static_cast<void*>(
context->GetScratchBuffer(context, data->scratch_tensor_index));

const int8_t* inp_data_ptr = tflite::micro::GetTensorData<int8_t>(input);
int8_t* out_data_ptr = tflite::micro::GetTensorData<int8_t>(output);

for (int batch = 0; batch < batches; ++batch) {
TF_LITE_ENSURE_EQ(
context,
xa_nn_avgpool_8(
&out_data_ptr[output_height * output_width * depth * batch],
const_cast<int8_t*>(
&inp_data_ptr[output_height * output_width * depth * batch]),
input_height, input_width, depth, params->filter_height,
params->filter_width, params->stride_width, params->stride_height,
data->reference_op_data.padding.width,
data->reference_op_data.padding.height, output_height, output_width,
0, 0, p_scratch),
0);
}
switch (input->type) {
case kTfLiteInt8: {
const int8_t* inp_data_ptr = tflite::micro::GetTensorData<int8_t>(input);
int8_t* out_data_ptr = tflite::micro::GetTensorData<int8_t>(output);

for (int batch = 0; batch < batches; ++batch) {
TF_LITE_ENSURE_EQ(
context,
xa_nn_avgpool_8(
&out_data_ptr[output_height * output_width * depth * batch],
const_cast<int8_t*>(&inp_data_ptr[output_height * output_width *
depth * batch]),
input_height, input_width, depth, params->filter_height,
params->filter_width, params->stride_width,
params->stride_height, data->reference_op_data.padding.width,
data->reference_op_data.padding.height, output_height,
output_width, 0, 0, p_scratch),
0);
}

const int out_length = batches * output_height * output_width * depth;
TF_LITE_ENSURE_EQ(context,
xa_nn_vec_activation_min_max_8_8(
out_data_ptr, out_data_ptr,
data->reference_op_data.activation_min,
data->reference_op_data.activation_max, out_length),
0);

const int out_length = batches * output_height * output_width * depth;
TF_LITE_ENSURE_EQ(
context,
xa_nn_vec_activation_min_max_8_8(
out_data_ptr, out_data_ptr, data->reference_op_data.activation_min,
data->reference_op_data.activation_max, out_length),
0);
break;
}
default: {
MicroPrintf("Input type %s is not currently supported",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}

return kTfLiteOk;
}
Expand All @@ -206,7 +278,7 @@ TfLiteStatus MaxPrepareHifi(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kPoolingInputTensor);

if (input->type == kTfLiteInt8) {
if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
auto* data = static_cast<XtensaOpDataPooling*>(node->user_data);

Expand All @@ -222,14 +294,41 @@ TfLiteStatus MaxPrepareHifi(TfLiteContext* context, TfLiteNode* node) {
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);

int required_scratch = xa_nn_maxpool_getsize(
depth, PREC_8, PREC_8, input_height, input_width, params->filter_height,
params->filter_width,
params->stride_width, // x_stride,
params->stride_height, // y_stride,
data->reference_op_data.padding.width, // x_padding,
data->reference_op_data.padding.height, // y_padding,
output_height, output_width, 0 /* NHWC inpput */, 0 /* NHWC output */);
int required_scratch = 0;

switch (input->type) {
case kTfLiteInt8: {
required_scratch = xa_nn_maxpool_getsize(
depth, PREC_8, PREC_8, input_height, input_width,
params->filter_height, params->filter_width,
params->stride_width, // x_stride,
params->stride_height, // y_stride,
data->reference_op_data.padding.width, // x_padding,
data->reference_op_data.padding.height, // y_padding,
output_height, output_width, 0 /* NHWC inpput */,
0 /* NHWC output */);

break;
}
case kTfLiteInt16: {
required_scratch = xa_nn_maxpool_getsize(
depth, PREC_16, PREC_16, input_height, input_width,
params->filter_height, params->filter_width,
params->stride_width, // x_stride,
params->stride_height, // y_stride,
data->reference_op_data.padding.width, // x_padding,
data->reference_op_data.padding.height, // y_padding,
output_height, output_width, 0 /* NHWC inpput */,
0 /* NHWC output */);

break;
}
default: {
MicroPrintf("Input type %s is not currently supported",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}

if (required_scratch <= 0) {
MicroPrintf("Maxpool: xa_nn_maxpool_getsize failed");
Expand Down Expand Up @@ -261,32 +360,72 @@ TfLiteStatus MaxEvalQuantizedHifi(TfLiteContext* context, TfLiteNode* node,
void* p_scratch = static_cast<void*>(
context->GetScratchBuffer(context, data->scratch_tensor_index));

const int8_t* inp_data_ptr = tflite::micro::GetTensorData<int8_t>(input);
int8_t* out_data_ptr = tflite::micro::GetTensorData<int8_t>(output);

for (int batch = 0; batch < batches; ++batch) {
TF_LITE_ENSURE_EQ(
context,
xa_nn_maxpool_8(
&out_data_ptr[output_height * output_width * depth * batch],
const_cast<int8_t*>(
&inp_data_ptr[output_height * output_width * depth * batch]),
input_height, input_width, depth, params->filter_height,
params->filter_width, params->stride_width, params->stride_height,
data->reference_op_data.padding.width,
data->reference_op_data.padding.height, output_height, output_width,
0, 0, p_scratch),
0);
switch (input->type) {
case kTfLiteInt8: {
const int8_t* inp_data_ptr = tflite::micro::GetTensorData<int8_t>(input);
int8_t* out_data_ptr = tflite::micro::GetTensorData<int8_t>(output);

for (int batch = 0; batch < batches; ++batch) {
TF_LITE_ENSURE_EQ(
context,
xa_nn_maxpool_8(
&out_data_ptr[output_height * output_width * depth * batch],
const_cast<int8_t*>(&inp_data_ptr[output_height * output_width *
depth * batch]),
input_height, input_width, depth, params->filter_height,
params->filter_width, params->stride_width,
params->stride_height, data->reference_op_data.padding.width,
data->reference_op_data.padding.height, output_height,
output_width, 0, 0, p_scratch),
0);
}

const int out_length = batches * output_height * output_width * depth;
TF_LITE_ENSURE_EQ(context,
xa_nn_vec_activation_min_max_8_8(
out_data_ptr, out_data_ptr,
data->reference_op_data.activation_min,
data->reference_op_data.activation_max, out_length),
0);
break;
}
case kTfLiteInt16: {
const int16_t* inp_data_ptr =
tflite::micro::GetTensorData<int16_t>(input);
int16_t* out_data_ptr = tflite::micro::GetTensorData<int16_t>(output);

for (int batch = 0; batch < batches; ++batch) {
TF_LITE_ENSURE_EQ(
context,
xa_nn_maxpool_16(
&out_data_ptr[output_height * output_width * depth * batch],
const_cast<int16_t*>(
&inp_data_ptr[output_height * output_width * depth *
batch]),
input_height, input_width, depth, params->filter_height,
params->filter_width, params->stride_width,
params->stride_height, data->reference_op_data.padding.width,
data->reference_op_data.padding.height, output_height,
output_width, 0, 0, p_scratch),
0);
}

const int out_length = batches * output_height * output_width * depth;
TF_LITE_ENSURE_EQ(context,
xa_nn_vec_activation_min_max_16_16(
out_data_ptr, out_data_ptr,
data->reference_op_data.activation_min,
data->reference_op_data.activation_max, out_length),
0);
break;
}
default: {
MicroPrintf("Input type %s is not currently supported",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}

const int out_length = batches * output_height * output_width * depth;
TF_LITE_ENSURE_EQ(
context,
xa_nn_vec_activation_min_max_8_8(
out_data_ptr, out_data_ptr, data->reference_op_data.activation_min,
data->reference_op_data.activation_max, out_length),
0);

return kTfLiteOk;
}

Expand Down Expand Up @@ -335,4 +474,24 @@ TFLMRegistration Register_MAX_POOL_2D_INT8() {
#endif
}

// Registration for the int16 AVERAGE_POOL_2D kernel. On HIFI5 builds the
// HiFi-specific Prepare is used (it falls back to reference prepare for
// non-int8 types — TODO confirm against AveragePrepareHifi); other builds
// use the generic PoolingPrepare. Both variants share AverageEvalInt16.
TFLMRegistration Register_AVERAGE_POOL_2D_INT16() {
#if defined(HIFI5)
  return tflite::micro::RegisterOp(XtensaPoolingInit, AveragePrepareHifi,
                                   AverageEvalInt16);
#else
  return tflite::micro::RegisterOp(XtensaPoolingInit, PoolingPrepare,
                                   AverageEvalInt16);
#endif
}

// Registration for the int16 MAX_POOL_2D kernel. On HIFI5 builds,
// MaxPrepareHifi sizes the NN-lib scratch buffer for int8/int16 and
// MaxEvalInt16 dispatches to the optimized xa_nn_maxpool_16 path;
// other builds use the generic PoolingPrepare + reference eval.
TFLMRegistration Register_MAX_POOL_2D_INT16() {
#if defined(HIFI5)
  return tflite::micro::RegisterOp(XtensaPoolingInit, MaxPrepareHifi,
                                   MaxEvalInt16);
#else
  return tflite::micro::RegisterOp(XtensaPoolingInit, PoolingPrepare,
                                   MaxEvalInt16);
#endif
}

} // namespace tflite
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ ifeq ($(OPTIMIZED_KERNEL_DIR), xtensa)
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/fully_connected_vision.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/lstm_eval_hifi.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/pad_vision.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/pooling_int8.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/pooling_hifi.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/pooling_vision.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/reduce_vision.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/reshape_vision.cc \
Expand Down
Loading