Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions NAM/convnet.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace convnet
class BatchNorm
{
public:
BatchNorm(){};
BatchNorm() {};
BatchNorm(const int dim, std::vector<float>::iterator& weights);
void process_(Eigen::MatrixXf& input, const long i_start, const long i_end) const;

Expand All @@ -39,7 +39,7 @@ class BatchNorm
class ConvNetBlock
{
public:
ConvNetBlock(){};
ConvNetBlock() {};
void set_weights_(const int in_channels, const int out_channels, const int _dilation, const bool batchnorm,
const std::string activation, std::vector<float>::iterator& weights);
void process_(const Eigen::MatrixXf& input, Eigen::MatrixXf& output, const long i_start, const long i_end) const;
Expand All @@ -55,7 +55,7 @@ class ConvNetBlock
class _Head
{
public:
_Head(){};
_Head() {};
_Head(const int channels, std::vector<float>::iterator& weights);
void process_(const Eigen::MatrixXf& input, Eigen::VectorXf& output, const long i_start, const long i_end) const;

Expand Down
15 changes: 8 additions & 7 deletions NAM/wavenet.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class _Layer
, _input_mixin(condition_size, gated ? 2 * channels : channels, false)
, _1x1(channels, channels, true)
, _activation(activations::Activation::get_activation(activation))
, _gated(gated){};
, _gated(gated) {};
void set_weights_(std::vector<float>::iterator& weights);
// :param `input`: from previous layer
// :param `output`: to next layer
Expand Down Expand Up @@ -170,30 +170,31 @@ class WaveNet : public DSP
WaveNet(const std::vector<LayerArrayParams>& layer_array_params, const float head_scale, const bool with_head,
std::vector<float> weights, const double expected_sample_rate = -1.0);
~WaveNet() = default;

void process(NAM_SAMPLE* input, NAM_SAMPLE* output, const int num_frames) override;
void set_weights_(std::vector<float>& weights);

protected:
// Element-wise arrays:
Eigen::MatrixXf _condition;
// Fill in the "condition" array that's fed into the various parts of the net.
virtual void _set_condition_array(NAM_SAMPLE* input, const int num_frames);

private:
long _num_frames;
std::vector<_LayerArray> _layer_arrays;
// Their outputs
std::vector<Eigen::MatrixXf> _layer_array_outputs;
// Head _head;

// Element-wise arrays:
Eigen::MatrixXf _condition;
// One more than total layer arrays
std::vector<Eigen::MatrixXf> _head_arrays;
float _head_scale;
Eigen::MatrixXf _head_output;

void _advance_buffers_(const int num_frames);
void _prepare_for_frames_(const long num_frames);
void process(NAM_SAMPLE* input, NAM_SAMPLE* output, const int num_frames) override;

virtual int _get_condition_dim() const { return 1; };
// Fill in the "condition" array that's fed into the various parts of the net.
virtual void _set_condition_array(NAM_SAMPLE* input, const int num_frames);
// Ensure that all buffer arrays are the right size for this num_frames
void _set_num_frames_(const long num_frames);

Expand Down
Loading