Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions NAM/container.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#include <algorithm>
#include <cmath>
#include <stdexcept>
#include <sstream>

#include "container.h"
#include "get_dsp.h"
#include "model_config.h"

namespace nam
{
namespace container
{

// =============================================================================
// ContainerModel
// =============================================================================

ContainerModel::ContainerModel(std::vector<Submodel> submodels, const double expected_sample_rate)
: DSP(1, 1, expected_sample_rate)
, _submodels(std::move(submodels))
{
if (_submodels.empty())
throw std::runtime_error("ContainerModel: no submodels provided");

// Validate ordering and that final max_value covers 1.0
for (size_t i = 1; i < _submodels.size(); ++i)
{
if (_submodels[i].max_value <= _submodels[i - 1].max_value)
throw std::runtime_error("ContainerModel: submodels must be sorted by ascending max_value");
}
if (_submodels.back().max_value < 1.0)
throw std::runtime_error("ContainerModel: last submodel max_value must be >= 1.0");

// Validate all submodels have the same expected sample rate
for (const auto& sm : _submodels)
{
double sr = sm.model->GetExpectedSampleRate();
if (sr != expected_sample_rate && sr != NAM_UNKNOWN_EXPECTED_SAMPLE_RATE
&& expected_sample_rate != NAM_UNKNOWN_EXPECTED_SAMPLE_RATE)
{
std::stringstream ss;
ss << "ContainerModel: submodel sample rate mismatch (expected " << expected_sample_rate << ", got " << sr << ")";
throw std::runtime_error(ss.str());
}
}

// Default to full size (last submodel)
_active_index = _submodels.size() - 1;
}

void ContainerModel::process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames)
{
_active_model().process(input, output, num_frames);
}

void ContainerModel::prewarm()
{
for (auto& sm : _submodels)
sm.model->prewarm();
}

void ContainerModel::Reset(const double sampleRate, const int maxBufferSize)
{
DSP::Reset(sampleRate, maxBufferSize);
for (auto& sm : _submodels)
sm.model->Reset(sampleRate, maxBufferSize);
}

void ContainerModel::SetSlimmableSize(const double val)
{
_active_index = _submodels.size() - 1;
for (size_t i = 0; i < _submodels.size(); ++i)
{
if (val < _submodels[i].max_value)
{
_active_index = i;
break;
}
}

const double sr = mHaveExternalSampleRate ? mExternalSampleRate : mExpectedSampleRate;
_active_model().ResetAndPrewarm(sr, GetMaxBufferSize());
}

// =============================================================================
// Config / factory
// =============================================================================

std::unique_ptr<DSP> ContainerConfig::create(std::vector<float> weights, double sampleRate)
{
(void)weights; // Container has no top-level weights

auto submodels_json = raw_config["submodels"];
if (!submodels_json.is_array() || submodels_json.empty())
throw std::runtime_error("SlimmableContainer: 'submodels' must be a non-empty array");

std::vector<Submodel> submodels;
submodels.reserve(submodels_json.size());

for (const auto& entry : submodels_json)
{
double max_val = entry.at("max_value").get<double>();
const auto& model_json = entry.at("model");

// Each submodel is a full NAM model spec (has architecture, config, weights, etc.)
auto dsp = get_dsp(model_json);

submodels.push_back({max_val, std::move(dsp)});
}

return std::make_unique<ContainerModel>(std::move(submodels), sampleRate);
}

std::unique_ptr<ModelConfig> create_config(const nlohmann::json& config, double sampleRate)
{
auto c = std::make_unique<ContainerConfig>();
c->raw_config = config;
c->sample_rate = sampleRate;
return c;
}

// Auto-register
static ConfigParserHelper _register_SlimmableContainer("SlimmableContainer", create_config);

} // namespace container
} // namespace nam
63 changes: 63 additions & 0 deletions NAM/container.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#pragma once

#include <memory>
#include <stdexcept>
#include <vector>

#include "dsp.h"
#include "model_config.h"
#include "slimmable.h"

