-
Notifications
You must be signed in to change notification settings - Fork 409
Expand file tree
/
Copy pathcodegen.py
More file actions
230 lines (201 loc) · 8.38 KB
/
codegen.py
File metadata and controls
230 lines (201 loc) · 8.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import atexit
import keyword
import linecache
import os
import re
import uuid
from typing import List, Callable, Union
from pyfory.resolver import NULL_FLAG, NOT_NULL_VALUE_FLAG
from pyfory.error import CompileError
# Serializer method names keyed by a Python type or by a string naming a
# fixed-width numeric type. Each value is a 4-tuple:
#   (buffer write method, buffer read method,
#    nullable write helper, nullable read helper)
# Slots 0/1 are emitted as ``<buffer>.<name>(...)`` calls; slots 2/3 are
# emitted as free-function calls taking the buffer as an argument.
_type_mapping = {
    # Python builtin types.
    bool: ("write_bool", "read_bool", "write_nullable_pybool", "read_nullable_pybool"),
    int: ("write_varint64", "read_varint64", "write_nullable_pyint64", "read_nullable_pyint64"),
    float: ("write_double", "read_double", "write_nullable_pyfloat64", "read_nullable_pyfloat64"),
    str: ("write_string", "read_string", "write_nullable_pystr", "read_nullable_pystr"),
    # Fixed-width types referenced by name.
    "int8": ("write_int8", "read_int8", "write_nullable_int8", "read_nullable_int8"),
    "int16": ("write_int16", "read_int16", "write_nullable_int16", "read_nullable_int16"),
    "int32": ("write_varint32", "read_varint32", "write_nullable_int32", "read_nullable_int32"),
    # "int64"/"float64" share the encodings of Python int/float.
    "int64": ("write_varint64", "read_varint64", "write_nullable_pyint64", "read_nullable_pyint64"),
    "float32": ("write_float32", "read_float32", "write_nullable_float32", "read_nullable_float32"),
    "float64": ("write_double", "read_double", "write_nullable_pyfloat64", "read_nullable_pyfloat64"),
    "bfloat16": ("write_bfloat16", "read_bfloat16", "write_nullable_bfloat16", "read_nullable_bfloat16"),
}
def gen_write_nullable_basic_stmts(
    buffer: str,
    value: str,
    type_: Union[type, str],
) -> List[str]:
    """
    Generate source lines that serialize a basic value which may be ``None``.

    Args:
        buffer: Source expression for the output buffer in the generated code.
        value: Source expression for the value being written.
        type_: Key into ``_type_mapping`` choosing the serializer methods.

    Returns:
        Unindented source lines. With Cython serialization enabled this is a
        single call to the type's nullable write helper. Otherwise it is an
        inline branch writing ``NULL_FLAG`` for ``None``, or
        ``NOT_NULL_VALUE_FLAG`` followed by the value via the buffer method.

    Raises:
        KeyError: if ``type_`` is not in ``_type_mapping``.
    """
    methods = _type_mapping[type_]
    from pyfory import ENABLE_FORY_CYTHON_SERIALIZATION

    if not ENABLE_FORY_CYTHON_SERIALIZATION:
        write_value = f"{buffer}.{methods[0]}({value})"
        return [
            f"if {value} is None:",
            f" {buffer}.write_int8({NULL_FLAG})",
            "else: ",
            f" {buffer}.write_int8({NOT_NULL_VALUE_FLAG})",
            f" {write_value}",
        ]
    # Delegate the whole nullable write to the free-function helper
    # (slot 2), which compile_function binds into the generated namespace.
    return [f"{methods[2]}({buffer}, {value})"]
