Skip to content

Commit 0ad10b5

Browse files
committed
Optimize naive Conv4D with flat contiguous buffers
1 parent ebcf79f commit 0ad10b5

1 file changed

Lines changed: 90 additions & 92 deletions

File tree

include/layers/ConvLayer.hpp

Lines changed: 90 additions & 92 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
#pragma once
22
#include <cmath>
3+
#include <cstddef>
34
#include <stdexcept>
45
#include <vector>
56

@@ -34,7 +35,7 @@ class ConvolutionalLayer : public Layer {
3435
dilations_ = 0;
3536
}
3637
ConvolutionalLayer(size_t step, size_t pads, size_t dilations,
37-
const Tensor& kernel, const Tensor& bias = Tensor(),
38+
const Tensor &kernel, const Tensor &bias = Tensor(),
3839
size_t group = 1, bool useLegacyImpl = false)
3940
: Layer(kConvolution),
4041
kernel_(std::make_shared<Tensor>(kernel)),
@@ -73,10 +74,10 @@ class ConvolutionalLayer : public Layer {
7374
return useLegacyImpl_;
7475
}
7576

76-
void run(const std::vector<Tensor>& input,
77-
std::vector<Tensor>& output) override;
78-
void run(const std::vector<Tensor>& input, std::vector<Tensor>& output,
79-
const RuntimeOptions& options) override;
77+
void run(const std::vector<Tensor> &input,
78+
std::vector<Tensor> &output) override;
79+
void run(const std::vector<Tensor> &input, std::vector<Tensor> &output,
80+
const RuntimeOptions &options) override;
8081
#ifdef ENABLE_STATISTIC_WEIGHTS
8182
Tensor get_weights() override {
8283
return *kernel_;
@@ -100,7 +101,7 @@ class ConvImpl : public LayerImpl<ValueType> {
100101
ConvImpl() = delete;
101102
ConvImpl(size_t stride, size_t pads, size_t dilations, int input_width,
102103
int input_height, int input_flow, size_t input_size,
103-
const std::vector<ValueType>& bias)
104+
const std::vector<ValueType> &bias)
104105
: input_width_(input_width),
105106
input_height_(input_height),
106107
input_flow_(input_flow),
@@ -110,10 +111,10 @@ class ConvImpl : public LayerImpl<ValueType> {
110111
input_size_(input_size),
111112
bias_(bias) {}
112113

113-
ConvImpl(const ConvImpl& c) = default;
114+
ConvImpl(const ConvImpl &c) = default;
114115

115116
[[nodiscard]] std::vector<ValueType> run(
116-
const std::vector<ValueType>& input) const override {
117+
const std::vector<ValueType> &input) const override {
117118
return input;
118119
}
119120

@@ -184,8 +185,8 @@ class ConvImpl : public LayerImpl<ValueType> {
184185

185186
// NCHW -> NCHW only
186187
template <typename ValueType>
187-
void Conv4D(const Tensor& input, const Tensor& kernel_, const Tensor& bias_,
188-
Tensor& output, size_t stride_, size_t pads_, size_t group_,
188+
void Conv4D(const Tensor &input, const Tensor &kernel_, const Tensor &bias_,
189+
Tensor &output, size_t stride_, size_t pads_, size_t group_,
189190
size_t dilations_, ParBackend backend = ParBackend::kSeq) {
190191
size_t batch_size = input.get_shape()[0];
191192
size_t in_channels = input.get_shape()[1];
@@ -212,106 +213,103 @@ void Conv4D(const Tensor& input, const Tensor& kernel_, const Tensor& bias_,
212213
size_t out_width =
213214
ComputeConvOutputDim(in_width, kernel_width, stride_, pads_, dilations_);
214215

215-
std::vector<std::vector<std::vector<std::vector<ValueType>>>> padded_input(
216-
batch_size,
217-
std::vector<std::vector<std::vector<ValueType>>>(
218-
in_height + 2 * pads_,
219-
std::vector<std::vector<ValueType>>(
220-
in_width + 2 * pads_, std::vector<ValueType>(in_channels, 0))));
221-
222216
parallel::Options options;
223217
options.backend = backend;
224218

225-
parallel::parallel_for(batch_size, [&](size_t b) {
226-
for (size_t h = 0; h < in_height; ++h) {
227-
for (size_t w = 0; w < in_width; ++w) {
228-
for (size_t c = 0; c < in_channels; ++c) {
229-
padded_input[b][h + pads_][w + pads_][c] =
230-
input.get<ValueType>({b, c, h, w});
231-
}
232-
}
233-
}
234-
}, options);
235-
236-
size_t dilated_kernel_height = (kernel_height - 1) * dilations_ + 1;
237-
size_t dilated_kernel_width = (kernel_width - 1) * dilations_ + 1;
238-
239-
std::vector<std::vector<std::vector<std::vector<ValueType>>>> dil_kernel(
240-
out_channels, std::vector<std::vector<std::vector<ValueType>>>(
241-
kernel_in_channels,
242-
std::vector<std::vector<ValueType>>(
243-
dilated_kernel_height,
244-
std::vector<ValueType>(dilated_kernel_width, 0))));
245-
246-
parallel::parallel_for(out_channels, [&](size_t oc) {
247-
for (size_t ic = 0; ic < kernel_in_channels; ++ic) {
248-
for (size_t kh = 0; kh < kernel_height; ++kh) {
249-
for (size_t kw = 0; kw < kernel_width; ++kw) {
250-
dil_kernel[oc][ic][kh * dilations_][kw * dilations_] =
251-
kernel_.get<ValueType>({oc, ic, kh, kw});
252-
}
253-
}
254-
}
255-
}, options);
219+
const auto &input_data = *input.as<ValueType>();
220+
const auto &kernel_data = *kernel_.as<ValueType>();
221+
const std::vector<ValueType> *bias_data = nullptr;
222+
if (!bias_.empty()) {
223+
bias_data = bias_.as<ValueType>();
224+
}
256225

257-
std::vector<std::vector<std::vector<std::vector<ValueType>>>> output_tensor(
258-
batch_size,
259-
std::vector<std::vector<std::vector<ValueType>>>(
260-
out_channels, std::vector<std::vector<ValueType>>(
261-
out_height, std::vector<ValueType>(out_width, 0))));
226+
const size_t input_channel_stride = in_height * in_width;
227+
const size_t input_batch_stride = in_channels * input_channel_stride;
228+
const size_t kernel_channel_stride = kernel_height * kernel_width;
229+
const size_t kernel_output_stride =
230+
kernel_in_channels * kernel_channel_stride;
231+
const size_t output_channel_stride = out_height * out_width;
232+
const size_t output_batch_stride = out_channels * output_channel_stride;
233+
const size_t in_channels_per_group = in_channels / group_;
234+
const size_t out_channels_per_group = out_channels / group_;
235+
const bool collapsed_kernel = dilations_ == 0;
262236

237+
Shape output_shape({batch_size, out_channels, out_height, out_width});
238+
std::vector<ValueType> flat_output(output_shape.count(), 0);
263239
size_t total_work = batch_size * out_channels;
264240
parallel::parallel_for(total_work, [&](size_t idx) {
265241
size_t b = idx / out_channels;
266242
size_t oc = idx % out_channels;
243+
size_t input_batch_base = b * input_batch_stride;
244+
size_t output_base = b * output_batch_stride + oc * output_channel_stride;
245+
size_t group = (group_ > 1) ? oc / out_channels_per_group : 0;
246+
size_t group_start_channel = group * in_channels_per_group;
247+
size_t group_end_channel = group_start_channel + in_channels_per_group;
248+
size_t kernel_oc_base = oc * kernel_output_stride;
249+
ValueType bias_value = ValueType{};
250+
if (bias_data != nullptr && oc < bias_data->size()) {
251+
bias_value = (*bias_data)[oc];
252+
}
267253

268254
for (size_t oh = 0; oh < out_height; ++oh) {
255+
std::ptrdiff_t input_h_base = static_cast<std::ptrdiff_t>(oh * stride_) -
256+
static_cast<std::ptrdiff_t>(pads_);
269257
for (size_t ow = 0; ow < out_width; ++ow) {
270-
ValueType value = 0;
271-
size_t h_start = oh * stride_;
272-
size_t w_start = ow * stride_;
273-
274-
size_t group = (group_ > 1) ? oc / (out_channels / group_) : 0;
275-
size_t group_start_channel = group * (in_channels / group_);
276-
size_t group_end_channel = (group + 1) * (in_channels / group_);
258+
ValueType value = bias_value;
259+
std::ptrdiff_t input_w_base =
260+
static_cast<std::ptrdiff_t>(ow * stride_) -
261+
static_cast<std::ptrdiff_t>(pads_);
262+
size_t output_idx = output_base + oh * out_width + ow;
277263

278264
for (size_t ic = group_start_channel; ic < group_end_channel; ++ic) {
279265
size_t kernel_ic = ic - group_start_channel;
280-
281-
for (size_t kh = 0; kh < dilated_kernel_height; ++kh) {
282-
for (size_t kw = 0; kw < dilated_kernel_width; ++kw) {
283-
size_t h_index = h_start + kh;
284-
size_t w_index = w_start + kw;
285-
286-
if (h_index < padded_input[b].size() &&
287-
w_index < padded_input[b][h_index].size()) {
288-
value += padded_input[b][h_index][w_index][ic] *
289-
dil_kernel[oc][kernel_ic][kh][kw];
290-
}
266+
size_t input_channel_base =
267+
input_batch_base + ic * input_channel_stride;
268+
size_t kernel_ic_base =
269+
kernel_oc_base + kernel_ic * kernel_channel_stride;
270+
271+
if (collapsed_kernel) {
272+
if (input_h_base >= 0 &&
273+
input_h_base < static_cast<std::ptrdiff_t>(in_height) &&
274+
input_w_base >= 0 &&
275+
input_w_base < static_cast<std::ptrdiff_t>(in_width)) {
276+
size_t input_idx = input_channel_base +
277+
static_cast<size_t>(input_h_base) * in_width +
278+
static_cast<size_t>(input_w_base);
279+
size_t kernel_idx = kernel_ic_base + kernel_channel_stride - 1;
280+
value += input_data[input_idx] * kernel_data[kernel_idx];
291281
}
282+
continue;
292283
}
293-
}
294284

295-
if (!bias_.empty() && oc < bias_.get_shape()[0]) {
296-
value += bias_.get<ValueType>({oc});
297-
}
285+
for (size_t kh = 0; kh < kernel_height; ++kh) {
286+
std::ptrdiff_t input_h =
287+
input_h_base + static_cast<std::ptrdiff_t>(kh * dilations_);
288+
if (input_h < 0 ||
289+
input_h >= static_cast<std::ptrdiff_t>(in_height)) {
290+
continue;
291+
}
298292

299-
output_tensor[b][oc][oh][ow] = value;
300-
}
301-
}
302-
}, options);
293+
size_t input_row_base =
294+
input_channel_base + static_cast<size_t>(input_h) * in_width;
295+
size_t kernel_row_base = kernel_ic_base + kh * kernel_width;
303296

304-
Shape output_shape({batch_size, out_channels, out_height, out_width});
305-
std::vector<ValueType> flat_output(batch_size * out_channels * out_height *
306-
out_width);
297+
for (size_t kw = 0; kw < kernel_width; ++kw) {
298+
std::ptrdiff_t input_w =
299+
input_w_base + static_cast<std::ptrdiff_t>(kw * dilations_);
300+
if (input_w < 0 ||
301+
input_w >= static_cast<std::ptrdiff_t>(in_width)) {
302+
continue;
303+
}
307304

308-
parallel::parallel_for(batch_size, [&](size_t b) {
309-
size_t base_idx = b * out_channels * out_height * out_width;
310-
for (size_t oc = 0; oc < out_channels; ++oc) {
311-
for (size_t h = 0; h < out_height; ++h) {
312-
for (size_t w = 0; w < out_width; ++w) {
313-
flat_output[base_idx++] = output_tensor[b][oc][h][w];
305+
value +=
306+
input_data[input_row_base + static_cast<size_t>(input_w)] *
307+
kernel_data[kernel_row_base + kw];
308+
}
309+
}
314310
}
311+
312+
flat_output[output_idx] = value;
315313
}
316314
}
317315
}, options);
@@ -320,8 +318,8 @@ void Conv4D(const Tensor& input, const Tensor& kernel_, const Tensor& bias_,
320318
}
321319

322320
template <typename ValueType>
323-
void DepthwiseConv4D(const Tensor& input, const Tensor& kernel_,
324-
const Tensor& bias_, Tensor& output, size_t stride_,
321+
void DepthwiseConv4D(const Tensor &input, const Tensor &kernel_,
322+
const Tensor &bias_, Tensor &output, size_t stride_,
325323
size_t pads_, size_t dilations_,
326324
ParBackend backend = ParBackend::kSeq) {
327325
size_t batch_size = input.get_shape()[0];
@@ -388,8 +386,8 @@ void DepthwiseConv4D(const Tensor& input, const Tensor& kernel_,
388386

389387
// NCHW -> NCHW only (Legacy version)
390388
template <typename ValueType>
391-
void Conv4D_Legacy(const Tensor& input, const Tensor& kernel_,
392-
const Tensor& bias_, Tensor& output, size_t stride_,
389+
void Conv4D_Legacy(const Tensor &input, const Tensor &kernel_,
390+
const Tensor &bias_, Tensor &output, size_t stride_,
393391
size_t pads_, size_t dilations_,
394392
ParBackend backend = ParBackend::kSeq) {
395393
size_t batch_size = input.get_shape()[0];

0 commit comments

Comments (0)