Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions stan/math/fwd/fun/log_softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,17 @@ inline auto log_softmax(T&& x) {
*
* @tparam Vec Eigen vector with `fvar` scalar
* @param x vector to transform
* @return log softmax of the vector
* @throw std::domain_error if the input size is 0
* @return log softmax of the vector, or an empty result if the input is empty
*/
template <typename Vec, require_eigen_vector_vt<is_fvar, Vec>* = nullptr>
inline auto log_softmax(Vec&& x) {
using vec = std::decay_t<Vec>;
constexpr int Rows = vec::RowsAtCompileTime;
constexpr int Cols = vec::ColsAtCompileTime;
using T = typename value_type_t<Vec>::Scalar;
check_nonzero_size("log_softmax", "x", x);
using T = typename value_type_t<vec>::Scalar;
decltype(auto) x_ref = to_ref(std::forward<Vec>(x));
if (x_ref.size() == 0)
return Eigen::Matrix<fvar<T>, Rows, Cols>{};
const auto s = softmax(value_of(x_ref));
const auto d_in = x_ref.d();
const auto dot_sd = s.dot(d_in);
Expand Down
8 changes: 4 additions & 4 deletions stan/math/fwd/fun/softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/fwd/core.hpp>
#include <stan/math/fwd/fun/value_of.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/fun/softmax.hpp>
#include <stan/math/prim/functor/apply_vector_unary.hpp>
Expand All @@ -30,18 +31,17 @@ inline auto softmax(T&& x) {
*
* @tparam Vec Eigen vector with `fvar` scalar
* @param x vector to transform
* @return softmax of the vector
* @return softmax of the vector, or an empty result if the input is empty
*/
template <typename Vec, require_eigen_vector_vt<is_fvar, Vec>* = nullptr>
inline auto softmax(Vec&& x) {
using vec = std::decay_t<Vec>;
constexpr int Rows = vec::RowsAtCompileTime;
constexpr int Cols = vec::ColsAtCompileTime;
using T = typename value_type_t<vec>::Scalar;
if (x.size() == 0) {
return Eigen::Matrix<fvar<T>, Rows, Cols>();
}
decltype(auto) x_ref = to_ref(std::forward<Vec>(x));
if (x_ref.size() == 0)
return Eigen::Matrix<fvar<T>, Rows, Cols>{};
const auto s = softmax(value_of(x_ref));
const auto d_in = x_ref.d();
const auto dot_sd = s.dot(d_in);
Expand Down
3 changes: 2 additions & 1 deletion stan/math/opencl/prim/log_softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ namespace math {
template <typename T,
require_all_kernel_expressions_and_none_scalar_t<T>* = nullptr>
inline matrix_cl<double> log_softmax(const T& a) {
check_nonzero_size("log_softmax (OpenCL)", "x", a);
if (a.size() == 0)
return matrix_cl<double>(a.rows(), a.cols());
return make_holder_cl([](auto&& x) { return x - log_sum_exp(x); }, to_ref(a));
}

Expand Down
6 changes: 3 additions & 3 deletions stan/math/opencl/prim/softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <stan/math/opencl/ref_type.hpp>
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err/check_matching_sizes.hpp>
#include <stan/math/prim/err/check_nonzero_size.hpp>
#include <stan/math/prim/fun/to_ref.hpp>

namespace stan {
Expand All @@ -22,9 +23,8 @@ template <typename T,
require_all_kernel_expressions_and_none_scalar_t<T>* = nullptr>
inline matrix_cl<double> softmax(const T& a) {
check_vector("softmax (OpenCL)", "a", a);
if (a.size() == 0) {
return a;
}
if (a.size() == 0)
return matrix_cl<double>(a.rows(), a.cols());
matrix_cl<double> theta;
if constexpr (stan::internal::is_trivial_kg_expression<T>::value) {
matrix_cl<double> a_max = max_2d(a);
Expand Down
6 changes: 3 additions & 3 deletions stan/math/opencl/rev/softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <stan/math/opencl/prim/dot_product.hpp>
#include <stan/math/opencl/prim/softmax.hpp>
#include <stan/math/opencl/kernel_generator.hpp>
#include <stan/math/prim/err/check_nonzero_size.hpp>
#include <stan/math/rev/core.hpp>
#include <stan/math/rev/fun/value_of.hpp>

Expand All @@ -22,11 +23,10 @@ namespace math {
template <typename T,
require_all_kernel_expressions_and_none_scalar_t<T>* = nullptr>
inline var_value<matrix_cl<double>> softmax(const var_value<T>& A) {
if (A.size() == 0) {
return A;
}
return make_callback_var(
softmax(A.val()), [A](vari_value<matrix_cl<double>>& res) mutable {
if (res.val().size() == 0)
return;
A.adj() += elt_multiply(
res.val(), (res.adj() - dot_product(res.adj(), res.val())));
});
Expand Down
13 changes: 8 additions & 5 deletions stan/math/prim/fun/log_softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,25 @@ namespace math {
*
* @tparam Container type of input: an Eigen vector, `std::vector` of doubles,
* or nested container whose scalar type is arithmetic
* @param[in] x vector or container of vectors to transform
* @return log softmax of the input, preserving the container structure
* @throw std::domain_error if any input vector is empty
* @param x vector or container of vectors to transform
* @return log softmax of the input, preserving the container structure; each
* empty vector in the input maps to an empty vector in the result
*/
template <typename Container, require_st_arithmetic<Container>* = nullptr,
require_container_t<Container>* = nullptr,
require_not_t<bool_constant<
is_eigen<std::decay_t<Container>>::value
&& !is_eigen_vector<std::decay_t<Container>>::value>>* = nullptr>
inline auto log_softmax(Container&& x) {
check_nonzero_size("log_softmax", "x", x);
return make_holder(
[](auto&& a) {
return apply_vector_unary<ref_type_t<Container>>::apply(
std::forward<decltype(a)>(a),
[](auto&& v) { return v.array() - log_sum_exp(v); });
[](auto&& v) -> plain_type_t<decltype(v)> {
if (v.size() == 0)
return v;
return (v.array() - log_sum_exp(v)).matrix();
});
},
to_ref(std::forward<Container>(x)));
}
Expand Down
55 changes: 25 additions & 30 deletions stan/math/prim/fun/softmax.hpp
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
#ifndef STAN_MATH_PRIM_FUN_SOFTMAX_HPP
#define STAN_MATH_PRIM_FUN_SOFTMAX_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/functor/apply_vector_unary.hpp>
#include <cmath>

namespace stan {
namespace math {

/**
* Return the softmax of the specified vector.
* Return the softmax of the specified vector, or of each vector in a container.
*
* <p>
* \f$
* \mbox{softmax}(y)
* = \frac{\exp(y)}
Expand All @@ -39,36 +38,32 @@ namespace math {
* \end{array}
* \f$
*
* @tparam Vec type of the input vector
* @param[in] v Vector to transform.
* @return Unit simplex result of the softmax transform of the vector.
* @tparam Container type of input: an Eigen vector, `std::vector` of doubles,
* or nested container whose scalar type is arithmetic
* @param x vector or container of vectors to transform
* @return softmax of the input, preserving the container structure; each
* empty vector in the input maps to an empty vector in the result
*/
template <typename Vec,
require_eigen_vector_vt<std::is_arithmetic, Vec>* = nullptr>
inline plain_type_t<Vec> softmax(Vec&& v) {
if (v.size() == 0) {
return v;
}
decltype(auto) v_ref = to_ref(std::forward<Vec>(v));
const auto theta = (v_ref.array() - v_ref.maxCoeff()).exp();
return (theta / theta.sum()).matrix();
}

/**
* Return the softmax of each vector in an array.
*
* @tparam T `std::vector` whose scalar type is arithmetic
* @param[in] x Array of vectors to transform.
* @return Array of unit simplex results.
*/
template <typename T, require_std_vector_st<std::is_arithmetic, T>* = nullptr>
inline auto softmax(T&& x) {
return apply_vector_unary<T>::apply(std::forward<T>(x), [](auto&& v) {
return softmax(std::forward<decltype(v)>(v));
});
template <typename Container, require_st_arithmetic<Container>* = nullptr,
require_container_t<Container>* = nullptr,
require_not_t<bool_constant<
is_eigen<std::decay_t<Container>>::value
&& !is_eigen_vector<std::decay_t<Container>>::value>>* = nullptr>
inline auto softmax(Container&& x) {
return make_holder(
[](auto&& a) {
return apply_vector_unary<ref_type_t<Container>>::apply(
std::forward<decltype(a)>(a),
[](auto&& v) -> plain_type_t<decltype(v)> {
if (v.size() == 0)
return v;
const auto theta = (v.array() - v.maxCoeff()).exp();
return (theta / theta.sum()).matrix();
});
},
to_ref(std::forward<Container>(x)));
}

} // namespace math
} // namespace stan

#endif
7 changes: 3 additions & 4 deletions stan/math/rev/fun/log_softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ namespace math {
*
* @tparam T a `var_value` or Eigen vector/row_vector with `var` scalar
* @param x input
* @return log softmax of the input
* @throw std::domain_error if the input size is 0
* @return log softmax of the input, or an empty result if the input is empty
*/
template <typename T, require_rev_matrix_t<T>* = nullptr>
inline auto log_softmax(T&& x) {
check_nonzero_size("log_softmax", "x", x);
auto x_arena = to_arena(std::forward<T>(x));
if (x_arena.size() == 0)
return x_arena;
using return_t
= return_var_matrix_t<plain_type_t<decltype(x_arena.val())>, T>;
arena_t<return_t> res = log_softmax(x_arena.val());
Expand All @@ -42,7 +42,6 @@ inline auto log_softmax(T&& x) {
* @tparam T `std::vector` whose scalar type is `var`
* @param x array of vectors to transform
* @return array of log softmax results
* @throw std::domain_error if any element size is 0
*/
template <typename T, require_std_vector_st<is_var, T>* = nullptr>
inline auto log_softmax(T&& x) {
Expand Down
8 changes: 4 additions & 4 deletions stan/math/rev/fun/softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <stan/math/rev/core/reverse_pass_callback.hpp>
#include <stan/math/rev/core/arena_matrix.hpp>
#include <stan/math/rev/fun/to_arena.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/fun/softmax.hpp>
#include <stan/math/prim/functor/apply_vector_unary.hpp>
Expand All @@ -18,16 +19,15 @@ namespace math {
*
* @tparam T a `var_value` or Eigen vector/row_vector with `var` scalar
* @param x input
* @return softmax of the input
* @return softmax of the input, or an empty result if the input is empty
*/
template <typename T, require_rev_matrix_t<T>* = nullptr>
inline auto softmax(T&& x) {
auto x_arena = to_arena(std::forward<T>(x));
if (x_arena.size() == 0)
return x_arena;
using return_t
= return_var_matrix_t<plain_type_t<decltype(x_arena.val())>, T>;
if (x_arena.size() == 0) {
return x_arena;
}
arena_t<return_t> res = softmax(x_arena.val());
reverse_pass_callback([x_arena, res]() mutable {
x_arena.adj().array()
Expand Down
12 changes: 6 additions & 6 deletions test/unit/math/mix/fun/log_softmax_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
TEST(MathMixMatFun, logSoftmax) {
auto f = [](const auto& x) { return stan::math::log_softmax(x); };
// Column Vectors
Eigen::VectorXd x0(0); // error case
Eigen::VectorXd x0(0);
stan::test::expect_ad(f, x0);
stan::test::expect_ad_matvar(f, x0);

Expand Down Expand Up @@ -34,7 +34,7 @@ TEST(MathMixMatFun, logSoftmax) {
stan::test::expect_ad_matvar(f, x3c);

// Row Vectors
Eigen::RowVectorXd rx0(0); // error case
Eigen::RowVectorXd rx0(0);
stan::test::expect_ad(f, rx0);
stan::test::expect_ad_matvar(f, rx0);

Expand Down Expand Up @@ -64,7 +64,7 @@ TEST(MathMixMatFun, logSoftmax) {
stan::test::expect_ad_matvar(f, rx3c);

// std vectors
std::vector<double> stx0(0); // error case
std::vector<double> stx0(0);
stan::test::expect_ad(f, stx0);

std::vector<double> stx1{0};
Expand All @@ -83,23 +83,23 @@ TEST(MathMixMatFun, logSoftmax) {
stan::test::expect_ad(f, stx3c);

// Nested containers
std::vector<Eigen::VectorXd> stvx0{x0, x0}; // error case
std::vector<Eigen::VectorXd> stvx0{x0, x0};
stan::test::expect_ad(f, stvx0);
stan::test::expect_ad_matvar(f, stvx0);

std::vector<Eigen::VectorXd> stvx1{x1, x1};
stan::test::expect_ad(f, stvx1);
stan::test::expect_ad_matvar(f, stvx1);

std::vector<Eigen::RowVectorXd> strx0{rx0, rx0}; // error case
std::vector<Eigen::RowVectorXd> strx0{rx0, rx0};
stan::test::expect_ad(f, strx0);
stan::test::expect_ad_matvar(f, strx0);

std::vector<Eigen::RowVectorXd> strx1{rx1, rx1};
stan::test::expect_ad(f, strx1);
stan::test::expect_ad_matvar(f, strx1);

std::vector<std::vector<double>> ststx0{stx0, stx0}; // error case
std::vector<std::vector<double>> ststx0{stx0, stx0};
stan::test::expect_ad(f, ststx0);

std::vector<std::vector<double>> ststx1{stx1, stx1};
Expand Down
4 changes: 2 additions & 2 deletions test/unit/math/mix/fun/softmax_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ TEST(MathMixMatFun, softmax) {
expect_ad_matvar(f, rd2);

// Arrays of vectors (array[] vector and array[] row_vector)
std::vector<Eigen::VectorXd> stvx0{a, a}; // error case
std::vector<Eigen::VectorXd> stvx0{a, a};
stan::test::expect_ad(tols, f, stvx0);
expect_ad_matvar(f, stvx0);

Expand All @@ -81,7 +81,7 @@ TEST(MathMixMatFun, softmax) {
stan::test::expect_ad(tols, f, stvx2);
expect_ad_matvar(f, stvx2);

std::vector<Eigen::RowVectorXd> strx0{ra, ra}; // error case
std::vector<Eigen::RowVectorXd> strx0{ra, ra};
stan::test::expect_ad(tols, f, strx0);
expect_ad_matvar(f, strx0);

Expand Down
5 changes: 5 additions & 0 deletions test/unit/math/opencl/rev/log_softmax_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ TEST(OpenCLLogSoftmax, prim_rev_size_1) {
stan::math::test::compare_cpu_opencl_prim_rev(log_softmax_functor, a);
}

/**
 * log_softmax of an empty vector must yield an empty result instead of
 * throwing, both for the CPU overload and for the OpenCL overload.
 */
TEST(OpenCLLogSoftmax, prim_rev_size_0) {
  Eigen::VectorXd a(0);

  // CPU reference path.
  EXPECT_EQ(0, stan::math::log_softmax(a).size());

  // Device path: without a matrix_cl argument the OpenCL overload (and its
  // empty-input early return) is never selected, so this test would not
  // cover the OpenCL implementation at all.
  stan::math::matrix_cl<double> a_cl(a);
  EXPECT_EQ(0, stan::math::log_softmax(a_cl).size());
}

TEST(OpenCLLogSoftmax, prim_rev_values_large) {
int N = 71;

Expand Down
6 changes: 2 additions & 4 deletions test/unit/math/opencl/rev/softmax_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@ TEST(OpenCLSoftmax, prim_rev_values_small) {
}

TEST(OpenCLSoftmax, prim_rev_size_0) {
int N = 0;

Eigen::VectorXd a(N);
stan::math::test::compare_cpu_opencl_prim_rev(softmax_functor, a);
Eigen::VectorXd a(0);
EXPECT_EQ(0, stan::math::softmax(a).size());
}

TEST(OpenCLSoftmax, prim_rev_values_large) {
Expand Down
7 changes: 3 additions & 4 deletions test/unit/math/prim/fun/log_softmax_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,8 @@ TEST(MathMatrixPrimMat, log_softmax_neg_inf) {
EXPECT_FLOAT_EQ(2.0 - lse_finite, result[2]);
}

TEST(MathMatrixPrimMat, log_softmax_exception) {
TEST(MathMatrixPrimMat, log_softmax_empty) {
using stan::math::log_softmax;
stan::math::vector_d v0; // size == 0

EXPECT_THROW(log_softmax(v0), std::invalid_argument);
stan::math::vector_d v0;
EXPECT_EQ(0, log_softmax(v0).size());
}
6 changes: 6 additions & 0 deletions test/unit/math/prim/fun/softmax_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ TEST(MathMatrixPrimMat, softmax_neg_inf) {
EXPECT_FLOAT_EQ(1.0, theta.sum());
}

/**
 * softmax applied to a zero-length column vector returns a zero-length
 * result.
 */
TEST(MathMatrixPrimMat, softmax_empty) {
  // Default-constructed dynamic column vector has size 0.
  Eigen::VectorXd empty_input;
  const auto result_size = stan::math::softmax(empty_input).size();
  EXPECT_EQ(0, result_size);
}

TEST(MathMatrixPrimMat, softmax_row_vector) {
using Eigen::Dynamic;
using Eigen::Matrix;
Expand Down
Loading