@@ -156,6 +156,113 @@ static size_t sizeof_dtype(const DTypeId dt) {
156156 };
157157};
158158
159+ static bool is_float (DTypeId t) {
160+ switch (t) {
161+ case DDPT ::DTypeId::FLOAT64 :
162+ case DDPT ::DTypeId::FLOAT32 :
163+ return true ;
164+ default :
165+ return false ;
166+ }
167+ }
168+
169+ static bool is_int (DTypeId t) {
170+ switch (t) {
171+ case DDPT ::DTypeId::INT64 :
172+ case DDPT ::DTypeId::INT32 :
173+ case DDPT ::DTypeId::INT16 :
174+ case DDPT ::DTypeId::INT8 :
175+ return true ;
176+ default :
177+ return false ;
178+ }
179+ }
180+
181+ static bool is_uint (DTypeId t) {
182+ switch (t) {
183+ case DDPT ::DTypeId::UINT64 :
184+ case DDPT ::DTypeId::UINT32 :
185+ case DDPT ::DTypeId::UINT16 :
186+ case DDPT ::DTypeId::UINT8 :
187+ case DDPT ::DTypeId::BOOL :
188+ return true ;
189+ default :
190+ return false ;
191+ }
192+ }
193+
194+ static size_t dtype_bitwidth (DTypeId t) {
195+ switch (t) {
196+ case DDPT ::DTypeId::FLOAT64 :
197+ case DDPT ::DTypeId::INT64 :
198+ case DDPT ::DTypeId::UINT64 :
199+ return 64 ;
200+ case DDPT ::DTypeId::FLOAT32 :
201+ case DDPT ::DTypeId::INT32 :
202+ case DDPT ::DTypeId::UINT32 :
203+ return 32 ;
204+ case DDPT ::DTypeId::INT16 :
205+ case DDPT ::DTypeId::UINT16 :
206+ return 16 ;
207+ case DDPT ::DTypeId::INT8 :
208+ case DDPT ::DTypeId::UINT8 :
209+ return 8 ;
210+ case DDPT ::DTypeId::BOOL :
211+ return 1 ;
212+ default :
213+ assert (!" Unknown DTypeId" );
214+ }
215+ }
216+
217+ static DTypeId get_float_dtype (size_t bitwidth) {
218+ switch (bitwidth) {
219+ case 64 :
220+ return DDPT ::DTypeId::FLOAT64 ;
221+ case 32 :
222+ return DDPT ::DTypeId::FLOAT32 ;
223+ default :
224+ assert (!" Unknown bitwidth" );
225+ }
226+ }
227+
228+ static DTypeId get_int_dtype (size_t bitwidth) {
229+ switch (bitwidth) {
230+ case 64 :
231+ return DDPT ::DTypeId::INT64 ;
232+ case 32 :
233+ return DDPT ::DTypeId::INT32 ;
234+ case 16 :
235+ return DDPT ::DTypeId::INT16 ;
236+ case 8 :
237+ return DDPT ::DTypeId::INT8 ;
238+ default :
239+ assert (!" Unknown bitwidth" );
240+ }
241+ }
242+
243+ static DTypeId promoted_dtype (DTypeId a, DTypeId b) {
244+ if ((is_float (a) && is_float (b)) || (is_int (a) && is_int (b)) ||
245+ (is_uint (a) && is_uint (b))) {
246+ return dtype_bitwidth (a) > dtype_bitwidth (b) ? a : b;
247+ }
248+ if (is_float (a) || is_float (b)) {
249+ return get_float_dtype (std::max (dtype_bitwidth (a), dtype_bitwidth (b)));
250+ }
251+ // mixed signed/unsigned int case
252+ size_t si_width, ui_width, max_width = 64 ;
253+ if (is_uint (a)) {
254+ ui_width = dtype_bitwidth (a);
255+ si_width = dtype_bitwidth (b);
256+ } else {
257+ ui_width = dtype_bitwidth (b);
258+ si_width = dtype_bitwidth (a);
259+ }
260+ if (ui_width < si_width) {
261+ return get_int_dtype (si_width);
262+ }
263+ return get_int_dtype (std::min (2 * ui_width, max_width));
264+ }
265+
159266using RedOpType = ReduceOpId;
160267
161268inline RedOpType red_op (const char *op) {
0 commit comments