def gen_read_nullable_basic_stmts(
    buffer: str,
    type_: Union[type, str],
    set_action: Callable[[str], str],
) -> List[str]:
    """
    Generate source lines that deserialize a basic value which may be ``None``.

    Args:
        buffer: Source expression for the input buffer in the generated code.
        type_: Key into ``_type_mapping`` choosing the deserializer methods.
        set_action: Given a source expression for the read value, returns the
            statement that stores it (e.g. an assignment).

    Returns:
        Unindented source lines. With Cython serialization enabled this is a
        single ``set_action`` statement wrapping the nullable read helper.
        Otherwise it reads an int8 flag and stores ``None`` when it equals
        ``NULL_FLAG``, else stores the value read via the buffer method.

    Raises:
        KeyError: if ``type_`` is not in ``_type_mapping``.
    """
    methods = _type_mapping[type_]
    from pyfory import ENABLE_FORY_CYTHON_SERIALIZATION

    if not ENABLE_FORY_CYTHON_SERIALIZATION:
        # set_action is invoked for the None branch first, then the value
        # branch, matching the order the lines appear in.
        null_branch = set_action("None")
        value_branch = set_action(f"{buffer}.{methods[1]}()")
        return [
            f"if {buffer}.read_int8() == {NULL_FLAG}:",
            f" {null_branch}",
            "else: ",
            f" {value_branch}",
        ]
    # The nullable read helper (slot 3) is a free function taking the buffer.
    return [set_action(f"{methods[3]}({buffer})")]
def _sanitize_function_name(name: str) -> str:
    """
    Turn an arbitrary name into a valid Python identifier.

    Function names may contain characters such as angle brackets, which are
    not valid in a ``def`` statement. The mapping is deterministic, so callers
    can recompute it to look the compiled function up again.

    Rules, applied in order:
      1. every character outside ``[0-9A-Za-z_]`` becomes ``_``;
      2. an empty result becomes ``_`` (``def ():`` would not compile);
      3. a result starting with a digit gets a ``_`` prefix;
      4. a result that is a Python keyword gets a ``_`` prefix.
    """
    # 1) Replace every non-identifier character with an underscore.
    sanitized = re.sub(r"[^0-9A-Za-z_]", "_", name)
    # 2) An empty name would produce invalid syntax.
    if not sanitized:
        return "_"
    # 3) + 4) Prevent a leading digit or a plain keyword. After step 1 only
    # ASCII remains, so isdigit() matches exactly 0-9. A digit-prefixed name
    # can never be a keyword, so one check covers both cases.
    if sanitized[0].isdigit() or keyword.iskeyword(sanitized):
        sanitized = "_" + sanitized
    return sanitized
def compile_function(
    function_name: str,
    params: List[str],
    stmts: List[str],
    context: dict,
):
    """
    Compile generated statements into a Python function.

    Args:
        function_name: Desired function name. It is passed through
            ``_sanitize_function_name``, so it may contain characters that
            are not valid in an identifier.
        params: Parameter names for the generated ``def``.
        stmts: Body statements, unindented; each is indented one level here.
        context: Namespace used as both globals and locals for ``exec``. It is
            mutated: the compiled function is stored under the sanitized
            name. When Cython serialization is enabled, the nullable
            read/write helpers are also bound into it.

    Returns:
        A ``(code, function)`` tuple with the generated source text and the
        compiled function object.

    Raises:
        CompileError: if the generated source fails to compile.
    """
    from pyfory import ENABLE_FORY_CYTHON_SERIALIZATION

    if ENABLE_FORY_CYTHON_SERIALIZATION:
        from pyfory import serialization

        # Bind every nullable helper the gen_*_nullable_basic_stmts generators
        # can emit (slots 2 and 3 of _type_mapping). Deriving the names from
        # the mapping keeps this list in sync when a type is added.
        for methods in _type_mapping.values():
            for helper_name in methods[2:]:
                context[helper_name] = getattr(serialization, helper_name)

    # The sanitized name is used both in the ``def`` line and for the lookup
    # in ``context`` below.
    sanitized_function_name = _sanitize_function_name(function_name)
    lines = [f"def {sanitized_function_name}({', '.join(params)}):"]
    lines.extend(ident(statement) for statement in stmts)
    # Suffix every line with its 1-based line number ("# line N").
    code = "\n".join(f"{line} # line {idx}" for idx, line in enumerate(lines, 1))

    filename = _generate_filename(function_name)
    code_dir = _get_code_dir()
    if code_dir:
        filename = os.path.join(code_dir, filename)
        # Generated statements may contain non-ASCII text. Python source is
        # UTF-8 by default, so do not depend on the platform locale encoding.
        with open(filename, "w", encoding="utf-8") as f:
            f.write(code)
        if _delete_code_on_exit():
            atexit.register(os.remove, filename)
    try:
        compiled = compile(code, filename, "exec")
    except Exception as e:
        raise CompileError(f"Failed to compile code:\n{code}") from e
    exec(compiled, context, context)
    # Register the source in linecache so debuggers such as pdb can display
    # and step through the generated code even if it never hit the disk. See
    # https://stackoverflow.com/questions/64879414/how-does-attrs-fool-the-debugger-to-step-into-auto-generated-code # noqa: E501
    linecache.cache[filename] = (
        len(code),
        None,
        code.splitlines(True),
        filename,
    )
    return code, context[sanitized_function_name]
