-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathexample.py
More file actions
149 lines (125 loc) · 4.44 KB
/
example.py
File metadata and controls
149 lines (125 loc) · 4.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import hashlib
from rustformlang.rustformlang import DFA
from constrained_diffusion.constrain_utils import compile_lex_map
from constrained_diffusion.eval.dllm.model import load_model
from constrained_diffusion.eval.dllm.datasets.generic import Instance, extract_code
from rustformlang.cfg import CFG
class CppInstance(Instance):
    """
    A single C++ code-generation problem for the dLLM evaluation harness.

    The instance is uniquely identified by the MD5 digest of its prompt text,
    so two instances with the same prompt share an ID.
    """

    def __init__(self, prompt: str):
        """Store the natural-language problem statement for this instance."""
        self._prompt = prompt

    def instance_id(self) -> str:
        """
        Return a stable, unique identifier for this instance.

        Derived from the MD5 hex digest of the UTF-8 encoded prompt.
        """
        digest = hashlib.md5(self._prompt.encode("utf-8"))
        return digest.hexdigest()

    def user_prompt_content(self) -> str:
        """Return the text presented to the model as the user prompt."""
        return self._prompt

    def assistant_start_line(self) -> str:
        """
        Return the text seeding the assistant's reply inside the code block.

        Empty here: the model writes its C++ answer from scratch.
        """
        return ""

    def language_short_name(self) -> str:
        """Return the code-fence language tag used for this dataset ("cpp")."""
        return "cpp"

    def extract_result(self, s: str) -> str:
        """
        Extract the generated code from the raw model output *s*.

        *s* includes the opening ```cpp fence plus the assistant start line;
        the first code block tagged with this language is returned.
        """
        return extract_code(s, self.language_short_name(), 0)

    def system_message_content(self) -> str:
        """Return the system prompt establishing the C++-expert persona."""
        return "You are an expert C++ programmer. Write a C++ function that solves the problem given by the user.\nMake sure to include all necessary headers and use standard C++ libraries."

    def language_lex_subtokens(
        self,
    ) -> tuple[CFG, dict[str, str | DFA], dict[str, set[str]]]:
        """Return the C++ grammar, lex map, and subtoken table for constraining."""
        # Imported lazily so the grammar is only built when actually needed.
        from constrained_diffusion.cfgs.cpp import cpp_grammar

        return cpp_grammar()

    def prelex(self) -> str | None:
        """
        Return the prelex string for this dataset.

        Non-None here: the STX/ETX control characters ("\x02\x03") —
        presumably markers consumed by the lexer; confirm against the lexer.
        """
        return "\x02\x03"

    def strip_chars(self):
        """
        Return the characters stripped between lexed tokens.

        None selects the default behaviour (strip any whitespace).
        """
        return None
def main():
    """Run one constrained C++ generation example end to end."""
    device = "cuda"
    model_name = "GSAI-ML/LLaDA-8B-Instruct"

    # Load the diffusion LLM and its tokenizer onto the target device.
    eval_model = load_model(model_name)
    model, tokenizer = eval_model.model(device), eval_model.tokenizer(device)

    instance = CppInstance(
        prompt="Write a C++ function that calculates the factorial of a number."
    )

    # Generation hyperparameters.
    diffusion_steps = 256
    generate_tokens = 256
    temperature = 0.2
    timeout = 300
    trace = True

    # Build the constraint artifacts: grammar, compiled lex map, subtokens.
    lang, orig_lex_map, subtokens = instance.language_lex_subtokens()
    lex_map = compile_lex_map(orig_lex_map, subtokens)

    result = eval_model.generate_constrained(
        instance,
        model,
        tokenizer,
        steps=diffusion_steps,
        gen_length=generate_tokens,
        temperature=temperature,
        lang=lang,
        lex_map=lex_map,
        subtokens=subtokens,
        prelex=instance.prelex(),
        timeout=timeout,
        trace=trace,
        orig_lex_map=orig_lex_map,
        alg="low_confidence",
        additional_stuff=None,
    )
    (
        prompt,
        code,
        code_raw,
        extracted,
        timed_out,
        resamples,
        autocompletion_raw,
        autocompletion,
        time_taken_autocompletion,
    ) = result

    print("----------- Prompt ----------------")
    print(prompt)
    print("Took {:.2f} seconds to generate.".format(time_taken_autocompletion))
    print("----------- Code ------------------")
    # Prefer the autocompleted output; fall back to the extracted code block.
    print(autocompletion or extracted)
# Run the example only when executed as a script, not on import.
if __name__ == "__main__":
    main()