diff --git a/extension/llm/runner/test/test_text_llm_runner.cpp b/extension/llm/runner/test/test_text_llm_runner.cpp
index e7dbb762dd9..7e9468a5f01 100644
--- a/extension/llm/runner/test/test_text_llm_runner.cpp
+++ b/extension/llm/runner/test/test_text_llm_runner.cpp
@@ -497,4 +497,56 @@ TEST_F(RunnerTest, GenerateEmptyWithoutPrefillFails) {
   EXPECT_EQ(err, Error::InvalidState);
 }
 
+// Test that TextTokenGenerator works correctly in non-kv-cache mode.
+// Exercises the code path fixed by reserving capacity before from_blob:
+// without reserve(), vector reallocation would invalidate the data pointer.
+TEST_F(RunnerTest, NonKvCacheGenerateCompletesSuccessfully) {
+  auto tokenizer = createMockTokenizer();
+  auto text_decoder_runner = createMockTextDecoderRunner();
+
+  // In non-kv-cache mode, the input tensor should grow by 1 token each step.
+  // Verify data is readable each time (catches dangling pointers under ASan).
+  int step_count = 0;
+  ON_CALL(*text_decoder_runner, step)
+      .WillByDefault(
+          [&](executorch::extension::TensorPtr& tokens_tensor, int64_t) {
+            // Initial tokens = 4 (prompt 1,2,3 + prefill token 4).
+            // Each step appends one token before the next call.
+            int64_t expected_size = 4 + step_count;
+            EXPECT_EQ(tokens_tensor->size(1), expected_size);
+
+            // Read data to verify the pointer is still valid.
+            auto* data = tokens_tensor->const_data_ptr<int64_t>();
+            EXPECT_EQ(data[0], 1); // first prompt token
+            EXPECT_EQ(data[1], 2);
+            EXPECT_EQ(data[2], 3);
+            EXPECT_EQ(data[3], 4); // prefill token
+
+            step_count++;
+            return Result<executorch::aten::Tensor>(tensor);
+          });
+
+  Stats stats;
+  auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
+      std::unordered_set<uint64_t>{100});
+  TextTokenGenerator generator(
+      tokenizer.get(),
+      text_decoder_runner.get(),
+      false, // use_kv_cache = false
+      std::move(eos_ids),
+      &stats);
+
+  // 4 tokens: prompt (1,2,3) + prefill token (4)
+  std::vector<uint64_t> tokens = {1, 2, 3, 4};
+  // Generate enough tokens that the vector would reallocate without reserve.
+  int32_t max_new_tokens = 20;
+
+  auto result = generator.generate(
+      tokens, 4, max_new_tokens, 0.0f, [](const std::string&) {});
+
+  EXPECT_TRUE(result.ok());
+  EXPECT_EQ(result.get(), max_new_tokens);
+  EXPECT_EQ(step_count, max_new_tokens);
+}
+
 } // namespace
diff --git a/extension/llm/runner/text_token_generator.h b/extension/llm/runner/text_token_generator.h
index 128de05d1d9..4bbc91a01c7 100644
--- a/extension/llm/runner/text_token_generator.h
+++ b/extension/llm/runner/text_token_generator.h
@@ -77,11 +77,23 @@ class ET_EXPERIMENTAL TextTokenGenerator {
     } else {
       token_data = tokens;
       token_shape = {1, static_cast<executorch::aten::SizesType>(tokens.size())};
+      // Prevent reallocation that would invalidate from_blob's data pointer.
+      token_data.reserve(token_data.size() + max_new_tokens);
     }
 
-    // initialize tensor wrappers
+    // Create tensor wrapper. For non-kv-cache, use max capacity shape so
+    // numel_bound_ is large enough for subsequent resize_tensor_ptr calls,
+    // then resize down to the actual token count.
+    auto max_shape = use_kv_cache_
+        ? token_shape
+        : std::vector<executorch::aten::SizesType>{
+              1, static_cast<executorch::aten::SizesType>(tokens.size() + max_new_tokens)};
     auto tokens_managed = from_blob(
-        token_data.data(), token_shape, executorch::aten::ScalarType::Long);
+        token_data.data(), max_shape, executorch::aten::ScalarType::Long);
+    if (!use_kv_cache_) {
+      ET_CHECK_OK_OR_RETURN_ERROR(
+          resize_tensor_ptr(tokens_managed, token_shape));
+    }
 
     should_stop_ = false;