Skip to content

Commit 4e31c01

Browse files
committed
fixes for CPU and more
1 parent 7f5764c commit 4e31c01

File tree

9 files changed

+207
-106
lines changed

9 files changed

+207
-106
lines changed

samples/20_matrixexperiments-bf16/main.cpp

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,12 @@ static size_t findMinSubGroupSize(cl::Device& device)
7676
return 0;
7777
}
7878

79+
// Reports whether `subgroupSize` is one of the sub-group sizes the device
// returns for the CL_DEVICE_SUB_GROUP_SIZES_INTEL query.
// Returns true when the size is found in that list, false otherwise.
// NOTE(review): no extension check is performed in this function; presumably
// the query is only valid on devices exposing the Intel required-sub-group-size
// extension -- confirm how the query behaves on other devices.
static bool supportsSubgroupSize(cl::Device& device, size_t subgroupSize)
80+
{
81+
auto s = device.getInfo<CL_DEVICE_SUB_GROUP_SIZES_INTEL>();
82+
return std::find(std::begin(s), std::end(s), subgroupSize) != std::end(s);
83+
}
84+
7985
static void setRoundRobin(cl::Kernel& kernel)
8086
{
8187
constexpr cl_kernel_exec_info CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_INTEL = 0x10025;
@@ -175,6 +181,23 @@ static float hw_time(cl::Event& event)
175181
return ns / 1e9f;
176182
}
177183

