@@ -55,6 +55,23 @@ int8 activation(int8 i)
5555#define __builtin_expect (x )
5656#endif
5757
#if defined(__opencl_c_integer_dot_product_input_4x8bit_packed)
// Feature macro is defined: dp4 maps straight to dot_4x8packed_ss_int.
#define dp4 dot_4x8packed_ss_int
#else
// Feature macro is absent: dp4 maps to the software fallback defined below.
#define dp4 emu_dot_4x8packed_ss_int

// Software fallback used as dp4.
//
// Reinterprets each 32-bit argument as a char4 (four signed 8-bit lanes) and
// returns the sum of the four lane-wise products. Each char * char product is
// evaluated in int, so the accumulation cannot overflow.
int emu_dot_4x8packed_ss_int(const uint a, const uint b)
{
    const char4 lhs = as_char4(a);
    const char4 rhs = as_char4(b);

    // Accumulate in the same x, y, z, w order as a left-to-right sum.
    int sum = lhs.x * rhs.x;
    sum += lhs.y * rhs.y;
    sum += lhs.z * rhs.z;
    sum += lhs.w * rhs.w;

    return sum;
}
#endif
74+
5875#if defined(cl_intel_subgroups ) && defined(cl_intel_subgroups_short ) && defined(cl_intel_subgroups_char )
5976
6077typedef global char * global_aligned_char_ptr __attribute__((align_value (4 )));
@@ -79,47 +96,14 @@ int emu_sub_group_i8_i8_matrix_mad_k32(int a, int8 b, int acc)
7996{
8097 int res = acc ;
8198
82- // TODO: this could use integer dot products instead?
83-
84- res = as_char4 (sub_group_broadcast (a , 0 )).x * as_char4 (b .s0 ).x + res ;
85- res = as_char4 (sub_group_broadcast (a , 0 )).y * as_char4 (b .s0 ).y + res ;
86- res = as_char4 (sub_group_broadcast (a , 0 )).z * as_char4 (b .s0 ).z + res ;
87- res = as_char4 (sub_group_broadcast (a , 0 )).w * as_char4 (b .s0 ).w + res ;
88-
89- res = as_char4 (sub_group_broadcast (a , 1 )).x * as_char4 (b .s1 ).x + res ;
90- res = as_char4 (sub_group_broadcast (a , 1 )).y * as_char4 (b .s1 ).y + res ;
91- res = as_char4 (sub_group_broadcast (a , 1 )).z * as_char4 (b .s1 ).z + res ;
92- res = as_char4 (sub_group_broadcast (a , 1 )).w * as_char4 (b .s1 ).w + res ;
93-
94- res = as_char4 (sub_group_broadcast (a , 2 )).x * as_char4 (b .s2 ).x + res ;
95- res = as_char4 (sub_group_broadcast (a , 2 )).y * as_char4 (b .s2 ).y + res ;
96- res = as_char4 (sub_group_broadcast (a , 2 )).z * as_char4 (b .s2 ).z + res ;
97- res = as_char4 (sub_group_broadcast (a , 2 )).w * as_char4 (b .s2 ).w + res ;
98-
99- res = as_char4 (sub_group_broadcast (a , 3 )).x * as_char4 (b .s3 ).x + res ;
100- res = as_char4 (sub_group_broadcast (a , 3 )).y * as_char4 (b .s3 ).y + res ;
101- res = as_char4 (sub_group_broadcast (a , 3 )).z * as_char4 (b .s3 ).z + res ;
102- res = as_char4 (sub_group_broadcast (a , 3 )).w * as_char4 (b .s3 ).w + res ;
103-
104- res = as_char4 (sub_group_broadcast (a , 4 )).x * as_char4 (b .s4 ).x + res ;
105- res = as_char4 (sub_group_broadcast (a , 4 )).y * as_char4 (b .s4 ).y + res ;
106- res = as_char4 (sub_group_broadcast (a , 4 )).z * as_char4 (b .s4 ).z + res ;
107- res = as_char4 (sub_group_broadcast (a , 4 )).w * as_char4 (b .s4 ).w + res ;
108-
109- res = as_char4 (sub_group_broadcast (a , 5 )).x * as_char4 (b .s5 ).x + res ;
110- res = as_char4 (sub_group_broadcast (a , 5 )).y * as_char4 (b .s5 ).y + res ;
111- res = as_char4 (sub_group_broadcast (a , 5 )).z * as_char4 (b .s5 ).z + res ;
112- res = as_char4 (sub_group_broadcast (a , 5 )).w * as_char4 (b .s5 ).w + res ;
113-
114- res = as_char4 (sub_group_broadcast (a , 6 )).x * as_char4 (b .s6 ).x + res ;
115- res = as_char4 (sub_group_broadcast (a , 6 )).y * as_char4 (b .s6 ).y + res ;
116- res = as_char4 (sub_group_broadcast (a , 6 )).z * as_char4 (b .s6 ).z + res ;
117- res = as_char4 (sub_group_broadcast (a , 6 )).w * as_char4 (b .s6 ).w + res ;
118-
119- res = as_char4 (sub_group_broadcast (a , 7 )).x * as_char4 (b .s7 ).x + res ;
120- res = as_char4 (sub_group_broadcast (a , 7 )).y * as_char4 (b .s7 ).y + res ;
121- res = as_char4 (sub_group_broadcast (a , 7 )).z * as_char4 (b .s7 ).z + res ;
122- res = as_char4 (sub_group_broadcast (a , 7 )).w * as_char4 (b .s7 ).w + res ;
99+ res = dp4 (sub_group_broadcast (a , 0 ), b .s0 ) + res ;
100+ res = dp4 (sub_group_broadcast (a , 1 ), b .s1 ) + res ;
101+ res = dp4 (sub_group_broadcast (a , 2 ), b .s2 ) + res ;
102+ res = dp4 (sub_group_broadcast (a , 3 ), b .s3 ) + res ;
103+ res = dp4 (sub_group_broadcast (a , 4 ), b .s4 ) + res ;
104+ res = dp4 (sub_group_broadcast (a , 5 ), b .s5 ) + res ;
105+ res = dp4 (sub_group_broadcast (a , 6 ), b .s6 ) + res ;
106+ res = dp4 (sub_group_broadcast (a , 7 ), b .s7 ) + res ;
123107
124108 return res ;
125109}
@@ -171,45 +155,14 @@ int emu_sub_group_i8_i8_matrix_mad_k32(short a, int8 b, int acc)
171155{
172156 float res = acc ;
173157
174- res = as_char2 (intel_sub_group_broadcast (a , 0 )).x * as_char4 (b .s0 ).x + res ;
175- res = as_char2 (intel_sub_group_broadcast (a , 0 )).y * as_char4 (b .s0 ).y + res ;
176- res = as_char2 (intel_sub_group_broadcast (a , 1 )).x * as_char4 (b .s0 ).z + res ;
177- res = as_char2 (intel_sub_group_broadcast (a , 1 )).y * as_char4 (b .s0 ).w + res ;
178-
179- res = as_char2 (intel_sub_group_broadcast (a , 2 )).x * as_char4 (b .s1 ).x + res ;
180- res = as_char2 (intel_sub_group_broadcast (a , 2 )).y * as_char4 (b .s1 ).y + res ;
181- res = as_char2 (intel_sub_group_broadcast (a , 3 )).x * as_char4 (b .s1 ).z + res ;
182- res = as_char2 (intel_sub_group_broadcast (a , 3 )).y * as_char4 (b .s1 ).w + res ;
183-
184- res = as_char2 (intel_sub_group_broadcast (a , 4 )).x * as_char4 (b .s2 ).x + res ;
185- res = as_char2 (intel_sub_group_broadcast (a , 4 )).y * as_char4 (b .s2 ).y + res ;
186- res = as_char2 (intel_sub_group_broadcast (a , 5 )).x * as_char4 (b .s2 ).z + res ;
187- res = as_char2 (intel_sub_group_broadcast (a , 5 )).y * as_char4 (b .s2 ).w + res ;
188-
189- res = as_char2 (intel_sub_group_broadcast (a , 6 )).x * as_char4 (b .s3 ).x + res ;
190- res = as_char2 (intel_sub_group_broadcast (a , 6 )).y * as_char4 (b .s3 ).y + res ;
191- res = as_char2 (intel_sub_group_broadcast (a , 7 )).x * as_char4 (b .s3 ).z + res ;
192- res = as_char2 (intel_sub_group_broadcast (a , 7 )).y * as_char4 (b .s3 ).w + res ;
193-
194- res = as_char2 (intel_sub_group_broadcast (a , 8 )).x * as_char4 (b .s4 ).x + res ;
195- res = as_char2 (intel_sub_group_broadcast (a , 8 )).y * as_char4 (b .s4 ).y + res ;
196- res = as_char2 (intel_sub_group_broadcast (a , 9 )).x * as_char4 (b .s4 ).z + res ;
197- res = as_char2 (intel_sub_group_broadcast (a , 9 )).y * as_char4 (b .s4 ).w + res ;
198-
199- res = as_char2 (intel_sub_group_broadcast (a , 10 )).x * as_char4 (b .s5 ).x + res ;
200- res = as_char2 (intel_sub_group_broadcast (a , 10 )).y * as_char4 (b .s5 ).y + res ;
201- res = as_char2 (intel_sub_group_broadcast (a , 11 )).x * as_char4 (b .s5 ).z + res ;
202- res = as_char2 (intel_sub_group_broadcast (a , 11 )).y * as_char4 (b .s5 ).w + res ;
203-
204- res = as_char2 (intel_sub_group_broadcast (a , 12 )).x * as_char4 (b .s6 ).x + res ;
205- res = as_char2 (intel_sub_group_broadcast (a , 12 )).y * as_char4 (b .s6 ).y + res ;
206- res = as_char2 (intel_sub_group_broadcast (a , 13 )).x * as_char4 (b .s6 ).z + res ;
207- res = as_char2 (intel_sub_group_broadcast (a , 13 )).y * as_char4 (b .s6 ).w + res ;
208-
209- res = as_char2 (intel_sub_group_broadcast (a , 14 )).x * as_char4 (b .s7 ).x + res ;
210- res = as_char2 (intel_sub_group_broadcast (a , 14 )).y * as_char4 (b .s7 ).y + res ;
211- res = as_char2 (intel_sub_group_broadcast (a , 15 )).x * as_char4 (b .s7 ).z + res ;
212- res = as_char2 (intel_sub_group_broadcast (a , 15 )).y * as_char4 (b .s7 ).w + res ;
158+ res = dp4 (as_uint ((short2 )(sub_group_broadcast (a , 0 ), sub_group_broadcast (a , 1 ))), b .s0 ) + res ;
159+ res = dp4 (as_uint ((short2 )(sub_group_broadcast (a , 2 ), sub_group_broadcast (a , 3 ))), b .s1 ) + res ;
160+ res = dp4 (as_uint ((short2 )(sub_group_broadcast (a , 4 ), sub_group_broadcast (a , 5 ))), b .s2 ) + res ;
161+ res = dp4 (as_uint ((short2 )(sub_group_broadcast (a , 6 ), sub_group_broadcast (a , 7 ))), b .s3 ) + res ;
162+ res = dp4 (as_uint ((short2 )(sub_group_broadcast (a , 8 ), sub_group_broadcast (a , 9 ))), b .s4 ) + res ;
163+ res = dp4 (as_uint ((short2 )(sub_group_broadcast (a , 10 ), sub_group_broadcast (a , 11 ))), b .s5 ) + res ;
164+ res = dp4 (as_uint ((short2 )(sub_group_broadcast (a , 12 ), sub_group_broadcast (a , 13 ))), b .s6 ) + res ;
165+ res = dp4 (as_uint ((short2 )(sub_group_broadcast (a , 14 ), sub_group_broadcast (a , 15 ))), b .s7 ) + res ;
213166
214167 return res ;
215168}
0 commit comments