Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ jobs:
run: |
cmake .. -DCMAKE_BUILD_TYPE=Debug
cmake --build . -j4

- name: Build Tools (Inline GEMM)
working-directory: ${{github.workspace}}/build_inline
env:
CXX: clang++
run: |
cmake .. -DCMAKE_BUILD_TYPE=Debug -DCMAKE_CXX_FLAGS="-DNAM_USE_INLINE_GEMM"
cmake --build . -j4

- name: Run tests
working-directory: ${{github.workspace}}
Expand All @@ -42,3 +50,8 @@ jobs:
./build/tools/benchmodel ./example_models/wavenet.nam
./build/tools/benchmodel ./example_models/lstm.nam
./build/tools/render ./example_models/wavenet.nam ./example_audio/input.wav ./example_audio/output.wav
./build_inline/tools/run_tests
./build_inline/tools/benchmodel ./example_models/wavenet.nam
./build_inline/tools/benchmodel ./example_models/lstm.nam
./build_inline/tools/render ./example_models/wavenet.nam ./example_audio/input.wav ./example_audio/output.wav

15 changes: 13 additions & 2 deletions NAM/activations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ std::unordered_map<std::string, nam::activations::Activation::Ptr> nam::activati
{"PReLU", make_singleton_ptr(_PRELU)},
{"Softsign", make_singleton_ptr(_SOFTSIGN)}};

// Variables to hold previous instances of activations when we replace them with fast or LUT implementations.
// Saved here so the disable_* paths can restore the original activation objects
// back into the _activations registry (see disable_lut below, which assigns these back).
nam::activations::Activation::Ptr tanh_bak = nullptr;
nam::activations::Activation::Ptr sigmoid_bak = nullptr;
nam::activations::Activation::Ptr silu_bak = nullptr;

nam::activations::Activation::Ptr nam::activations::Activation::get_activation(const std::string name)
{
Expand Down Expand Up @@ -197,9 +199,14 @@ void nam::activations::Activation::enable_lut(std::string function_name, float m
fn = sigmoid;
sigmoid_bak = _activations["Sigmoid"];
}
else if (function_name == "SiLU")
{
fn = swish;
silu_bak = _activations["SiLU"];
}
else
{
throw std::runtime_error("Tried to enable LUT for a function other than Tanh or Sigmoid");
throw std::runtime_error("Tried to enable LUT for a function other than Tanh, Sigmoid, or SiLU");
}
_activations[function_name] = std::make_shared<FastLUTActivation>(min, max, n_points, fn);
}
Expand All @@ -214,8 +221,12 @@ void nam::activations::Activation::disable_lut(std::string function_name)
{
_activations["Sigmoid"] = sigmoid_bak;
}
else if (function_name == "SiLU")
{
_activations["SiLU"] = silu_bak;
}
else
{
throw std::runtime_error("Tried to disable LUT for a function other than Tanh or Sigmoid");
throw std::runtime_error("Tried to disable LUT for a function other than Tanh, Sigmoid, or SiLU");
}
}
43 changes: 18 additions & 25 deletions NAM/activations.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,12 @@ inline float swish(float x)

// Hard-swish activation.
//
//   hardswish(x) = x * relu6(x + 3) / 6
//                = x * clamp(x + 3, 0, 6) / 6
//
// Piecewise this is: 0 for x <= -3, x for x >= 3, and x*(x+3)/6 in between.
// Written branchlessly (clamp via nested conditionals the compiler can turn
// into min/max) so the element-wise loops over buffers can vectorize.
inline float hardswish(float x)
{
  const float t = x + 3.0f;
  // clamp(t, 0, 6) without <algorithm>: conditional moves, no branches needed.
  const float clamped = t < 0.0f ? 0.0f : (t > 6.0f ? 6.0f : t);
  // Multiply by the reciprocal of 6 (compile-time constant) instead of dividing.
  return x * clamped * (1.0f / 6.0f);
}

inline float softsign(float x)
Expand All @@ -147,9 +141,18 @@ class Activation
Activation() = default;
// Virtual destructor: Activation is a polymorphic base held via Ptr, so derived
// activation objects must be destroyed correctly through a base pointer.
virtual ~Activation() = default;
virtual void apply(Eigen::MatrixXf& matrix) { apply(matrix.data(), matrix.rows() * matrix.cols()); }
// Apply the activation in place to a general matrix block.
//
// The flat data()/size call below treats the block as one contiguous array.
// That is only valid when the block spans whole columns of its parent
// (outerStride == rows). Non-contiguous blocks (e.g. topRows() of a wider
// matrix) would read/write the wrong elements, so guard with a debug assert.
virtual void apply(Eigen::Block<Eigen::MatrixXf> block)
{
  assert(block.outerStride() == block.rows());
  apply(block.data(), block.rows() * block.cols());
}
// Overload for inner-panel blocks (third template argument InnerPanel == true,
// e.g. the result of leftCols() on a column-major matrix), applied in place.
virtual void apply(Eigen::Block<Eigen::MatrixXf, -1, -1, true> block)
{
// Inner-panel blocks (e.g. leftCols()) are always contiguous for column-major matrices,
// but assert anyway for safety.
assert(block.outerStride() == block.rows());
apply(block.data(), block.rows() * block.cols());
}
virtual void apply(float* data, long size) = 0;
Expand Down Expand Up @@ -244,9 +247,7 @@ class ActivationReLU : public Activation
// Rectify the flat buffer in place: negatives become zero, positives pass through.
void apply(float* data, long size) override
{
  float* const end = data + size;
  for (float* p = data; p != end; ++p)
    *p = relu(*p);
}
};

Expand Down Expand Up @@ -336,9 +337,7 @@ class ActivationSigmoid : public Activation
// Apply the logistic sigmoid to every element of the flat buffer, in place.
void apply(float* data, long size) override
{
  long i = 0;
  while (i < size)
  {
    data[i] = sigmoid(data[i]);
    ++i;
  }
}
};

Expand All @@ -348,9 +347,7 @@ class ActivationSwish : public Activation
// Apply Swish (SiLU) to every element of the flat buffer, in place.
void apply(float* data, long size) override
{
  for (float *p = data, *const stop = data + size; p != stop; ++p)
    *p = swish(*p);
}
};

Expand All @@ -360,9 +357,7 @@ class ActivationHardSwish : public Activation
// Apply hard-swish to every element of the flat buffer, in place.
void apply(float* data, long size) override
{
  float* p = data;
  for (long remaining = size; remaining > 0; --remaining, ++p)
    *p = hardswish(*p);
}
};

Expand All @@ -372,9 +367,7 @@ class ActivationSoftsign : public Activation
// Apply softsign to every element of the flat buffer, in place.
// (Elements are independent, so iteration order does not matter.)
void apply(float* data, long size) override
{
  for (long idx = size; idx-- > 0;)
    data[idx] = softsign(data[idx]);
}
};

Expand All @@ -400,8 +393,8 @@ class FastLUTActivation : public Activation
// Fast lookup with linear interpolation
inline float lookup(float x) const
{
// Clamp input to range
x = std::clamp(x, min_x_, max_x_);
// Clamp input to range (inline to avoid header dependency)
x = x < min_x_ ? min_x_ : (x > max_x_ ? max_x_ : x);

// Calculate float index
float f_idx = (x - min_x_) * inv_step_;
Expand Down
Loading