Skip to content

Commit ada8a4a

Browse files
committed
switch to kernels that use integer dot products
1 parent 4e31c01 commit ada8a4a

File tree

1 file changed

+33
-80
lines changed

1 file changed

+33
-80
lines changed

samples/20_matrixexperiments-i8/matrix_helpers_i8.cl

Lines changed: 33 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,23 @@ int8 activation(int8 i)
5555
#define __builtin_expect(x)
5656
#endif
5757

58+
// dp4(a, b): dot product of two 32-bit values, each treated as four packed
// 8-bit lanes. When the packed 4x8-bit integer dot product feature macro is
// defined, dp4 maps to the dot_4x8packed_ss_int built-in; otherwise it maps to
// the emulation function defined in the #else branch below.
#if defined(__opencl_c_integer_dot_product_input_4x8bit_packed)
59+
#define dp4 dot_4x8packed_ss_int
60+
#else
61+
#define dp4 emu_dot_4x8packed_ss_int
62+
63+
// Fallback for dp4: reinterprets each uint operand as a char4 (signed 8-bit
// components in OpenCL C) and returns the sum of the four component-wise
// products (x, y, z, w) as an int.
int emu_dot_4x8packed_ss_int(const uint a, const uint b)
64+
{
65+
const char4 a_c4 = as_char4(a);
66+
const char4 b_c4 = as_char4(b);
67+
68+
return a_c4.x * b_c4.x +
69+
a_c4.y * b_c4.y +
70+
a_c4.z * b_c4.z +
71+
a_c4.w * b_c4.w;
72+
}
73+
#endif
74+
5875
#if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_subgroups_char)
5976

