diff --git a/NAM/convnet.cpp b/NAM/convnet.cpp index 98bf649..535fac6 100644 --- a/NAM/convnet.cpp +++ b/NAM/convnet.cpp @@ -8,6 +8,7 @@ #include #include "dsp.h" +#include "registry.h" #include "convnet.h" nam::convnet::BatchNorm::BatchNorm(const int dim, std::vector::iterator& weights) @@ -184,3 +185,20 @@ void nam::convnet::ConvNet::_rewind_buffers_() // Now we can do the rest of the rewind this->Buffer::_rewind_buffers_(); } + +// Factory +std::unique_ptr nam::convnet::Factory(const nlohmann::json& config, std::vector& weights, + const double expectedSampleRate) +{ + const int channels = config["channels"]; + const std::vector dilations = config["dilations"]; + const bool batchnorm = config["batchnorm"]; + const std::string activation = config["activation"]; + return std::make_unique( + channels, dilations, batchnorm, activation, weights, expectedSampleRate); +} + +namespace +{ +static nam::factory::Helper _register_ConvNet("ConvNet", nam::convnet::Factory); +} diff --git a/NAM/convnet.h b/NAM/convnet.h index 458cf67..34cbaa0 100644 --- a/NAM/convnet.h +++ b/NAM/convnet.h @@ -86,5 +86,10 @@ class ConvNet : public Buffer int mPrewarmSamples = 0; // Pre-compute during initialization int PrewarmSamples() override { return mPrewarmSamples; }; }; + +// Factory +std::unique_ptr Factory(const nlohmann::json& config, std::vector& weights, + const double expectedSampleRate); + }; // namespace convnet }; // namespace nam diff --git a/NAM/dsp.cpp b/NAM/dsp.cpp index ec5892b..f3b5c14 100644 --- a/NAM/dsp.cpp +++ b/NAM/dsp.cpp @@ -7,6 +7,7 @@ #include #include "dsp.h" +#include "registry.h" #define tanh_impl_ std::tanh // #define tanh_impl_ fast_tanh_ @@ -192,6 +193,15 @@ void nam::Linear::process(NAM_SAMPLE* input, NAM_SAMPLE* output, const int num_f nam::Buffer::_advance_input_buffer_(num_frames); } +// Factory +std::unique_ptr nam::linear::Factory(const nlohmann::json& config, std::vector& weights, + const double expectedSampleRate) +{ + const int receptive_field = config["receptive_field"]; + const bool bias = config["bias"]; + return std::make_unique(receptive_field, bias, weights, expectedSampleRate); +} + // NN modules ================================================================= void nam::Conv1D::set_weights_(std::vector::iterator& weights) diff --git a/NAM/dsp.h b/NAM/dsp.h index e8a6ef3..bf6c448 100644 --- a/NAM/dsp.h +++ b/NAM/dsp.h @@ -164,6 +164,12 @@ class Linear : public Buffer float _bias; }; +namespace linear +{ +std::unique_ptr Factory(const nlohmann::json& config, std::vector& weights, + const double expectedSampleRate); +} // namespace linear + // NN modules ================================================================= // TODO conv could take care of its own ring buffer. diff --git a/NAM/get_dsp.cpp b/NAM/get_dsp.cpp index f0ff312..99dd3a0 100644 --- a/NAM/get_dsp.cpp +++ b/NAM/get_dsp.cpp @@ -4,10 +4,12 @@ #include #include "dsp.h" +#include "registry.h" #include "json.hpp" #include "lstm.h" #include "convnet.h" #include "wavenet.h" +#include "get_dsp.h" namespace nam { @@ -102,12 +104,7 @@ std::unique_ptr get_dsp(const std::filesystem::path config_filename, dspDat returnedConfig.config = j["config"]; returnedConfig.metadata = j["metadata"]; returnedConfig.weights = weights; - if (j.find("sample_rate") != j.end()) - returnedConfig.expected_sample_rate = j["sample_rate"]; - else - { - returnedConfig.expected_sample_rate = -1.0; - } + returnedConfig.expected_sample_rate = nam::get_sample_rate_from_nam_file(j); /*Copy to a new dsp_config object for get_dsp below, since not sure if weights actually get modified as being non-const references on some @@ -152,47 +149,9 @@ std::unique_ptr get_dsp(dspData& conf) } const double expectedSampleRate = conf.expected_sample_rate; - std::unique_ptr out = nullptr; - if (architecture == "Linear") - { - const int receptive_field = config["receptive_field"]; - const bool _bias = config["bias"]; - out = std::make_unique(receptive_field, _bias, weights, expectedSampleRate); - } - else if (architecture == "ConvNet") - { - const int channels = config["channels"]; - const bool batchnorm = config["batchnorm"]; - std::vector dilations = config["dilations"]; - const std::string activation = config["activation"]; - out = std::make_unique(channels, dilations, batchnorm, activation, weights, expectedSampleRate); - } - else if (architecture == "LSTM") - { - const int num_layers = config["num_layers"]; - const int input_size = config["input_size"]; - const int hidden_size = config["hidden_size"]; - out = std::make_unique(num_layers, input_size, hidden_size, weights, expectedSampleRate); - } - else if (architecture == "WaveNet") - { - std::vector layer_array_params; - for (size_t i = 0; i < config["layers"].size(); i++) - { - nlohmann::json layer_config = config["layers"][i]; - layer_array_params.push_back( - wavenet::LayerArrayParams(layer_config["input_size"], layer_config["condition_size"], layer_config["head_size"], - layer_config["channels"], layer_config["kernel_size"], layer_config["dilations"], - layer_config["activation"], layer_config["gated"], layer_config["head_bias"])); - } - const bool with_head = !config["head"].is_null(); - const float head_scale = config["head_scale"]; - out = std::make_unique(layer_array_params, head_scale, with_head, weights, expectedSampleRate); - } - else - { - throw std::runtime_error("Unrecognized architecture"); - } + // Initialize using registry-based factory + std::unique_ptr out = + nam::factory::FactoryRegistry::instance().create(architecture, config, weights, expectedSampleRate); if (loudness.have) { out->SetLoudness(loudness.value); @@ -212,4 +171,13 @@ std::unique_ptr get_dsp(dspData& conf) return out; } + +double get_sample_rate_from_nam_file(const nlohmann::json& j) +{ + if (j.find("sample_rate") != j.end()) + return j["sample_rate"]; + else + return -1.0; +} + }; // namespace nam diff --git a/NAM/get_dsp.h b/NAM/get_dsp.h index 40bbbc9..41b3cd9 100644 --- a/NAM/get_dsp.h +++ b/NAM/get_dsp.h @@ -1,3 +1,5 @@ +#pragma once + #include #include "dsp.h" @@ -12,4 +14,8 @@ std::unique_ptr get_dsp(dspData& conf); // Get NAM from a provided .nam file path and store its configuration in the provided conf std::unique_ptr get_dsp(const std::filesystem::path config_filename, dspData& returnedConfig); + +// Get sample rate from a .nam file +// Returns -1 if not known (Really old .nam files) +double get_sample_rate_from_nam_file(const nlohmann::json& j); }; // namespace nam diff --git a/NAM/lstm.cpp b/NAM/lstm.cpp index 4d74c29..6fa33a2 100644 --- a/NAM/lstm.cpp +++ b/NAM/lstm.cpp @@ -1,7 +1,9 @@ #include #include #include +#include +#include "registry.h" #include "lstm.h" nam::lstm::LSTMCell::LSTMCell(const int input_size, const int hidden_size, std::vector::iterator& weights) @@ -102,3 +104,19 @@ float nam::lstm::LSTM::_process_sample(const float x) this->_layers[i].process_(this->_layers[i - 1].get_hidden_state()); return this->_head_weight.dot(this->_layers[this->_layers.size() - 1].get_hidden_state()) + this->_head_bias; } + +// Factory to instantiate from nlohmann json +std::unique_ptr nam::lstm::Factory(const nlohmann::json& config, std::vector& weights, + const double expectedSampleRate) +{ + const int num_layers = config["num_layers"]; + const int input_size = config["input_size"]; + const int hidden_size = config["hidden_size"]; + return std::make_unique(num_layers, input_size, hidden_size, weights, expectedSampleRate); +} + +// Register the factory +namespace +{ +static nam::factory::Helper _register_LSTM("LSTM", nam::lstm::Factory); +} diff --git a/NAM/lstm.h b/NAM/lstm.h index 7ee38e0..17d0ada 100644 --- a/NAM/lstm.h +++ b/NAM/lstm.h @@ -3,6 +3,7 @@ #include #include +#include #include @@ -69,5 +70,10 @@ class LSTM : public DSP // Since this is assumed to not be a parametric model, its shape should be (1,) Eigen::VectorXf _input; }; + +// Factory to instantiate from nlohmann json +std::unique_ptr Factory(const nlohmann::json& config, std::vector& weights, + const double expectedSampleRate); + }; // namespace lstm }; // namespace nam diff --git a/NAM/registry.h b/NAM/registry.h new file mode 100644 index 0000000..e3bc4e8 --- /dev/null +++ b/NAM/registry.h @@ -0,0 +1,63 @@ +#pragma once + +// Registry for DSP objects + +#include +#include +#include +#include + +#include "dsp.h" + +namespace nam +{ +namespace factory +{ +// TODO get rid of weights and expectedSampleRate +using FactoryFunction = std::function(const nlohmann::json&, std::vector&, const double)>; + +// Register factories for instantiating DSP objects +class FactoryRegistry +{ +public: + static FactoryRegistry& instance() + { + static FactoryRegistry inst; + return inst; + } + + void registerFactory(const std::string& key, FactoryFunction func) + { + // Assert that the key is not already registered + if (factories_.find(key) != factories_.end()) + { + throw std::runtime_error("Factory already registered for key: " + key); + } + factories_[key] = func; + } + + std::unique_ptr create(const std::string& name, const nlohmann::json& config, std::vector& weights, + const double expectedSampleRate) const + { + auto it = factories_.find(name); + if (it != factories_.end()) + { + return it->second(config, weights, expectedSampleRate); + } + throw std::runtime_error("Factory not found for name: " + name); + } + +private: + std::unordered_map factories_; +}; + +// Registration helper. Use this to register your factories. +struct Helper +{ + Helper(const std::string& name, FactoryFunction factory) + { + FactoryRegistry::instance().registerFactory(name, std::move(factory)); + } +}; +} // namespace factory +} // namespace nam diff --git a/NAM/version.h b/NAM/version.h index b1328ad..f746905 100644 --- a/NAM/version.h +++ b/NAM/version.h @@ -1,9 +1,6 @@ -#ifndef version_h -#define version_h +#pragma once // Make sure this matches NAM version in ../CMakeLists.txt! #define NEURAL_AMP_MODELER_DSP_VERSION_MAJOR 0 #define NEURAL_AMP_MODELER_DSP_VERSION_MINOR 3 #define NEURAL_AMP_MODELER_DSP_VERSION_PATCH 0 - -#endif diff --git a/NAM/wavenet.cpp b/NAM/wavenet.cpp index 4abff25..0d573d7 100644 --- a/NAM/wavenet.cpp +++ b/NAM/wavenet.cpp @@ -4,6 +4,7 @@ #include +#include "registry.h" #include "wavenet.h" nam::wavenet::_DilatedConv::_DilatedConv(const int in_channels, const int out_channels, const int kernel_size, @@ -397,3 +398,28 @@ void nam::wavenet::WaveNet::process(NAM_SAMPLE* input, NAM_SAMPLE* output, const // Finalize to prepare for the next call: this->_advance_buffers_(num_frames); } + +// Factory to instantiate from nlohmann json +std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, std::vector& weights, + const double expectedSampleRate) +{ + std::vector layer_array_params; + for (size_t i = 0; i < config["layers"].size(); i++) + { + nlohmann::json layer_config = config["layers"][i]; + layer_array_params.push_back(nam::wavenet::LayerArrayParams( + layer_config["input_size"], layer_config["condition_size"], layer_config["head_size"], layer_config["channels"], + layer_config["kernel_size"], layer_config["dilations"], layer_config["activation"], layer_config["gated"], + layer_config["head_bias"])); + } + const bool with_head = !config["head"].is_null(); + const float head_scale = config["head_scale"]; + return std::make_unique( + layer_array_params, head_scale, with_head, weights, expectedSampleRate); +} + +// Register the factory +namespace +{ +static nam::factory::Helper _register_WaveNet("WaveNet", nam::wavenet::Factory); +} diff --git a/NAM/wavenet.h b/NAM/wavenet.h index d2e74fb..12fbfea 100644 --- a/NAM/wavenet.h +++ b/NAM/wavenet.h @@ -220,5 +220,9 @@ class WaveNet : public DSP int mPrewarmSamples = 0; // Pre-compute during initialization int PrewarmSamples() override { return mPrewarmSamples; }; }; + +// Factory to instantiate from nlohmann json +std::unique_ptr Factory(const nlohmann::json& config, std::vector& weights, + const double expectedSampleRate); }; // namespace wavenet }; // namespace nam