diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index e04aa78579f..9a307e95ea4 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -183,8 +183,7 @@ Result getConstantDataPtr( uint64_t constant_data_size, const NamedDataMap* named_data_map, std::vector& freeable_buffers, - XNNWeightsCache* weights_cache, - bool use_weight_cache) { + XNNWeightsCache* weights_cache) { if (buffer_idx) { if (!constant_data_ptr) { // TODO(T172265611): Remove constant_buffer in flatbuffer path after BC @@ -243,30 +242,30 @@ Result getConstantDataPtr( InvalidProgram, "Named key is null"); const std::string& data_name = constant_data_offset->named_key()->str(); - if (use_weight_cache) { - Result data_ptr = - weights_cache->load_unpacked_data(data_name); - if (!data_ptr.ok()) { - ET_LOG(Error, "Failed to load weights from cache"); - return data_ptr.error(); - } - return data_ptr.get(); - } else { - Result buffer = - named_data_map->get_data(data_name.c_str()); - if (!buffer.ok()) { - ET_LOG( - Error, - "Failed to get constant data for key %s from named_data_map. Error code: %u", - data_name.c_str(), - static_cast(buffer.error())); - return buffer.error(); - } - const uint8_t* data_ptr = - static_cast(buffer.get().data()); - freeable_buffers.push_back(std::move(buffer.get())); - return data_ptr; +#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE + Result data_ptr = + weights_cache->load_unpacked_data(data_name); + if (!data_ptr.ok()) { + ET_LOG(Error, "Failed to load weights from cache"); + return data_ptr.error(); + } + return data_ptr.get(); +#else + Result buffer = + named_data_map->get_data(data_name.c_str()); + if (!buffer.ok()) { + ET_LOG( + Error, + "Failed to get constant data for key %s from named_data_map. Error code: %u", + data_name.c_str(), + static_cast(buffer.error())); + return buffer.error(); } + const uint8_t* data_ptr = + static_cast(buffer.get().data()); + freeable_buffers.push_back(std::move(buffer.get())); + return data_ptr; +#endif } } } @@ -281,8 +280,7 @@ Result getConstantDataPtr( uint64_t constant_data_size, const NamedDataMap* named_data_map, std::vector& freeable_buffers, - XNNWeightsCache* weights_cache, - bool use_weight_cache) { + XNNWeightsCache* weights_cache) { return getConstantDataPtr( tensor_value->constant_buffer_idx(), flatbuffer_graph, @@ -290,8 +288,7 @@ Result getConstantDataPtr( constant_data_size, named_data_map, freeable_buffers, - weights_cache, - use_weight_cache); + weights_cache); } /** @@ -311,8 +308,7 @@ Error defineTensor( CompileAllocator& allocator, const NamedDataMap* named_data_map, std::vector& freeable_buffers, - XNNWeightsCache* weights_cache, - bool use_weight_cache) { + XNNWeightsCache* weights_cache) { const fb_xnnpack::XNNTensorValue* tensor_value = nullptr; const fb_xnnpack::XNNQuantizedTensorValue* qtensor_value = nullptr; @@ -367,8 +363,7 @@ Error defineTensor( constant_data_size, named_data_map, freeable_buffers, - weights_cache, - use_weight_cache); + weights_cache); if (!buffer_result.ok()) { return buffer_result.error(); } @@ -524,8 +519,7 @@ Error defineTensor( constant_data_size, named_data_map, freeable_buffers, - weights_cache, - use_weight_cache); + weights_cache); if (!scale_result.ok()) { return scale_result.error(); } @@ -572,8 +566,7 @@ Error defineTensor( constant_data_size, named_data_map, freeable_buffers, - weights_cache, - use_weight_cache); + weights_cache); if (!scale_data_result.ok()) { return scale_data_result.error(); } @@ -2001,8 +1994,7 @@ ET_NODISCARD Error XNNCompiler::compileModel( XNNExecutor* executor, XNNWeightsCache* weights_cache, xnn_workspace_t workspace, - const NamedDataMap* named_data_map, - bool use_weight_cache) { + const NamedDataMap* named_data_map) { Result header = XNNHeader::Parse(buffer_pointer, num_bytes); const uint8_t* flatbuffer_data = nullptr; const uint8_t* constant_data = nullptr; @@ -2115,8 +2107,7 @@ ET_NODISCARD Error XNNCompiler::compileModel( compile_allocator, named_data_map, unpacked_buffers, - weights_cache, - use_weight_cache); + weights_cache); if (err != Error::Ok) { return err; @@ -2138,16 +2129,19 @@ ET_NODISCARD Error XNNCompiler::compileModel( xnn_runtime_t runtime_ptr = nullptr; + // XNNWeightsCache if weights cache is not enabled, then XNNWeightsCache + // just manages the unpacked weights until the runtime is created. +#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE + ET_CHECK_OR_RETURN_ERROR( + unpacked_buffers.size() == 0, + Internal, + "Weight Cache is enabled, which means unpacked buffers should be owned by the cache"); + xnn_weights_cache_t weights_cache_ptr = + weights_cache->get_num_unpacked_data() > 0 ? weights_cache->get() + : nullptr; +#else xnn_weights_cache_t weights_cache_ptr = nullptr; - if (use_weight_cache) { - ET_CHECK_OR_RETURN_ERROR( - unpacked_buffers.size() == 0, - Internal, - "Weight Cache is enabled, which means unpacked buffers should be owned by the cache"); - weights_cache_ptr = weights_cache->get_num_unpacked_data() > 0 - ? weights_cache->get() - : nullptr; - } +#endif // NOLINTBEGIN(facebook-hte-NullableDereference) - weights cache is allowed to // be null @@ -2166,25 +2160,25 @@ ET_NODISCARD Error XNNCompiler::compileModel( "XNN Runtime creation failed with code: %s", xnn_status_to_string(status)); - std::vector packed_weights_names; - if (use_weight_cache) { - auto packed_weights_names_result = weights_cache->finalize_for_runtime(); - ET_CHECK_OR_RETURN_ERROR( - packed_weights_names_result.ok(), - Internal, - "Failed to finalize weights cache after creating the xnn runtime"); - packed_weights_names = std::move(packed_weights_names_result.get()); - } else { - for (auto& buffer : unpacked_buffers) { - buffer.Free(); - } +#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE + auto packed_weights_names = weights_cache->finalize_for_runtime(); + ET_CHECK_OR_RETURN_ERROR( + packed_weights_names.ok(), + Internal, + "Failed to finalize weights cache after creating the xnn runtime") +#else + for (auto& buffer : unpacked_buffers) { + buffer.Free(); } + Result> packed_weights_names = + std::vector(); +#endif err = executor->initialize( // NOLINT: runtime_ptr is non-null runtime_ptr, std::move(input_ids), std::move(output_ids), - std::move(packed_weights_names)); + std::move(packed_weights_names.get())); return err; }; diff --git a/backends/xnnpack/runtime/XNNCompiler.h b/backends/xnnpack/runtime/XNNCompiler.h index 639df0438cb..bcc87351d7d 100644 --- a/backends/xnnpack/runtime/XNNCompiler.h +++ b/backends/xnnpack/runtime/XNNCompiler.h @@ -29,8 +29,7 @@ class XNNCompiler { XNNExecutor* executor, XNNWeightsCache* weights_cache, xnn_workspace_t workspace, - const NamedDataMap* named_data_map, - bool use_weight_cache); + const NamedDataMap* named_data_map); }; } // namespace delegate diff --git a/backends/xnnpack/runtime/XNNPACKBackend.cpp b/backends/xnnpack/runtime/XNNPACKBackend.cpp index c20fa985f46..366fb220fc4 100644 --- a/backends/xnnpack/runtime/XNNPACKBackend.cpp +++ b/backends/xnnpack/runtime/XNNPACKBackend.cpp @@ -115,8 +115,7 @@ class XnnpackBackend final executor, weights_cache_.get(), workspace_ptr, - named_data_map, - use_weight_cache); + named_data_map); // This backend does not need its processed data after compiling the model. processed->Free();