@@ -36,33 +36,35 @@ infiniStatus_t calculateAvgPool1d(const AvgPool1dInfo &info,
3636 const T *x) {
3737 const float inv_kernel = 1 .0f / static_cast <float >(info.kernel_size );
3838
39- #pragma omp parallel for collapse(2)
40- for (ptrdiff_t b = 0 ; b < ptrdiff_t (info.batch ); ++b) {
41- for (ptrdiff_t c = 0 ; c < ptrdiff_t (info.channels ); ++c) {
39+ #pragma omp parallel for
40+ for (ptrdiff_t bc = 0 ; bc < ptrdiff_t (info.batch * info.channels ); ++bc) {
4241
43- size_t y_base = b * info.y_stride_batch + c * info. y_stride_channel ;
44- size_t x_base = b * info.x_stride_batch + c * info. x_stride_channel ;
42+ ptrdiff_t b = bc / info.channels ;
43+ ptrdiff_t c = bc % info.channels ;
4544
46- for ( size_t ow = 0 ; ow < info.out_width ; ++ow) {
47- size_t y_offset = y_base + ow * info.y_stride_width ;
45+ size_t y_base = b * info.y_stride_batch + c * info. y_stride_channel ;
46+ size_t x_base = b * info. x_stride_batch + c * info.x_stride_channel ;
4847
49- long long start_w = static_cast < long long >( ow * info.stride ) - info. padding ;
50- long long end_w = start_w + info.kernel_size ;
48+ for ( size_t ow = 0 ; ow < info.out_width ; ++ow) {
49+ size_t y_offset = y_base + ow * info.y_stride_width ;
5150
52- long long valid_start = std::max ( 0LL , start_w) ;
53- long long valid_end = std::min ( static_cast < long long >( info.in_width ), end_w) ;
51+ long long start_w = static_cast < long long >(ow * info. stride ) - info. padding ;
52+ long long end_w = start_w + info.kernel_size ;
5453
55- float sum = 0 .0f ;
56- for (long long iw = valid_start; iw < valid_end; ++iw) {
57- size_t x_offset = x_base + iw * info.x_stride_width ;
58- sum += utils::cast<float >(x[x_offset]);
59- }
54+ long long valid_start = std::max (0LL , start_w);
55+ long long valid_end = std::min (static_cast <long long >(info.in_width ), end_w);
6056
61- const float avg = sum * inv_kernel;
62- y[y_offset] = utils::cast<T>(avg);
57+ float sum = 0 .0f ;
58+ for (long long iw = valid_start; iw < valid_end; ++iw) {
59+ size_t x_offset = x_base + iw * info.x_stride_width ;
60+ sum += utils::cast<float >(x[x_offset]);
6361 }
62+
63+ const float avg = sum * inv_kernel;
64+ y[y_offset] = utils::cast<T>(avg);
6465 }
6566 }
67+
6668 return INFINI_STATUS_SUCCESS;
6769}
6870
0 commit comments