# Based on https://github.com/python-attrs/attrs/blob/32fb12789e5cba4b2e71c09e47196b10763ddd7d/src/attr/_make.py#L1863 # noqa: E501
def _generate_filename(func_name):
    """
    Choose a unique pseudo-filename for a generated function and reserve it.

    Candidates have the form ``fory_generated_<sanitized name>_<suffix>.py``.
    The suffixes tried are "0", then "2", "3", ... (the counter skips "1").

    A candidate is reserved by storing a placeholder entry in
    ``linecache.cache`` with a single ``dict.setdefault`` call. This is
    intended to stop concurrent callers from claiming the same name. The
    caller is expected to replace the placeholder with the real source.
    """
    base = _sanitize_function_name(func_name)
    # A per-call token distinguishes our placeholder from an entry that
    # someone else already stored under the same key.
    token = str(uuid.uuid4())
    attempt = 1
    suffix = "0"
    while True:
        candidate = f"fory_generated_{base}_{suffix}.py"
        placeholder = (1, None, [token], candidate)
        if linecache.cache.setdefault(candidate, placeholder) == placeholder:
            return candidate
        # Name already taken; move on to the next suffix.
        attempt += 1
        suffix = "{0}".format(attempt)
def _get_code_dir():
    """
    Return the directory for dumping generated source, creating it if needed.

    The path comes from the ``FORY_CODE_DIR`` environment variable. Returns
    ``None`` when it is unset. An empty value is returned as-is without
    touching the filesystem; ``compile_function`` treats it like ``None``,
    i.e. nothing is dumped.
    """
    code_dir = os.environ.get("FORY_CODE_DIR")
    if code_dir:
        # exist_ok=True avoids the check-then-create race when several
        # threads or processes generate code at the same time.
        os.makedirs(code_dir, exist_ok=True)
    return code_dir
def _delete_code_on_exit():
    """
    Report whether dumped generated-source files should be removed at exit.

    Read from the ``DELETE_CODE_ON_EXIT`` environment variable, which
    defaults to ``"True"`` when unset. Only ``"true"`` or ``"1"``
    (case-insensitive) count as enabled.
    """
    setting = os.environ.get("DELETE_CODE_ON_EXIT", "True")
    return setting.lower() in {"true", "1"}
def ident_lines(lines: Union[List[str], str]):
    """
    Indent every line by one level (four spaces).

    If ``lines`` is exactly a ``str``, it is split on newlines, each piece is
    indented, and the pieces are rejoined into a string. Any other iterable
    of strings yields a new list of indented lines.
    """
    if type(lines) is not str:
        return [ident(line) for line in lines]
    return "\n".join(ident(line) for line in lines.split("\n"))


def ident(line: str):
    """
    Return ``line`` with four spaces of indentation prepended.

    ``line`` is asserted to be exactly ``str``; the assertion reports the
    offending type.
    """
    line_type = type(line)
    assert line_type is str, line_type
    return " " * 4 + line