@@ -76,6 +76,12 @@ static size_t findMinSubGroupSize(cl::Device& device)
7676 return 0 ;
7777}
7878
79+ static bool supportsSubgroupSize (cl::Device& device, size_t subgroupSize)
80+ {
81+ auto s = device.getInfo <CL_DEVICE_SUB_GROUP_SIZES_INTEL>();
82+ return std::find (std::begin (s), std::end (s), subgroupSize) != std::end (s);
83+ }
84+
7985static void setRoundRobin (cl::Kernel& kernel)
8086{
8187 constexpr cl_kernel_exec_info CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_INTEL = 0x10025 ;
@@ -175,6 +181,23 @@ static float hw_time(cl::Event& event)
175181 return ns / 1e9f;
176182}
177183
184+ static cl::NDRange getRequiredLocalWorkSize (cl::Kernel& kernel, cl::CommandQueue queue)
185+ {
186+ // Note: This shouldn't be necessary, and the OpenCL implementation should
187+ // automatically choose the required local work-group size when the local
188+ // work-group size is `nullptr`. This is not working for some OpenCL
189+ // implementations, though, so we will just query and use the required local
190+ // work-group size explicitly.
191+ auto device = queue.getInfo <CL_QUEUE_DEVICE>();
192+ auto reqd_wgs = kernel.getWorkGroupInfo <CL_KERNEL_COMPILE_WORK_GROUP_SIZE>(device);
193+
194+ if (reqd_wgs[0 ] > 0 && reqd_wgs[1 ] > 0 && reqd_wgs[2 ] > 0 ) {
195+ return cl::NDRange (reqd_wgs[0 ], reqd_wgs[1 ], reqd_wgs[2 ]);
196+ }
197+
198+ return cl::NullRange;
199+ }
200+
178201static void bfloat16_naive (
179202 cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
180203 cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
@@ -187,6 +210,8 @@ static void bfloat16_naive(
187210 if (kernel () == nullptr ) {
188211 printf (" unsupported.\n " );
189212 } else {
213+ const cl::NDRange localWorkSize = getRequiredLocalWorkSize (kernel, queue);
214+
190215 kernel.setArg (0 , C);
191216 kernel.setArg (1 , A);
192217 kernel.setArg (2 , B);
@@ -201,7 +226,7 @@ static void bfloat16_naive(
201226 cl::Event event;
202227 auto start = test_clock::now ();
203228 queue.enqueueNDRangeKernel (kernel, cl::NullRange,
204- cl::NDRange{N, M}, cl::NullRange , nullptr , &event);
229+ cl::NDRange{N, M}, localWorkSize , nullptr , &event);
205230 queue.finish ();
206231 auto end = test_clock::now ();
207232 std::chrono::duration<float > sw_time = end - start;
@@ -237,6 +262,8 @@ static void bfloat16_dpas_rowmajor(
237262 if (kernel () == nullptr ) {
238263 printf (" unsupported.\n " );
239264 } else {
265+ const cl::NDRange localWorkSize = getRequiredLocalWorkSize (kernel, queue);
266+
240267 kernel.setArg (0 , C);
241268 kernel.setArg (1 , A);
242269 kernel.setArg (2 , B);
@@ -251,7 +278,7 @@ static void bfloat16_dpas_rowmajor(
251278 cl::Event event;
252279 auto start = test_clock::now ();
253280 queue.enqueueNDRangeKernel (kernel, cl::NullRange,
254- cl::NDRange{N, M/tM}, cl::NullRange , nullptr , &event);
281+ cl::NDRange{N, M/tM}, localWorkSize , nullptr , &event);
255282 queue.finish ();
256283 auto end = test_clock::now ();
257284 std::chrono::duration<float > sw_time = end - start;
@@ -293,6 +320,8 @@ static void bfloat16_dpas_rowmajor_tiled(
293320 } else if (tN * NN > N) {
294321 printf (" N is too small.\n " );
295322 } else {
323+ const cl::NDRange localWorkSize = getRequiredLocalWorkSize (kernel, queue);
324+
296325 kernel.setArg (0 , C);
297326 kernel.setArg (1 , A);
298327 kernel.setArg (2 , B);
@@ -307,7 +336,7 @@ static void bfloat16_dpas_rowmajor_tiled(
307336 cl::Event event;
308337 auto start = test_clock::now ();
309338 queue.enqueueNDRangeKernel (kernel, cl::NullRange,
310- cl::NDRange{N/NN, M/tM/MM}, cl::NullRange , nullptr , &event);
339+ cl::NDRange{N/NN, M/tM/MM}, localWorkSize , nullptr , &event);
311340 queue.finish ();
312341 auto end = test_clock::now ();
313342 std::chrono::duration<float > sw_time = end - start;
@@ -343,6 +372,8 @@ static void bfloat16_dpas_vnni(
343372 if (kernel () == nullptr ) {
344373 printf (" unsupported.\n " );
345374 } else {
375+ const cl::NDRange localWorkSize = getRequiredLocalWorkSize (kernel, queue);
376+
346377 kernel.setArg (0 , C);
347378 kernel.setArg (1 , A);
348379 kernel.setArg (2 , B);
@@ -357,7 +388,7 @@ static void bfloat16_dpas_vnni(
357388 cl::Event event;
358389 auto start = test_clock::now ();
359390 queue.enqueueNDRangeKernel (kernel, cl::NullRange,
360- cl::NDRange{N, M/tM}, cl::NullRange , nullptr , &event);
391+ cl::NDRange{N, M/tM}, localWorkSize , nullptr , &event);
361392 queue.finish ();
362393 auto end = test_clock::now ();
363394 std::chrono::duration<float > sw_time = end - start;
@@ -399,6 +430,8 @@ static void bfloat16_dpas_vnni_tiled(
399430 } else if (tN * NN > N) {
400431 printf (" N is too small.\n " );
401432 } else {
433+ const cl::NDRange localWorkSize = getRequiredLocalWorkSize (kernel, queue);
434+
402435 kernel.setArg (0 , C);
403436 kernel.setArg (1 , A);
404437 kernel.setArg (2 , B);
@@ -413,7 +446,7 @@ static void bfloat16_dpas_vnni_tiled(
413446 cl::Event event;
414447 auto start = test_clock::now ();
415448 queue.enqueueNDRangeKernel (kernel, cl::NullRange,
416- cl::NDRange{N/NN, M/tM/MM}, cl::NullRange , nullptr , &event);
449+ cl::NDRange{N/NN, M/tM/MM}, localWorkSize , nullptr , &event);
417450 queue.finish ();
418451 auto end = test_clock::now ();
419452 std::chrono::duration<float > sw_time = end - start;
@@ -449,6 +482,8 @@ static void bfloat16_dpas_blockread_rowmajor(
449482 if (kernel () == nullptr ) {
450483 printf (" unsupported.\n " );
451484 } else {
485+ const cl::NDRange localWorkSize = getRequiredLocalWorkSize (kernel, queue);
486+
452487 kernel.setArg (0 , C);
453488 kernel.setArg (1 , A);
454489 kernel.setArg (2 , B);
@@ -466,7 +501,7 @@ static void bfloat16_dpas_blockread_rowmajor(
466501 cl::Event event;
467502 auto start = test_clock::now ();
468503 queue.enqueueNDRangeKernel (kernel, cl::NullRange,
469- cl::NDRange{N, M/tM}, cl::NullRange , nullptr , &event);
504+ cl::NDRange{N, M/tM}, localWorkSize , nullptr , &event);
470505 queue.finish ();
471506 auto end = test_clock::now ();
472507 std::chrono::duration<float > sw_time = end - start;
@@ -508,6 +543,8 @@ static void bfloat16_dpas_blockread_rowmajor_tiled(
508543 } else if (tN * NN > N) {
509544 printf (" N is too small.\n " );
510545 } else {
546+ const cl::NDRange localWorkSize = getRequiredLocalWorkSize (kernel, queue);
547+
511548 kernel.setArg (0 , C);
512549 kernel.setArg (1 , A);
513550 kernel.setArg (2 , B);
@@ -525,7 +562,7 @@ static void bfloat16_dpas_blockread_rowmajor_tiled(
525562 cl::Event event;
526563 auto start = test_clock::now ();
527564 queue.enqueueNDRangeKernel (kernel, cl::NullRange,
528- cl::NDRange{N/NN, M/tM/MM}, cl::NullRange , nullptr , &event);
565+ cl::NDRange{N/NN, M/tM/MM}, localWorkSize , nullptr , &event);
529566 queue.finish ();
530567 auto end = test_clock::now ();
531568 std::chrono::duration<float > sw_time = end - start;
@@ -561,6 +598,8 @@ static void bfloat16_dpas_blockread_vnni(
561598 if (kernel () == nullptr ) {
562599 printf (" unsupported.\n " );
563600 } else {
601+ const cl::NDRange localWorkSize = getRequiredLocalWorkSize (kernel, queue);
602+
564603 kernel.setArg (0 , C);
565604 kernel.setArg (1 , A);
566605 kernel.setArg (2 , B);
@@ -578,7 +617,7 @@ static void bfloat16_dpas_blockread_vnni(
578617 cl::Event event;
579618 auto start = test_clock::now ();
580619 queue.enqueueNDRangeKernel (kernel, cl::NullRange,
581- cl::NDRange{N, M/tM}, cl::NullRange , nullptr , &event);
620+ cl::NDRange{N, M/tM}, localWorkSize , nullptr , &event);
582621 queue.finish ();
583622 auto end = test_clock::now ();
584623 std::chrono::duration<float > sw_time = end - start;
@@ -620,6 +659,8 @@ static void bfloat16_dpas_blockread_vnni_tiled(
620659 } else if (tN * NN > N) {
621660 printf (" N is too small.\n " );
622661 } else {
662+ const cl::NDRange localWorkSize = getRequiredLocalWorkSize (kernel, queue);
663+
623664 kernel.setArg (0 , C);
624665 kernel.setArg (1 , A);
625666 kernel.setArg (2 , B);
@@ -637,7 +678,7 @@ static void bfloat16_dpas_blockread_vnni_tiled(
637678 cl::Event event;
638679 auto start = test_clock::now ();
639680 queue.enqueueNDRangeKernel (kernel, cl::NullRange,
640- cl::NDRange{N/NN, M/tM/MM}, cl::NullRange , nullptr , &event);
681+ cl::NDRange{N/NN, M/tM/MM}, localWorkSize , nullptr , &event);
641682 queue.finish ();
642683 auto end = test_clock::now ();
643684 std::chrono::duration<float > sw_time = end - start;
@@ -729,7 +770,7 @@ int main(int argc, char** argv)
729770
730771 auto minSubGroupSize = findMinSubGroupSize (device);
731772
732- bool has_simd8 = minSubGroupSize == 8 ;
773+ bool has_sg8 = supportsSubgroupSize (device, 8 ) ;
733774 bool emulate_tN8 = true ;
734775 bool emulate_tN16 = true ;
735776 if (!emulate && checkDeviceForExtension (device, " cl_intel_subgroup_matrix_multiply_accumulate" )) {
@@ -741,7 +782,7 @@ int main(int argc, char** argv)
741782 }
742783 }
743784
744- buildOptions += " -DHAS_SIMD8 =" + std::to_string (has_simd8 );
785+ buildOptions += " -DHAS_SG8 =" + std::to_string (has_sg8 );
745786 buildOptions += " -DEMULATE_tN8=" + std::to_string (emulate_tN8);
746787 buildOptions += " -DEMULATE_tN16=" + std::to_string (emulate_tN16);
747788
0 commit comments