namespace nam
{
namespace container
{

struct Submodel
{
double max_value;
std::unique_ptr<DSP> model;
};

/// \brief A container model that holds multiple submodels at different sizes
///
/// SetSlimmableSize selects the active submodel based on the max_value thresholds.
/// Each submodel covers values up to (but not including) its max_value.
/// The last submodel is the fallback for values at or above the last threshold.
class ContainerModel : public DSP, public SlimmableModel
{
public:
/// \brief Constructor
/// \param submodels Vector of submodels sorted by max_value ascending
/// \param expected_sample_rate Expected sample rate in Hz
ContainerModel(std::vector<Submodel> submodels, const double expected_sample_rate);

void process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames) override;
void prewarm() override;
void Reset(const double sampleRate, const int maxBufferSize) override;
void SetSlimmableSize(const double val) override;

protected:
int PrewarmSamples() override { return 0; }

private:
std::vector<Submodel> _submodels;
size_t _active_index = 0;

DSP& _active_model() { return *_submodels[_active_index].model; }
};

// Config / registration

struct ContainerConfig : public ModelConfig
{
nlohmann::json raw_config;
double sample_rate;

std::unique_ptr<DSP> create(std::vector<float> weights, double sampleRate) override;
};

std::unique_ptr<ModelConfig> create_config(const nlohmann::json& config, double sampleRate);

} // namespace container
} // namespace nam
2 changes: 1 addition & 1 deletion NAM/get_dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void verify_config_version(const std::string versionStr)
<< currentVersion.toString();
throw std::runtime_error(ss.str());
}
else if (version.major == 0 && version.minor == 6 && version.patch > 0)
else if (currentVersion < version)
{
std::cerr << "Model config is a partially-supported version " << versionStr
<< ". The latest fully-supported version is " << currentVersion.toString()
Expand Down
2 changes: 1 addition & 1 deletion NAM/get_dsp.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Version ParseVersion(const std::string& versionStr);

void verify_config_version(const std::string versionStr);

const std::string LATEST_FULLY_SUPPORTED_NAM_FILE_VERSION = "0.6.0";
const std::string LATEST_FULLY_SUPPORTED_NAM_FILE_VERSION = "0.7.0";
const std::string EARLIEST_SUPPORTED_NAM_FILE_VERSION = "0.5.0";

/// \brief Get NAM from a .nam file at the provided location
Expand Down
21 changes: 21 additions & 0 deletions NAM/slimmable.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#pragma once

namespace nam
{

/// \brief Interface for models that support dynamic size reduction
///
/// Models implementing this interface can reduce their computational cost
/// at the expense of quality. The interpretation of the size parameter is
/// model-specific (e.g., selecting a sub-model, pruning channels, etc.).
class SlimmableModel
{
public:
virtual ~SlimmableModel() = default;

/// \brief Set the slimmable size of the model
/// \param val Value between 0.0 (minimum size) and 1.0 (maximum size)
virtual void SetSlimmableSize(const double val) = 0;
};

} // namespace nam
2 changes: 2 additions & 0 deletions docs/nam_file_version.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,5 @@ The following table shows which versions of NeuralAmpModelerCore support which m
- 0.5.3
* - 0.4.0
- 0.6.0
* - 0.4.1
- 0.7.0
1 change: 1 addition & 0 deletions example_models/slimmable_container.nam

Large diffs are not rendered by default.

57 changes: 52 additions & 5 deletions tools/render.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
#include <cmath>
#include <cstdint>
#include <cstring>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include "NAM/dsp.h"
#include "NAM/get_dsp.h"
#include "NAM/slimmable.h"
#include "wav.h"

namespace
Expand Down Expand Up @@ -60,15 +63,47 @@ bool SaveWavFloat32(const char* fileName, const float* samples, size_t numSample

int main(int argc, char* argv[])
{
if (argc < 3 || argc > 4)
// Parse optional --slim <value> from the arguments
double slimValue = -1.0;
bool hasSlim = false;
std::vector<char*> positionalArgs;
positionalArgs.push_back(argv[0]);

for (int i = 1; i < argc; i++)
{
std::string arg(argv[i]);
if (arg == "--slim")
{
if (i + 1 >= argc)
{
std::cerr << "Error: --slim requires a value between 0.0 and 1.0\n";
return 1;
}
char* end = nullptr;
slimValue = std::strtod(argv[i + 1], &end);
if (end == argv[i + 1] || *end != '\0' || slimValue < 0.0 || slimValue > 1.0)
{
std::cerr << "Error: --slim value must be a number between 0.0 and 1.0\n";
return 1;
}
hasSlim = true;
i++; // skip the value
}
else
{
positionalArgs.push_back(argv[i]);
}
}

if (positionalArgs.size() < 3 || positionalArgs.size() > 4)
{
std::cerr << "Usage: render <model.nam> <input.wav> [output.wav]\n";
std::cerr << "Usage: render [--slim <0.0-1.0>] <model.nam> <input.wav> [output.wav]\n";
return 1;
}

const char* modelPath = argv[1];
const char* inputPath = argv[2];
const char* outputPath = (argc >= 4) ? argv[3] : "output.wav";
const char* modelPath = positionalArgs[1];
const char* inputPath = positionalArgs[2];
const char* outputPath = (positionalArgs.size() >= 4) ? positionalArgs[3] : "output.wav";

std::cerr << "Loading model [" << modelPath << "]\n";
auto model = nam::get_dsp(std::filesystem::path(modelPath));
Expand All @@ -79,6 +114,18 @@ int main(int argc, char* argv[])
}
std::cerr << "Model loaded successfully\n";

if (hasSlim)
{
auto* slimmable = dynamic_cast<nam::SlimmableModel*>(model.get());
if (!slimmable)
{
std::cerr << "Error: --slim requires a model that implements the SlimmableModel interface\n";
return 1;
}
std::cerr << "Setting slimmable size to " << slimValue << "\n";
slimmable->SetSlimmableSize(slimValue);
}

std::vector<float> inputAudio;
double inputSampleRate = 0.0;
auto loadResult = dsp::wav::Load(inputPath, inputAudio, inputSampleRate);
Expand Down
20 changes: 20 additions & 0 deletions tools/run_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include "test/test_wavenet_configurable_gating.cpp"
#include "test/test_noncontiguous_blocks.cpp"
#include "test/test_extensible.cpp"
#include "test/test_container.cpp"
#include "test/test_render_slim.cpp"

int main()
{
Expand Down Expand Up @@ -268,6 +270,24 @@ int main()
// Extensibility: external architecture registration and get_dsp (issue #230)
test_extensible::run_extensibility_tests();

// Container / SlimmableContainer tests
test_container::test_container_loads_from_json();
test_container::test_container_processes_audio();
test_container::test_container_slimmable_selects_submodel();
test_container::test_container_boundary_values();
test_container::test_container_empty_submodels_throws();
test_container::test_container_last_max_value_must_cover_one();
test_container::test_container_unsorted_submodels_throws();
test_container::test_container_sample_rate_mismatch_throws();
test_container::test_container_load_from_file();
test_container::test_container_default_is_max_size();

// Render --slim tests
test_render_slim::test_slim_changes_output();
test_render_slim::test_slim_rejects_non_slimmable();
test_render_slim::test_slim_boundary_values();
test_render_slim::test_slim_applied_before_processing();

std::cout << "Success!" << std::endl;
#ifdef ADDASSERT
std::cerr << "===============================================================" << std::endl;
Expand Down
Loading
Loading