diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 4881844ac6d..dc7c638e4b8 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #pragma clang diagnostic ignored "-Wmissing-prototypes" @@ -1876,19 +1877,27 @@ ET_NODISCARD Error XNNCompiler::compileModel( // Invalid ids do not need to be remapped remapped_ids.emplace(XNN_INVALID_VALUE_ID, XNN_INVALID_VALUE_ID); - // If weight cache is not on we hold onto all the unpacked buffers - // and we free them at the end + // Buffers loaded from the named data map. After xnn_create_runtime, + // buffers consumed by packing operators are freed; the rest are moved + // into the executor to keep them alive for non-packing operators. std::vector unpacked_buffers; + // Maps xvalue index to unpacked_buffers index for values whose data was + // loaded from the named data map. Used to selectively retain buffers that + // are still referenced at runtime (non-packing operators). + std::unordered_map named_data_buffer_map; + // External Ids for inputs and outputs std::vector input_ids; std::vector output_ids; Error err = Error::Ok; - for (auto value : *flatbuffer_graph->xvalues()) { + auto xvalues = flatbuffer_graph->xvalues(); + for (uint32_t i = 0; i < xvalues->size(); i++) { + size_t prev_buffers = unpacked_buffers.size(); err = defineTensor( subgraph.get(), remapped_ids, - value, + xvalues->Get(i), flatbuffer_graph, constant_data, input_ids, @@ -1901,6 +1910,10 @@ ET_NODISCARD Error XNNCompiler::compileModel( if (err != Error::Ok) { return err; } + + if (unpacked_buffers.size() > prev_buffers) { + named_data_buffer_map[i] = prev_buffers; + } } for (auto node : *flatbuffer_graph->xnodes()) { @@ -1956,8 +1969,55 @@ ET_NODISCARD Error XNNCompiler::compileModel( Internal, "Failed to finalize weights cache after creating the xnn runtime") #else - for (auto& buffer : unpacked_buffers) { - buffer.Free(); + + // Operators like convolution and fully-connected pack weights during + // load, so those buffers can be freed. Other operators (PreLU) retain + // raw pointers to the original constant data, so those buffers need + // to remain alive. + if (!named_data_buffer_map.empty()) { + // Collect xvalue indices whose data was packed by create_runtime. + // These are the filter and bias inputs of weight-packing operators. + std::unordered_set packed_value_indices; + for (auto node : *flatbuffer_graph->xnodes()) { + auto type = node->xnode_union_type(); + switch (type) { + case fb_xnnpack::XNodeUnion::XNNFullyConnected: { + auto n = node->xnode_union_as_XNNFullyConnected(); + packed_value_indices.insert(n->filter_id()); + packed_value_indices.insert(n->bias_id()); + break; + } + case fb_xnnpack::XNodeUnion::XNNConv2d: { + auto n = node->xnode_union_as_XNNConv2d(); + packed_value_indices.insert(n->filter_id()); + packed_value_indices.insert(n->bias_id()); + break; + } + case fb_xnnpack::XNodeUnion::XNNDepthwiseConv2d: { + auto n = node->xnode_union_as_XNNDepthwiseConv2d(); + packed_value_indices.insert(n->filter_id()); + packed_value_indices.insert(n->bias_id()); + break; + } + case fb_xnnpack::XNodeUnion::XNNConvTranspose2d: { + auto n = node->xnode_union_as_XNNConvTranspose2d(); + packed_value_indices.insert(n->filter_id()); + packed_value_indices.insert(n->bias_id()); + break; + } + default: + break; + } + } + + for (auto& [value_idx, buffer_idx] : named_data_buffer_map) { + if (packed_value_indices.count(value_idx)) { + unpacked_buffers[buffer_idx].Free(); + } else { + executor->unpacked_buffers_.push_back( + std::move(unpacked_buffers[buffer_idx])); + } + } } Result> packed_weights_names = std::vector(); diff --git a/backends/xnnpack/runtime/XNNExecutor.h b/backends/xnnpack/runtime/XNNExecutor.h index 6c07771b02a..3cde5747f58 100644 --- a/backends/xnnpack/runtime/XNNExecutor.h +++ b/backends/xnnpack/runtime/XNNExecutor.h @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -26,6 +27,10 @@ namespace delegate { class XNNExecutor { private: + // For XNN constant data that isn't packed (PreLU weights, for example), + // we need to hold onto the buffers to keep them alive. + std::vector unpacked_buffers_; + std::unique_ptr runtime_{ nullptr, &xnn_delete_runtime}; diff --git a/backends/xnnpack/test/ops/test_prelu.py b/backends/xnnpack/test/ops/test_prelu.py index 47b2851278c..a239603a893 100644 --- a/backends/xnnpack/test/ops/test_prelu.py +++ b/backends/xnnpack/test/ops/test_prelu.py @@ -4,10 +4,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import os +import tempfile import unittest +from pathlib import Path import torch +from executorch.backends.test.harness.stages import StageType from executorch.backends.xnnpack.test.tester import Tester +from executorch.runtime import Runtime, Verification class TestPrelu(unittest.TestCase): @@ -48,3 +53,34 @@ def test_fp32_prelu(self): module = self.PReLU() inputs = (torch.randn(1, 5, 3, 2),) self._test_prelu(module, inputs) + + def test_fp32_prelu_file_load(self): + """ + Make sure that PreLU doesn't free its weight buffer after load. It's a weird + op that doesn't copy or pack its data, so we need to hold onto the buffer. + Run specifically from a file to exercise the path. + """ + module = self.PReLU() + module.eval() + x = torch.randn(1, 5, 3, 2) + expected = module(x) + + tester = Tester(module, (x,)) + tester.export() + tester.to_edge_transform_and_lower() + tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + tester.to_executorch() + tester.serialize() + + buf = tester.stages[StageType.SERIALIZE].artifact + fd, pte_path = tempfile.mkstemp(suffix=".pte") + try: + os.write(fd, buf) + os.close(fd) + rt = Runtime.get() + program = rt.load_program(Path(pte_path), verification=Verification.Minimal) + method = program.load_method("forward") + actual = method.execute((x,))[0] + self.assertTrue(torch.allclose(expected, actual, atol=1e-5)) + finally: + os.unlink(pte_path) diff --git a/backends/xnnpack/xnnpack_preprocess.py b/backends/xnnpack/xnnpack_preprocess.py index cdceb8a90a1..2cf6b88cc46 100644 --- a/backends/xnnpack/xnnpack_preprocess.py +++ b/backends/xnnpack/xnnpack_preprocess.py @@ -140,6 +140,7 @@ def preprocess( passes.append(ConvertToLinearPass) passes = passes if len(passes) > 0 else None + # XNNPACK Delegate Specific Passes ep = XNNPACKPassManager(ep, passes=passes).transform() graph_module = ep.graph_module