11#pragma once
22#include < cmath>
3+ #include < cstddef>
34#include < stdexcept>
45#include < vector>
56
@@ -34,7 +35,7 @@ class ConvolutionalLayer : public Layer {
3435 dilations_ = 0 ;
3536 }
3637 ConvolutionalLayer (size_t step, size_t pads, size_t dilations,
37- const Tensor& kernel, const Tensor& bias = Tensor(),
38+ const Tensor & kernel, const Tensor & bias = Tensor(),
3839 size_t group = 1 , bool useLegacyImpl = false )
3940 : Layer(kConvolution ),
4041 kernel_ (std::make_shared<Tensor>(kernel)),
@@ -73,10 +74,10 @@ class ConvolutionalLayer : public Layer {
7374 return useLegacyImpl_;
7475 }
7576
76- void run (const std::vector<Tensor>& input,
77- std::vector<Tensor>& output) override ;
78- void run (const std::vector<Tensor>& input, std::vector<Tensor>& output,
79- const RuntimeOptions& options) override ;
77+ void run (const std::vector<Tensor> & input,
78+ std::vector<Tensor> & output) override ;
79+ void run (const std::vector<Tensor> & input, std::vector<Tensor> & output,
80+ const RuntimeOptions & options) override ;
8081#ifdef ENABLE_STATISTIC_WEIGHTS
8182 Tensor get_weights () override {
8283 return *kernel_;
@@ -100,7 +101,7 @@ class ConvImpl : public LayerImpl<ValueType> {
100101 ConvImpl () = delete ;
101102 ConvImpl (size_t stride, size_t pads, size_t dilations, int input_width,
102103 int input_height, int input_flow, size_t input_size,
103- const std::vector<ValueType>& bias)
104+ const std::vector<ValueType> & bias)
104105 : input_width_(input_width),
105106 input_height_ (input_height),
106107 input_flow_(input_flow),
@@ -110,10 +111,10 @@ class ConvImpl : public LayerImpl<ValueType> {
110111 input_size_(input_size),
111112 bias_(bias) {}
112113
113- ConvImpl (const ConvImpl& c) = default;
114+ ConvImpl (const ConvImpl & c) = default;
114115
115116 [[nodiscard]] std::vector<ValueType> run (
116- const std::vector<ValueType>& input) const override {
117+ const std::vector<ValueType> & input) const override {
117118 return input;
118119 }
119120
@@ -184,8 +185,8 @@ class ConvImpl : public LayerImpl<ValueType> {
184185
185186// NCHW -> NCHW only
186187template <typename ValueType>
187- void Conv4D (const Tensor& input, const Tensor& kernel_, const Tensor& bias_,
188- Tensor& output, size_t stride_, size_t pads_, size_t group_,
188+ void Conv4D (const Tensor & input, const Tensor & kernel_, const Tensor & bias_,
189+ Tensor & output, size_t stride_, size_t pads_, size_t group_,
189190 size_t dilations_, ParBackend backend = ParBackend::kSeq ) {
190191 size_t batch_size = input.get_shape ()[0 ];
191192 size_t in_channels = input.get_shape ()[1 ];
@@ -212,106 +213,103 @@ void Conv4D(const Tensor& input, const Tensor& kernel_, const Tensor& bias_,
212213 size_t out_width =
213214 ComputeConvOutputDim (in_width, kernel_width, stride_, pads_, dilations_);
214215
215- std::vector<std::vector<std::vector<std::vector<ValueType>>>> padded_input (
216- batch_size,
217- std::vector<std::vector<std::vector<ValueType>>>(
218- in_height + 2 * pads_,
219- std::vector<std::vector<ValueType>>(
220- in_width + 2 * pads_, std::vector<ValueType>(in_channels, 0 ))));
221-
222216 parallel::Options options;
223217 options.backend = backend;
224218
225- parallel::parallel_for (batch_size, [&](size_t b) {
226- for (size_t h = 0 ; h < in_height; ++h) {
227- for (size_t w = 0 ; w < in_width; ++w) {
228- for (size_t c = 0 ; c < in_channels; ++c) {
229- padded_input[b][h + pads_][w + pads_][c] =
230- input.get <ValueType>({b, c, h, w});
231- }
232- }
233- }
234- }, options);
235-
236- size_t dilated_kernel_height = (kernel_height - 1 ) * dilations_ + 1 ;
237- size_t dilated_kernel_width = (kernel_width - 1 ) * dilations_ + 1 ;
238-
239- std::vector<std::vector<std::vector<std::vector<ValueType>>>> dil_kernel (
240- out_channels, std::vector<std::vector<std::vector<ValueType>>>(
241- kernel_in_channels,
242- std::vector<std::vector<ValueType>>(
243- dilated_kernel_height,
244- std::vector<ValueType>(dilated_kernel_width, 0 ))));
245-
246- parallel::parallel_for (out_channels, [&](size_t oc) {
247- for (size_t ic = 0 ; ic < kernel_in_channels; ++ic) {
248- for (size_t kh = 0 ; kh < kernel_height; ++kh) {
249- for (size_t kw = 0 ; kw < kernel_width; ++kw) {
250- dil_kernel[oc][ic][kh * dilations_][kw * dilations_] =
251- kernel_.get <ValueType>({oc, ic, kh, kw});
252- }
253- }
254- }
255- }, options);
219+ const auto &input_data = *input.as <ValueType>();
220+ const auto &kernel_data = *kernel_.as <ValueType>();
221+ const std::vector<ValueType> *bias_data = nullptr ;
222+ if (!bias_.empty ()) {
223+ bias_data = bias_.as <ValueType>();
224+ }
256225
257- std::vector<std::vector<std::vector<std::vector<ValueType>>>> output_tensor (
258- batch_size,
259- std::vector<std::vector<std::vector<ValueType>>>(
260- out_channels, std::vector<std::vector<ValueType>>(
261- out_height, std::vector<ValueType>(out_width, 0 ))));
226+ const size_t input_channel_stride = in_height * in_width;
227+ const size_t input_batch_stride = in_channels * input_channel_stride;
228+ const size_t kernel_channel_stride = kernel_height * kernel_width;
229+ const size_t kernel_output_stride =
230+ kernel_in_channels * kernel_channel_stride;
231+ const size_t output_channel_stride = out_height * out_width;
232+ const size_t output_batch_stride = out_channels * output_channel_stride;
233+ const size_t in_channels_per_group = in_channels / group_;
234+ const size_t out_channels_per_group = out_channels / group_;
235+ const bool collapsed_kernel = dilations_ == 0 ;
262236
237+ Shape output_shape ({batch_size, out_channels, out_height, out_width});
238+ std::vector<ValueType> flat_output (output_shape.count (), 0 );
263239 size_t total_work = batch_size * out_channels;
264240 parallel::parallel_for (total_work, [&](size_t idx) {
265241 size_t b = idx / out_channels;
266242 size_t oc = idx % out_channels;
243+ size_t input_batch_base = b * input_batch_stride;
244+ size_t output_base = b * output_batch_stride + oc * output_channel_stride;
245+ size_t group = (group_ > 1 ) ? oc / out_channels_per_group : 0 ;
246+ size_t group_start_channel = group * in_channels_per_group;
247+ size_t group_end_channel = group_start_channel + in_channels_per_group;
248+ size_t kernel_oc_base = oc * kernel_output_stride;
249+ ValueType bias_value = ValueType{};
250+ if (bias_data != nullptr && oc < bias_data->size ()) {
251+ bias_value = (*bias_data)[oc];
252+ }
267253
268254 for (size_t oh = 0 ; oh < out_height; ++oh) {
255+ std::ptrdiff_t input_h_base = static_cast <std::ptrdiff_t >(oh * stride_) -
256+ static_cast <std::ptrdiff_t >(pads_);
269257 for (size_t ow = 0 ; ow < out_width; ++ow) {
270- ValueType value = 0 ;
271- size_t h_start = oh * stride_;
272- size_t w_start = ow * stride_;
273-
274- size_t group = (group_ > 1 ) ? oc / (out_channels / group_) : 0 ;
275- size_t group_start_channel = group * (in_channels / group_);
276- size_t group_end_channel = (group + 1 ) * (in_channels / group_);
258+ ValueType value = bias_value;
259+ std::ptrdiff_t input_w_base =
260+ static_cast <std::ptrdiff_t >(ow * stride_) -
261+ static_cast <std::ptrdiff_t >(pads_);
262+ size_t output_idx = output_base + oh * out_width + ow;
277263
278264 for (size_t ic = group_start_channel; ic < group_end_channel; ++ic) {
279265 size_t kernel_ic = ic - group_start_channel;
280-
281- for (size_t kh = 0 ; kh < dilated_kernel_height; ++kh) {
282- for (size_t kw = 0 ; kw < dilated_kernel_width; ++kw) {
283- size_t h_index = h_start + kh;
284- size_t w_index = w_start + kw;
285-
286- if (h_index < padded_input[b].size () &&
287- w_index < padded_input[b][h_index].size ()) {
288- value += padded_input[b][h_index][w_index][ic] *
289- dil_kernel[oc][kernel_ic][kh][kw];
290- }
266+ size_t input_channel_base =
267+ input_batch_base + ic * input_channel_stride;
268+ size_t kernel_ic_base =
269+ kernel_oc_base + kernel_ic * kernel_channel_stride;
270+
271+ if (collapsed_kernel) {
272+ if (input_h_base >= 0 &&
273+ input_h_base < static_cast <std::ptrdiff_t >(in_height) &&
274+ input_w_base >= 0 &&
275+ input_w_base < static_cast <std::ptrdiff_t >(in_width)) {
276+ size_t input_idx = input_channel_base +
277+ static_cast <size_t >(input_h_base) * in_width +
278+ static_cast <size_t >(input_w_base);
279+ size_t kernel_idx = kernel_ic_base + kernel_channel_stride - 1 ;
280+ value += input_data[input_idx] * kernel_data[kernel_idx];
291281 }
282+ continue ;
292283 }
293- }
294284
295- if (!bias_.empty () && oc < bias_.get_shape ()[0 ]) {
296- value += bias_.get <ValueType>({oc});
297- }
285+ for (size_t kh = 0 ; kh < kernel_height; ++kh) {
286+ std::ptrdiff_t input_h =
287+ input_h_base + static_cast <std::ptrdiff_t >(kh * dilations_);
288+ if (input_h < 0 ||
289+ input_h >= static_cast <std::ptrdiff_t >(in_height)) {
290+ continue ;
291+ }
298292
299- output_tensor[b][oc][oh][ow] = value;
300- }
301- }
302- }, options);
293+ size_t input_row_base =
294+ input_channel_base + static_cast <size_t >(input_h) * in_width;
295+ size_t kernel_row_base = kernel_ic_base + kh * kernel_width;
303296
304- Shape output_shape ({batch_size, out_channels, out_height, out_width});
305- std::vector<ValueType> flat_output (batch_size * out_channels * out_height *
306- out_width);
297+ for (size_t kw = 0 ; kw < kernel_width; ++kw) {
298+ std::ptrdiff_t input_w =
299+ input_w_base + static_cast <std::ptrdiff_t >(kw * dilations_);
300+ if (input_w < 0 ||
301+ input_w >= static_cast <std::ptrdiff_t >(in_width)) {
302+ continue ;
303+ }
307304
308- parallel::parallel_for (batch_size, [&](size_t b) {
309- size_t base_idx = b * out_channels * out_height * out_width;
310- for (size_t oc = 0 ; oc < out_channels; ++oc) {
311- for (size_t h = 0 ; h < out_height; ++h) {
312- for (size_t w = 0 ; w < out_width; ++w) {
313- flat_output[base_idx++] = output_tensor[b][oc][h][w];
305+ value +=
306+ input_data[input_row_base + static_cast <size_t >(input_w)] *
307+ kernel_data[kernel_row_base + kw];
308+ }
309+ }
314310 }
311+
312+ flat_output[output_idx] = value;
315313 }
316314 }
317315 }, options);
@@ -320,8 +318,8 @@ void Conv4D(const Tensor& input, const Tensor& kernel_, const Tensor& bias_,
320318}
321319
322320template <typename ValueType>
323- void DepthwiseConv4D (const Tensor& input, const Tensor& kernel_,
324- const Tensor& bias_, Tensor& output, size_t stride_,
321+ void DepthwiseConv4D (const Tensor & input, const Tensor & kernel_,
322+ const Tensor & bias_, Tensor & output, size_t stride_,
325323 size_t pads_, size_t dilations_,
326324 ParBackend backend = ParBackend::kSeq ) {
327325 size_t batch_size = input.get_shape ()[0 ];
@@ -388,8 +386,8 @@ void DepthwiseConv4D(const Tensor& input, const Tensor& kernel_,
388386
389387// NCHW -> NCHW only (Legacy version)
390388template <typename ValueType>
391- void Conv4D_Legacy (const Tensor& input, const Tensor& kernel_,
392- const Tensor& bias_, Tensor& output, size_t stride_,
389+ void Conv4D_Legacy (const Tensor & input, const Tensor & kernel_,
390+ const Tensor & bias_, Tensor & output, size_t stride_,
393391 size_t pads_, size_t dilations_,
394392 ParBackend backend = ParBackend::kSeq ) {
395393 size_t batch_size = input.get_shape ()[0 ];
0 commit comments