Skip to content

Commit 6718843

Browse files
committed
test: add unit tests for BlockRefinementScheduler
12 tests covering set_timesteps, get_num_transfer_tokens, step logic (confidence-based commits, threshold behavior, editing, prompt masking, batched inputs, tuple output).
1 parent b3f6cb5 commit 6718843

1 file changed

Lines changed: 231 additions & 0 deletions

File tree

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
import unittest
2+
3+
import torch
4+
5+
from diffusers import BlockRefinementScheduler
6+
7+
8+
class BlockRefinementSchedulerTest(unittest.TestCase):
    """Unit tests for ``BlockRefinementScheduler``.

    Covers timestep setup, the per-step transfer-token schedule, and the
    confidence-based commit / editing behavior of ``step()``.
    """

    def _run_step(self, scheduler, sample, tokens, probs, **step_kwargs):
        """Invoke ``scheduler.step()`` at timestep 0 with the given inputs."""
        return scheduler.step(
            sampled_tokens=tokens,
            sampled_probs=probs,
            timestep=0,
            sample=sample,
            **step_kwargs,
        )

    def test_set_timesteps(self):
        """set_timesteps records the step count and a descending timestep ladder."""
        sched = BlockRefinementScheduler(block_length=32, num_inference_steps=8)
        sched.set_timesteps(8)
        self.assertEqual(sched.num_inference_steps, 8)
        self.assertEqual(len(sched.timesteps), 8)
        # The schedule runs from 7 down to 0.
        self.assertEqual(sched.timesteps[0].item(), 7)
        self.assertEqual(sched.timesteps[-1].item(), 0)

    def test_set_timesteps_invalid(self):
        """A non-positive step count must be rejected with ValueError."""
        sched = BlockRefinementScheduler()
        with self.assertRaises(ValueError):
            sched.set_timesteps(0)

    def test_get_num_transfer_tokens_even(self):
        """An evenly divisible block yields a uniform per-step schedule."""
        sched = BlockRefinementScheduler()
        plan = sched.get_num_transfer_tokens(block_length=32, num_inference_steps=8)
        self.assertEqual(len(plan), 8)
        self.assertEqual(plan.sum().item(), 32)
        # 32 tokens over 8 steps -> exactly 4 per step.
        self.assertTrue((plan == 4).all().item())

    def test_get_num_transfer_tokens_remainder(self):
        """A remainder is front-loaded onto the earliest steps."""
        sched = BlockRefinementScheduler()
        plan = sched.get_num_transfer_tokens(block_length=10, num_inference_steps=3)
        self.assertEqual(len(plan), 3)
        self.assertEqual(plan.sum().item(), 10)
        # 10 = 3 + 3 + 3 with remainder 1 -> [4, 3, 3].
        self.assertEqual(plan[0].item(), 4)
        self.assertEqual(plan[1].item(), 3)
        self.assertEqual(plan[2].item(), 3)

    def test_transfer_schedule_created_on_set_timesteps(self):
        """set_timesteps eagerly builds the internal transfer schedule."""
        sched = BlockRefinementScheduler(block_length=16)
        sched.set_timesteps(4)
        self.assertIsNotNone(sched._transfer_schedule)
        self.assertEqual(sched._transfer_schedule.sum().item(), 16)

    def test_step_commits_tokens(self):
        """step() commits the most confident masked positions first."""
        sched = BlockRefinementScheduler(block_length=8)
        sched.set_timesteps(2)

        mask_id = 99
        # Fully masked block of 8 positions.
        sample = torch.full((1, 8), mask_id, dtype=torch.long)
        tokens = torch.arange(8, dtype=torch.long).unsqueeze(0)
        # Monotonically decreasing confidence: position 0 is the most certain.
        probs = torch.tensor([[0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2]])

        out = self._run_step(
            sched,
            sample,
            tokens,
            probs,
            mask_token_id=mask_id,
            threshold=0.95,
            return_dict=True,
        )

        # 8 tokens over 2 steps -> 4 commits on the first step,
        # taken from the highest-probability positions (indices 0-3).
        self.assertEqual(out.transfer_index[0].sum().item(), 4)
        for pos in range(4):
            self.assertTrue(out.transfer_index[0, pos].item())

    def test_step_threshold_commits_all_above(self):
        """Every token above the threshold is committed, even past the per-step quota."""
        sched = BlockRefinementScheduler(block_length=8)
        sched.set_timesteps(4)  # quota of 2 tokens per step

        mask_id = 99
        sample = torch.full((1, 8), mask_id, dtype=torch.long)
        tokens = torch.arange(8, dtype=torch.long).unsqueeze(0)
        # Exactly 5 positions clear the 0.5 threshold.
        probs = torch.tensor([[0.9, 0.8, 0.7, 0.6, 0.55, 0.1, 0.1, 0.1]])

        out = self._run_step(
            sched,
            sample,
            tokens,
            probs,
            mask_token_id=mask_id,
            threshold=0.5,
            return_dict=True,
        )

        # All 5 above-threshold tokens commit, exceeding the quota of 2.
        self.assertEqual(out.transfer_index[0].sum().item(), 5)

    def test_step_no_editing_by_default(self):
        """With editing_threshold=None, committed (non-mask) tokens stay untouched."""
        sched = BlockRefinementScheduler(block_length=4)
        sched.set_timesteps(2)

        # Positions 0-1 already hold tokens; 2-3 are masked (99).
        sample = torch.tensor([[10, 20, 99, 99]], dtype=torch.long)
        tokens = torch.tensor([[50, 60, 70, 80]], dtype=torch.long)
        probs = torch.tensor([[0.99, 0.99, 0.99, 0.99]])

        out = self._run_step(
            sched,
            sample,
            tokens,
            probs,
            mask_token_id=99,
            editing_threshold=None,
            return_dict=True,
        )

        # No edits anywhere, and no commits on the non-mask slots.
        self.assertFalse(out.editing_transfer_index.any().item())
        self.assertFalse(out.transfer_index[0, 0].item())
        self.assertFalse(out.transfer_index[0, 1].item())

    def test_step_editing_replaces_tokens(self):
        """Editing rewrites a committed token only when the model confidently disagrees."""
        sched = BlockRefinementScheduler(block_length=4)
        sched.set_timesteps(2)

        sample = torch.tensor([[10, 20, 99, 99]], dtype=torch.long)
        # Position 0: prediction 50 differs from current 10 at high confidence.
        # Position 1: prediction 20 matches the current token -> no edit.
        tokens = torch.tensor([[50, 20, 70, 80]], dtype=torch.long)
        probs = torch.tensor([[0.99, 0.99, 0.5, 0.5]])

        out = self._run_step(
            sched,
            sample,
            tokens,
            probs,
            mask_token_id=99,
            editing_threshold=0.8,
            return_dict=True,
        )

        # Disagreement + high confidence -> edited.
        self.assertTrue(out.editing_transfer_index[0, 0].item())
        # Agreement -> left alone.
        self.assertFalse(out.editing_transfer_index[0, 1].item())
        # The edit shows up in the returned sample.
        self.assertEqual(out.prev_sample[0, 0].item(), 50)

    def test_step_prompt_mask_prevents_editing(self):
        """Positions flagged as prompt are never edited, even with editing enabled."""
        sched = BlockRefinementScheduler(block_length=4)
        sched.set_timesteps(2)

        sample = torch.tensor([[10, 20, 99, 99]], dtype=torch.long)
        tokens = torch.tensor([[50, 60, 70, 80]], dtype=torch.long)
        probs = torch.tensor([[0.99, 0.99, 0.99, 0.99]])
        prompt_mask = torch.tensor([True, True, False, False])

        out = self._run_step(
            sched,
            sample,
            tokens,
            probs,
            mask_token_id=99,
            editing_threshold=0.5,
            prompt_mask=prompt_mask,
            return_dict=True,
        )

        # The two prompt slots must survive untouched.
        self.assertFalse(out.editing_transfer_index[0, 0].item())
        self.assertFalse(out.editing_transfer_index[0, 1].item())

    def test_step_return_tuple(self):
        """return_dict=False yields a 5-element tuple instead of an output object."""
        sched = BlockRefinementScheduler(block_length=4)
        sched.set_timesteps(2)

        sample = torch.full((1, 4), 99, dtype=torch.long)
        tokens = torch.arange(4, dtype=torch.long).unsqueeze(0)
        probs = torch.ones(1, 4)

        result = self._run_step(
            sched,
            sample,
            tokens,
            probs,
            mask_token_id=99,
            return_dict=False,
        )

        self.assertIsInstance(result, tuple)
        self.assertEqual(len(result), 5)

    def test_step_batched(self):
        """step() handles batch_size > 1 and preserves batched output shapes."""
        sched = BlockRefinementScheduler(block_length=4)
        sched.set_timesteps(2)

        batch_size = 3
        mask_id = 99
        sample = torch.full((batch_size, 4), mask_id, dtype=torch.long)
        tokens = torch.arange(4, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)
        probs = torch.rand(batch_size, 4)

        out = self._run_step(
            sched,
            sample,
            tokens,
            probs,
            mask_token_id=mask_id,
            return_dict=True,
        )

        self.assertEqual(out.prev_sample.shape, (batch_size, 4))
        self.assertEqual(out.transfer_index.shape, (batch_size, 4))
230+
# Allow running this test module directly (e.g. `python test_scheduler.py`)
# in addition to discovery via pytest/unittest.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)