-
Notifications
You must be signed in to change notification settings - Fork 409
Expand file tree
/
Copy pathbfloat16_array.py
More file actions
104 lines (87 loc) · 3.33 KB
/
bfloat16_array.py
File metadata and controls
104 lines (87 loc) · 3.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import array
from pyfory.bfloat16 import bfloat16
from pyfory.utils import is_little_endian
class BFloat16Array:
    """Mutable sequence of ``bfloat16`` values backed by raw 16-bit patterns.

    Elements are stored as their bit patterns in an ``array.array("H")`` and
    are materialized as ``bfloat16`` objects only when read back.

    ``values`` may be:
      * ``None`` -- an empty array;
      * another ``BFloat16Array`` or an ``array.array`` with typecode ``"H"`` --
        copied as raw bit patterns (no numeric conversion);
      * any other iterable of ``bfloat16`` instances or values accepted by
        ``bfloat16()``.
    """

    # Mutable container with value-based __eq__: explicitly unhashable
    # (Python would also set this implicitly because __eq__ is defined).
    __hash__ = None

    def __init__(self, values=None):
        if values is None:
            self._data = array.array("H")
            return
        raw = self._raw_bits(values)
        if raw is not None:
            # Copy so later mutation of the source does not affect this array.
            self._data = array.array("H", raw)
        else:
            self._data = array.array("H", (self._to_bits(v) for v in values))

    @staticmethod
    def _to_bits(value):
        """Return the 16-bit pattern of *value*, converting via ``bfloat16()`` if needed."""
        if isinstance(value, bfloat16):
            return value.to_bits()
        return bfloat16(value).to_bits()

    @staticmethod
    def _raw_bits(values):
        """Return *values* as an ``array("H")`` of bit patterns if it already is one, else None."""
        if isinstance(values, BFloat16Array):
            return values._data
        if isinstance(values, array.array) and values.typecode == "H":
            return values
        return None

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index):
        """Return one ``bfloat16`` for an integer index, or a new ``BFloat16Array`` for a slice."""
        if isinstance(index, slice):
            sliced = type(self)()
            sliced._data = self._data[index]  # array slicing already copies
            return sliced
        return bfloat16.from_bits(self._data[index])

    def __setitem__(self, index, value):
        """Store one value at an integer index, or an iterable of values into a slice.

        For slices, a ``BFloat16Array`` or ``array("H")`` is taken as raw bit
        patterns (as in ``__init__``); other iterables are converted elementwise.
        Extended-slice length rules are those of ``array.array``.
        """
        if not isinstance(index, slice):
            self._data[index] = self._to_bits(value)
            return
        raw = self._raw_bits(value)
        if raw is None:
            raw = array.array("H", (self._to_bits(v) for v in value))
        self._data[index] = raw

    def __iter__(self):
        for bits in self._data:
            yield bfloat16.from_bits(bits)

    def __repr__(self):
        return f"BFloat16Array([{', '.join(str(bf16) for bf16 in self)}])"

    def __eq__(self, other):
        """Compare raw bit patterns with another ``BFloat16Array`` (no numeric comparison)."""
        if isinstance(other, BFloat16Array):
            return self._data == other._data
        # Let Python try the reflected operation / fall back to identity.
        return NotImplemented

    def append(self, value):
        """Append one ``bfloat16`` (or value convertible by ``bfloat16()``)."""
        self._data.append(self._to_bits(value))

    def extend(self, values):
        """Append every element of *values*.

        A ``BFloat16Array`` or ``array("H")`` is taken as raw bit patterns,
        consistent with ``__init__``. Other iterables are converted elementwise
        up front, so a conversion error leaves this array unchanged.
        """
        raw = self._raw_bits(values)
        if raw is None:
            raw = array.array("H", (self._to_bits(v) for v in values))
        self._data.extend(raw)

    @property
    def itemsize(self):
        """Size in bytes of one element (always 2)."""
        return 2

    def tobytes(self):
        """Serialize the bit patterns, two bytes per element.

        NOTE(review): assumes ``is_little_endian`` is a bool describing the host
        byte order, which makes the output little-endian on every host.
        """
        if is_little_endian:
            return self._data.tobytes()
        # Swap a copy so self._data keeps native order.
        swapped = array.array("H", self._data)
        swapped.byteswap()
        return swapped.tobytes()

    def to_bits_array(self):
        """Return a copy of the underlying ``array("H")`` of bit patterns."""
        return array.array("H", self._data)

    @classmethod
    def from_bits_array(cls, values):
        """Build an instance from an iterable of 16-bit patterns (copied)."""
        arr = cls()
        arr._data = array.array("H", values)
        return arr

    @classmethod
    def frombytes(cls, data):
        """Build an instance from a bytes-like payload as produced by ``tobytes``.

        Raises:
            ValueError: if the payload's byte length is odd.
        """
        # Use the buffer's byte count: len() of a formatted memoryview counts
        # elements, not bytes.
        with memoryview(data) as view:
            nbytes = view.nbytes
        if nbytes % 2 != 0:
            raise ValueError("bfloat16 byte payload length must be a multiple of 2")
        bits = array.array("H")
        bits.frombytes(data)
        if not is_little_endian:
            bits.byteswap()
        arr = cls()
        arr._data = bits
        return arr