184+
// Returns the local work-group size to pass when enqueueing `kernel` on `queue`.
// Queries CL_KERNEL_COMPILE_WORK_GROUP_SIZE for the queue's device and, when
// all three components are non-zero, returns them as a 3-D cl::NDRange;
// otherwise returns cl::NullRange.
// NOTE(review): `queue` is taken by value here, while the benchmark functions
// in this file take `cl::CommandQueue&` -- consider passing by reference.
static cl::NDRange getRequiredLocalWorkSize(cl::Kernel& kernel, cl::CommandQueue queue)
185+
{
186+
// Note: This shouldn't be necessary, and the OpenCL implementation should
187+
// automatically choose the required local work-group size when the local
188+
// work-group size is `nullptr`. This is not working for some OpenCL
189+
// implementations, though, so we will just query and use the required local
190+
// work-group size explicitly.
191+
auto device = queue.getInfo<CL_QUEUE_DEVICE>();
192+
auto reqd_wgs = kernel.getWorkGroupInfo<CL_KERNEL_COMPILE_WORK_GROUP_SIZE>(device);
193+
// Only use the queried size when every component is non-zero.
194+
if (reqd_wgs[0] > 0 && reqd_wgs[1] > 0 && reqd_wgs[2] > 0) {
195+
return cl::NDRange(reqd_wgs[0], reqd_wgs[1], reqd_wgs[2]);
196+
}
197+
// Otherwise return the null range, i.e. no explicit local work-group size.
198+
return cl::NullRange;
199+
}
200+
178201
static void bfloat16_naive(
179202
cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
180203
cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
@@ -187,6 +210,8 @@ static void bfloat16_naive(
187210
if (kernel() == nullptr) {
188211
printf("unsupported.\n");
189212
} else {
213+
const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue);
214+
190215
kernel.setArg(0, C);
191216
kernel.setArg(1, A);
192217
kernel.setArg(2, B);
@@ -201,7 +226,7 @@ static void bfloat16_naive(
201226
cl::Event event;
202227
auto start = test_clock::now();
203228
queue.enqueueNDRangeKernel(kernel, cl::NullRange,
204-
cl::NDRange{N, M}, cl::NullRange, nullptr, &event);
229+
cl::NDRange{N, M}, localWorkSize, nullptr, &event);
205230
queue.finish();
206231
auto end = test_clock::now();
207232
std::chrono::duration<float> sw_time = end - start;
@@ -237,6 +262,8 @@ static void bfloat16_dpas_rowmajor(
237262
if (kernel() == nullptr) {
238263
printf("unsupported.\n");
239264
} else {
265+
const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue);
266+
240267
kernel.setArg(0, C);
241268
kernel.setArg(1, A);
242269
kernel.setArg(2, B);
@@ -251,7 +278,7 @@ static void bfloat16_dpas_rowmajor(
251278
cl::Event event;
252279
auto start = test_clock::now();
253280
queue.enqueueNDRangeKernel(kernel, cl::NullRange,
254-
cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event);
281+
cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event);
255282
queue.finish();
256283
auto end = test_clock::now();
257284
std::chrono::duration<float> sw_time = end - start;
@@ -293,6 +320,8 @@ static void bfloat16_dpas_rowmajor_tiled(
293320
} else if (tN * NN > N) {
294321
printf("N is too small.\n");
295322
} else {
323+
const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue);
324+
296325
kernel.setArg(0, C);
297326
kernel.setArg(1, A);
298327
kernel.setArg(2, B);
@@ -307,7 +336,7 @@ static void bfloat16_dpas_rowmajor_tiled(
307336
cl::Event event;
308337
auto start = test_clock::now();
309338
queue.enqueueNDRangeKernel(kernel, cl::NullRange,
310-
cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event);
339+
cl::NDRange{N/NN, M/tM/MM}, localWorkSize, nullptr, &event);
311340
queue.finish();
312341
auto end = test_clock::now();
313342
std::chrono::duration<float> sw_time = end - start;
@@ -343,6 +372,8 @@ static void bfloat16_dpas_vnni(
343372
if (kernel() == nullptr) {
344373
printf("unsupported.\n");
345374
} else {
375+
const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue);
376+
346377
kernel.setArg(0, C);
347378
kernel.setArg(1, A);
348379
kernel.setArg(2, B);
@@ -357,7 +388,7 @@ static void bfloat16_dpas_vnni(
357388
cl::Event event;
358389
auto start = test_clock::now();
359390
queue.enqueueNDRangeKernel(kernel, cl::NullRange,
360-
cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event);
391+
cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event);
361392
queue.finish();
362393
auto end = test_clock::now();
363394
std::chrono::duration<float> sw_time = end - start;
@@ -399,6 +430,8 @@ static void bfloat16_dpas_vnni_tiled(
399430
} else if (tN * NN > N) {
400431
printf("N is too small.\n");
401432
} else {
433+
const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue);
434+
402435
kernel.setArg(0, C);
403436
kernel.setArg(1, A);
404437
kernel.setArg(2, B);
@@ -413,7 +446,7 @@ static void bfloat16_dpas_vnni_tiled(
413446
cl::Event event;
414447
auto start = test_clock::now();
415448
queue.enqueueNDRangeKernel(kernel, cl::NullRange,
416-
cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event);
449+
cl::NDRange{N/NN, M/tM/MM}, localWorkSize, nullptr, &event);
417450
queue.finish();
418451
auto end = test_clock::now();
419452
std::chrono::duration<float> sw_time = end - start;
@@ -449,6 +482,8 @@ static void bfloat16_dpas_blockread_rowmajor(
449482
if (kernel() == nullptr) {
450483
printf("unsupported.\n");
451484
} else {
485+
const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue);
486+
452487
kernel.setArg(0, C);
453488
kernel.setArg(1, A);
454489
kernel.setArg(2, B);
@@ -466,7 +501,7 @@ static void bfloat16_dpas_blockread_rowmajor(
466501
cl::Event event;
467502
auto start = test_clock::now();
468503
queue.enqueueNDRangeKernel(kernel, cl::NullRange,
469-
cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event);
504+
cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event);
470505
queue.finish();
471506
auto end = test_clock::now();
472507
std::chrono::duration<float> sw_time = end - start;
@@ -508,6 +543,8 @@ static void bfloat16_dpas_blockread_rowmajor_tiled(
508543
} else if (tN * NN > N) {
509544
printf("N is too small.\n");
510545
} else {
546+
const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue);
547+
511548
kernel.setArg(0, C);
512549
kernel.setArg(1, A);
513550
kernel.setArg(2, B);
@@ -525,7 +562,7 @@ static void bfloat16_dpas_blockread_rowmajor_tiled(
525562
cl::Event event;
526563
auto start = test_clock::now();
527564
queue.enqueueNDRangeKernel(kernel, cl::NullRange,
528-
cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event);
565+
cl::NDRange{N/NN, M/tM/MM}, localWorkSize, nullptr, &event);
529566
queue.finish();
530567
auto end = test_clock::now();
531568
std::chrono::duration<float> sw_time = end - start;
@@ -561,6 +598,8 @@ static void bfloat16_dpas_blockread_vnni(
561598
if (kernel() == nullptr) {
562599
printf("unsupported.\n");
563600
} else {
601+
const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue);
602+
564603
kernel.setArg(0, C);
565604
kernel.setArg(1, A);
566605
kernel.setArg(2, B);
@@ -578,7 +617,7 @@ static void bfloat16_dpas_blockread_vnni(
578617
cl::Event event;
579618
auto start = test_clock::now();
580619
queue.enqueueNDRangeKernel(kernel, cl::NullRange,
581-
cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event);
620+
cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event);
582621
queue.finish();
583622
auto end = test_clock::now();
584623
std::chrono::duration<float> sw_time = end - start;
@@ -620,6 +659,8 @@ static void bfloat16_dpas_blockread_vnni_tiled(
620659
} else if (tN * NN > N) {
621660
printf("N is too small.\n");
622661
} else {
662+
const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue);
663+
623664
kernel.setArg(0, C);
624665
kernel.setArg(1, A);
625666
kernel.setArg(2, B);
@@ -637,7 +678,7 @@ static void bfloat16_dpas_blockread_vnni_tiled(
637678
cl::Event event;
638679
auto start = test_clock::now();
639680
queue.enqueueNDRangeKernel(kernel, cl::NullRange,
640-
cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event);
681+
cl::NDRange{N/NN, M/tM/MM}, localWorkSize, nullptr, &event);
641682
queue.finish();
642683
auto end = test_clock::now();
643684
std::chrono::duration<float> sw_time = end - start;
@@ -729,7 +770,7 @@ int main(int argc, char** argv)
729770

730771
auto minSubGroupSize = findMinSubGroupSize(device);
731772

732-
bool has_simd8 = minSubGroupSize == 8;
773+
bool has_sg8 = supportsSubgroupSize(device, 8);
733774
bool emulate_tN8 = true;
734775
bool emulate_tN16 = true;
735776
if (!emulate && checkDeviceForExtension(device, "cl_intel_subgroup_matrix_multiply_accumulate")) {
@@ -741,7 +782,7 @@ int main(int argc, char** argv)
741782
}
742783
}
743784

744-
buildOptions += " -DHAS_SIMD8=" + std::to_string(has_simd8);
785+
buildOptions += " -DHAS_SG8=" + std::to_string(has_sg8);
745786
buildOptions += " -DEMULATE_tN8=" + std::to_string(emulate_tN8);
746787
buildOptions += " -DEMULATE_tN16=" + std::to_string(emulate_tN16);
747788

samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -155,22 +155,22 @@ float emu_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc)
155155
{
156156
float res = acc;
157157

158-
res = fma(bf16_to_fp32(sub_group_broadcast(a, 0)), bf16_to_fp32(as_ushort2(b.s0).x), res);
159-
res = fma(bf16_to_fp32(sub_group_broadcast(a, 1)), bf16_to_fp32(as_ushort2(b.s0).y), res);
160-
res = fma(bf16_to_fp32(sub_group_broadcast(a, 2)), bf16_to_fp32(as_ushort2(b.s1).x), res);
161-
res = fma(bf16_to_fp32(sub_group_broadcast(a, 3)), bf16_to_fp32(as_ushort2(b.s1).y), res);
162-
res = fma(bf16_to_fp32(sub_group_broadcast(a, 4)), bf16_to_fp32(as_ushort2(b.s2).x), res);
163-
res = fma(bf16_to_fp32(sub_group_broadcast(a, 5)), bf16_to_fp32(as_ushort2(b.s2).y), res);
164-
res = fma(bf16_to_fp32(sub_group_broadcast(a, 6)), bf16_to_fp32(as_ushort2(b.s3).x), res);
165-
res = fma(bf16_to_fp32(sub_group_broadcast(a, 7)), bf16_to_fp32(as_ushort2(b.s3).y), res);
166-
res = fma(bf16_to_fp32(sub_group_broadcast(a, 8)), bf16_to_fp32(as_ushort2(b.s4).x), res);
167-
res = fma(bf16_to_fp32(sub_group_broadcast(a, 9)), bf16_to_fp32(as_ushort2(b.s4).y), res);
168-
res = fma(bf16_to_fp32(sub_group_broadcast(a, 10)), bf16_to_fp32(as_ushort2(b.s5).x), res);
169-
res = fma(bf16_to_fp32(sub_group_broadcast(a, 11)), bf16_to_fp32(as_ushort2(b.s5).y), res);
170-
res = fma(bf16_to_fp32(sub_group_broadcast(a, 12)), bf16_to_fp32(as_ushort2(b.s6).x), res);
171-
res = fma(bf16_to_fp32(sub_group_broadcast(a, 13)), bf16_to_fp32(as_ushort2(b.s6).y), res);
172-
res = fma(bf16_to_fp32(sub_group_broadcast(a, 14)), bf16_to_fp32(as_ushort2(b.s7).x), res);
173-
res = fma(bf16_to_fp32(sub_group_broadcast(a, 15)), bf16_to_fp32(as_ushort2(b.s7).y), res);
158+
res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 0)), bf16_to_fp32(as_ushort2(b.s0).x), res);
159+
res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 1)), bf16_to_fp32(as_ushort2(b.s0).y), res);
160+
res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 2)), bf16_to_fp32(as_ushort2(b.s1).x), res);
161+
res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 3)), bf16_to_fp32(as_ushort2(b.s1).y), res);
162+
res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 4)), bf16_to_fp32(as_ushort2(b.s2).x), res);
163+
res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 5)), bf16_to_fp32(as_ushort2(b.s2).y), res);
164+
res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 6)), bf16_to_fp32(as_ushort2(b.s3).x), res);
165+
res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 7)), bf16_to_fp32(as_ushort2(b.s3).y), res);
166+
res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 8)), bf16_to_fp32(as_ushort2(b.s4).x), res);
167+
res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 9)), bf16_to_fp32(as_ushort2(b.s4).y), res);
168+
res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 10)), bf16_to_fp32(as_ushort2(b.s5).x), res);
169+
res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 11)), bf16_to_fp32(as_ushort2(b.s5).y), res);
170+
res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 12)), bf16_to_fp32(as_ushort2(b.s6).x), res);
171+
res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 13)), bf16_to_fp32(as_ushort2(b.s6).y), res);
172+
res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 14)), bf16_to_fp32(as_ushort2(b.s7).x), res);
173+
res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 15)), bf16_to_fp32(as_ushort2(b.s7).y), res);
174174

