From 7418491dfe132b0bef77225d8417d29af8ac6002 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Felipe=20Santos?= Date: Wed, 11 Mar 2026 20:14:06 -0700 Subject: [PATCH 1/3] First draft of SlimmableWavenet --- NAM/slimmable_wavenet.cpp | 465 ++++++++++++++++++++++++ NAM/slimmable_wavenet.h | 76 ++++ example_models/slimmable_wavenet.nam | 501 ++++++++++++++++++++++++++ tools/run_tests.cpp | 13 + tools/test/test_slimmable_wavenet.cpp | 328 +++++++++++++++++ 5 files changed, 1383 insertions(+) create mode 100644 NAM/slimmable_wavenet.cpp create mode 100644 NAM/slimmable_wavenet.h create mode 100644 example_models/slimmable_wavenet.nam create mode 100644 tools/test/test_slimmable_wavenet.cpp diff --git a/NAM/slimmable_wavenet.cpp b/NAM/slimmable_wavenet.cpp new file mode 100644 index 0000000..e31e306 --- /dev/null +++ b/NAM/slimmable_wavenet.cpp @@ -0,0 +1,465 @@ +#include "slimmable_wavenet.h" +#include "get_dsp.h" + +#include +#include +#include + +namespace nam +{ +namespace slimmable_wavenet +{ + +namespace +{ + +// ============================================================================ +// Weight extraction helpers (groups=1 only) +// ============================================================================ + +// Extract Conv1x1 weight subset: take first slim_out rows, first slim_in cols. +// Full layout (groups=1): row-major out*in, then optional bias(out). +void extract_conv1x1(std::vector::const_iterator& src, int full_in, int full_out, int slim_in, int slim_out, + bool bias, std::vector& dst) +{ + for (int i = 0; i < full_out; i++) + { + for (int j = 0; j < full_in; j++) + { + float w = *(src++); + if (i < slim_out && j < slim_in) + dst.push_back(w); + } + } + if (bias) + { + for (int i = 0; i < full_out; i++) + { + float b = *(src++); + if (i < slim_out) + dst.push_back(b); + } + } +} + +// Extract Conv1D weight subset: take first slim_out output channels, first slim_in input channels. +// Full layout (groups=1): for each out i, for each in j, for each kernel tap k: weight. Then bias(out). +void extract_conv1d(std::vector::const_iterator& src, int full_in, int full_out, int slim_in, int slim_out, + int kernel_size, std::vector& dst) +{ + for (int i = 0; i < full_out; i++) + { + for (int j = 0; j < full_in; j++) + { + for (int k = 0; k < kernel_size; k++) + { + float w = *(src++); + if (i < slim_out && j < slim_in) + dst.push_back(w); + } + } + } + // Bias is always present for conv in WaveNet layers + for (int i = 0; i < full_out; i++) + { + float b = *(src++); + if (i < slim_out) + dst.push_back(b); + } +} + +// Copy n weights unchanged +void copy_weights(std::vector::const_iterator& src, int n, std::vector& dst) +{ + for (int i = 0; i < n; i++) + dst.push_back(*(src++)); +} + +// Compute slim bottleneck from original params and target channels +int compute_slim_bottleneck(const wavenet::LayerArrayParams& p, int new_channels) +{ + if (!p.layer1x1_params.active) + return new_channels; // bottleneck must equal channels when layer1x1 inactive + return std::max(1, p.bottleneck * new_channels / p.channels); +} + +// Validate that all convolution groups are 1 +void validate_groups(const wavenet::LayerArrayParams& p) +{ + if (p.groups_input != 1) + throw std::runtime_error("SlimmableWavenet: groups_input > 1 not supported"); + if (p.groups_input_mixin != 1) + throw std::runtime_error("SlimmableWavenet: groups_input_mixin > 1 not supported"); + if (p.layer1x1_params.active && p.layer1x1_params.groups != 1) + throw std::runtime_error("SlimmableWavenet: layer1x1 groups > 1 not supported"); + if (p.head1x1_params.active && p.head1x1_params.groups != 1) + throw std::runtime_error("SlimmableWavenet: head1x1 groups > 1 not supported"); +} + +// Map ratio [0,1] to a channel count from allowed_channels. +// Matches Python: idx = min(floor(ratio * len), len - 1) +int ratio_to_channels(double ratio, const std::vector& allowed) +{ + int idx = std::min((int)std::floor(ratio * (double)allowed.size()), (int)allowed.size() - 1); + return allowed[idx]; +} + +// ============================================================================ +// Extract slimmed weights by walking the full weight vector in set_weights_ +// order, using typed LayerArrayParams for dimensions. +// ============================================================================ + +std::vector extract_slimmed_weights(const std::vector& original_params, + const std::vector& full_weights, + const std::vector& new_channels_per_array) +{ + std::vector slim; + auto src = full_weights.cbegin(); + const int num_arrays = (int)original_params.size(); + + for (int arr = 0; arr < num_arrays; arr++) + { + const auto& p = original_params[arr]; + validate_groups(p); + + const int full_ch = p.channels; + const int full_bn = p.bottleneck; + const int num_layers = (int)p.dilations.size(); + const int slim_ch = new_channels_per_array[arr]; + + const int slim_bn = compute_slim_bottleneck(p, slim_ch); + + // Input size: first array keeps original, others get previous array's target channels + const int slim_input_size = (arr == 0) ? p.input_size : new_channels_per_array[arr - 1]; + // Head size: intermediate arrays must match next array's channels; last keeps original + const int slim_head_size = (arr < num_arrays - 1) ? new_channels_per_array[arr + 1] : p.head_size; + + const int full_head_out = p.head1x1_params.active ? p.head1x1_params.out_channels : full_bn; + const int slim_head_out = p.head1x1_params.active ? p.head1x1_params.out_channels : slim_bn; + + // ---- rechannel: Conv1x1(input_size -> channels, no bias) ---- + extract_conv1x1(src, p.input_size, full_ch, slim_input_size, slim_ch, false, slim); + + // ---- Per layer ---- + for (int l = 0; l < num_layers; l++) + { + const bool gated = p.gating_modes[l] != wavenet::GatingMode::NONE; + const int full_bg = gated ? 2 * full_bn : full_bn; + const int slim_bg = gated ? 2 * slim_bn : slim_bn; + + // conv: Conv1D(channels -> B_g, K, bias=true) + extract_conv1d(src, full_ch, full_bg, slim_ch, slim_bg, p.kernel_size, slim); + + // input_mixin: Conv1x1(condition_size -> B_g, no bias) + extract_conv1x1(src, p.condition_size, full_bg, p.condition_size, slim_bg, false, slim); + + // layer1x1 (optional): Conv1x1(B -> C, bias=true) + if (p.layer1x1_params.active) + extract_conv1x1(src, full_bn, full_ch, slim_bn, slim_ch, true, slim); + + // head1x1 (optional): Conv1x1(B -> head1x1_out, bias=true) + if (p.head1x1_params.active) + extract_conv1x1(src, full_bn, p.head1x1_params.out_channels, slim_bn, p.head1x1_params.out_channels, true, + slim); + + // ---- FiLM objects (8, in set_weights_ order) ---- + + // conv_pre_film: FiLM(condition_size -> channels) + if (p.conv_pre_film_params.active) + { + int full_out = (p.conv_pre_film_params.shift ? 2 : 1) * full_ch; + int slim_out = (p.conv_pre_film_params.shift ? 2 : 1) * slim_ch; + extract_conv1x1(src, p.condition_size, full_out, p.condition_size, slim_out, true, slim); + } + + // conv_post_film: FiLM(condition_size -> B_g) + if (p.conv_post_film_params.active) + { + int full_out = (p.conv_post_film_params.shift ? 2 : 1) * full_bg; + int slim_out = (p.conv_post_film_params.shift ? 2 : 1) * slim_bg; + extract_conv1x1(src, p.condition_size, full_out, p.condition_size, slim_out, true, slim); + } + + // input_mixin_pre_film: FiLM(condition_size -> condition_size) -- unchanged + if (p.input_mixin_pre_film_params.active) + { + int dim = (p.input_mixin_pre_film_params.shift ? 2 : 1) * p.condition_size; + copy_weights(src, p.condition_size * dim + dim, slim); + } + + // input_mixin_post_film: FiLM(condition_size -> B_g) + if (p.input_mixin_post_film_params.active) + { + int full_out = (p.input_mixin_post_film_params.shift ? 2 : 1) * full_bg; + int slim_out = (p.input_mixin_post_film_params.shift ? 2 : 1) * slim_bg; + extract_conv1x1(src, p.condition_size, full_out, p.condition_size, slim_out, true, slim); + } + + // activation_pre_film: FiLM(condition_size -> B_g) + if (p.activation_pre_film_params.active) + { + int full_out = (p.activation_pre_film_params.shift ? 2 : 1) * full_bg; + int slim_out = (p.activation_pre_film_params.shift ? 2 : 1) * slim_bg; + extract_conv1x1(src, p.condition_size, full_out, p.condition_size, slim_out, true, slim); + } + + // activation_post_film: FiLM(condition_size -> B) + if (p.activation_post_film_params.active) + { + int full_out = (p.activation_post_film_params.shift ? 2 : 1) * full_bn; + int slim_out = (p.activation_post_film_params.shift ? 2 : 1) * slim_bn; + extract_conv1x1(src, p.condition_size, full_out, p.condition_size, slim_out, true, slim); + } + + // layer1x1_post_film: FiLM(condition_size -> C) + if (p._layer1x1_post_film_params.active && p.layer1x1_params.active) + { + int full_out = (p._layer1x1_post_film_params.shift ? 2 : 1) * full_ch; + int slim_out = (p._layer1x1_post_film_params.shift ? 2 : 1) * slim_ch; + extract_conv1x1(src, p.condition_size, full_out, p.condition_size, slim_out, true, slim); + } + + // head1x1_post_film: FiLM(condition_size -> head1x1_out) -- unchanged + if (p.head1x1_post_film_params.active && p.head1x1_params.active) + { + int dim = (p.head1x1_post_film_params.shift ? 2 : 1) * p.head1x1_params.out_channels; + copy_weights(src, p.condition_size * dim + dim, slim); + } + } + + // ---- head_rechannel: Conv1x1(head_output_size -> head_size, bias=head_bias) ---- + extract_conv1x1(src, full_head_out, p.head_size, slim_head_out, slim_head_size, p.head_bias, slim); + } + + // head_scale: 1 float, copy as-is + slim.push_back(*(src++)); + + return slim; +} + +// ============================================================================ +// Build modified LayerArrayParams with per-array channel counts +// ============================================================================ + +std::vector modify_params_for_channels( + const std::vector& original_params, const std::vector& new_channels_per_array) +{ + std::vector modified; + const int num_arrays = (int)original_params.size(); + + for (int i = 0; i < num_arrays; i++) + { + const auto& p = original_params[i]; + const int new_ch = new_channels_per_array[i]; + + int new_bottleneck = compute_slim_bottleneck(p, new_ch); + int new_input_size = (i == 0) ? p.input_size : new_channels_per_array[i - 1]; + int new_head_size = (i < num_arrays - 1) ? new_channels_per_array[i + 1] : p.head_size; + + modified.push_back(wavenet::LayerArrayParams( + new_input_size, p.condition_size, new_head_size, new_ch, new_bottleneck, p.kernel_size, + std::vector(p.dilations), std::vector(p.activation_configs), + std::vector(p.gating_modes), p.head_bias, p.groups_input, p.groups_input_mixin, + p.layer1x1_params, p.head1x1_params, + std::vector(p.secondary_activation_configs), p.conv_pre_film_params, + p.conv_post_film_params, p.input_mixin_pre_film_params, p.input_mixin_post_film_params, + p.activation_pre_film_params, p.activation_post_film_params, p._layer1x1_post_film_params, + p.head1x1_post_film_params)); + } + + return modified; +} + +// Check if all per-array channels match full (no slimming needed) +bool is_full_size(const std::vector& params, const std::vector& channels) +{ + for (size_t i = 0; i < params.size(); i++) + { + if (channels[i] != params[i].channels) + return false; + } + return true; +} + +} // anonymous namespace + +// ============================================================================ +// SlimmableWavenet +// ============================================================================ + +SlimmableWavenet::SlimmableWavenet(std::vector original_params, + std::vector> per_array_allowed_channels, int in_channels, + float head_scale, bool with_head, nlohmann::json condition_dsp_json, + std::vector full_weights, double expected_sample_rate) +: DSP(in_channels, original_params.back().head_size, expected_sample_rate) +, _original_params(std::move(original_params)) +, _per_array_allowed_channels(std::move(per_array_allowed_channels)) +, _in_channels(in_channels) +, _head_scale(head_scale) +, _with_head(with_head) +, _condition_dsp_json(std::move(condition_dsp_json)) +, _full_weights(std::move(full_weights)) +{ + if (_per_array_allowed_channels.size() != _original_params.size()) + throw std::runtime_error("SlimmableWavenet: per_array_allowed_channels size must match number of layer arrays"); + + // Validate: at least one array must be slimmable + bool any_slimmable = false; + for (size_t i = 0; i < _per_array_allowed_channels.size(); i++) + { + const auto& allowed = _per_array_allowed_channels[i]; + if (!allowed.empty()) + { + any_slimmable = true; + // Validate sorted + for (size_t j = 1; j < allowed.size(); j++) + { + if (allowed[j] <= allowed[j - 1]) + throw std::runtime_error("SlimmableWavenet: allowed_channels must be sorted ascending"); + } + // Validate last entry matches full channel count + if (allowed.back() != _original_params[i].channels) + throw std::runtime_error( + "SlimmableWavenet: last allowed_channels entry must equal the full channel count for that array"); + } + } + if (!any_slimmable) + throw std::runtime_error("SlimmableWavenet: at least one layer array must have allowed_channels"); + + // Build with full channel counts as default (ratio=1.0) + std::vector full_channels(_original_params.size()); + for (size_t i = 0; i < _original_params.size(); i++) + full_channels[i] = _original_params[i].channels; + _rebuild_model(full_channels); +} + +void SlimmableWavenet::_rebuild_model(const std::vector& target_channels) +{ + if (target_channels == _current_channels && _active_model) + return; + + std::vector weights; + std::vector modified_params; + const std::vector* params_ptr; + + if (is_full_size(_original_params, target_channels)) + { + weights = _full_weights; + params_ptr = &_original_params; + } + else + { + weights = extract_slimmed_weights(_original_params, _full_weights, target_channels); + modified_params = modify_params_for_channels(_original_params, target_channels); + params_ptr = &modified_params; + } + + // Rebuild condition_dsp if present (WaveNet takes ownership each time) + std::unique_ptr condition_dsp; + if (!_condition_dsp_json.is_null()) + condition_dsp = get_dsp(_condition_dsp_json); + + double sampleRate = _current_sample_rate > 0 ? _current_sample_rate : GetExpectedSampleRate(); + _active_model = std::make_unique(_in_channels, *params_ptr, _head_scale, _with_head, + std::move(weights), std::move(condition_dsp), sampleRate); + _current_channels = target_channels; + + if (_current_buffer_size > 0) + _active_model->Reset(_current_sample_rate, _current_buffer_size); +} + +void SlimmableWavenet::process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames) +{ + if (_active_model) + _active_model->process(input, output, num_frames); +} + +void SlimmableWavenet::prewarm() +{ + if (_active_model) + _active_model->prewarm(); +} + +void SlimmableWavenet::Reset(const double sampleRate, const int maxBufferSize) +{ + _current_sample_rate = sampleRate; + _current_buffer_size = maxBufferSize; + if (_active_model) + _active_model->Reset(sampleRate, maxBufferSize); +} + +void SlimmableWavenet::SetSlimmableSize(const double val) +{ + const size_t num_arrays = _original_params.size(); + std::vector target(num_arrays); + + for (size_t i = 0; i < num_arrays; i++) + { + const auto& allowed = _per_array_allowed_channels[i]; + if (allowed.empty()) + target[i] = _original_params[i].channels; // Non-slimmable: keep full + else + target[i] = ratio_to_channels(val, allowed); + } + + _rebuild_model(target); +} + +// ============================================================================ +// Config / factory / registration +// ============================================================================ + +std::unique_ptr SlimmableWavenetConfig::create(std::vector weights, double sampleRate) +{ + // Parse the WaveNet model config into typed params + nlohmann::json model_json = raw_config["model"]; + auto wc = wavenet::parse_config_json(model_json, sampleRate); + + // Extract per-array allowed_channels from slimmable config fields + const auto& layers_json = model_json["layers"]; + std::vector> per_array_allowed; + for (size_t i = 0; i < layers_json.size(); i++) + { + const auto& lc = layers_json[i]; + std::vector allowed; + if (lc.find("slimmable") != lc.end() && lc["slimmable"].is_object()) + { + const auto& slim_cfg = lc["slimmable"]; + const std::string method = slim_cfg.value("method", ""); + if (method != "slice_channels_uniform") + throw std::runtime_error("SlimmableWavenet: unsupported slimmable method '" + method + "'"); + if (slim_cfg.find("kwargs") != slim_cfg.end() && slim_cfg["kwargs"].find("allowed_channels") != slim_cfg["kwargs"].end()) + { + for (const auto& ch : slim_cfg["kwargs"]["allowed_channels"]) + allowed.push_back(ch.get()); + } + } + per_array_allowed.push_back(std::move(allowed)); + } + + // Extract condition_dsp JSON for future rebuilds + nlohmann::json condition_dsp_json = nullptr; + if (model_json.find("condition_dsp") != model_json.end() && !model_json["condition_dsp"].is_null()) + condition_dsp_json = model_json["condition_dsp"]; + + return std::make_unique(std::move(wc.layer_array_params), std::move(per_array_allowed), + wc.in_channels, wc.head_scale, wc.with_head, + std::move(condition_dsp_json), std::move(weights), sampleRate); +} + +std::unique_ptr create_config(const nlohmann::json& config, double sampleRate) +{ + auto sc = std::make_unique(); + sc->raw_config = config; + sc->sample_rate = sampleRate; + return sc; +} + +// Auto-register with the config parser registry +namespace +{ +static ConfigParserHelper _register_SlimmableWavenet("SlimmableWavenet", nam::slimmable_wavenet::create_config); +} + +} // namespace slimmable_wavenet +} // namespace nam diff --git a/NAM/slimmable_wavenet.h b/NAM/slimmable_wavenet.h new file mode 100644 index 0000000..5fb28f4 --- /dev/null +++ b/NAM/slimmable_wavenet.h @@ -0,0 +1,76 @@ +#pragma once + +#include +#include + +#include "dsp.h" +#include "json.hpp" +#include "model_config.h" +#include "slimmable.h" +#include "wavenet.h" + +namespace nam +{ +namespace slimmable_wavenet +{ + +/// \brief A WaveNet model that supports per-layer-array dynamic channel reduction +/// +/// Stores the full WaveNet LayerArrayParams and weights. Each layer array has its +/// own allowed_channels list (from the "slimmable" config field). On SetSlimmableSize(), +/// maps the ratio to a channel count per array, extracts a weight subset, builds +/// modified LayerArrayParams, and reconstructs the WaveNet. +class SlimmableWavenet : public DSP, public SlimmableModel +{ +public: + /// \param original_params Full-size LayerArrayParams from parse_config_json + /// \param per_array_allowed_channels Per-array sorted allowed channel counts (empty = non-slimmable) + /// \param in_channels WaveNet input channels + /// \param head_scale WaveNet head scale + /// \param with_head WaveNet head flag + /// \param condition_dsp_json JSON for rebuilding condition_dsp (nullptr if none) + /// \param full_weights Full weight vector for the max-channel model + /// \param expected_sample_rate Expected sample rate + SlimmableWavenet(std::vector original_params, + std::vector> per_array_allowed_channels, int in_channels, float head_scale, + bool with_head, nlohmann::json condition_dsp_json, std::vector full_weights, + double expected_sample_rate); + + void process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames) override; + void prewarm() override; + void Reset(const double sampleRate, const int maxBufferSize) override; + void SetSlimmableSize(const double val) override; + +protected: + int PrewarmSamples() override { return 0; } + +private: + std::vector _original_params; + std::vector> _per_array_allowed_channels; + int _in_channels; + float _head_scale; + bool _with_head; + nlohmann::json _condition_dsp_json; + std::vector _full_weights; + std::unique_ptr _active_model; + std::vector _current_channels; + int _current_buffer_size = 0; + double _current_sample_rate = 0.0; + + void _rebuild_model(const std::vector& target_channels); +}; + +// Config / registration + +struct SlimmableWavenetConfig : public ModelConfig +{ + nlohmann::json raw_config; + double sample_rate; + + std::unique_ptr create(std::vector weights, double sampleRate) override; +}; + +std::unique_ptr create_config(const nlohmann::json& config, double sampleRate); + +} // namespace slimmable_wavenet +} // namespace nam diff --git a/example_models/slimmable_wavenet.nam b/example_models/slimmable_wavenet.nam new file mode 100644 index 0000000..19eafd1 --- /dev/null +++ b/example_models/slimmable_wavenet.nam @@ -0,0 +1,501 @@ +{ + "version": "0.7.0", + "metadata": {}, + "architecture": "SlimmableWavenet", + "config": { + "model": { + "layers": [ + { + "input_size": 1, + "condition_size": 1, + "head_size": 1, + "channels": 3, + "kernel_size": 3, + "dilations": [ + 1, + 2, + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 512 + ], + "activation": "ReLU", + "gated": false, + "head_bias": false, + "slimmable": { + "method": "slice_channels_uniform", + "kwargs": { + "allowed_channels": [1, 2, 3] + } + } + } + ], + "head": null, + "head_scale": 0.02 + } + }, + "weights": [ + 0.2788535969157675, + -0.9499784895546661, + -0.4499413632617615, + -0.5535785237023545, + 0.4729424283280248, + 0.3533989748458226, + 0.7843591354096908, + -0.8261223347411677, + -0.15615636062945915, + -0.9404055611238593, + -0.5627240503927933, + 0.010710576206724776, + -0.9469280606322728, + -0.602324698626703, + 0.2997688755590464, + 0.08988296120643335, + -0.5591187559186066, + 0.17853136775181744, + 0.6188609133556533, + -0.987002480643878, + 0.6116385036656158, + 0.3962787899764537, + -0.31949896696401625, + -0.6890410003764369, + 0.9144261444135624, + -0.32681090977474647, + -0.8145083132397042, + -0.806567246333072, + 0.6949887326949196, + 0.20745206273378214, + 0.6142565465487604, + 0.45946357338763577, + 0.07245618290940148, + 0.9462315279587412, + -0.24293124558329304, + 0.104081262546454, + 0.6588093285059897, + 0.2370395047284921, + 0.7234138006215545, + 0.15470429051352408, + 0.40914367242984695, + -0.9083512326886756, + -0.5442034486969063, + -0.42122407279578566, + -0.840416046152745, + -0.5344182272779396, + -0.7979971411805418, + -0.44405279377981577, + 0.27136888852880037, + -0.2703356420598315, + -0.2596380657662347, + -0.5809859384570246, + -0.4660443559017733, + 0.873309175424988, + 0.2960707704931871, + 0.21826201133397638, + -0.657722703603806, + 0.45825359590069836, + -0.6731950124761432, + -0.24108911648470444, + 0.9790467012731905, + 0.2799995197081857, + 0.11389948754929247, + 0.3692285019797492, + 0.6857038403796192, + 0.5519998230924896, + -0.5419038560717913, + -0.9357995121919245, + -0.36909390388183616, + -0.46451824804859454, + -0.5780343128273471, + 0.8858194286701089, + 0.7527352529453377, + -0.37064423840304417, + 0.3108773305897601, + -0.20873619787867148, + 0.829095179481087, + -0.0822962948252024, + -0.4702396670038951, + -0.5067449846120331, + 0.12273626832630158, + -0.47451678295412947, + 0.16917198044708104, + 0.795645767204954, + -0.20119898971920547, + -0.5613584816854333, + 0.9950752129902205, + 0.01905258735292903, + -0.8181811756524122, + -0.9057672491505309, + -0.7807017392986817, + 0.2548920834061801, + 0.5841587287259282, + -0.15568006640063192, + -0.8729445876960857, + -0.23676142698692648, + 0.9922427604801936, + 0.058228690198274036, + 0.9421567552272363, + 0.7215594044689961, + -0.9770379561143607, + 0.4414436387203893, + 0.36342073805314956, + 0.07394066081759032, + -0.46634962009491443, + 0.2819235971596161, + -0.7768956528082471, + -0.13046949866179003, + -0.09255258734158711, + 0.9076318550421603, + 0.7517058807563881, + -0.4732218984978185, + 0.0011722261005966406, + -0.6426962389397373, + 0.825255678689641, + 0.7410371396735338, + -0.4031104171027342, + 0.2778989897320103, + 0.21794042287634463, + -0.6943214629007304, + 0.5250216001503025, + 0.07875806023925147, + 0.5572529572611165, + 0.06070734439035497, + -0.998856207744113, + -0.3516878859906538, + -0.9610465152283354, + 0.8581972325292342, + 0.7574437556463685, + 0.6633310587223589, + -0.38497174919467714, + -0.8841496670116249, + 0.7560191984080811, + 0.8938988905959881, + -0.8286930958642424, + -0.02801907336677245, + -0.8615749630632328, + 0.5212043305144631, + 0.5316688586139755, + -0.7432170710004744, + -0.04943524380253739, + 0.0996071869898878, + -0.4698867421198818, + 0.7448660821705149, + -0.15372411959822618, + -0.5764035891158359, + 0.07859217755891668, + 0.45986213817995236, + -0.5976978732206082, + -0.3765674173982101, + 0.9902987133217893, + 0.299756115278907, + -0.12379983217099189, + 0.035151682071181245, + -0.7579916082634686, + -0.5506053259368853, + -0.32382887570508934, + 0.17661743691446663, + -0.539770534806846, + -0.559565231096881, + -0.8580138279819349, + 0.2622059145401978, + -0.5421164323776912, + 0.8108400260122559, + 0.719270800507493, + -0.8582853002226931, + -0.5239907312620096, + 0.33795555659256116, + -0.5715263852591228, + -0.73537630254995, + 0.871028481161342, + 0.14208618665056894, + -0.05465794737641172, + 0.5692388485815068, + 0.6149939955332868, + -0.6191801712762446, + -0.8061383715423533, + -0.13789763518724496, + -0.15284275396015845, + -0.06595066392665005, + 0.4581516989197012, + 0.34672909458660306, + 0.9683304227319323, + -0.8031642576960822, + -0.19475743579546245, + -0.3213947892100737, + 0.7233450727055821, + -0.5026873321594287, + -0.619582183118377, + -0.1027729043337362, + -0.15623672033119163, + -0.44290971066611906, + -0.500387104235799, + 0.8465311985520256, + -0.11373850989308609, + 0.7226982095236612, + 0.10065062489969612, + -0.8988233409502375, + 0.9985649368254532, + 0.6720551701599038, + 0.9379925145695025, + 0.8527339660162552, + 0.6973914688286109, + -0.667377778792172, + -0.02871774909856306, + -0.5725054016016367, + -0.19791941490109477, + -0.8827292000556421, + -0.2420537620461678, + 0.9706176875594519, + -0.4695938836556961, + 0.5681412038971387, + -0.08998326532171341, + -0.15398502801967417, + 0.9146352817193464, + 0.9908453789854277, + 0.11153664681123643, + 0.436816550592652, + -0.6904063494518717, + -0.4065843490108716, + 0.9374187299383177, + 0.15836058163251243, + 0.08439040274854848, + 0.4959511207581282, + -0.8856694541850338, + 0.16835518891794243, + 0.005700765839027122, + 0.7054397840965707, + -0.6851345441210335, + 0.9215578065489007, + -0.8397770695188262, + -0.6283500780385536, + 0.19007021290005532, + 0.3504251072081803, + -0.5295922099981376, + -0.7602267721057516, + 0.780574628258875, + -0.5075693044227503, + 0.1890383070668824, + 0.23876302066420618, + -0.16155016932825506, + 0.16734457858244944, + 0.04556543106391775, + 0.8694125154728545, + -0.5914816011529271, + 0.4323836015788296, + -0.522628094768308, + -0.208428306417491, + 0.34338044591994255, + -0.40000584040247555, + -0.36764560745629193, + 0.5037289848288042, + -0.8549137710136854, + -0.08342895476282775, + 0.9969088817088847, + 0.9921928957101889, + -0.853478557800734, + -0.5736913754659192, + -0.4695991704991973, + 0.8665187559874181, + 0.7617283473728791, + 0.7585404849690855, + -0.2609458225222321, + -0.6845063352855361, + 0.667489909279614, + 0.4070798501747419, + 0.22335553145190024, + 0.9744661272630086, + 0.3079526354214652, + -0.9843537856956841, + 0.6342082702309233, + -0.4012424956000442, + 0.32677742993215464, + 0.8778600078542078, + -0.7314177712132646, + -0.7691426591617956, + -0.7859280445811647, + 0.10644728176963181, + -0.45530357537036736, + 0.20965965406044784, + 0.43522437427759586, + -0.5928053753450941, + 0.26847591777015944, + -0.4720321967391812, + -0.02293629570124689, + 0.8106729821586465, + 0.6922074265897109, + -0.8154030645745332, + -0.15284845487254728, + -0.44663955205549666, + -0.9929086218244354, + 0.5422384460392542, + 0.27422675460275925, + -0.4760894751313036, + 0.48246181669586163, + 0.10336084225278253, + -0.14462616203864131, + -0.9806606007833201, + -0.8495122798524659, + 0.7662127866002859, + 0.8078571431197863, + 0.09118057841104465, + 0.6691900397720334, + 0.16501913297958803, + -0.7038124288650347, + -0.7451089614357225, + -0.38348330013973264, + 0.79796297748518, + 0.5922446097760834, + 0.7214051640018055, + 0.7978492730529492, + -0.5798469233204919, + -0.5009405215541511, + -0.7944127566564287, + 0.5602324837428854, + 0.7682694029020178, + -0.1872452203357664, + 0.2413230203014256, + -0.6908933233355907, + 0.8597620313873489, + 0.7292113924399279, + 0.9524120658619257, + 0.6215434398807937, + 0.7628324093266488, + -0.9504272762036226, + 0.4731289435101642, + -0.33562906410714266, + 0.8616317720966511, + 0.6044702778742779, + 0.7281280567505588, + 0.621498633148778, + -0.46638858081105594, + 0.5747490182709423, + -0.7838087471940858, + 0.7443335658121795, + 0.7171865026755633, + -0.5551325649086711, + 0.6331732111938579, + -0.07939353064211585, + -0.38961826532279886, + 0.5906909983057236, + -0.5448090251844593, + -0.952671130597097, + -0.6137404233445827, + -0.34347609760458697, + 0.7287058840605727, + 0.9337782080967223, + -0.4417500145562572, + 0.2829634772152554, + -0.20064323127987826, + 0.9622993743965202, + 0.07243146495744379, + 0.8784742806494314, + -0.7693164962971448, + 0.9408012220444559, + -0.6428643676550727, + 0.9250686315231109, + -0.46906727495406275, + -0.7831949055705778, + -0.1308724828707113, + 0.4570901213054086, + -0.37264537161001754, + 0.21241770661228654, + 0.022846119338956195, + -0.22960913331054567, + 0.15317608699319907, + -0.4905549877228361, + 0.4175705676683412, + -0.9966174435627411, + 0.8511503309981654, + 0.0769039941855838, + 0.438859998289691, + 0.48390015567895306, + 0.34125700886599897, + -0.27155705643747163, + -0.8600523777473796, + 0.32847536982254466, + -0.3395999279148072, + -0.37216870988328066, + 0.696030559012671, + 0.43950852602790036, + -0.39935546357747165, + -0.3814306755826935, + -0.1832141827615663, + -0.19519922588455074, + -0.40868959494810597, + -0.7454244018816936, + -0.1591073324541834, + 0.880727341460366, + 0.35463589054546585, + 0.8056110914651653, + 0.23102983190276105, + -0.39810025086886935, + 0.09587442627139642, + -0.999188120605425, + -0.4261725662621456, + -0.1402237000203308, + 0.15996956239136395, + 0.30941124740614323, + -0.07002361950597158, + -0.11568040139038516, + -0.5725971980217994, + -0.053627628181347475, + 0.8023616516565084, + 0.5920495202535605, + -0.6606172076038905, + -0.8304089265497565, + 0.030904019830432894, + 0.26588171153159146, + -0.3296234891803982, + 0.6368469290733285, + 0.5022762750814644, + 0.3455913411143341, + -0.5507186680054235, + -0.6017401345468467, + -0.9511492245463473, + -0.5103149118432997, + -0.04972731156238974, + 0.6994753892494638, + -0.8543435416308618, + -0.1711179780045613, + 0.2595307614754274, + -0.6111295265205814, + 0.39270850098100984, + -0.011245661979126131, + -0.5120311208431223, + 0.31211602222356816, + -0.9889103637239365, + 0.5019289532369458, + 0.5400923771480501, + -0.7868254068729221, + -0.14970761211453176, + -0.6482266365869367, + 0.9159320845590795, + 0.03591550088748163, + -0.8995632297187182, + -0.5016034406800567, + 0.6966726947033195, + -0.08707634905965489, + 0.602833203444529, + 0.3351554651727062, + 0.975784906132896, + 0.1909046369388394, + 0.9000792168863119, + 0.782851851620874, + 0.2253046455235257, + 0.4385479225519342, + 0.00955632964880393, + 0.6611383394428301, + 0.09574390122165677, + 0.7944162064665243, + 0.4873108843191698, + -0.05065112635389335, + -0.4816169030699613, + -0.5055205249806809, + 0.27532287355231255, + 0.5316273685943309 + ], + "sample_rate": 48000 +} \ No newline at end of file diff --git a/tools/run_tests.cpp b/tools/run_tests.cpp index d01f893..82697d7 100644 --- a/tools/run_tests.cpp +++ b/tools/run_tests.cpp @@ -30,6 +30,7 @@ #include "test/test_extensible.cpp" #include "test/test_container.cpp" #include "test/test_render_slim.cpp" +#include "test/test_slimmable_wavenet.cpp" int main() { @@ -288,6 +289,18 @@ int main() test_render_slim::test_slim_boundary_values(); test_render_slim::test_slim_applied_before_processing(); + // SlimmableWavenet tests + test_slimmable_wavenet::test_loads_from_file(); + test_slimmable_wavenet::test_implements_slimmable(); + test_slimmable_wavenet::test_processes_audio(); + test_slimmable_wavenet::test_slimming_changes_output(); + test_slimmable_wavenet::test_boundary_values(); + test_slimmable_wavenet::test_default_is_max_size(); + test_slimmable_wavenet::test_ratio_mapping(); + test_slimmable_wavenet::test_from_json(); + test_slimmable_wavenet::test_no_slimmable_layers_throws(); + test_slimmable_wavenet::test_unsupported_method_throws(); + std::cout << "Success!" << std::endl; #ifdef ADDASSERT std::cerr << "===============================================================" << std::endl; diff --git a/tools/test/test_slimmable_wavenet.cpp b/tools/test/test_slimmable_wavenet.cpp new file mode 100644 index 0000000..ffbb0d3 --- /dev/null +++ b/tools/test/test_slimmable_wavenet.cpp @@ -0,0 +1,328 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "json.hpp" + +#include "NAM/dsp.h" +#include "NAM/get_dsp.h" +#include "NAM/slimmable.h" + +namespace test_slimmable_wavenet +{ + +// Helper: load a .nam file as JSON +nlohmann::json load_nam_json(const std::string& path) +{ + std::ifstream f(path); + if (!f.is_open()) + throw std::runtime_error("Cannot open " + path); + nlohmann::json j; + f >> j; + return j; +} + +// Helper: process audio and verify finite output +void process_and_verify(nam::DSP* dsp, int num_buffers, int buffer_size) +{ + const double sample_rate = dsp->GetExpectedSampleRate() > 0 ? dsp->GetExpectedSampleRate() : 48000.0; + dsp->Reset(sample_rate, buffer_size); + + std::vector input(buffer_size); + std::vector output(buffer_size); + NAM_SAMPLE* in_ptr = input.data(); + NAM_SAMPLE* out_ptr = output.data(); + + for (int buf = 0; buf < num_buffers; buf++) + { + for (int i = 0; i < buffer_size; i++) + input[i] = (NAM_SAMPLE)(0.1 * ((buf * buffer_size + i) % 100) / 100.0); + + dsp->process(&in_ptr, &out_ptr, buffer_size); + + for (int i = 0; i < buffer_size; i++) + assert(std::isfinite(output[i])); + } +} + +// ===================================================================== +// Tests +// ===================================================================== + +void test_loads_from_file() +{ + std::cout << " test_slimmable_wavenet_loads_from_file" << std::endl; + + std::filesystem::path path("example_models/slimmable_wavenet.nam"); + auto dsp = nam::get_dsp(path); + assert(dsp != nullptr); +} + +void test_implements_slimmable() +{ + std::cout << " test_slimmable_wavenet_implements_slimmable" << std::endl; + + auto dsp = nam::get_dsp(std::filesystem::path("example_models/slimmable_wavenet.nam")); + assert(dsp != nullptr); + + auto* slimmable = dynamic_cast(dsp.get()); + assert(slimmable != nullptr); +} + +void test_processes_audio() +{ + std::cout << " test_slimmable_wavenet_processes_audio" << std::endl; + + auto dsp = nam::get_dsp(std::filesystem::path("example_models/slimmable_wavenet.nam")); + assert(dsp != nullptr); + process_and_verify(dsp.get(), 3, 64); +} + +void test_slimming_changes_output() +{ + std::cout << " test_slimmable_wavenet_slimming_changes_output" << std::endl; + + auto dsp = nam::get_dsp(std::filesystem::path("example_models/slimmable_wavenet.nam")); + assert(dsp != nullptr); + + const double sample_rate = dsp->GetExpectedSampleRate() > 0 ? dsp->GetExpectedSampleRate() : 48000.0; + const int buffer_size = 64; + dsp->Reset(sample_rate, buffer_size); + + std::vector input(buffer_size, 0.1); + std::vector out_small(buffer_size); + std::vector out_large(buffer_size); + NAM_SAMPLE* in_ptr = input.data(); + NAM_SAMPLE* out_ptr; + + auto* slimmable = dynamic_cast(dsp.get()); + assert(slimmable != nullptr); + + // Process at minimum size (ratio 0.0 -> allowed_channels[0] = 1) + slimmable->SetSlimmableSize(0.0); + dsp->Reset(sample_rate, buffer_size); + out_ptr = out_small.data(); + dsp->process(&in_ptr, &out_ptr, buffer_size); + + // Process at maximum size (ratio 1.0 -> allowed_channels[2] = 3) + slimmable->SetSlimmableSize(1.0); + dsp->Reset(sample_rate, buffer_size); + out_ptr = out_large.data(); + dsp->process(&in_ptr, &out_ptr, buffer_size); + + // Outputs should differ since different channel counts are used + bool any_different = false; + for (int i = 0; i < buffer_size; i++) + { + if (std::abs(out_small[i] - out_large[i]) > 1e-6) + { + any_different = true; + break; + } + } + assert(any_different); +} + +void test_boundary_values() +{ + std::cout << " test_slimmable_wavenet_boundary_values" << std::endl; + + auto dsp = nam::get_dsp(std::filesystem::path("example_models/slimmable_wavenet.nam")); + assert(dsp != nullptr); + + const double sample_rate = dsp->GetExpectedSampleRate() > 0 ? dsp->GetExpectedSampleRate() : 48000.0; + const int buffer_size = 64; + dsp->Reset(sample_rate, buffer_size); + + std::vector input(buffer_size, 0.05); + std::vector output(buffer_size); + NAM_SAMPLE* in_ptr = input.data(); + NAM_SAMPLE* out_ptr = output.data(); + + auto* slimmable = dynamic_cast(dsp.get()); + assert(slimmable != nullptr); + + // Test at various ratio values — with 3 allowed channels [1,2,3]: + // ratio_to_channels: idx = min(floor(ratio * 3), 2) + // 0.0 -> idx=0 -> 1ch, 0.33 -> idx=0 -> 1ch, 0.34 -> idx=1 -> 2ch, + // 0.5 -> idx=1 -> 2ch, 0.66 -> idx=1 -> 2ch, 0.67 -> idx=2 -> 3ch, 1.0 -> idx=2 -> 3ch + double values[] = {0.0, 0.25, 0.33, 0.34, 0.5, 0.66, 0.67, 0.75, 1.0}; + for (double val : values) + { + slimmable->SetSlimmableSize(val); + dsp->Reset(sample_rate, buffer_size); + dsp->process(&in_ptr, &out_ptr, buffer_size); + for (int i = 0; i < buffer_size; i++) + assert(std::isfinite(output[i])); + } +} + +void test_default_is_max_size() +{ + std::cout << " test_slimmable_wavenet_default_is_max_size" << std::endl; + + auto dsp = nam::get_dsp(std::filesystem::path("example_models/slimmable_wavenet.nam")); + assert(dsp != nullptr); + + const double sample_rate = dsp->GetExpectedSampleRate() > 0 ? dsp->GetExpectedSampleRate() : 48000.0; + const int buffer_size = 64; + dsp->Reset(sample_rate, buffer_size); + + std::vector input(buffer_size, 0.1); + std::vector out_default(buffer_size); + std::vector out_max(buffer_size); + NAM_SAMPLE* in_ptr = input.data(); + NAM_SAMPLE* out_ptr; + + // Process with default (should be max size = 3 channels) + out_ptr = out_default.data(); + dsp->process(&in_ptr, &out_ptr, buffer_size); + + // Explicitly set to max + auto* slimmable = dynamic_cast(dsp.get()); + assert(slimmable != nullptr); + slimmable->SetSlimmableSize(1.0); + dsp->Reset(sample_rate, buffer_size); + out_ptr = out_max.data(); + dsp->process(&in_ptr, &out_ptr, buffer_size); + + // Both should produce the same output + for (int i = 0; i < buffer_size; i++) + assert(std::abs(out_default[i] - out_max[i]) < 1e-6); +} + +void test_ratio_mapping() +{ + std::cout << " test_slimmable_wavenet_ratio_mapping" << std::endl; + + // With allowed_channels [1, 2, 3] (len=3): + // idx = min(floor(ratio * 3), 2) + // ratio < 1/3 -> idx=0 -> 1ch + // 1/3 <= ratio < 2/3 -> idx=1 -> 2ch + // 2/3 <= ratio -> idx=2 -> 3ch + + auto dsp = nam::get_dsp(std::filesystem::path("example_models/slimmable_wavenet.nam")); + assert(dsp != nullptr); + + const double sample_rate = dsp->GetExpectedSampleRate() > 0 ? dsp->GetExpectedSampleRate() : 48000.0; + const int buffer_size = 64; + + std::vector input(buffer_size, 0.1); + std::vector out_a(buffer_size); + std::vector out_b(buffer_size); + NAM_SAMPLE* in_ptr = input.data(); + NAM_SAMPLE* out_ptr; + + auto* slimmable = dynamic_cast(dsp.get()); + assert(slimmable != nullptr); + + // 0.32 -> floor(0.32*3)=0 -> 1ch + slimmable->SetSlimmableSize(0.32); + dsp->Reset(sample_rate, buffer_size); + out_ptr = out_a.data(); + dsp->process(&in_ptr, &out_ptr, buffer_size); + + // 0.34 -> floor(0.34*3)=1 -> 2ch (different from 1ch) + slimmable->SetSlimmableSize(0.34); + dsp->Reset(sample_rate, buffer_size); + out_ptr = out_b.data(); + dsp->process(&in_ptr, &out_ptr, buffer_size); + + // These should differ (different channel counts: 1 vs 2) + bool any_different = false; + for (int i = 0; i < buffer_size; i++) + { + if (std::abs(out_a[i] - out_b[i]) > 1e-6) + { + any_different = true; + break; + } + } + assert(any_different); +} + +void test_from_json() +{ + std::cout << " test_slimmable_wavenet_from_json" << std::endl; + + // Build a SlimmableWavenet JSON from an existing WaveNet + auto wavenet_json = load_nam_json("example_models/wavenet_3ch.nam"); + + nlohmann::json j; + j["version"] = "0.7.0"; + j["architecture"] = "SlimmableWavenet"; + + // Copy the WaveNet config and add slimmable field to the first layer + j["config"]["model"] = wavenet_json["config"]; + j["config"]["model"]["layers"][0]["slimmable"] = { + {"method", "slice_channels_uniform"}, + {"kwargs", {{"allowed_channels", {2, 3}}}} + }; + j["weights"] = wavenet_json["weights"]; + j["sample_rate"] = wavenet_json["sample_rate"]; + + auto dsp = nam::get_dsp(j); + assert(dsp != nullptr); + process_and_verify(dsp.get(), 3, 64); +} + +void test_no_slimmable_layers_throws() +{ + std::cout << " test_slimmable_wavenet_no_slimmable_layers_throws" << std::endl; + + auto wavenet_json = load_nam_json("example_models/wavenet_3ch.nam"); + + nlohmann::json j; + j["version"] = "0.7.0"; + j["architecture"] = "SlimmableWavenet"; + j["config"]["model"] = wavenet_json["config"]; + // No slimmable field on any layer -> all allowed_channels empty -> should throw + j["weights"] = wavenet_json["weights"]; + j["sample_rate"] = wavenet_json["sample_rate"]; + + bool threw = false; + try + { + auto dsp = nam::get_dsp(j); + } + catch (const std::runtime_error&) + { + threw = true; + } + assert(threw); +} + +void test_unsupported_method_throws() +{ + std::cout << " test_slimmable_wavenet_unsupported_method_throws" << std::endl; + + auto wavenet_json = load_nam_json("example_models/wavenet_3ch.nam"); + + nlohmann::json j; + j["version"] = "0.7.0"; + j["architecture"] = "SlimmableWavenet"; + j["config"]["model"] = wavenet_json["config"]; + j["config"]["model"]["layers"][0]["slimmable"] = { + {"method", "some_future_method"}, + {"kwargs", {{"allowed_channels", {2, 3}}}} + }; + j["weights"] = wavenet_json["weights"]; + j["sample_rate"] = wavenet_json["sample_rate"]; + + bool threw = false; + try + { + auto dsp = nam::get_dsp(j); + } + catch (const std::runtime_error&) + { + threw = true; + } + assert(threw); +} + +} // namespace test_slimmable_wavenet From 9161a296377904019d476fab851b55190b6a46c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Felipe=20Santos?= Date: Thu, 12 Mar 2026 09:54:21 -0700 Subject: [PATCH 2/3] Added test comparing small model to slimmed model --- tools/run_tests.cpp | 1 + tools/test/test_slimmable_wavenet.cpp | 195 ++++++++++++++++++++++++++ 2 files changed, 196 insertions(+) diff --git a/tools/run_tests.cpp b/tools/run_tests.cpp index 82697d7..8bfd954 100644 --- a/tools/run_tests.cpp +++ b/tools/run_tests.cpp @@ -300,6 +300,7 @@ int main() test_slimmable_wavenet::test_from_json(); test_slimmable_wavenet::test_no_slimmable_layers_throws(); test_slimmable_wavenet::test_unsupported_method_throws(); + test_slimmable_wavenet::test_slimmed_matches_small_model(); std::cout << "Success!" << std::endl; #ifdef ADDASSERT diff --git a/tools/test/test_slimmable_wavenet.cpp b/tools/test/test_slimmable_wavenet.cpp index ffbb0d3..dded0b4 100644 --- a/tools/test/test_slimmable_wavenet.cpp +++ b/tools/test/test_slimmable_wavenet.cpp @@ -325,4 +325,199 @@ void test_unsupported_method_throws() assert(threw); } +void test_slimmed_matches_small_model() +{ + std::cout << " test_slimmable_wavenet_slimmed_matches_small_model" << std::endl; + + // Build a minimal WaveNet config: 1 layer array, 2 layers (dilations [1,2]), + // kernel_size=3, no gating, no layer1x1, no head1x1, no FiLM, Tanh activation. + // bottleneck = channels (required when layer1x1 is inactive). + const int small_ch = 2; + const int large_ch = 4; + const int kernel_size = 3; + const int input_size = 1; + const int condition_size = 1; + const int head_size = 1; + const int num_layers = 2; + + auto make_layer_config = [&](int channels) -> nlohmann::json { + nlohmann::json lc; + lc["input_size"] = input_size; + lc["condition_size"] = condition_size; + lc["head_size"] = head_size; + lc["channels"] = channels; + lc["kernel_size"] = kernel_size; + lc["dilations"] = {1, 2}; + lc["activation"] = "Tanh"; + lc["head_bias"] = false; + // Disable layer1x1 so bottleneck == channels (simplest config) + lc["layer1x1"] = {{"active", false}, {"groups", 1}}; + return lc; + }; + + // --- Generate deterministic weights for the small (2ch) model --- + // Weight layout for 1 array, no gating, no layer1x1, no head1x1, no FiLM: + // rechannel: Conv1x1(input_size -> ch, no bias) = input_size * ch + // per layer: + // conv: Conv1D(ch -> ch, K, bias) = ch * ch * K + ch + // input_mixin: Conv1x1(condition_size -> ch, no bias) = condition_size * ch + // head_rechannel: Conv1x1(ch -> head_size, no bias) = ch * head_size + // head_scale: 1 + auto count_weights = [&](int ch) { + int n = input_size * ch; // rechannel + for (int l = 0; l < num_layers; l++) + { + n += ch * ch * kernel_size + ch; // conv + n += condition_size * ch; // input_mixin + } + n += ch * head_size; // head_rechannel + n += 1; // head_scale + return n; + }; + + const int small_weight_count = count_weights(small_ch); + std::vector small_weights(small_weight_count); + // Fill with a deterministic pattern (small non-zero values) + for (int i = 0; i < small_weight_count; i++) + small_weights[i] = 0.01f * ((i % 17) - 8); // values in [-0.08, 0.08] + + // --- Embed small weights into large weight vector --- + // Walk both weight layouts in parallel: for each matrix, place small weights + // in the top-left corner and fill the rest with arbitrary filler. + std::vector large_weights; + auto small_it = small_weights.cbegin(); + + // Helper: embed Conv1x1(small_in, small_out) into Conv1x1(full_in, full_out) + auto embed_conv1x1 = [](std::vector::const_iterator& src, int small_in, int small_out, int full_in, + int full_out, bool bias, std::vector& dst) { + for (int i = 0; i < full_out; i++) + for (int j = 0; j < full_in; j++) + { + if (i < small_out && j < small_in) + dst.push_back(*(src++)); + else + dst.push_back(0.02f); + } + if (bias) + for (int i = 0; i < full_out; i++) + { + if (i < small_out) + dst.push_back(*(src++)); + else + dst.push_back(0.02f); + } + }; + + // Helper: embed Conv1D(small_in, small_out) into Conv1D(full_in, full_out) + auto embed_conv1d = [](std::vector::const_iterator& src, int small_in, int small_out, int full_in, + int full_out, int ks, std::vector& dst) { + for (int i = 0; i < full_out; i++) + for (int j = 0; j < full_in; j++) + for (int k = 0; k < ks; k++) + { + if (i < small_out && j < small_in) + dst.push_back(*(src++)); + else + dst.push_back(0.02f); + } + // bias + for (int i = 0; i < full_out; i++) + { + if (i < small_out) + dst.push_back(*(src++)); + else + dst.push_back(0.02f); + } + }; + + // rechannel: Conv1x1(input_size -> ch, no bias) + embed_conv1x1(small_it, input_size, small_ch, input_size, large_ch, false, large_weights); + // per layer + for (int l = 0; l < num_layers; l++) + { + // conv: Conv1D(ch -> ch, K, bias) + embed_conv1d(small_it, small_ch, small_ch, large_ch, large_ch, kernel_size, large_weights); + // input_mixin: Conv1x1(condition_size -> ch, no bias) + embed_conv1x1(small_it, condition_size, small_ch, condition_size, large_ch, false, large_weights); + } + // head_rechannel: Conv1x1(ch -> head_size, no bias) + embed_conv1x1(small_it, small_ch, head_size, large_ch, head_size, false, large_weights); + // head_scale + large_weights.push_back(*(small_it++)); + + assert(small_it == small_weights.cend()); + assert((int)large_weights.size() == count_weights(large_ch)); + + // --- Build the 2ch WaveNet (non-slimmable) --- + nlohmann::json small_json; + small_json["version"] = "0.7.0"; + small_json["architecture"] = "WaveNet"; + small_json["config"]["layers"] = nlohmann::json::array({make_layer_config(small_ch)}); + small_json["config"]["head_scale"] = 1.0; + small_json["weights"] = small_weights; + small_json["sample_rate"] = 48000; + + auto small_dsp = nam::get_dsp(small_json); + assert(small_dsp != nullptr); + + // --- Build the 4ch SlimmableWavenet --- + nlohmann::json large_json; + large_json["version"] = "0.7.0"; + large_json["architecture"] = "SlimmableWavenet"; + auto large_layer_config = make_layer_config(large_ch); + large_layer_config["slimmable"] = {{"method", "slice_channels_uniform"}, + {"kwargs", {{"allowed_channels", {small_ch, large_ch}}}}}; + large_json["config"]["model"]["layers"] = nlohmann::json::array({large_layer_config}); + large_json["config"]["model"]["head_scale"] = 1.0; + large_json["weights"] = large_weights; + large_json["sample_rate"] = 48000; + + auto large_dsp = nam::get_dsp(large_json); + assert(large_dsp != nullptr); + + // Slim the large model down to match the small model + auto* slimmable = dynamic_cast(large_dsp.get()); + assert(slimmable != nullptr); + // ratio 0.0 -> idx = floor(0.0 * 2) = 0 -> allowed_channels[0] = small_ch + slimmable->SetSlimmableSize(0.0); + + // --- Process audio through both and compare --- + const double sample_rate = 48000.0; + const int buffer_size = 64; + const int num_buffers = 5; // process enough buffers to exercise the dilated convolutions + + small_dsp->Reset(sample_rate, buffer_size); + large_dsp->Reset(sample_rate, buffer_size); + + for (int buf = 0; buf < num_buffers; buf++) + { + std::vector input(buffer_size); + for (int i = 0; i < buffer_size; i++) + input[i] = (NAM_SAMPLE)(0.1 * std::sin(0.1 * (buf * buffer_size + i))); + + std::vector out_small(buffer_size); + std::vector out_large(buffer_size); + NAM_SAMPLE* in_ptr = input.data(); + NAM_SAMPLE* out_ptr; + + out_ptr = out_small.data(); + small_dsp->process(&in_ptr, &out_ptr, buffer_size); + + out_ptr = out_large.data(); + large_dsp->process(&in_ptr, &out_ptr, buffer_size); + + for (int i = 0; i < buffer_size; i++) + { + assert(std::isfinite(out_small[i])); + assert(std::isfinite(out_large[i])); + if (std::abs(out_small[i] - out_large[i]) > 1e-6) + { + std::cerr << " MISMATCH at buffer " << buf << " sample " << i << ": small=" << out_small[i] + << " slimmed=" << out_large[i] << " diff=" << std::abs(out_small[i] - out_large[i]) << std::endl; + assert(false); + } + } + } +} + } // namespace test_slimmable_wavenet From d513ef18d6419bbcd18ba25271f9e3e7c29a096b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Felipe=20Santos?= Date: Thu, 12 Mar 2026 10:15:17 -0700 Subject: [PATCH 3/3] Fixed tests to not print anything --- tools/test/test_render_slim.cpp | 9 ---- tools/test/test_slimmable_wavenet.cpp | 65 +++++++++------------------ 2 files changed, 21 insertions(+), 53 deletions(-) diff --git a/tools/test/test_render_slim.cpp b/tools/test/test_render_slim.cpp index 9b3c7ee..3f9513a 100644 --- a/tools/test/test_render_slim.cpp +++ b/tools/test/test_render_slim.cpp @@ -1,12 +1,7 @@ #include #include -#include #include -#include -#include #include -#include -#include #include #include "NAM/dsp.h" @@ -19,7 +14,6 @@ namespace test_render_slim // Test that --slim with a SlimmableContainer model changes the output void test_slim_changes_output() { - std::cout << " test_slim_changes_output" << std::endl; // Load the slimmable container model auto model = nam::get_dsp(std::filesystem::path("example_models/slimmable_container.nam")); @@ -65,7 +59,6 @@ void test_slim_changes_output() // Test that --slim rejects non-slimmable models void test_slim_rejects_non_slimmable() { - std::cout << " test_slim_rejects_non_slimmable" << std::endl; // Load a regular (non-container) model auto model = nam::get_dsp(std::filesystem::path("example_models/lstm.nam")); @@ -79,7 +72,6 @@ void test_slim_rejects_non_slimmable() // Test that --slim with boundary values produces finite output void test_slim_boundary_values() { - std::cout << " test_slim_boundary_values" << std::endl; auto model = nam::get_dsp(std::filesystem::path("example_models/slimmable_container.nam")); assert(model != nullptr); @@ -109,7 +101,6 @@ void test_slim_boundary_values() // Test that SetSlimmableSize is called before processing (simulates --slim flow) void test_slim_applied_before_processing() { - std::cout << " test_slim_applied_before_processing" << std::endl; auto model = nam::get_dsp(std::filesystem::path("example_models/slimmable_container.nam")); assert(model != nullptr); diff --git a/tools/test/test_slimmable_wavenet.cpp b/tools/test/test_slimmable_wavenet.cpp index dded0b4..44b9b5f 100644 --- a/tools/test/test_slimmable_wavenet.cpp +++ b/tools/test/test_slimmable_wavenet.cpp @@ -2,7 +2,6 @@ #include #include #include -#include #include #include @@ -55,7 +54,7 @@ void process_and_verify(nam::DSP* dsp, int num_buffers, int buffer_size) void test_loads_from_file() { - std::cout << " test_slimmable_wavenet_loads_from_file" << std::endl; + std::filesystem::path path("example_models/slimmable_wavenet.nam"); auto dsp = nam::get_dsp(path); @@ -64,7 +63,7 @@ void test_loads_from_file() void test_implements_slimmable() { - std::cout << " test_slimmable_wavenet_implements_slimmable" << std::endl; + auto dsp = nam::get_dsp(std::filesystem::path("example_models/slimmable_wavenet.nam")); assert(dsp != nullptr); @@ -75,7 +74,7 @@ void test_implements_slimmable() void test_processes_audio() { - std::cout << " test_slimmable_wavenet_processes_audio" << std::endl; + auto dsp = nam::get_dsp(std::filesystem::path("example_models/slimmable_wavenet.nam")); assert(dsp != nullptr); @@ -84,7 +83,7 @@ void test_processes_audio() void test_slimming_changes_output() { - std::cout << " test_slimmable_wavenet_slimming_changes_output" << std::endl; + auto dsp = nam::get_dsp(std::filesystem::path("example_models/slimmable_wavenet.nam")); assert(dsp != nullptr); @@ -129,7 +128,7 @@ void test_slimming_changes_output() void test_boundary_values() { - std::cout << " test_slimmable_wavenet_boundary_values" << std::endl; + auto dsp = nam::get_dsp(std::filesystem::path("example_models/slimmable_wavenet.nam")); assert(dsp != nullptr); @@ -163,7 +162,7 @@ void test_boundary_values() void test_default_is_max_size() { - std::cout << " test_slimmable_wavenet_default_is_max_size" << std::endl; + auto dsp = nam::get_dsp(std::filesystem::path("example_models/slimmable_wavenet.nam")); assert(dsp != nullptr); @@ -197,7 +196,7 @@ void test_default_is_max_size() void test_ratio_mapping() { - std::cout << " test_slimmable_wavenet_ratio_mapping" << std::endl; + // With allowed_channels [1, 2, 3] (len=3): // idx = min(floor(ratio * 3), 2) @@ -247,24 +246,9 @@ void test_ratio_mapping() void test_from_json() { - std::cout << " test_slimmable_wavenet_from_json" << std::endl; - - // Build a SlimmableWavenet JSON from an existing WaveNet - auto wavenet_json = load_nam_json("example_models/wavenet_3ch.nam"); - nlohmann::json j; - j["version"] = "0.7.0"; - j["architecture"] = "SlimmableWavenet"; - - // Copy the WaveNet config and add slimmable field to the first layer - j["config"]["model"] = wavenet_json["config"]; - j["config"]["model"]["layers"][0]["slimmable"] = { - {"method", "slice_channels_uniform"}, - {"kwargs", {{"allowed_channels", {2, 3}}}} - }; - j["weights"] = wavenet_json["weights"]; - j["sample_rate"] = wavenet_json["sample_rate"]; + auto j = load_nam_json("example_models/slimmable_wavenet.nam"); auto dsp = nam::get_dsp(j); assert(dsp != nullptr); process_and_verify(dsp.get(), 3, 64); @@ -272,9 +256,9 @@ void test_from_json() void test_no_slimmable_layers_throws() { - std::cout << " test_slimmable_wavenet_no_slimmable_layers_throws" << std::endl; - auto wavenet_json = load_nam_json("example_models/wavenet_3ch.nam"); + + auto wavenet_json = load_nam_json("example_models/wavenet.nam"); nlohmann::json j; j["version"] = "0.7.0"; @@ -298,18 +282,16 @@ void test_no_slimmable_layers_throws() void test_unsupported_method_throws() { - std::cout << " test_slimmable_wavenet_unsupported_method_throws" << std::endl; - auto wavenet_json = load_nam_json("example_models/wavenet_3ch.nam"); + + auto wavenet_json = load_nam_json("example_models/wavenet.nam"); nlohmann::json j; j["version"] = "0.7.0"; j["architecture"] = "SlimmableWavenet"; j["config"]["model"] = wavenet_json["config"]; j["config"]["model"]["layers"][0]["slimmable"] = { - {"method", "some_future_method"}, - {"kwargs", {{"allowed_channels", {2, 3}}}} - }; + {"method", "some_future_method"}, {"kwargs", {{"allowed_channels", {2, 3}}}}}; j["weights"] = wavenet_json["weights"]; j["sample_rate"] = wavenet_json["sample_rate"]; @@ -327,7 +309,7 @@ void test_unsupported_method_throws() void test_slimmed_matches_small_model() { - std::cout << " test_slimmable_wavenet_slimmed_matches_small_model" << std::endl; + // Build a minimal WaveNet config: 1 layer array, 2 layers (dilations [1,2]), // kernel_size=3, no gating, no layer1x1, no head1x1, no FiLM, Tanh activation. @@ -368,10 +350,10 @@ void test_slimmed_matches_small_model() for (int l = 0; l < num_layers; l++) { n += ch * ch * kernel_size + ch; // conv - n += condition_size * ch; // input_mixin + n += condition_size * ch; // input_mixin } n += ch * head_size; // head_rechannel - n += 1; // head_scale + n += 1; // head_scale return n; }; @@ -389,7 +371,7 @@ void test_slimmed_matches_small_model() // Helper: embed Conv1x1(small_in, small_out) into Conv1x1(full_in, full_out) auto embed_conv1x1 = [](std::vector::const_iterator& src, int small_in, int small_out, int full_in, - int full_out, bool bias, std::vector& dst) { + int full_out, bool bias, std::vector& dst) { for (int i = 0; i < full_out; i++) for (int j = 0; j < full_in; j++) { @@ -410,7 +392,7 @@ void test_slimmed_matches_small_model() // Helper: embed Conv1D(small_in, small_out) into Conv1D(full_in, full_out) auto embed_conv1d = [](std::vector::const_iterator& src, int small_in, int small_out, int full_in, - int full_out, int ks, std::vector& dst) { + int full_out, int ks, std::vector& dst) { for (int i = 0; i < full_out; i++) for (int j = 0; j < full_in; j++) for (int k = 0; k < ks; k++) @@ -465,8 +447,8 @@ void test_slimmed_matches_small_model() large_json["version"] = "0.7.0"; large_json["architecture"] = "SlimmableWavenet"; auto large_layer_config = make_layer_config(large_ch); - large_layer_config["slimmable"] = {{"method", "slice_channels_uniform"}, - {"kwargs", {{"allowed_channels", {small_ch, large_ch}}}}}; + large_layer_config["slimmable"] = { + {"method", "slice_channels_uniform"}, {"kwargs", {{"allowed_channels", {small_ch, large_ch}}}}}; large_json["config"]["model"]["layers"] = nlohmann::json::array({large_layer_config}); large_json["config"]["model"]["head_scale"] = 1.0; large_json["weights"] = large_weights; @@ -510,12 +492,7 @@ void test_slimmed_matches_small_model() { assert(std::isfinite(out_small[i])); assert(std::isfinite(out_large[i])); - if (std::abs(out_small[i] - out_large[i]) > 1e-6) - { - std::cerr << " MISMATCH at buffer " << buf << " sample " << i << ": small=" << out_small[i] - << " slimmed=" << out_large[i] << " diff=" << std::abs(out_small[i] - out_large[i]) << std::endl; - assert(false); - } + assert(std::abs(out_small[i] - out_large[i]) <= 1e-6); } } }