Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 97 additions & 28 deletions stan/math/prim/prob/wiener5_lpdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,40 +516,109 @@ template <bool WrtLog = false, typename T_y, typename T_a, typename T_w,
inline auto wiener5_grad_w(const T_y& y, const T_a& a, const T_v& v,
const T_w& w, const T_sv& sv,
T_err log_err = log(1e-12)) noexcept {
const auto two_log_a = 2.0 * log(a);
const auto log_y_asq = log(y) - two_log_a;
const auto log_error_term = wiener5_compute_log_error_term(y, a, v, w, sv);
const auto one_m_w = 1.0 - w;
using ret_t = return_type_t<T_y, T_a, T_w, T_v, T_sv, T_err>;

const auto y_asq = y / square(a);
const auto q = 1.0 - w;
const auto sv_sqr = square(sv);
Comment on lines +521 to 523
const auto one_plus_svsqr_y = 1.0 + sv_sqr * y;
const auto density_part_one
= (v * a + sv_sqr * square(a) * one_m_w) / one_plus_svsqr_y;
const auto log_error = (log_err - log_error_term);

const auto n_terms_small_t
= wiener5_n_terms_small_t<GradientCalc::OFF, GradientCalc::ON>(y, a, w,
log_error);
const auto n_terms_large_t
const auto log_error_term = wiener5_compute_log_error_term(y, a, v, w, sv);
const auto log_error = log_err - log_error_term;

// d/dw of
//
// -2 log(a)
// - 0.5 log(1 + sv^2 y)
// + [-v^2 y + 2 a v (1-w) + a^2 (1-w)^2 sv^2]
// / [2 (1 + sv^2 y)].
const auto pref_grad_w = (-a * v - square(a) * sv_sqr * q) / one_plus_svsqr_y;

// Use the density branch decision. The derivative is the derivative of
// the same scalar value, so do not let the w-gradient path switch to a
// different underconverged representation.
const auto n_small_density
= wiener5_n_terms_small_t<true, GradientCalc::OFF>(y, a, w, log_error);
const auto n_large_density
= wiener5_density_large_reaction_time_terms(y, a, w, log_error);

const auto n_small_grad
= wiener5_n_terms_small_t<false, GradientCalc::ON>(y, a, w, log_error);
const auto n_large_grad
= wiener5_gradient_large_reaction_time_terms<GradientCalc::ON>(y, a, w,
log_error);
auto wiener_res = wiener5_log_sum_exp<GradientCalc::OFF, GradientCalc::ON>(
y, a, w, n_terms_small_t, n_terms_large_t);
auto&& result = wiener_res.first;
auto&& newsign = wiener_res.second;
const auto log_density = wiener5_density<GradientCalc::OFF>(
y, a, v, w, sv, log_err - log(fabs(density_part_one)));
if (2.0 * n_terms_small_t < n_terms_large_t) {
auto ans = -(density_part_one
- newsign
* exp(result - (log_density - log_error_term)
- 2.5 * log_y_asq - 0.5 * LOG_TWO - 0.5 * LOG_PI));
return WrtLog ? ans * exp(log_density) : ans;

const int n_small = static_cast<int>(
fmax(value_of_rec(n_small_density), value_of_rec(n_small_grad)));
const int n_large = static_cast<int>(
fmax(value_of_rec(n_large_density), value_of_rec(n_large_grad)));

ret_t series_grad_w = 0.0;

if (2.0 * n_small_density <= n_large_density) {
// Small-time representation.
//
// R_s = sum_{k=-K}^{K} z_k exp(-z_k^2 / (2 t*)),
// z_k = 1 - w + 2k.
//
// dR_s/dw = sum_k (z_k^2 / t* - 1)
// exp(-z_k^2 / (2 t*)).
ret_t max_log = NEGATIVE_INFTY;

for (int k = -n_small; k <= n_small; ++k) {
const double kd = static_cast<double>(k);
const auto z = q + 2.0 * kd;
const auto log_e = -square(z) / (2.0 * y_asq);
max_log = fmax(max_log, log_e);
}

ret_t raw = 0.0;
ret_t draw_dw = 0.0;

for (int k = -n_small; k <= n_small; ++k) {
const double kd = static_cast<double>(k);
const auto z = q + 2.0 * kd;
const auto e = exp(-square(z) / (2.0 * y_asq) - max_log);

raw += z * e;
draw_dw += (square(z) / y_asq - 1.0) * e;
}

series_grad_w = draw_dw / raw;
} else {
auto ans = -(
density_part_one
+ newsign
* exp(result - (log_density - log_error_term) + 2.0 * LOG_PI));
return WrtLog ? ans * exp(log_density) : ans;
// Large-time representation in the same upper-bound coordinate used
// by the density code: q = 1 - w.
//
// R_l = sum_{k=1}^{K}
// k sin(k pi q)
// exp(-(k^2 - 1) pi^2 t* / 2).
//
// dR_l/dw = -sum_{k=1}^{K}
// k^2 pi cos(k pi q)
// exp(-(k^2 - 1) pi^2 t* / 2).
ret_t raw = 0.0;
ret_t draw_dw = 0.0;

const auto half_pi2_y = 0.5 * square(pi()) * y_asq;

for (int k = 1; k <= n_large; ++k) {
const double kd = static_cast<double>(k);
const auto exp_term = exp(-(square(kd) - 1.0) * half_pi2_y);
const auto angle = kd * pi() * q;

raw += kd * sin(angle) * exp_term;
draw_dw += -square(kd) * pi() * cos(angle) * exp_term;
}

series_grad_w = draw_dw / raw;
}

const auto ans = pref_grad_w + series_grad_w;

if constexpr (WrtLog) {
return ans * wiener5_density<true>(y, a, v, w, sv, log_err);
} else {
return ans;
}
}

Expand Down
77 changes: 77 additions & 0 deletions test/unit/math/mix/prob/wiener_lpdf_ad_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#include <stan/math/mix.hpp>
#include <test/unit/math/test_ad.hpp>
#include <gtest/gtest.h>

// Autodiff check of the six-argument wiener_lpdf overload with respect to
// its fourth argument (w), evaluated at w = 0.1. Every other argument is a
// fixed double, so w is the only differentiated input.
TEST(MathMixProbWienerLpdf, fiveParamWGradientExpectAd) {
  stan::test::ad_tolerances tolerances;
  tolerances.gradient_grad_ = 1e-5;

  // Same leading values as "row_4" of the row-sweep test in this file:
  // (6.0, 10.0, 0.01, w, -3.0, 0.2).
  const auto lpdf_of_w = [](const auto& w) {
    return stan::math::wiener_lpdf(6.0, 10.0, 0.01, w, -3.0, 0.2);
  };

  stan::test::expect_ad(tolerances, lpdf_of_w, 0.1);
}

// Autodiff check of the six-argument wiener_lpdf overload with respect to
// its fourth argument (w) at w = 0.1, with the last argument set to 0.0
// (the "ZeroSv" case named by this test).
TEST(MathMixProbWienerLpdf, fiveParamZeroSvWGradientExpectAd) {
  stan::test::ad_tolerances tolerances;
  tolerances.gradient_grad_ = 1e-5;

  // Only w varies; the remaining arguments are compile-time constants.
  const auto lpdf_of_w = [](const auto& w) {
    return stan::math::wiener_lpdf(6.0, 10.0, 0.01, w, -3.0, 0.0);
  };

  stan::test::expect_ad(tolerances, lpdf_of_w, 0.1);
}

// Autodiff check of the eight-argument wiener_lpdf overload with respect to
// its fourth argument (w) at w = 0.1.
TEST(MathMixProbWienerLpdf, fullParamWGradientExpectAd) {
  stan::test::ad_tolerances tolerances;
  tolerances.gradient_grad_ = 1e-5;

  // Constant arguments match "row_4" of the row-sweep test in this file:
  // (6.0, 10.0, 0.01, w, -3.0, 0.2, 0.1, 0.0).
  const auto lpdf_of_w = [](const auto& w) {
    return stan::math::wiener_lpdf(6.0, 10.0, 0.01, w, -3.0, 0.2, 0.1, 0.0);
  };

  stan::test::expect_ad(tolerances, lpdf_of_w, 0.1);
}

// Sweeps a table of eight-argument wiener_lpdf inputs and, for every row,
// runs the autodiff check with respect to w at that row's w value.
TEST(MathMixProbWienerLpdf, existingFullRowsWGradientExpectAd) {
  // One table entry: a label used for failure traces, followed by the
  // values forwarded to wiener_lpdf in call order
  // (y, a, t0, w, v, sv, sw, st0).
  struct Row {
    const char* label;
    double y;
    double a;
    double t0;
    double w;
    double v;
    double sv;
    double sw;
    double st0;
  };

  stan::test::ad_tolerances tolerances;
  tolerances.gradient_grad_ = 1e-4;

  const std::vector<Row> rows = {
      {"row_0", 2.0, 2.0, 1e-9, 0.10, 2.0, 0.0, 0.00, 0.000},
      {"row_1", 3.0, 2.0, 0.01, 0.50, 2.0, 0.2, 0.00, 0.000},
      {"row_2", 4.0, 10.0, 0.01, 0.80, 4.0, 0.0, 0.10, 0.000},
      {"row_3", 5.0, 4.0, 0.01, 0.70, 3.0, 0.0, 0.00, 0.007},
      {"row_4", 6.0, 10.0, 0.01, 0.10, -3.0, 0.2, 0.10, 0.000},
      {"row_5", 7.0, 1.0, 0.01, 0.90, 1.0, 0.2, 0.00, 0.007},
      {"row_6", 8.0, 3.0, 0.01, 0.70, -1.0, 0.0, 0.10, 0.007},
      {"row_7", 8.85, 1.7, 0.01, 0.92, -7.3, 0.7, 0.01, 0.009},
      {"row_8", 8.9, 2.4, 0.01, 0.90, -4.9, 0.0, 0.00, 0.009},
      {"row_9", 9.0, 11.0, 0.01, 0.12, 4.5, 0.7, 0.10, 0.009},
      {"row_10", 1.0, 1.5, 0.10, 0.50, 3.0, 0.5, 0.20, 0.000},
  };

  for (const auto& row : rows) {
    // Tag any failure reported inside this iteration with the row label.
    SCOPED_TRACE(row.label);

    // The row is copied into the closure; w is the only free argument.
    const auto lpdf_of_w = [row](const auto& w) {
      return stan::math::wiener_lpdf(row.y, row.a, row.t0, w, row.v, row.sv,
                                     row.sw, row.st0);
    };

    // The sweep is meant to compare the reverse-mode adjoint for w against
    // finite differences. Some full-Wiener rows are not stable enough for
    // higher-order mixed-mode finite-difference checks.
    // NOTE(review): the <true> template argument presumably narrows the
    // check accordingly -- confirm against test/unit/math/test_ad.hpp.
    stan::test::expect_ad<true>(tolerances, lpdf_of_w, row.w);
  }
}
2 changes: 1 addition & 1 deletion test/unit/math/prim/prob/wiener_full_lpdf_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ TEST(mathPrimCorrectValues, wiener_lpdf) {
12.8617364931501, 1.12047317491985, 5.68799957241344};
std::vector<double> true_grad_w
= {5.67120184517318, -3.64396221090076, -38.7775057146792,
-14.1837930137393, -34.5869239580708, -10.4535345681946,
-14.1837930137393, 35.71918681520359, -10.4535345681946,
0.679597983582904, -9.93144540834201, 2.09117200953597,
-6.0858540417876, -3.74870310978083};
std::vector<double> true_grad_v = {
Expand Down
Loading