175175
return res;
176176
}

samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ void HELPER_NAME(btile_load_packed, MM, NN)(global ushort* B, int tN, int N, int
6464
}
6565
}
6666

67-
#if HAS_SIMD8
67+
#if HAS_SG8
6868

6969
void HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int K, int m, int prefetch_k)
7070
{
@@ -236,7 +236,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float*
236236
}
237237
}
238238

239-
#endif // HAS_SIMD8
239+
#endif // HAS_SG8
240240

241241
void HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(global ushort* A, int tM, int K, int m, int prefetch_k)
242242
{

samples/20_matrixexperiments-bf16/matrix_kernels_bf16.cl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ kernel void bfloat16_naive(global float* C, global ushort* A, global ushort* B,
3838

3939
#if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size)
4040

41-
#if HAS_SIMD8
41+
#if HAS_SG8
4242

4343
// rowmajor kernels:
4444

@@ -212,9 +212,9 @@ kernel void bfloat16_dpas_vnni_m8_n8(global float* C, global ushort* A, global u
212212
store_c_rowmajor_fp32_8rNc(C, sum, m, n, N);
213213
}
214214

215-
#endif // HAS_SIMD8
215+
#endif // HAS_SG8
216216

217-
// rowmajor krenels:
217+
// rowmajor kernels:
218218

219219
__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
220220
kernel void bfloat16_dpas_rowmajor_m1_n16(global float* C, global ushort* A, global ushort* B, int K)
@@ -224,7 +224,7 @@ kernel void bfloat16_dpas_rowmajor_m1_n16(global float* C, global ushort* A, glo
224224
const int tN = 16;
225225
const int N = get_global_size(0);
226226
const int m = get_group_id(1) * tM;
227-
const int n = get_group_id(0) * get_local_size(0);
227+
const int n = get_group_id(0) * tN;
228228

229229
float sum = 0;
230230
for (int k = 0; k < K; k += tK) {
@@ -245,7 +245,7 @@ kernel void bfloat16_dpas_rowmajor_m2_n16(global float* C, global ushort* A, glo
245245
const int tN = 16;
246246
const int N = get_global_size(0);
247247
const int m = get_group_id(1) * tM;
248-
const int n = get_group_id(0) * get_local_size(0);
248+
const int n = get_group_id(0) * tN;
249249

250250
float2 sum = 0;
251251
for (int k = 0; k < K; k += tK) {
@@ -266,7 +266,7 @@ kernel void bfloat16_dpas_rowmajor_m4_n16(global float* C, global ushort* A, glo
266266
const int tN = 16;
267267
const int N = get_global_size(0);
268268
const int m = get_group_id(1) * tM;
269-
const int n = get_group_id(0) * get_local_size(0);
269+
const int n = get_group_id(0) * tN;
270270

271271
float4 sum = 0;
272272
for (int k = 0; k < K; k += tK) {
@@ -287,7 +287,7 @@ kernel void bfloat16_dpas_rowmajor_m8_n16(global float* C, global ushort* A, glo
287287
const int tN = 16;
288288
const int N = get_global_size(0);
289289
const int m = get_group_id(1) * tM;
290-
const int n = get_group_id(0) * get_local_size(0);
290+
const int n = get_group_id(0) * tN;
291291

292292
float8 sum = 0;
293293
for (int k = 0; k < K; k += tK) {

0 commit comments

Comments
 (0)