Skip to content

Commit 0ab6656

Browse files
issue/1031fxi avg_pool1d cross_entropy omp
1 parent 8581a2c commit 0ab6656

2 files changed

Lines changed: 21 additions & 19 deletions

File tree

src/infiniop/ops/avg_pool1d/cpu/avg_pool1d_cpu.cc

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -36,33 +36,35 @@ infiniStatus_t calculateAvgPool1d(const AvgPool1dInfo &info,
3636
const T *x) {
3737
const float inv_kernel = 1.0f / static_cast<float>(info.kernel_size);
3838

39-
#pragma omp parallel for collapse(2)
40-
for (ptrdiff_t b = 0; b < ptrdiff_t(info.batch); ++b) {
41-
for (ptrdiff_t c = 0; c < ptrdiff_t(info.channels); ++c) {
39+
#pragma omp parallel for
40+
for (ptrdiff_t bc = 0; bc < ptrdiff_t(info.batch * info.channels); ++bc) {
4241

43-
size_t y_base = b * info.y_stride_batch + c * info.y_stride_channel;
44-
size_t x_base = b * info.x_stride_batch + c * info.x_stride_channel;
42+
ptrdiff_t b = bc / info.channels;
43+
ptrdiff_t c = bc % info.channels;
4544

46-
for (size_t ow = 0; ow < info.out_width; ++ow) {
47-
size_t y_offset = y_base + ow * info.y_stride_width;
45+
size_t y_base = b * info.y_stride_batch + c * info.y_stride_channel;
46+
size_t x_base = b * info.x_stride_batch + c * info.x_stride_channel;
4847

49-
long long start_w = static_cast<long long>(ow * info.stride) - info.padding;
50-
long long end_w = start_w + info.kernel_size;
48+
for (size_t ow = 0; ow < info.out_width; ++ow) {
49+
size_t y_offset = y_base + ow * info.y_stride_width;
5150

52-
long long valid_start = std::max(0LL, start_w);
53-
long long valid_end = std::min(static_cast<long long>(info.in_width), end_w);
51+
long long start_w = static_cast<long long>(ow * info.stride) - info.padding;
52+
long long end_w = start_w + info.kernel_size;
5453

55-
float sum = 0.0f;
56-
for (long long iw = valid_start; iw < valid_end; ++iw) {
57-
size_t x_offset = x_base + iw * info.x_stride_width;
58-
sum += utils::cast<float>(x[x_offset]);
59-
}
54+
long long valid_start = std::max(0LL, start_w);
55+
long long valid_end = std::min(static_cast<long long>(info.in_width), end_w);
6056

61-
const float avg = sum * inv_kernel;
62-
y[y_offset] = utils::cast<T>(avg);
57+
float sum = 0.0f;
58+
for (long long iw = valid_start; iw < valid_end; ++iw) {
59+
size_t x_offset = x_base + iw * info.x_stride_width;
60+
sum += utils::cast<float>(x[x_offset]);
6361
}
62+
63+
const float avg = sum * inv_kernel;
64+
y[y_offset] = utils::cast<T>(avg);
6465
}
6566
}
67+
6668
return INFINI_STATUS_SUCCESS;
6769
}
6870

src/infiniop/ops/cross_entropy/cpu/cross_entropy_cpu.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ infiniStatus_t cross_entropy_kernel(const CrossEntropyInfo *info,
4141
const Tidx *label = reinterpret_cast<const Tidx *>(target);
4242

4343
#pragma omp parallel for
44-
for (size_t i = 0; i < info->outer_size; ++i) {
44+
for (ptrdiff_t i = 0; i < info->outer_size; ++i) {
4545
const T *row = x + i * info->x_stride;
4646
Tidx idx = label[i];
4747

0 commit comments

Comments
 (0)