22#define STAN_MATH_PRIM_FUN_BINOMIAL_COEFFICIENT_LOG_HPP
33
44#include < stan/math/prim/meta.hpp>
5- #include < stan/math/prim/fun/inv.hpp>
5+ #include < stan/math/prim/err.hpp>
6+ #include < stan/math/prim/fun/constants.hpp>
7+ #include < stan/math/prim/fun/digamma.hpp>
8+ #include < stan/math/prim/fun/is_any_nan.hpp>
9+ #include < stan/math/prim/fun/log1p.hpp>
10+ #include < stan/math/prim/fun/lbeta.hpp>
611#include < stan/math/prim/fun/lgamma.hpp>
7- #include < stan/math/prim/fun/multiply_log .hpp>
12+ #include < stan/math/prim/fun/value_of .hpp>
813
914namespace stan {
1015namespace math {
@@ -13,22 +18,24 @@ namespace math {
1318 * Return the log of the binomial coefficient for the specified
1419 * arguments.
1520 *
16- * The binomial coefficient, \f${N \choose n }\f$, read "N choose n ", is
17- * defined for \f$0 \leq n \leq N \f$ by
21+ * The binomial coefficient, \f${n \choose k }\f$, read "n choose k ", is
22+ * defined for \f$0 \leq k \leq n \f$ by
1823 *
19- * \f${N \choose n } = \frac{N !}{n ! (N-n )!}\f$.
24+ * \f${n \choose k } = \frac{n !}{k ! (n-k )!}\f$.
2025 *
2126 * This function uses Gamma functions to define the log
22- * and generalize the arguments to continuous N and n.
27+ * and generalize the arguments to continuous n and k.
28+ *
29+ * \f$ \log {n \choose k}
30+ * = \log \ \Gamma(n+1) - \log \Gamma(k+1) - \log \Gamma(n-k+1)\f$.
2331 *
24- * \f$ \log {N \choose n}
25- * = \log \ \Gamma(N+1) - \log \Gamma(n+1) - \log \Gamma(N-n+1)\f$.
2632 *
2733 \f[
2834 \mbox{binomial\_coefficient\_log}(x, y) =
2935 \begin{cases}
30- \textrm{error} & \mbox{if } y > x \textrm{ or } y < 0\\
31- \ln\Gamma(x+1) & \mbox{if } 0\leq y \leq x \\
36+ \textrm{error} & \mbox{if } y > x + 1 \textrm{ or } y < -1 \textrm{ or } x
37+ < -1\\
38+ \ln\Gamma(x+1) & \mbox{if } -1 < y < x + 1 \\
3239 \quad -\ln\Gamma(y+1)& \\
3340 \quad -\ln\Gamma(x-y+1)& \\[6pt]
3441 \textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN}
@@ -38,7 +45,8 @@ namespace math {
3845 \f[
3946 \frac{\partial\, \mbox{binomial\_coefficient\_log}(x, y)}{\partial x} =
4047 \begin{cases}
41- \textrm{error} & \mbox{if } y > x \textrm{ or } y < 0\\
48+ \textrm{error} & \mbox{if } y > x + 1 \textrm{ or } y < -1 \textrm{ or } x
49+ < -1\\
4250 \Psi(x+1) & \mbox{if } -1 < y < x + 1 \\
4351 \quad -\Psi(x-y+1)& \\[6pt]
4452 \textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN}
@@ -48,32 +56,95 @@ namespace math {
4856 \f[
4957 \frac{\partial\, \mbox{binomial\_coefficient\_log}(x, y)}{\partial y} =
5058 \begin{cases}
51- \textrm{error} & \mbox{if } y > x \textrm{ or } y < 0\\
59+ \textrm{error} & \mbox{if } y > x + 1 \textrm{ or } y < -1 \textrm{ or } x
60+ < -1\\
5261 -\Psi(y+1) & \mbox{if } -1 < y < x + 1 \\
5362 \quad +\Psi(x-y+1)& \\[6pt]
5463 \textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN}
5564 \end{cases}
5665 \f]
5766 *
58- * @tparam T_N type of the first argument
59- * @tparam T_n type of the second argument
60- * @param N total number of objects.
61- * @param n number of objects chosen.
62- * @return log (N choose n).
67+ * This function is numerically more stable than naive evaluation via lgamma.
68+ *
69+ * @tparam T_n type of the first argument
70+ * @tparam T_k type of the second argument
71+ *
72+ * @param n total number of objects.
73+ * @param k number of objects chosen.
74+ * @return log (n choose k).
6375 */
64- template <typename T_N, typename T_n>
65- inline return_type_t <T_N, T_n> binomial_coefficient_log (const T_N N,
66- const T_n n) {
67- const double CUTOFF = 1000 ;
68- if (N - n < CUTOFF) {
69- const T_N N_plus_1 = N + 1 ;
70- return lgamma (N_plus_1) - lgamma (n + 1 ) - lgamma (N_plus_1 - n);
76+
77+ template <typename T_n, typename T_k>
78+ inline return_type_t <T_n, T_k> binomial_coefficient_log (const T_n n,
79+ const T_k k) {
80+ using T_partials_return = partials_return_t <T_n, T_k>;
81+
82+ if (is_any_nan (n, k)) {
83+ return NOT_A_NUMBER;
84+ }
85+
86+ // Use the symmetry (n choose k) == (n choose n - k): when k is above n / 2 (plus a 1e-8 slack), recurse with n - k to pick the more stable of the symmetric branches
87+ if (n > -1 && k > value_of_rec (n) / 2.0 + 1e-8 ) {
88+ return binomial_coefficient_log (n, n - k);
89+ }
90+
91+ const T_partials_return n_dbl = value_of (n);
92+ const T_partials_return k_dbl = value_of (k);
93+ const T_partials_return n_plus_1 = n_dbl + 1 ;
94+ const T_partials_return n_plus_1_mk = n_plus_1 - k_dbl;
95+
96+ static const char * function = " binomial_coefficient_log" ;
97+ check_greater_or_equal (function, " first argument" , n, -1 );
98+ check_greater_or_equal (function, " second argument" , k, -1 );
99+ check_greater_or_equal (function, " (first argument - second argument + 1)" ,
100+ n_plus_1_mk, 0.0 );
101+
102+ operands_and_partials<T_n, T_k> ops_partials (n, k);
103+
104+ T_partials_return value;
105+ if (k_dbl == 0 ) {
106+ value = 0 ;
107+ } else if (n_plus_1 < lgamma_stirling_diff_useful) {
108+ value = lgamma (n_plus_1) - lgamma (k_dbl + 1 ) - lgamma (n_plus_1_mk);
71109 } else {
72- return_type_t <T_N, T_n> N_minus_n = N - n;
73- const double one_twelfth = inv (12 );
74- return multiply_log (n, N_minus_n) + multiply_log ((N + 0.5 ), N / N_minus_n)
75- + one_twelfth / N - n - one_twelfth / N_minus_n - lgamma (n + 1 );
110+ value = -lbeta (n_plus_1_mk, k_dbl + 1 ) - log1p (n_dbl);
76111 }
112+
113+ if (!is_constant_all<T_n, T_k>::value) {
114+ // Branching on all the edge cases.
115+ // In direct computation many of those would be NaN.
116+ // But one-sided limits from within the domain exist; all of the below
117+ // follows from lim x->0 from above digamma(x) == -Inf.
118+ //
119+ // Note that for n > -1 we have k <= n / 2 + 1e-8 here (see the first branch in this function),
120+ // so we can ignore the n == k - 1 edge case; n == -1 is handled explicitly below.
121+ T_partials_return digamma_n_plus_1_mk = digamma (n_plus_1_mk);
122+
123+ if (!is_constant_all<T_n>::value) {
124+ if (n_dbl == -1.0 ) {
125+ if (k_dbl == 0 ) {
126+ ops_partials.edge1_ .partials_ [0 ] = 0 ;
127+ } else {
128+ ops_partials.edge1_ .partials_ [0 ] = NEGATIVE_INFTY;
129+ }
130+ } else {
131+ ops_partials.edge1_ .partials_ [0 ]
132+ = (digamma (n_plus_1) - digamma_n_plus_1_mk);
133+ }
134+ }
135+ if (!is_constant_all<T_k>::value) {
136+ if (k_dbl == 0 && n_dbl == -1.0 ) {
137+ ops_partials.edge2_ .partials_ [0 ] = NEGATIVE_INFTY;
138+ } else if (k_dbl == -1 ) {
139+ ops_partials.edge2_ .partials_ [0 ] = INFTY;
140+ } else {
141+ ops_partials.edge2_ .partials_ [0 ]
142+ = (digamma_n_plus_1_mk - digamma (k_dbl + 1 ));
143+ }
144+ }
145+ }
146+
147+ return ops_partials.build (value);
77148}
78149
79150} // namespace math
0 commit comments