Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions extension/llm/runner/test/test_text_llm_runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,4 +497,56 @@ TEST_F(RunnerTest, GenerateEmptyWithoutPrefillFails) {
EXPECT_EQ(err, Error::InvalidState);
}

// Test that TextTokenGenerator works correctly in non-kv-cache mode.
// Exercises the code path fixed by reserving capacity before from_blob:
// without reserve(), vector reallocation would invalidate the data pointer
// that from_blob captured, leaving the tensor dangling.
TEST_F(RunnerTest, NonKvCacheGenerateCompletesSuccessfully) {
  auto tokenizer = createMockTokenizer();
  auto text_decoder_runner = createMockTextDecoderRunner();

  // In non-kv-cache mode, the input tensor should grow by 1 token each step.
  // Verify data is readable each time (catches dangling pointers under ASan).
  // step_count is captured by reference so the lambda can track how many
  // decode steps the generator has issued.
  int step_count = 0;
  ON_CALL(*text_decoder_runner, step)
      .WillByDefault(
          [&](executorch::extension::TensorPtr& tokens_tensor, int64_t) {
            // Initial tokens = 4 (prompt 1,2,3 + prefill token 4).
            // Each step appends one token before the next call, so the
            // sequence dimension must be 4 + number of completed steps.
            int64_t expected_size = 4 + step_count;
            EXPECT_EQ(tokens_tensor->size(1), expected_size);

            // Read data to verify the pointer is still valid. If the backing
            // vector had reallocated, these reads would fault / trip ASan
            // rather than return the original prompt tokens.
            auto* data = tokens_tensor->const_data_ptr<int64_t>();
            EXPECT_EQ(data[0], 1); // first prompt token
            EXPECT_EQ(data[1], 2);
            EXPECT_EQ(data[2], 3);
            EXPECT_EQ(data[3], 4); // prefill token

            step_count++;
            // `tensor` is a fixture-provided logits tensor (defined outside
            // this excerpt); returning it lets generation proceed. The mock
            // tokenizer presumably never maps it to EOS id 100, so generation
            // runs until max_new_tokens — TODO confirm against the fixture.
            return Result<executorch::aten::Tensor>(tensor);
          });

  Stats stats;
  // EOS set contains only token id 100, which the sampled tokens never hit,
  // so the stop condition is exhausting max_new_tokens.
  auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
      std::unordered_set<uint64_t>{100});
  TextTokenGenerator generator(
      tokenizer.get(),
      text_decoder_runner.get(),
      false, // use_kv_cache = false
      std::move(eos_ids),
      &stats);

  // 4 tokens: prompt (1,2,3) + prefill token (4)
  std::vector<uint64_t> tokens = {1, 2, 3, 4};
  // Generate enough tokens that the vector would reallocate without reserve.
  int32_t max_new_tokens = 20;

  // Args: tokens, start_pos=4, max_new_tokens, temperature=0.0 (greedy),
  // and a no-op token callback.
  auto result = generator.generate(
      tokens, 4, max_new_tokens, 0.0f, [](const std::string&) {});

  // Generation must succeed, produce exactly max_new_tokens tokens, and have
  // invoked the decoder once per generated token.
  EXPECT_TRUE(result.ok());
  EXPECT_EQ(result.get(), max_new_tokens);
  EXPECT_EQ(step_count, max_new_tokens);
}

} // namespace
16 changes: 14 additions & 2 deletions extension/llm/runner/text_token_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,23 @@ class ET_EXPERIMENTAL TextTokenGenerator {
} else {
token_data = tokens;
token_shape = {1, static_cast<int>(tokens.size())};
// Prevent reallocation that would invalidate from_blob's data pointer.
token_data.reserve(token_data.size() + max_new_tokens);
}

// initialize tensor wrappers
// Create tensor wrapper. For non-kv-cache, use max capacity shape so
// numel_bound_ is large enough for subsequent resize_tensor_ptr calls,
// then resize down to the actual token count.
auto max_shape = use_kv_cache_
? token_shape
: std::vector<executorch::aten::SizesType>{
1, static_cast<int>(tokens.size() + max_new_tokens)};
auto tokens_managed = from_blob(
token_data.data(), token_shape, executorch::aten::ScalarType::Long);
token_data.data(), max_shape, executorch::aten::ScalarType::Long);
if (!use_kv_cache_) {
ET_CHECK_OK_OR_RETURN_ERROR(
resize_tensor_ptr(tokens_managed, token_shape));
}

should_stop_ = false;

Expand Down
Loading