diff --git a/.gitignore b/.gitignore
index 259148f..859e9fc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -30,3 +30,9 @@
 *.exe
 *.out
 *.app
+
+# IDE Based files
+*.idea
+
+# Build folder
+build/
diff --git a/.gitmodules b/.gitmodules
index 11c1984..3965bf6 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,6 @@
 [submodule "Dependencies/eigen"]
 	path = Dependencies/eigen
 	url = https://gitlab.com/libeigen/eigen
+[submodule "Dependencies/Catch2"]
+	path = Dependencies/Catch2
+	url = https://github.com/catchorg/Catch2.git
diff --git a/Dependencies/Catch2 b/Dependencies/Catch2
new file mode 160000
index 0000000..a1faad9
--- /dev/null
+++ b/Dependencies/Catch2
@@ -0,0 +1 @@
+Subproject commit a1faad9315ece8e7146e9d2263ceb3d42ea0619a
diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt
index 1e09436..f3fecb7 100644
--- a/tools/CMakeLists.txt
+++ b/tools/CMakeLists.txt
@@ -1,6 +1,10 @@
 file(GLOB_RECURSE NAM_SOURCES ../NAM/*.cpp ../NAM/*.c ../NAM/*.h)
+file(GLOB TEST_SOURCES test/*.cpp)
+
+set(CATCH_INSTALL_DOCS OFF CACHE BOOL "" FORCE)
+set(CATCH_INSTALL_EXTRAS OFF CACHE BOOL "" FORCE)
+add_subdirectory(${NAM_DEPS_PATH}/Catch2 Catch2 EXCLUDE_FROM_ALL)
 
-# TODO: add loadmodel and run_tests to TOOLS?
 set(TOOLS benchmodel)
 
 add_custom_target(tools ALL
@@ -12,11 +16,14 @@ include_directories(tools ${NAM_DEPS_PATH}/nlohmann)
 
 add_executable(loadmodel loadmodel.cpp ${NAM_SOURCES})
 add_executable(benchmodel benchmodel.cpp ${NAM_SOURCES})
-add_executable(run_tests run_tests.cpp ${NAM_SOURCES})
+add_executable(run_tests run_tests.cpp ${TEST_SOURCES} ${NAM_SOURCES})
+
+target_link_libraries(run_tests PRIVATE Catch2::Catch2WithMain)
 
 source_group(NAM ${CMAKE_CURRENT_SOURCE_DIR} FILES ${NAM_SOURCES})
 
 target_compile_features(${TOOLS} PUBLIC cxx_std_20)
+target_compile_features(run_tests PUBLIC cxx_std_20)
 
 set_target_properties(${TOOLS}
   PROPERTIES
@@ -27,6 +34,7 @@ set_target_properties(${TOOLS}
 
 if (CMAKE_SYSTEM_NAME STREQUAL "Windows")
   target_compile_definitions(${TOOLS} PRIVATE NOMINMAX WIN32_LEAN_AND_MEAN)
+  target_compile_definitions(run_tests PRIVATE NOMINMAX WIN32_LEAN_AND_MEAN)
 endif()
 
 if (MSVC)
@@ -34,15 +42,21 @@ if (MSVC)
     "$<$<CONFIG:DEBUG>:/W4>"
     "$<$<CONFIG:RELEASE>:/O2>"
   )
+  target_compile_options(run_tests PRIVATE
+    "$<$<CONFIG:DEBUG>:/W4>"
+    "$<$<CONFIG:RELEASE>:/O2>"
+  )
 else()
   target_compile_options(${TOOLS} PRIVATE
     -Wall -Wextra -Wpedantic -Wstrict-aliasing -Wunreachable-code -Weffc++ -Wno-unused-parameter
     "$<$<CONFIG:DEBUG>:-Og;-ggdb;-Werror>"
     "$<$<CONFIG:RELEASE>:-Ofast>"
   )
+  target_compile_options(run_tests PRIVATE
+    -Wall -Wextra -Wpedantic -Wstrict-aliasing -Wunreachable-code -Weffc++ -Wno-unused-parameter
+    "$<$<CONFIG:DEBUG>:-Og;-ggdb;-Werror>"
+    "$<$<CONFIG:RELEASE>:-Ofast>"
+  )
 endif()
 
-# There's an error in eigen's
-# /Users/steve/src/NeuralAmpModelerCore/Dependencies/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h
-# Don't let this break my build on debug:
 set_source_files_properties(../NAM/dsp.cpp PROPERTIES COMPILE_FLAGS "-Wno-error")
diff --git a/tools/run_tests.cpp b/tools/run_tests.cpp
index bee341a..417253d 100644
--- a/tools/run_tests.cpp
+++ b/tools/run_tests.cpp
@@ -1,40 +1,2 @@
-// Entry point for tests
-// See the GitHub Action for a demo how to build and run tests.
-
-#include <iostream>
-#include "test/test_activations.cpp"
-#include "test/test_dsp.cpp"
-#include "test/test_get_dsp.cpp"
-#include "test/test_wavenet.cpp"
-
-int main()
-{
-  std::cout << "Running tests..." << std::endl;
-  // TODO Automatically loop, catch exceptions, log results
-
-  test_activations::TestFastTanh::test_core_function();
-  test_activations::TestFastTanh::test_get_by_init();
-  test_activations::TestFastTanh::test_get_by_str();
-
-  test_activations::TestLeakyReLU::test_core_function();
-  test_activations::TestLeakyReLU::test_get_by_init();
-  test_activations::TestLeakyReLU::test_get_by_str();
-
-  test_dsp::test_construct();
-  test_dsp::test_get_input_level();
-  test_dsp::test_get_output_level();
-  test_dsp::test_has_input_level();
-  test_dsp::test_has_output_level();
-  test_dsp::test_set_input_level();
-  test_dsp::test_set_output_level();
-
-  test_get_dsp::test_gets_input_level();
-  test_get_dsp::test_gets_output_level();
-  test_get_dsp::test_null_input_level();
-  test_get_dsp::test_null_output_level();
-
-  test_wavenet::test_gated();
-
-  std::cout << "Success!" << std::endl;
-  return 0;
-}
\ No newline at end of file
+#define CATCH_CONFIG_MAIN
+#include <catch2/catch_all.hpp>
\ No newline at end of file
diff --git a/tools/test/test_activations.cpp b/tools/test/test_activations.cpp
index 75c472b..c4dcbcc 100644
--- a/tools/test/test_activations.cpp
+++ b/tools/test/test_activations.cpp
@@ -1,122 +1,46 @@
-// Tests for activation functions
-//
-// Things you want ot test for:
-// 1. That the core elementwise funciton is snapshot-correct.
-// 2. The class that wraps the core function for an array of data
-// 3. .cpp: that you have the singleton defined, and that it's in the unordered map to get by string
-
-#include <cassert>
-#include <string>
-#include <vector>
-
+#include <catch2/catch_test_macros.hpp>
 #include "NAM/activations.h"
 
-namespace test_activations
-{
-// TODO get nonzero cases
-class TestFastTanh
-{
-public:
-  static void test_core_function()
-  {
-    auto TestCase = [](float input, float expectedOutput) {
-      float actualOutput = nam::activations::fast_tanh(input);
-      assert(actualOutput == expectedOutput);
-    };
-    // A few snapshot tests
-    TestCase(0.0f, 0.0f);
-    // TestCase(1.0f, 1.0f);
-    // TestCase(-1.0f, -0.01f);
-  };
-
-  static void test_get_by_init()
-  {
-    auto a = nam::activations::ActivationLeakyReLU();
-    _test_class(&a);
+TEST_CASE("FastTanh core function", "[activations]") {
+  REQUIRE(nam::activations::fast_tanh(0.0f) == 0.0f);
+}
+
+TEST_CASE("FastTanh get by init", "[activations]") {
+  auto a = nam::activations::ActivationLeakyReLU();
+  std::vector<float> inputs{0.0f};
+  a.apply(inputs.data(), inputs.size());
+  REQUIRE(inputs[0] == 0.0f);
+}
+
+TEST_CASE("FastTanh get by string", "[activations]") {
+  auto a = nam::activations::Activation::get_activation("Fasttanh");
+  std::vector<float> inputs{0.0f};
+  a->apply(inputs.data(), inputs.size());
+  REQUIRE(inputs[0] == 0.0f);
+}
+
+TEST_CASE("LeakyReLU core function", "[activations]") {
+  REQUIRE(nam::activations::leaky_relu(0.0f) == 0.0f);
+  REQUIRE(nam::activations::leaky_relu(1.0f) == 1.0f);
+  REQUIRE(nam::activations::leaky_relu(-1.0f) == -0.01f);
+}
+
+TEST_CASE("LeakyReLU get by init", "[activations]") {
+  auto a = nam::activations::ActivationLeakyReLU();
+  std::vector<float> inputs{0.0f, 1.0f, -1.0f};
+  std::vector<float> expected{0.0f, 1.0f, -0.01f};
+  a.apply(inputs.data(), inputs.size());
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    REQUIRE(inputs[i] == expected[i]);
   }
-
-  // Get the singleton and test it
-  static void test_get_by_str()
-  {
-    const std::string name = "Fasttanh";
-    auto a = nam::activations::Activation::get_activation(name);
-    _test_class(a);
+}
+
+TEST_CASE("LeakyReLU get by string", "[activations]") {
+  auto a = nam::activations::Activation::get_activation("LeakyReLU");
+  std::vector<float> inputs{0.0f, 1.0f, -1.0f};
+  std::vector<float> expected{0.0f, 1.0f, -0.01f};
+  a->apply(inputs.data(), inputs.size());
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    REQUIRE(inputs[i] == expected[i]);
   }
-
-private:
-  // Put the class through its paces
-  static void _test_class(nam::activations::Activation* a)
-  {
-    std::vector<float> inputs, expectedOutputs;
-
-    inputs.push_back(0.0f);
-    expectedOutputs.push_back(0.0f);
-
-    // inputs.push_back(1.0f);
-    // expectedOutputs.push_back(1.0f);
-
-    // inputs.push_back(-1.0f);
-    // expectedOutputs.push_back(-0.01f);
-
-    a->apply(inputs.data(), (long)inputs.size());
-    for (auto itActual = inputs.begin(), itExpected = expectedOutputs.begin(); itActual != inputs.end();
-         ++itActual, ++itExpected)
-    {
-      assert(*itActual == *itExpected);
-    }
-  };
-};
-
-class TestLeakyReLU
-{
-public:
-  static void test_core_function()
-  {
-    auto TestCase = [](float input, float expectedOutput) {
-      float actualOutput = nam::activations::leaky_relu(input);
-      assert(actualOutput == expectedOutput);
-    };
-    // A few snapshot tests
-    TestCase(0.0f, 0.0f);
-    TestCase(1.0f, 1.0f);
-    TestCase(-1.0f, -0.01f);
-  };
-
-  static void test_get_by_init()
-  {
-    auto a = nam::activations::ActivationLeakyReLU();
-    _test_class(&a);
-  }
-
-  // Get the singleton and test it
-  static void test_get_by_str()
-  {
-    const std::string name = "LeakyReLU";
-    auto a = nam::activations::Activation::get_activation(name);
-    _test_class(a);
-  }
-
-private:
-  // Put the class through its paces
-  static void _test_class(nam::activations::Activation* a)
-  {
-    std::vector<float> inputs, expectedOutputs;
-
-    inputs.push_back(0.0f);
-    expectedOutputs.push_back(0.0f);
-
-    inputs.push_back(1.0f);
-    expectedOutputs.push_back(1.0f);
-
-    inputs.push_back(-1.0f);
-    expectedOutputs.push_back(-0.01f);
-
-    a->apply(inputs.data(), (long)inputs.size());
-    for (auto itActual = inputs.begin(), itExpected = expectedOutputs.begin(); itActual != inputs.end();
-         ++itActual, ++itExpected)
-    {
-      assert(*itActual == *itExpected);
-    }
-  };
-};
-}; // namespace test_activations
+}
diff --git a/tools/test/test_dsp.cpp b/tools/test/test_dsp.cpp
index bbdee63..2b4399c 100644
--- a/tools/test/test_dsp.cpp
+++ b/tools/test/test_dsp.cpp
@@ -1,66 +1,46 @@
-// Tests for dsp
-
+#include <catch2/catch_test_macros.hpp>
 #include "NAM/dsp.h"
 
-namespace test_dsp
-{
-// Simplest test: can I construct something!
-void test_construct()
-{
+TEST_CASE("DSP construct", "[dsp]") {
   nam::DSP myDsp(48000.0);
 }
 
-void test_get_input_level()
-{
+TEST_CASE("DSP get input level", "[dsp]") {
   nam::DSP myDsp(48000.0);
   const double expected = 19.0;
   myDsp.SetInputLevel(expected);
-  assert(myDsp.HasInputLevel());
-  const double actual = myDsp.GetInputLevel();
-
-  assert(actual == expected);
+  REQUIRE(myDsp.HasInputLevel());
+  REQUIRE(myDsp.GetInputLevel() == expected);
 }
 
-void test_get_output_level()
-{
+TEST_CASE("DSP get output level", "[dsp]") {
   nam::DSP myDsp(48000.0);
   const double expected = 12.0;
   myDsp.SetOutputLevel(expected);
-  assert(myDsp.HasOutputLevel());
-  const double actual = myDsp.GetOutputLevel();
-
-  assert(actual == expected);
+  REQUIRE(myDsp.HasOutputLevel());
+  REQUIRE(myDsp.GetOutputLevel() == expected);
 }
 
-// Test correct function of DSP::HasInputLevel()
-void test_has_input_level()
-{
+TEST_CASE("DSP has input level", "[dsp]") {
   nam::DSP myDsp(48000.0);
-  assert(!myDsp.HasInputLevel());
-
+  REQUIRE(!myDsp.HasInputLevel());
   myDsp.SetInputLevel(19.0);
-  assert(myDsp.HasInputLevel());
+  REQUIRE(myDsp.HasInputLevel());
 }
 
-void test_has_output_level()
-{
+TEST_CASE("DSP has output level", "[dsp]") {
   nam::DSP myDsp(48000.0);
-  assert(!myDsp.HasOutputLevel());
-
+  REQUIRE(!myDsp.HasOutputLevel());
   myDsp.SetOutputLevel(12.0);
-  assert(myDsp.HasOutputLevel());
+  REQUIRE(myDsp.HasOutputLevel());
 }
 
-// Test correct function of DSP::HasInputLevel()
-void test_set_input_level()
-{
+TEST_CASE("DSP set input level", "[dsp]") {
   nam::DSP myDsp(48000.0);
   myDsp.SetInputLevel(19.0);
 }
 
-void test_set_output_level()
-{
+TEST_CASE("DSP set output level", "[dsp]") {
   nam::DSP myDsp(48000.0);
   myDsp.SetOutputLevel(19.0);
 }
-}; // namespace test_dsp
diff --git a/tools/test/test_get_dsp.cpp b/tools/test/test_get_dsp.cpp
index b9f2d10..b879af8 100644
--- a/tools/test/test_get_dsp.cpp
+++ b/tools/test/test_get_dsp.cpp
@@ -1,14 +1,11 @@
-#include <cassert>
+#include <catch2/catch_test_macros.hpp>
 #include <memory>
 #include <string>
 #include <vector>
-
 #include "json.hpp"
-
 #include "NAM/get_dsp.h"
 
-namespace test_get_dsp
-{
+namespace {
 // Config
 const std::string basicConfigStr = R"({"version": "0.5.4", "metadata": {"date": {"year": 2024, "month": 10, "day": 9, "hour": 18, "minute": 44, "second": 41}, "loudness": -37.8406867980957, "gain": 0.13508800804658277, "name": "Test LSTM", "modeled_by": "Steve", "gear_type": "amp", "gear_make": "Darkglass Electronics", "gear_model": "Microtubes 900 v2", "tone_type": "clean", "input_level_dbu": 18.3, "output_level_dbu": 12.3, "training": {"settings": {"ignore_checks": false}, "data": {"latency": {"manual": null, "calibration": {"algorithm_version": 1, "delays": [-16], "safety_factor": 1, "recommended": -17, "warnings": {"matches_lookahead": false, "disagreement_too_high": false}}}, "checks": {"version": 3, "passed": true}}, "validation_esr": null}}, "architecture": "LSTM", "config": {"input_size": 1, "hidden_size": 3, "num_layers": 1}, "weights": [-0.21677088737487793, -0.6683622002601624, -0.2560940980911255, -0.3588429093360901, 0.17952610552310944, 0.19445613026618958, -0.01662646047770977, 0.5353694558143616, -0.2536540627479553, -0.5132213234901428, -0.020476307719945908, 0.08592455089092255, -0.6891753673553467, 0.3627359867095947, 0.008421811275184155, 0.3113192617893219, 0.14251480996608734, 0.07989779114723206, -0.18211324512958527, 0.7118963003158569, 0.41084015369415283, -0.6571938395500183, -0.13214066624641418, -0.2698603868484497, 0.49387243390083313, -0.3491725027561188, 0.6353667974472046, -0.5005152225494385, 0.2052856683731079, -0.4301638901233673, -0.15770092606544495, -0.7181791067123413, 0.056290093809366226, -0.49049463868141174, 0.6623441576957703, 0.09029324352741241, 0.34005245566368103, 0.16416560113430023, 0.15520110726356506, -0.4155678153038025, -0.36928507685661316, 0.3211132884025574, -0.6769840121269226, -0.1575538069009781, 0.05268515646457672, -0.4191459119319916, 0.599330484867096, 0.21518059074878693, -4.246325492858887, -3.315647840499878, -4.328850746154785, 4.496089458465576, 5.015639305114746, 3.6492037773132324, 0.14431169629096985, -0.6633821725845337, 0.11673200130462646, -0.1418764889240265, -0.4897872805595398, -0.8689419031143188, -0.06714004278182983, -0.4450395107269287, -0.02142983116209507, -0.15136894583702087, -2.775207996368408, -0.08681213855743408, 0.05702732503414154, 0.670292317867279, 0.31442636251449585, 0.30793967843055725], "sample_rate": 48000})";
 
@@ -42,41 +39,34 @@ nam::dspData _GetConfig(const std::string& configStr = basicConfigStr)
   return returnedConfig;
 }
-void test_gets_input_level()
-{
+}
+
+TEST_CASE("get_dsp gets input level", "[get_dsp]") {
   nam::dspData config = _GetConfig();
   std::unique_ptr<nam::DSP> dsp = get_dsp(config);
-  assert(dsp->HasInputLevel());
+  REQUIRE(dsp->HasInputLevel());
 }
-void test_gets_output_level()
-{
+
+TEST_CASE("get_dsp gets output level", "[get_dsp]") {
   nam::dspData config = _GetConfig();
   std::unique_ptr<nam::DSP> dsp = get_dsp(config);
-  assert(dsp->HasOutputLevel());
+  REQUIRE(dsp->HasOutputLevel());
 }
 
-void test_null_input_level()
-{
-  // Issue 129
+TEST_CASE("get_dsp null input level", "[get_dsp]") {
   const std::string configStr = R"({"version": "0.5.4", "metadata": {"date": {"year": 2024, "month": 10, "day": 9, "hour": 18, "minute": 44, "second": 41}, "loudness": -37.8406867980957, "gain": 0.13508800804658277, "name": "Test LSTM", "modeled_by": "Steve", "gear_type": "amp", "gear_make": "Darkglass Electronics", "gear_model": "Microtubes 900 v2", "tone_type": "clean", "input_level_dbu": null, "output_level_dbu": 12.3, "training": {"settings": {"ignore_checks": false}, "data": {"latency": {"manual": null, "calibration": {"algorithm_version": 1, "delays": [-16], "safety_factor": 1, "recommended": -17, "warnings": {"matches_lookahead": false, "disagreement_too_high": false}}}, "checks": {"version": 3, "passed": true}}, "validation_esr": null}}, "architecture": "LSTM", "config": {"input_size": 1, "hidden_size": 3, "num_layers": 1}, "weights": [-0.21677088737487793, -0.6683622002601624, -0.2560940980911255, -0.3588429093360901, 0.17952610552310944, 0.19445613026618958, -0.01662646047770977, 0.5353694558143616, -0.2536540627479553, -0.5132213234901428, -0.020476307719945908, 0.08592455089092255, -0.6891753673553467, 0.3627359867095947, 0.008421811275184155, 0.3113192617893219, 0.14251480996608734, 0.07989779114723206, -0.18211324512958527, 0.7118963003158569, 0.41084015369415283, -0.6571938395500183, -0.13214066624641418, -0.2698603868484497, 0.49387243390083313, -0.3491725027561188, 0.6353667974472046, -0.5005152225494385, 0.2052856683731079, -0.4301638901233673, -0.15770092606544495, -0.7181791067123413, 0.056290093809366226, -0.49049463868141174, 0.6623441576957703, 0.09029324352741241, 0.34005245566368103, 0.16416560113430023, 0.15520110726356506, -0.4155678153038025, -0.36928507685661316, 0.3211132884025574, -0.6769840121269226, -0.1575538069009781, 0.05268515646457672, -0.4191459119319916, 0.599330484867096, 0.21518059074878693, -4.246325492858887, -3.315647840499878, -4.328850746154785, 4.496089458465576, 5.015639305114746, 3.6492037773132324, 0.14431169629096985, -0.6633821725845337, 0.11673200130462646, -0.1418764889240265, -0.4897872805595398, -0.8689419031143188, -0.06714004278182983, -0.4450395107269287, -0.02142983116209507, -0.15136894583702087, -2.775207996368408, -0.08681213855743408, 0.05702732503414154, 0.670292317867279, 0.31442636251449585, 0.30793967843055725], "sample_rate": 48000})";
   nam::dspData config = _GetConfig(configStr);
-  // The first part of this is that the following line doesn't fail:
   std::unique_ptr<nam::DSP> dsp = get_dsp(config);
-
-  assert(!dsp->HasInputLevel());
-  assert(dsp->HasOutputLevel());
+  REQUIRE(!dsp->HasInputLevel());
+  REQUIRE(dsp->HasOutputLevel());
 }
 
-void test_null_output_level()
-{
-  // Issue 129
+TEST_CASE("get_dsp null output level", "[get_dsp]") {
   const std::string configStr = R"({"version": "0.5.4", "metadata": {"date": {"year": 2024, "month": 10, "day": 9, "hour": 18, "minute": 44, "second": 41}, "loudness": -37.8406867980957, "gain": 0.13508800804658277, "name": "Test LSTM", "modeled_by": "Steve", "gear_type": "amp", "gear_make": "Darkglass Electronics", "gear_model": "Microtubes 900 v2", "tone_type": "clean", "input_level_dbu": 19.0, "output_level_dbu": null, "training": {"settings": {"ignore_checks": false}, "data": {"latency": {"manual": null, "calibration": {"algorithm_version": 1, "delays": [-16], "safety_factor": 1, "recommended": -17, "warnings": {"matches_lookahead": false, "disagreement_too_high": false}}}, "checks": {"version": 3, "passed": true}}, "validation_esr": null}}, "architecture": "LSTM", "config": {"input_size": 1, "hidden_size": 3, "num_layers": 1}, "weights": [-0.21677088737487793, -0.6683622002601624, -0.2560940980911255, -0.3588429093360901, 0.17952610552310944, 0.19445613026618958, -0.01662646047770977, 0.5353694558143616, -0.2536540627479553, -0.5132213234901428, -0.020476307719945908, 0.08592455089092255, -0.6891753673553467, 0.3627359867095947, 0.008421811275184155, 0.3113192617893219, 0.14251480996608734, 0.07989779114723206, -0.18211324512958527, 0.7118963003158569, 0.41084015369415283, -0.6571938395500183, -0.13214066624641418, -0.2698603868484497, 0.49387243390083313, -0.3491725027561188, 0.6353667974472046, -0.5005152225494385, 0.2052856683731079, -0.4301638901233673, -0.15770092606544495, -0.7181791067123413, 0.056290093809366226, -0.49049463868141174, 0.6623441576957703, 0.09029324352741241, 0.34005245566368103, 0.16416560113430023, 0.15520110726356506, -0.4155678153038025, -0.36928507685661316, 0.3211132884025574, -0.6769840121269226, -0.1575538069009781, 0.05268515646457672, -0.4191459119319916, 0.599330484867096, 0.21518059074878693, -4.246325492858887, -3.315647840499878, -4.328850746154785, 4.496089458465576, 5.015639305114746, 3.6492037773132324, 0.14431169629096985, -0.6633821725845337, 0.11673200130462646, -0.1418764889240265, -0.4897872805595398, -0.8689419031143188, -0.06714004278182983, -0.4450395107269287, -0.02142983116209507, -0.15136894583702087, -2.775207996368408, -0.08681213855743408, 0.05702732503414154, 0.670292317867279, 0.31442636251449585, 0.30793967843055725], "sample_rate": 48000})";
   nam::dspData config = _GetConfig(configStr);
-  // The first part of this is that the following line doesn't fail:
   std::unique_ptr<nam::DSP> dsp = get_dsp(config);
-  assert(dsp->HasInputLevel());
-  assert(!dsp->HasOutputLevel());
-}
-}; // namespace test_get_dsp
\ No newline at end of file
+  REQUIRE(dsp->HasInputLevel());
+  REQUIRE(!dsp->HasOutputLevel());
+}
\ No newline at end of file
diff --git a/tools/test/test_wavenet.cpp b/tools/test/test_wavenet.cpp
index 87390cd..0771ece 100644
--- a/tools/test/test_wavenet.cpp
+++ b/tools/test/test_wavenet.cpp
@@ -1,17 +1,8 @@
-// Tests for the WaveNet
-
+#include <catch2/catch_test_macros.hpp>
 #include <string>
-#include <cassert>
-#include <iostream>
-
 #include "NAM/wavenet.h"
 
-namespace test_wavenet
-{
-void test_gated()
-{
-  // Assert correct nuemrics of the gating activation.
-  // Issue 101
+TEST_CASE("WaveNet gated activation", "[wavenet]") {
   const int conditionSize = 1;
   const int channels = 1;
   const int kernelSize = 1;
@@ -20,20 +11,13 @@ void test_gated()
   const bool gated = true;
   auto layer = nam::wavenet::_Layer(conditionSize, channels, kernelSize, dilation, activation, gated);
 
-  // Conv, input mixin, 1x1
   std::vector<float> weights{
-    // Conv (weight, bias) NOTE: 2 channels out bc gated, so shapes are (2,1,1), (2,)
     1.0f, 1.0f, 0.0f, 0.0f,
-    // Input mixin (weight only: (2,1,1))
     1.0f, -1.0f,
-    // 1x1 (weight (1,1,1), bias (1,))
-    // NOTE: Weights are (1,1) on conv, (1,-1), so the inputs sum on the upper channel and cancel on the lower.
-    // This should give us a nice zero if the input & condition are the same, so that'll sigmoid to 0.5 for the
-    // gate.
     1.0f, 0.0f};
   auto it = weights.begin();
   layer.set_weights_(it);
-  assert(it == weights.end());
+  REQUIRE(it == weights.end());
 
   const long numFrames = 4;
   layer.SetMaxBufferSize(numFrames);
@@ -47,30 +31,15 @@ void test_gated()
   const float signalValue = 0.25f;
   input.fill(signalValue);
   condition.fill(signalValue);
-  // So input & condition will sum to 0.5 on the top channel (-> ReLU), cancel to 0 on bottom (-> sigmoid)
-
   headInput.setZero();
   output.setZero();
 
   layer.process_(input, condition, headInput, output, 0, 0, (int)numFrames);
-  // 0.25 + 0.25 -> 0.5 for conv & input mixin top channel
-  // (0 on bottom channel)
-  // Top ReLU -> preseves 0.5
-  // Bottom sigmoid 0->0.5
-  // Product is 0.25
-  // 1x1 is unity
-  // Skip-connect -> 0.25 (input) + 0.25 (output) -> 0.5 output
-  // head output gets 0+0.25 = 0.25
   const float expectedOutput = 0.5;
   const float expectedHeadInput = 0.25;
-  for (int i = 0; i < numFrames; i++)
-  {
-    const float actualOutput = output(0, i);
-    const float actualHeadInput = headInput(0, i);
-    // std::cout << actualOutput << std::endl;
-    assert(actualOutput == expectedOutput);
-    assert(actualHeadInput == expectedHeadInput);
+  for (int i = 0; i < numFrames; i++) {
+    REQUIRE(output(0, i) == expectedOutput);
+    REQUIRE(headInput(0, i) == expectedHeadInput);
   }
-}
-}; // namespace test_wavenet
\ No newline at end of file
+}
\ No newline at end of file
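
Note (not part of the diff above): because run_tests now links Catch2::Catch2WithMain, the suite could also be registered with CTest as a follow-up. A minimal sketch, assuming the Dependencies/Catch2 submodule layout and the NAM_DEPS_PATH variable used in tools/CMakeLists.txt; the exact extras path and whether CTest integration is wanted here are assumptions, not something this PR does:

    # Catch2 ships Catch.cmake in its extras/ directory; put it on the module path
    list(APPEND CMAKE_MODULE_PATH ${NAM_DEPS_PATH}/Catch2/extras)
    include(CTest)
    include(Catch)
    # Registers each TEST_CASE in run_tests as an individual CTest test
    catch_discover_tests(run_tests)

With that in place `ctest` would enumerate the cases by name; invoking the binary directly still allows Catch2 tag filtering, e.g. running only the "[dsp]" cases.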