6077
typedef global char* global_aligned_char_ptr __attribute__((align_value(4)));
@@ -79,47 +96,14 @@ int emu_sub_group_i8_i8_matrix_mad_k32(int a, int8 b, int acc)
7996
{
8097
int res = acc;
8198

82-
// TODO: this could use integer dot products instead?
83-
84-
res = as_char4(sub_group_broadcast(a, 0)).x * as_char4(b.s0).x + res;
85-
res = as_char4(sub_group_broadcast(a, 0)).y * as_char4(b.s0).y + res;
86-
res = as_char4(sub_group_broadcast(a, 0)).z * as_char4(b.s0).z + res;
87-
res = as_char4(sub_group_broadcast(a, 0)).w * as_char4(b.s0).w + res;
88-
89-
res = as_char4(sub_group_broadcast(a, 1)).x * as_char4(b.s1).x + res;
90-
res = as_char4(sub_group_broadcast(a, 1)).y * as_char4(b.s1).y + res;
91-
res = as_char4(sub_group_broadcast(a, 1)).z * as_char4(b.s1).z + res;
92-
res = as_char4(sub_group_broadcast(a, 1)).w * as_char4(b.s1).w + res;
93-
94-
res = as_char4(sub_group_broadcast(a, 2)).x * as_char4(b.s2).x + res;
95-
res = as_char4(sub_group_broadcast(a, 2)).y * as_char4(b.s2).y + res;
96-
res = as_char4(sub_group_broadcast(a, 2)).z * as_char4(b.s2).z + res;
97-
res = as_char4(sub_group_broadcast(a, 2)).w * as_char4(b.s2).w + res;
98-
99-
res = as_char4(sub_group_broadcast(a, 3)).x * as_char4(b.s3).x + res;
100-
res = as_char4(sub_group_broadcast(a, 3)).y * as_char4(b.s3).y + res;
101-
res = as_char4(sub_group_broadcast(a, 3)).z * as_char4(b.s3).z + res;
102-
res = as_char4(sub_group_broadcast(a, 3)).w * as_char4(b.s3).w + res;
103-
104-
res = as_char4(sub_group_broadcast(a, 4)).x * as_char4(b.s4).x + res;
105-
res = as_char4(sub_group_broadcast(a, 4)).y * as_char4(b.s4).y + res;
106-
res = as_char4(sub_group_broadcast(a, 4)).z * as_char4(b.s4).z + res;
107-
res = as_char4(sub_group_broadcast(a, 4)).w * as_char4(b.s4).w + res;
108-
109-
res = as_char4(sub_group_broadcast(a, 5)).x * as_char4(b.s5).x + res;
110-
res = as_char4(sub_group_broadcast(a, 5)).y * as_char4(b.s5).y + res;
111-
res = as_char4(sub_group_broadcast(a, 5)).z * as_char4(b.s5).z + res;
112-
res = as_char4(sub_group_broadcast(a, 5)).w * as_char4(b.s5).w + res;
113-
114-
res = as_char4(sub_group_broadcast(a, 6)).x * as_char4(b.s6).x + res;
115-
res = as_char4(sub_group_broadcast(a, 6)).y * as_char4(b.s6).y + res;
116-
res = as_char4(sub_group_broadcast(a, 6)).z * as_char4(b.s6).z + res;
117-
res = as_char4(sub_group_broadcast(a, 6)).w * as_char4(b.s6).w + res;
118-
119-
res = as_char4(sub_group_broadcast(a, 7)).x * as_char4(b.s7).x + res;
120-
res = as_char4(sub_group_broadcast(a, 7)).y * as_char4(b.s7).y + res;
121-
res = as_char4(sub_group_broadcast(a, 7)).z * as_char4(b.s7).z + res;
122-
res = as_char4(sub_group_broadcast(a, 7)).w * as_char4(b.s7).w + res;
99+
res = dp4(sub_group_broadcast(a, 0), b.s0) + res;
100+
res = dp4(sub_group_broadcast(a, 1), b.s1) + res;
101+
res = dp4(sub_group_broadcast(a, 2), b.s2) + res;
102+
res = dp4(sub_group_broadcast(a, 3), b.s3) + res;
103+
res = dp4(sub_group_broadcast(a, 4), b.s4) + res;
104+
res = dp4(sub_group_broadcast(a, 5), b.s5) + res;
105+
res = dp4(sub_group_broadcast(a, 6), b.s6) + res;
106+
res = dp4(sub_group_broadcast(a, 7), b.s7) + res;
123107

124108
return res;
125109
}
@@ -171,45 +155,14 @@ int emu_sub_group_i8_i8_matrix_mad_k32(short a, int8 b, int acc)
171155
{
172156
float res = acc;
173157

174-
res = as_char2(intel_sub_group_broadcast(a, 0)).x * as_char4(b.s0).x + res;
175-
res = as_char2(intel_sub_group_broadcast(a, 0)).y * as_char4(b.s0).y + res;
176-
res = as_char2(intel_sub_group_broadcast(a, 1)).x * as_char4(b.s0).z + res;
177-
res = as_char2(intel_sub_group_broadcast(a, 1)).y * as_char4(b.s0).w + res;
178-
179-
res = as_char2(intel_sub_group_broadcast(a, 2)).x * as_char4(b.s1).x + res;
180-
res = as_char2(intel_sub_group_broadcast(a, 2)).y * as_char4(b.s1).y + res;
181-
res = as_char2(intel_sub_group_broadcast(a, 3)).x * as_char4(b.s1).z + res;
182-
res = as_char2(intel_sub_group_broadcast(a, 3)).y * as_char4(b.s1).w + res;
183-
184-
res = as_char2(intel_sub_group_broadcast(a, 4)).x * as_char4(b.s2).x + res;
185-
res = as_char2(intel_sub_group_broadcast(a, 4)).y * as_char4(b.s2).y + res;
186-
res = as_char2(intel_sub_group_broadcast(a, 5)).x * as_char4(b.s2).z + res;
187-
res = as_char2(intel_sub_group_broadcast(a, 5)).y * as_char4(b.s2).w + res;
188-
189-
res = as_char2(intel_sub_group_broadcast(a, 6)).x * as_char4(b.s3).x + res;
190-
res = as_char2(intel_sub_group_broadcast(a, 6)).y * as_char4(b.s3).y + res;
191-
res = as_char2(intel_sub_group_broadcast(a, 7)).x * as_char4(b.s3).z + res;
192-
res = as_char2(intel_sub_group_broadcast(a, 7)).y * as_char4(b.s3).w + res;
193-
194-
res = as_char2(intel_sub_group_broadcast(a, 8)).x * as_char4(b.s4).x + res;
195-
res = as_char2(intel_sub_group_broadcast(a, 8)).y * as_char4(b.s4).y + res;
196-
res = as_char2(intel_sub_group_broadcast(a, 9)).x * as_char4(b.s4).z + res;
197-
res = as_char2(intel_sub_group_broadcast(a, 9)).y * as_char4(b.s4).w + res;
198-
199-
res = as_char2(intel_sub_group_broadcast(a, 10)).x * as_char4(b.s5).x + res;
200-
res = as_char2(intel_sub_group_broadcast(a, 10)).y * as_char4(b.s5).y + res;
201-
res = as_char2(intel_sub_group_broadcast(a, 11)).x * as_char4(b.s5).z + res;
202-
res = as_char2(intel_sub_group_broadcast(a, 11)).y * as_char4(b.s5).w + res;
203-
204-
res = as_char2(intel_sub_group_broadcast(a, 12)).x * as_char4(b.s6).x + res;
205-
res = as_char2(intel_sub_group_broadcast(a, 12)).y * as_char4(b.s6).y + res;
206-
res = as_char2(intel_sub_group_broadcast(a, 13)).x * as_char4(b.s6).z + res;
207-
res = as_char2(intel_sub_group_broadcast(a, 13)).y * as_char4(b.s6).w + res;
208-
209-
res = as_char2(intel_sub_group_broadcast(a, 14)).x * as_char4(b.s7).x + res;
210-
res = as_char2(intel_sub_group_broadcast(a, 14)).y * as_char4(b.s7).y + res;
211-
res = as_char2(intel_sub_group_broadcast(a, 15)).x * as_char4(b.s7).z + res;
212-
res = as_char2(intel_sub_group_broadcast(a, 15)).y * as_char4(b.s7).w + res;
158+
res = dp4(as_uint((short2)(sub_group_broadcast(a, 0), sub_group_broadcast(a, 1))), b.s0) + res;
159+
res = dp4(as_uint((short2)(sub_group_broadcast(a, 2), sub_group_broadcast(a, 3))), b.s1) + res;
160+
res = dp4(as_uint((short2)(sub_group_broadcast(a, 4), sub_group_broadcast(a, 5))), b.s2) + res;
161+
res = dp4(as_uint((short2)(sub_group_broadcast(a, 6), sub_group_broadcast(a, 7))), b.s3) + res;
162+
res = dp4(as_uint((short2)(sub_group_broadcast(a, 8), sub_group_broadcast(a, 9))), b.s4) + res;
163+
res = dp4(as_uint((short2)(sub_group_broadcast(a, 10), sub_group_broadcast(a, 11))), b.s5) + res;
164+
res = dp4(as_uint((short2)(sub_group_broadcast(a, 12), sub_group_broadcast(a, 13))), b.s6) + res;
165+
res = dp4(as_uint((short2)(sub_group_broadcast(a, 14), sub_group_broadcast(a, 15))), b.s7) + res;
213166

214167
return res;
215168
}

0 commit comments

Comments
 (0)