diff --git a/stan/math/mix/functor/laplace_likelihood.hpp b/stan/math/mix/functor/laplace_likelihood.hpp
index 88c36163b8d..8bc9b0a4a95 100644
--- a/stan/math/mix/functor/laplace_likelihood.hpp
+++ b/stan/math/mix/functor/laplace_likelihood.hpp
@@ -5,6 +5,7 @@
 #include <stan/math/mix/functor/hessian_times_vector.hpp>
 #include <stan/math/mix/functor/hessian_block_diag.hpp>
 #include <stan/math/prim/fun/value_of.hpp>
+#include <type_traits>
 
 namespace stan {
 namespace math {
@@ -16,6 +17,41 @@ namespace math {
 namespace laplace_likelihood {
 namespace internal {
+
+/**
+ * Type trait to detect if a likelihood functor `F` provides a custom
+ * `diff` method that computes the gradient and negative Hessian
+ * analytically, avoiding the cost of embedded reverse-mode autodiff.
+ *
+ * A functor with a custom `diff` method should provide:
+ *   auto diff(theta, hessian_block_size, args...) const
+ * returning std::pair<Eigen::VectorXd, Eigen::SparseMatrix<double>>.
+ * The trailing `args...` are part of the probe: detection with only
+ * `(theta, 1)` would never match a functor whose `diff` also takes
+ * the likelihood's data arguments.
+ */
+template <typename F, typename ArgsPack, typename = void>
+struct has_custom_diff : std::false_type {};
+
+template <typename F, typename... Args>
+struct has_custom_diff<
+    F, void(Args...),
+    std::void_t<decltype(std::declval<const F&>().diff(
+        std::declval<const Eigen::VectorXd&>(), 1,
+        std::declval<Args>()...))>> : std::true_type {};
+
+template <typename F, typename... Args>
+inline constexpr bool has_custom_diff_v
+    = has_custom_diff<F, void(Args...)>::value;
+
+/**
+ * Type trait to detect if a likelihood functor `F` provides a custom
+ * `third_diff` method for the third derivative w.r.t. theta, probed
+ * with the same trailing `args...` the caller will supply.
+ */
+template <typename F, typename ArgsPack, typename = void>
+struct has_custom_third_diff : std::false_type {};
+
+template <typename F, typename... Args>
+struct has_custom_third_diff<
+    F, void(Args...),
+    std::void_t<decltype(std::declval<const F&>().third_diff(
+        std::declval<const Eigen::VectorXd&>(),
+        std::declval<Args>()...))>> : std::true_type {};
+
+template <typename F, typename... Args>
+inline constexpr bool has_custom_third_diff_v
+    = has_custom_third_diff<F, void(Args...)>::value;
 /**
  * @tparam F A functor with `opertor()(Args&&...)` returning a scalar
  * @tparam Theta A class assignable to an Eigen vector type
@@ -158,6 +194,8 @@ inline auto block_hessian(F&& f, Theta&& theta,
  * `theta` and `args...`
  * @note If `Args` contains \ref var types then their adjoints will be
  * calculated as a side effect.
+ * @note If `F` provides a custom `diff` method, it will be used instead
+ * of the generic autodiff path for better performance.
  * @tparam F A functor with `opertor()(Args&&...)` returning a scalar
  * @tparam Theta A class assignable to an Eigen vector type
  * @tparam Stream Type of stream for messages.
@@ -174,33 +212,41 @@ template <typename F, typename Theta, typename Stream, typename... Args,
           require_eigen_vector_t<Theta>* = nullptr>
 inline auto diff(F&& f, Theta&& theta, const Eigen::Index hessian_block_size,
                  Stream* msgs, Args&&... args) {
-  using Eigen::Dynamic;
-  using Eigen::Matrix;
-  const Eigen::Index theta_size = theta.size();
-  auto theta_gradient = [&theta, &f, &msgs](auto&&... args) {
-    nested_rev_autodiff nested;
-    Matrix<var, Dynamic, 1> theta_var = theta;
-    var f_var = f(theta_var, args..., msgs);
-    grad(f_var.vi_);
-    return theta_var.adj().eval();
-  }(args...);
-  if (hessian_block_size == 1) {
-    auto v = Eigen::VectorXd::Ones(theta_size);
-    Eigen::VectorXd hessian_v = Eigen::VectorXd::Zero(theta_size);
-    hessian_times_vector(f, hessian_v, std::forward<Theta>(theta), std::move(v),
-                         value_of(args)..., msgs);
-    Eigen::SparseMatrix<double> hessian_theta(theta_size, theta_size);
-    hessian_theta.reserve(Eigen::VectorXi::Constant(theta_size, 1));
-    for (Eigen::Index i = 0; i < theta_size; i++) {
-      hessian_theta.insert(i, i) = hessian_v(i);
-    }
-    return std::make_pair(std::move(theta_gradient), (-hessian_theta).eval());
+  using F_t = std::decay_t<F>;
+  if constexpr (has_custom_diff_v<F_t, Args...>) {
+    // Use the functor's specialized analytic derivatives
+    return f.diff(std::forward<Theta>(theta), hessian_block_size,
+                  std::forward<Args>(args)...);
   } else {
-    return std::make_pair(
-        std::move(theta_gradient),
-        (-hessian_block_diag(f, std::forward<Theta>(theta), hessian_block_size,
-                             value_of(args)..., msgs))
-            .eval());
+    // Fall back to generic autodiff
+    using Eigen::Dynamic;
+    using Eigen::Matrix;
+    const Eigen::Index theta_size = theta.size();
+    auto theta_gradient = [&theta, &f, &msgs](auto&&... args) {
+      nested_rev_autodiff nested;
+      Matrix<var, Dynamic, 1> theta_var = theta;
+      var f_var = f(theta_var, args..., msgs);
+      grad(f_var.vi_);
+      return theta_var.adj().eval();
+    }(args...);
+    if (hessian_block_size == 1) {
+      auto v = Eigen::VectorXd::Ones(theta_size);
+      Eigen::VectorXd hessian_v = Eigen::VectorXd::Zero(theta_size);
+      hessian_times_vector(f, hessian_v, std::forward<Theta>(theta),
+                           std::move(v), value_of(args)..., msgs);
+      Eigen::SparseMatrix<double> hessian_theta(theta_size, theta_size);
+      hessian_theta.reserve(Eigen::VectorXi::Constant(theta_size, 1));
+      for (Eigen::Index i = 0; i < theta_size; i++) {
+        hessian_theta.insert(i, i) = hessian_v(i);
+      }
+      return std::make_pair(std::move(theta_gradient),
+                            (-hessian_theta).eval());
+    } else {
+      return std::make_pair(
+          std::move(theta_gradient),
+          (-hessian_block_diag(f, std::forward<Theta>(theta),
+                               hessian_block_size, value_of(args)..., msgs))
+              .eval());
+    }
   }
 }
 
@@ -208,6 +254,8 @@ inline auto diff(F&& f, Theta&& theta, const Eigen::Index hessian_block_size,
  * Compute third order derivative of `f` wrt `theta` and `args...`
  * @note If `Args` contains \ref var types then their adjoints will be
  * calculated as a side effect.
+ * @note If `F` provides a custom `third_diff` method, it will be used
+ * instead of the generic `fvar<fvar<var>>` autodiff path.
  * @tparam F A functor with `opertor()(Args&&...)` returning a scalar
  * @tparam Theta A class assignable to an Eigen vector type
  * @tparam Stream Type of stream for messages.
@@ -221,18 +269,26 @@ template <typename F, typename Theta, typename Stream, typename... Args,
           require_eigen_vector_t<Theta>* = nullptr>
 inline Eigen::VectorXd third_diff(F&& f, Theta&& theta, Stream&& msgs,
                                   Args&&... args) {
-  nested_rev_autodiff nested;
-  const Eigen::Index theta_size = theta.size();
-  arena_t<Eigen::Matrix<var, Eigen::Dynamic, 1>> theta_var
-      = std::forward<Theta>(theta);
-  arena_t<Eigen::Matrix<fvar<fvar<var>>, Eigen::Dynamic, 1>> theta_ffvar(
-      theta_size);
-  for (Eigen::Index i = 0; i < theta_size; ++i) {
-    theta_ffvar(i) = fvar<fvar<var>>(fvar<var>(theta_var(i), 1.0), 1.0);
+  using F_t = std::decay_t<F>;
+  if constexpr (has_custom_third_diff_v<F_t, Args...>) {
+    // Use the functor's specialized analytic third derivative
+    return f.third_diff(std::forward<Theta>(theta),
+                        std::forward<Args>(args)...);
+  } else {
+    // Fall back to generic fvar<fvar<var>> autodiff
+    nested_rev_autodiff nested;
+    const Eigen::Index theta_size = theta.size();
+    arena_t<Eigen::Matrix<var, Eigen::Dynamic, 1>> theta_var
+        = std::forward<Theta>(theta);
+    arena_t<Eigen::Matrix<fvar<fvar<var>>, Eigen::Dynamic, 1>> theta_ffvar(
+        theta_size);
+    for (Eigen::Index i = 0; i < theta_size; ++i) {
+      theta_ffvar(i) = fvar<fvar<var>>(fvar<var>(theta_var(i), 1.0), 1.0);
+    }
+    fvar<fvar<var>> ftheta_ffvar = f(theta_ffvar, args..., msgs);
+    grad(ftheta_ffvar.d_.d_.vi_);
+    return theta_var.adj().eval();
   }
-  fvar<fvar<var>> ftheta_ffvar = f(theta_ffvar, args..., msgs);
-  grad(ftheta_ffvar.d_.d_.vi_);
-  return theta_var.adj().eval();
 }
 
 /**
diff --git a/stan/math/mix/prob/laplace_marginal_neg_binomial_2_log_lpmf.hpp b/stan/math/mix/prob/laplace_marginal_neg_binomial_2_log_lpmf.hpp
index 843f9d2e4dd..973c5b84b58 100644
--- a/stan/math/mix/prob/laplace_marginal_neg_binomial_2_log_lpmf.hpp
+++ b/stan/math/mix/prob/laplace_marginal_neg_binomial_2_log_lpmf.hpp
@@ -53,6 +53,115 @@ struct neg_binomial_2_log_likelihood {
             elt_multiply(multiply(n_per_group, eta), subtract(log_eta, lse))));
   }
+
+  /**
+   * Compute gradient and negative Hessian of the neg_binomial_2_log
+   * likelihood analytically, avoiding nested autodiff.
+   *
+   * @param theta Latent Gaussian variable (double).
+   * @param hessian_block_size Size of each diagonal block (typically 1).
+   * @param eta Dispersion parameter (scalar or 1-element vector).
+   * @param y Observed counts.
+   * @param y_index Group index for each observation.
+   * @param mean Mean offset for theta.
+   * @return pair of (gradient, negative Hessian) as (VectorXd, SparseMatrix).
+   */
+  template <typename Mean>
+  inline auto diff(const Eigen::VectorXd& theta, int hessian_block_size,
+                   const Eigen::VectorXd& eta, const std::vector<int>& y,
+                   const std::vector<int>& y_index, Mean&& mean) const {
+    const int theta_size = theta.size();
+    const double eta_scalar = eta(0);
+
+    Eigen::VectorXd sums = Eigen::VectorXd::Zero(theta_size);
+    Eigen::VectorXd n_samples = Eigen::VectorXd::Zero(theta_size);
+    for (size_t i = 0; i < y.size(); i++) {
+      n_samples(y_index[i] - 1) += 1;
+      sums(y_index[i] - 1) += y[i];
+    }
+
+    // theta + mean
+    Eigen::VectorXd theta_offset = add(theta, value_of(mean));
+
+    // exp(-theta_offset)
+    Eigen::VectorXd exp_neg_theta = exp(-theta_offset);
+    // sums + eta * n_samples
+    Eigen::VectorXd sums_plus_n_eta = sums + eta_scalar * n_samples;
+    // 1 + eta * exp(-theta_offset)
+    Eigen::VectorXd one_plus_exp
+        = Eigen::VectorXd::Ones(theta_size) + eta_scalar * exp_neg_theta;
+
+    // Gradient: sums - (sums + eta * n) / (1 + eta * exp(-theta))
+    Eigen::VectorXd gradient
+        = sums - sums_plus_n_eta.cwiseQuotient(one_plus_exp);
+
+    // Negative Hessian diagonal:
+    // eta * (sums + eta * n) * exp(-theta) / (1 + eta * exp(-theta))^2
+    Eigen::VectorXd hessian_diag
+        = eta_scalar
+          * sums_plus_n_eta.cwiseProduct(exp_neg_theta.cwiseQuotient(
+              one_plus_exp.cwiseProduct(one_plus_exp)));
+
+    Eigen::SparseMatrix<double> hessian(theta_size, theta_size);
+    hessian.reserve(Eigen::VectorXi::Constant(theta_size, hessian_block_size));
+    for (int i = 0; i < theta_size; i++) {
+      hessian.insert(i, i) = hessian_diag(i);
+    }
+
+    return std::make_pair(std::move(gradient), std::move(hessian));
+  }
+
+  /**
+   * Compute the third derivative of the neg_binomial_2_log likelihood
+   * w.r.t. theta analytically, avoiding fvar<fvar<var>> autodiff.
+   *
+   * The third derivative is:
+   *   d^3/dtheta^3 log p(y|theta,eta) =
+   *     -(sums + eta*n) * eta * exp(theta) * (eta - exp(theta))
+   *       / (eta + exp(theta))^3
+   *
+   * @param theta Latent Gaussian variable (double).
+   * @param eta Dispersion parameter (scalar or 1-element vector).
+   * @param y Observed counts.
+   * @param y_index Group index for each observation.
+   * @param mean Mean offset for theta.
+   * @return Third derivative as a VectorXd.
+   */
+  template <typename Mean>
+  inline Eigen::VectorXd third_diff(const Eigen::VectorXd& theta,
+                                    const Eigen::VectorXd& eta,
+                                    const std::vector<int>& y,
+                                    const std::vector<int>& y_index,
+                                    Mean&& mean) const {
+    const int theta_size = theta.size();
+    const double eta_scalar = eta(0);
+
+    Eigen::VectorXd sums = Eigen::VectorXd::Zero(theta_size);
+    Eigen::VectorXd n_samples = Eigen::VectorXd::Zero(theta_size);
+    for (size_t i = 0; i < y.size(); i++) {
+      n_samples(y_index[i] - 1) += 1;
+      sums(y_index[i] - 1) += y[i];
+    }
+
+    // theta + mean
+    Eigen::VectorXd theta_offset = add(theta, value_of(mean));
+
+    Eigen::VectorXd exp_theta = exp(theta_offset);
+    Eigen::VectorXd eta_vec = Eigen::VectorXd::Constant(theta_size, eta_scalar);
+    Eigen::VectorXd eta_plus_exp_theta = eta_vec + exp_theta;
+
+    // -(sums + eta*n) * eta * exp(theta) * (eta - exp(theta))
+    //   / (eta + exp(theta))^3
+    Eigen::VectorXd eta_plus_exp_theta_sq
+        = eta_plus_exp_theta.cwiseProduct(eta_plus_exp_theta);
+    Eigen::VectorXd eta_plus_exp_theta_cubed
+        = eta_plus_exp_theta_sq.cwiseProduct(eta_plus_exp_theta);
+
+    return -((sums + eta_scalar * n_samples) * eta_scalar)
+                .cwiseProduct(exp_theta.cwiseProduct(
+                    (eta_vec - exp_theta)
+                        .cwiseQuotient(eta_plus_exp_theta_cubed)));
+  }
 };
 
 /**