diff --git a/stan/math/prim/prob/wiener5_lpdf.hpp b/stan/math/prim/prob/wiener5_lpdf.hpp index bd4c8beecfa..b810caafd91 100644 --- a/stan/math/prim/prob/wiener5_lpdf.hpp +++ b/stan/math/prim/prob/wiener5_lpdf.hpp @@ -516,40 +516,109 @@ template ; + + const auto y_asq = y / square(a); + const auto q = 1.0 - w; const auto sv_sqr = square(sv); const auto one_plus_svsqr_y = 1.0 + sv_sqr * y; - const auto density_part_one - = (v * a + sv_sqr * square(a) * one_m_w) / one_plus_svsqr_y; - const auto log_error = (log_err - log_error_term); - const auto n_terms_small_t - = wiener5_n_terms_small_t(y, a, w, - log_error); - const auto n_terms_large_t + const auto log_error_term = wiener5_compute_log_error_term(y, a, v, w, sv); + const auto log_error = log_err - log_error_term; + + // d/dw of + // + // -2 log(a) + // - 0.5 log(1 + sv^2 y) + // + [-v^2 y + 2 a v (1-w) + a^2 (1-w)^2 sv^2] + // / [2 (1 + sv^2 y)]. + const auto pref_grad_w = (-a * v - square(a) * sv_sqr * q) / one_plus_svsqr_y; + + // Use the density branch decision. The derivative is the derivative of + // the same scalar value, so do not let the w-gradient path switch to a + // different underconverged representation. + const auto n_small_density + = wiener5_n_terms_small_t(y, a, w, log_error); + const auto n_large_density + = wiener5_density_large_reaction_time_terms(y, a, w, log_error); + + const auto n_small_grad + = wiener5_n_terms_small_t(y, a, w, log_error); + const auto n_large_grad = wiener5_gradient_large_reaction_time_terms(y, a, w, log_error); - auto wiener_res = wiener5_log_sum_exp( - y, a, w, n_terms_small_t, n_terms_large_t); - auto&& result = wiener_res.first; - auto&& newsign = wiener_res.second; - const auto log_density = wiener5_density( - y, a, v, w, sv, log_err - log(fabs(density_part_one))); - if (2.0 * n_terms_small_t < n_terms_large_t) { - auto ans = -(density_part_one - - newsign - * exp(result - (log_density - log_error_term) - - 2.5 * log_y_asq - 0.5 * LOG_TWO - 0.5 * LOG_PI)); - return WrtLog ? ans * exp(log_density) : ans; + + const int n_small = static_cast( + fmax(value_of_rec(n_small_density), value_of_rec(n_small_grad))); + const int n_large = static_cast( + fmax(value_of_rec(n_large_density), value_of_rec(n_large_grad))); + + ret_t series_grad_w = 0.0; + + if (2.0 * n_small_density <= n_large_density) { + // Small-time representation. + // + // R_s = sum_{k=-K}^{K} z_k exp(-z_k^2 / (2 t*)), + // z_k = 1 - w + 2k. + // + // dR_s/dw = sum_k (z_k^2 / t* - 1) + // exp(-z_k^2 / (2 t*)). + ret_t max_log = NEGATIVE_INFTY; + + for (int k = -n_small; k <= n_small; ++k) { + const double kd = static_cast(k); + const auto z = q + 2.0 * kd; + const auto log_e = -square(z) / (2.0 * y_asq); + max_log = fmax(max_log, log_e); + } + + ret_t raw = 0.0; + ret_t draw_dw = 0.0; + + for (int k = -n_small; k <= n_small; ++k) { + const double kd = static_cast(k); + const auto z = q + 2.0 * kd; + const auto e = exp(-square(z) / (2.0 * y_asq) - max_log); + + raw += z * e; + draw_dw += (square(z) / y_asq - 1.0) * e; + } + + series_grad_w = draw_dw / raw; } else { - auto ans = -( - density_part_one - + newsign - * exp(result - (log_density - log_error_term) + 2.0 * LOG_PI)); - return WrtLog ? ans * exp(log_density) : ans; + // Large-time representation in the same upper-bound coordinate used + // by the density code: q = 1 - w. + // + // R_l = sum_{k=1}^{K} + // k sin(k pi q) + // exp(-(k^2 - 1) pi^2 t* / 2). + // + // dR_l/dw = -sum_{k=1}^{K} + // k^2 pi cos(k pi q) + // exp(-(k^2 - 1) pi^2 t* / 2). + ret_t raw = 0.0; + ret_t draw_dw = 0.0; + + const auto half_pi2_y = 0.5 * square(pi()) * y_asq; + + for (int k = 1; k <= n_large; ++k) { + const double kd = static_cast(k); + const auto exp_term = exp(-(square(kd) - 1.0) * half_pi2_y); + const auto angle = kd * pi() * q; + + raw += kd * sin(angle) * exp_term; + draw_dw += -square(kd) * pi() * cos(angle) * exp_term; + } + + series_grad_w = draw_dw / raw; + } + + const auto ans = pref_grad_w + series_grad_w; + + if constexpr (WrtLog) { + return ans * wiener5_density(y, a, v, w, sv, log_err); + } else { + return ans; } } diff --git a/test/unit/math/mix/prob/wiener_lpdf_ad_test.cpp b/test/unit/math/mix/prob/wiener_lpdf_ad_test.cpp new file mode 100644 index 00000000000..d561e96df4c --- /dev/null +++ b/test/unit/math/mix/prob/wiener_lpdf_ad_test.cpp @@ -0,0 +1,77 @@ +#include +#include +#include + +TEST(MathMixProbWienerLpdf, fiveParamWGradientExpectAd) { + auto f = [](const auto& w) { + return stan::math::wiener_lpdf(6.0, 10.0, 0.01, w, -3.0, 0.2); + }; + + stan::test::ad_tolerances tols; + tols.gradient_grad_ = 1e-5; + stan::test::expect_ad(tols, f, 0.1); +} + +TEST(MathMixProbWienerLpdf, fiveParamZeroSvWGradientExpectAd) { + auto f = [](const auto& w) { + return stan::math::wiener_lpdf(6.0, 10.0, 0.01, w, -3.0, 0.0); + }; + + stan::test::ad_tolerances tols; + tols.gradient_grad_ = 1e-5; + stan::test::expect_ad(tols, f, 0.1); +} + +TEST(MathMixProbWienerLpdf, fullParamWGradientExpectAd) { + auto f = [](const auto& w) { + return stan::math::wiener_lpdf(6.0, 10.0, 0.01, w, -3.0, 0.2, 0.1, 0.0); + }; + + stan::test::ad_tolerances tols; + tols.gradient_grad_ = 1e-5; + stan::test::expect_ad(tols, f, 0.1); +} + +TEST(MathMixProbWienerLpdf, existingFullRowsWGradientExpectAd) { + struct Case { + const char* name; + double y; + double a; + double t0; + double w; + double v; + double sv; + double sw; + double st0; + }; + + const std::vector cases = { + {"row_0", 2.0, 2.0, 1e-9, 0.10, 2.0, 0.0, 0.00, 0.000}, + {"row_1", 3.0, 2.0, 0.01, 0.50, 2.0, 0.2, 0.00, 0.000}, + {"row_2", 4.0, 10.0, 0.01, 0.80, 4.0, 0.0, 0.10, 0.000}, + {"row_3", 5.0, 4.0, 0.01, 0.70, 3.0, 0.0, 0.00, 0.007}, + {"row_4", 6.0, 10.0, 0.01, 0.10, -3.0, 0.2, 0.10, 0.000}, + {"row_5", 7.0, 1.0, 0.01, 0.90, 1.0, 0.2, 0.00, 0.007}, + {"row_6", 8.0, 3.0, 0.01, 0.70, -1.0, 0.0, 0.10, 0.007}, + {"row_7", 8.85, 1.7, 0.01, 0.92, -7.3, 0.7, 0.01, 0.009}, + {"row_8", 8.9, 2.4, 0.01, 0.90, -4.9, 0.0, 0.00, 0.009}, + {"row_9", 9.0, 11.0, 0.01, 0.12, 4.5, 0.7, 0.10, 0.009}, + {"row_10", 1.0, 1.5, 0.10, 0.50, 3.0, 0.5, 0.20, 0.000}, + }; + + stan::test::ad_tolerances tols; + tols.gradient_grad_ = 1e-4; + + for (const auto& c : cases) { + SCOPED_TRACE(c.name); + + auto f = [c](const auto& w) { + return stan::math::wiener_lpdf(c.y, c.a, c.t0, w, c.v, c.sv, c.sw, c.st0); + }; + + // The row sweep is intended to check the reverse-mode w adjoint against + // finite differences. Some full-Wiener rows are not stable enough for + // higher-order mixed-mode finite-difference checks + stan::test::expect_ad(tols, f, c.w); + } +} diff --git a/test/unit/math/prim/prob/wiener_full_lpdf_test.cpp b/test/unit/math/prim/prob/wiener_full_lpdf_test.cpp index 8dedb448428..2a6a19a5484 100644 --- a/test/unit/math/prim/prob/wiener_full_lpdf_test.cpp +++ b/test/unit/math/prim/prob/wiener_full_lpdf_test.cpp @@ -320,7 +320,7 @@ TEST(mathPrimCorrectValues, wiener_lpdf) { 12.8617364931501, 1.12047317491985, 5.68799957241344}; std::vector true_grad_w = {5.67120184517318, -3.64396221090076, -38.7775057146792, - -14.1837930137393, -34.5869239580708, -10.4535345681946, + -14.1837930137393, 35.71918681520359, -10.4535345681946, 0.679597983582904, -9.93144540834201, 2.09117200953597, -6.0858540417876, -3.74870310978083}; std::vector true_grad_v = {