From 9d86d6cea611c2ea066c59674ddf3713e6bf1b16 Mon Sep 17 00:00:00 2001
From: Vincent Danjean <Vincent.Danjean@ens-lyon.org>
Date: Fri, 2 Nov 2012 15:19:09 +0100
Subject: [PATCH] Use the hypergeometric distribution

This will give us a factor of about 2.5.

The code taken from R (rhyper.c) is improved:
- use exact values for afc up to 1754 (the largest argument whose
  factorial can be computed with "long double")
- use per-thread static variables so that it can run in parallel
  in different threads
---
 CUtils/CUtils.xs          |  14 +++
 CUtils/c_sources/rhyper.c | 243 +++++++++++++++++++++++++++++++-------
 CUtils/c_sources/rhyper.h |   7 ++
 CUtils/c_sources/stats.c  |  19 ++-
 4 files changed, 229 insertions(+), 54 deletions(-)
 create mode 100644 CUtils/c_sources/rhyper.h

diff --git a/CUtils/CUtils.xs b/CUtils/CUtils.xs
index 5c4f2da..3c0e571 100644
--- a/CUtils/CUtils.xs
+++ b/CUtils/CUtils.xs
@@ -56,6 +56,20 @@ critchi(p, df)
 	double	p
 	int	df
 
+############################################################
+# rhyper.h
+############################################################
+
+int
+RHyper(n1, n2, k)
+        int n1
+	int n2
+	int k
+    CODE:
+        RETVAL=rhyper(n1, n2, k);
+    OUTPUT:
+        RETVAL
+
 ############################################################
 # double_permutation.h
 ############################################################
diff --git a/CUtils/c_sources/rhyper.c b/CUtils/c_sources/rhyper.c
index b758f06..c755373 100644
--- a/CUtils/c_sources/rhyper.c
+++ b/CUtils/c_sources/rhyper.c
@@ -1,4 +1,3 @@
-#if 0
 /*
  *  Mathlib : A C Library of Special Functions
  *  Copyright (C) 1998 Ross Ihaka
@@ -42,7 +41,12 @@
  *    where (m < 100 || ix <= 50) , see below.
  */
 
-#include "nmath.h"
+#include "mt19937ar.h"
+#include "debug.h"
+#include "stdio.h"
+#include "stdlib.h"
+#include <math.h>
+#include "rhyper.h"
 
 /* afc(i) :=  ln( i! )	[logarithm of the factorial i.
  *	   If (i > 7), use Stirling's approximation, otherwise use table lookup.
@@ -50,27 +54,53 @@
 
 static double afc(int i)
 {
-    const static double al[9] =
+    static int computed=10;
+    static double al[1756] =
     {
 	0.0,
-	0.0,/*ln(0!)=ln(1)*/
-	0.0,/*ln(1!)=ln(1)*/
-	0.69314718055994530941723212145817,/*ln(2) */
-	1.79175946922805500081247735838070,/*ln(6) */
-	3.17805383034794561964694160129705,/*ln(24)*/
-	4.78749174278204599424770093452324,
-	6.57925121201010099506017829290394,
-	8.52516136106541430016553103634712
-	/*, 10.60460290274525022841722740072165*/
+	0,/*ln(0!)*/
+	0,/*ln(1!)*/
+	0.693147180559945309,/*ln(2!)*/
+	1.791759469228055,/*ln(3!)*/
+	3.17805383034794562,/*ln(4!)*/
+	4.78749174278204599,/*ln(5!)*/
+	6.579251212010101,/*ln(6!)*/
+	8.5251613610654143,/*ln(7!)*/
+	10.6046029027452502,/*ln(8!)*/
+	12.8018274800814696,/*ln(9!)*/
+	15.1044125730755153,/*ln(10!)*/
     };
+    /* Extend the ln(k!) table so that it covers k=n and return ln(n!).
+       afc() calls this only for n <= 1754 not yet covered by 'computed'. */
+    double compute(int n) {
+	static long double cur=3628800; /* (i-1)!, initially 10! */
+	static int i=11; /* next k whose ln(k!) is not stored yet */
+	static volatile int mutex=0;
+	/* Spinlock; the inner read-only loop avoids cache line ping-pong. */
+	while (__sync_lock_test_and_set(&mutex, 1)) {
+		while(mutex) { /* spin */ }
+	}
+
+	for(; i<=n; i++) {
+		cur*=i;
+		al[i+1]=logl(cur);
+	}
+	/* i may already be past n+1 (another thread got here first): keep
+	   'computed' monotonic and index the result by n, not by i. */
+	computed=i-1;
+	__sync_lock_release(&mutex);
+	return al[n+1];
+    };
+
     double di, value;
 
     if (i < 0) {
-      MATHLIB_WARNING(("rhyper.c: afc(i), i=%d < 0 -- SHOULD NOT HAPPEN!\n"),
-		      i);
-      return -1;/* unreached (Wall) */
-    } else if (i <= 7) {
+      fprintf(stderr, "rhyper.c: afc(i), i=%d < 0 -- SHOULD NOT HAPPEN!\n", i);
+      exit(1);
+    } else if (i <= computed) {
 	value = al[i + 1];
+    } else if (i <= 1754) {
+	value = compute(i);
     } else {
 	di = i;
 	value = (di + 0.5) * log(di) - di + 0.08333333333333 / di
@@ -79,7 +109,21 @@ static double afc(int i)
     return value;
 }
 
-double rhyper(double nn1in, double nn2in, double kkin)
+#define imin2(a,b) ({ /* smaller of a and b, each evaluated once */ \
+	__typeof__(a) _a=(a); \
+	__typeof__(b) _b=(b); \
+	(_a < _b) ? _a : _b ;\
+})
+/* __typeof__ rather than typeof: gcc also accepts it under -std=c99/-std=c11 */
+#define imax2(a,b) ({ /* larger of a and b, each evaluated once */ \
+	__typeof__(a) _a=(a); \
+	__typeof__(b) _b=(b); \
+	(_a > _b) ? _a : _b ;\
+})
+
+#define unif_rand() genrand_real2()
+
+int rhyper(int nn1, int nn2, int kk)
 {
     const static double con = 57.56462733;
     const static double deltal = 0.0078;
@@ -88,45 +132,37 @@ double rhyper(double nn1in, double nn2in, double kkin)
 
     /* extern double afc(int); */
 
-    int nn1, nn2, kk;
     int i, ix;
-    Rboolean reject, setup1, setup2;
+    int reject, setup1, setup2;
 
     double e, f, g, p, r, t, u, v, y;
     double de, dg, dr, ds, dt, gl, gu, nk, nm, ub;
     double xk, xm, xn, y1, ym, yn, yk, alv;
 
     /* These should become `thread_local globals' : */
-    static int ks = -1;
-    static int n1s = -1, n2s = -1;
+    static __thread int ks = -1;
+    static __thread int n1s = -1, n2s = -1;
 
-    static int k, m;
-    static int minjx, maxjx, n1, n2;
+    static __thread int k, m;
+    static __thread int minjx, maxjx, n1, n2;
 
-    static double a, d, s, w;
-    static double tn, xl, xr, kl, kr, lamdl, lamdr, p1, p2, p3;
+    static __thread double a, d, s, w;
+    static __thread double tn, xl, xr, kl, kr, lamdl, lamdr, p1, p2, p3;
 
 
     /* check parameter validity */
 
-    if(!R_FINITE(nn1in) || !R_FINITE(nn2in) || !R_FINITE(kkin))
-	ML_ERR_return_NAN;
-
-    nn1 = (int) floor(nn1in+0.5);
-    nn2 = (int) floor(nn2in+0.5);
-    kk	= (int) floor(kkin +0.5);
-
     if (nn1 < 0 || nn2 < 0 || kk < 0 || kk > nn1 + nn2)
-	ML_ERR_return_NAN;
+	return -1;
 
     /* if new parameter values, initialize */
-    reject = TRUE;
+    reject = 1;
     if (nn1 != n1s || nn2 != n2s) {
-	setup1 = TRUE;	setup2 = TRUE;
+	setup1 = 1;	setup2 = 1;
     } else if (kk != ks) {
-	setup1 = FALSE;	setup2 = TRUE;
+	setup1 = 0;	setup2 = 1;
     } else {
-	setup1 = FALSE;	setup2 = FALSE;
+	setup1 = 0;	setup2 = 0;
     }
     if (setup1) {
 	n1s = nn1;
@@ -171,6 +207,11 @@ double rhyper(double nn1in, double nn2in, double kkin)
 	  if (nn1 > nn2)
 	    ix = kk - ix;
 	}
+	//debug("RHYPER: (%i, %i, %i)=%i", nn1, nn2, kk, ix);
+	assert(ix <= nn1);
+	assert(kk-ix <= nn2);
+	assert(ix <= kk);
+	assert(0 <= ix);
 	return ix;
 
     } else if (m - minjx < 10) { /* II: inverse transformation ---------- */
@@ -257,7 +298,7 @@ double rhyper(double nn1in, double nn2in, double kkin)
 		    f = f * i * (n2 - k + i) / (n1 - i + 1) / (k - i + 1);
 	    }
 	    if (v <= f) {
-		reject = FALSE;
+		reject = 0;
 	    }
 	} else {
 	    /* squeeze using upper and lower bounds */
@@ -289,7 +330,7 @@ double rhyper(double nn1in, double nn2in, double kkin)
 	    /* test against upper bound */
 	    alv = log(v);
 	    if (alv > ub) {
-		reject = TRUE;
+		reject = 1;
 	    } else {
 				/* test against lower bound */
 		dr = xm * (r * r * r * r);
@@ -306,16 +347,16 @@ double rhyper(double nn1in, double nn2in, double kkin)
 		    de /= (1.0 + e);
 		if (alv < ub - 0.25 * (dr + ds + dt + de)
 		    + (y + m) * (gl - gu) - deltal) {
-		    reject = FALSE;
+		    reject = 0;
 		}
 		else {
 		    /* * Stirling's formula to machine accuracy
 		     */
 		    if (alv <= (a - afc(ix) - afc(n1 - ix)
 				- afc(k - ix) - afc(n2 - k + ix))) {
-			reject = FALSE;
+			reject = 0;
 		    } else {
-			reject = TRUE;
+			reject = 1;
 		    }
 		}
 	    }
@@ -336,6 +377,124 @@ double rhyper(double nn1in, double nn2in, double kkin)
 	if (nn1 > nn2)
 	    ix = kk - ix;
     }
+    //debug("RHYPER: (%i, %i, %i)=%i", nn1, nn2, kk, ix);
+    assert(ix <= nn1);
+    assert(kk-ix <= nn2);
+    assert(ix <= kk);
+    assert(0 <= ix);
     return ix;
 }
+
+#if TEST_AFC
+static double origafc(int i) /* ln(i!) as the pre-patch afc() computed it; kept as a comparison baseline */
+{
+    const static double al[9] = /* al[i + 1] holds ln(i!) for 0 <= i <= 7 */
+    {
+	0.0,/* al[0]: never read, lookups use al[i + 1] with i >= 0 */
+	0.0,/*ln(0!)=ln(1)*/
+	0.0,/*ln(1!)=ln(1)*/
+	0.69314718055994530941723212145817,/*ln(2!)=ln(2) */
+	1.79175946922805500081247735838070,/*ln(3!)=ln(6) */
+	3.17805383034794561964694160129705,/*ln(4!)=ln(24)*/
+	4.78749174278204599424770093452324,/*ln(5!)=ln(120)*/
+	6.57925121201010099506017829290394,/*ln(6!)=ln(720)*/
+	8.52516136106541430016553103634712/*ln(7!)=ln(5040)*/
+	/*, 10.60460290274525022841722740072165 = ln(8!), not part of the table */
+    };
+    double di, value;
+
+    if (i < 0) { /* negative argument: report and terminate the process */
+      fprintf(stderr, "rhyper.c: afc(i), i=%d < 0 -- SHOULD NOT HAPPEN!\n", i);
+      exit(1);
+    } else if (i <= 7) { /* small i: table lookup */
+	value = al[i + 1];
+    } else { /* i > 7: Stirling's approximation */
+	di = i;
+	value = (di + 0.5) * log(di) - di + 0.08333333333333 / di /* ~ 1/(12 i) */
+	    - 0.00277777777777 / di / di / di + 0.9189385332; /* ~ 1/(360 i^3), ln(2*pi)/2 */
+    }
+    return value;
+}
+
+static double afc2(int n) /* ln(n!) by Ramanujan's approximation */
+{
+	static const double logpi=__builtin_log(M_PI);
+	double x=n; /* work in double: 8n^3+4n^2+n overflows int for n>=645 */
+	return x*log(x)-x+log(x*(1+4*x*(1+2*x)))/6+logpi/2;
+}
+
+static double afc3(int n) /* Ramanujan's approximation of ln(n!), ln(2n)/2 term split off */
+{
+	static const double logpi=__builtin_log(M_PI);
+	double x=n; /* divide in double: int 1/(2*n) and 1/(8*n*n) are 0 for n>=1 */
+	return x*log(x)-x+log(1+1/(2*x)+1/(8*x*x))/6+log(2*x)/2+logpi/2;
+}
+
+static double afc4(int n) /* Ramanujan's approximation of ln(n!), reusing one log(n) evaluation */
+{
+	static const double logpi=__builtin_log(M_PI);
+	static const double ln2=__builtin_log(2); /* not named "log2": that is a <math.h> function */
+	double x=n, logn=log(x);
+	/* divide in double: int 1/(2*n) and 1/(8*n*n) are 0 for n>=1 */
+	return x*logn-x+log(1+1/(2*x)+1/(8*x*x))/6+(logn+(logpi+ln2))/2;
+}
+
+static double afc5(int n) /* ln(n!) from a running product kept in a double */
+{
+	/* double, not long long: 21! overflows a 64-bit signed integer (undefined
+	   behaviour), whereas a double simply saturates to +inf past 170!. */
+	static double cur=1;
+	static int i=1; /* next factor; state persists, so call with non-decreasing n */
+	for(; i<=n; i++) {
+		cur*=i;
+	}
+	return log(cur);
+}
+
+static double afc6(int n) /* ln(n!) from a running product kept in a long double */
+{
+	static long double cur=1; /* (i-1)!, carried over between calls */
+	static int i=1; /* next factor to multiply in */
+
+	for(; i<=n; i++) {
+		cur*=i;
+	}
+	/* state is static: the result is ln(n!) only if n never decreases between calls */
+	return logl(cur);
+}
+
+static double afc7(int n) /* like afc6, but also prints ln(n!) to stdout */
+{
+	static long double cur=1; /* (i-1)!, carried over between calls */
+	static int i=1; /* next factor to multiply in */
+
+	for(; i<=n; i++) {
+		cur*=i;
+	}
+	/* output is one C initializer line: the value, then a comment naming n and n! */
+	printf("\t%.18Lg, /* ln(%i!) = ln(%.0Lf) */\n",logl(cur),n,cur);
+	return logl(cur);
+}
+
+static void compare(int k) { /* print ln(i!) for i=1..k from every variant, mostly as differences to afc6 */
+	int i;
+	printf("           %20s / %20s / %13s / %13s / %13s / %13s / %13s / %13s / %13s\n",
+		"ref=exact(long double)", "my", "orig-ref", "orig-my", "my-ref", "exact(double)-ref", "afc2-ref", "afc3-ref", "afc4-ref");
+	for (i=1; i<=k; i++) {
+		double ref=afc6(i);
+		double ref2=afc(i);
+		printf("log %4i! = %20.17lg / %20.17lg / %13.7lg / %13.7lg / %13.7lg / %13.7lg / %13.7lg / %13.7lg / %13.7lg\n",
+			i, ref, ref2, origafc(i)-ref, origafc(i)-ref2, afc(i)-ref, afc5(i)-ref, afc2(i)-ref, afc3(i)-ref, afc4(i)-ref);
+	}
+	for (i=1; i<=50; i++) { /* afc7 prints one table-initializer line per i */
+		afc7(i);
+	}
+	return;
+}
+
+int main(int argc, char**argv) {
+	/* Test driver, compiled only when TEST_AFC is non-zero; argc/argv are unused. */
+	compare(1755); /* one past 1754, the last value afc() tabulates exactly */
+	return 0;
+}
 #endif
diff --git a/CUtils/c_sources/rhyper.h b/CUtils/c_sources/rhyper.h
new file mode 100644
index 0000000..7481370
--- /dev/null
+++ b/CUtils/c_sources/rhyper.h
@@ -0,0 +1,7 @@
+#ifndef RHYPER_H
+#define RHYPER_H
+/* Hypergeometric random draw; returns -1 if an argument is negative or kkin > nn1in+nn2in. */
+int rhyper(int nn1in, int nn2in, int kkin);
+
+#endif
+
diff --git a/CUtils/c_sources/stats.c b/CUtils/c_sources/stats.c
index 8b3ec14..910781e 100644
--- a/CUtils/c_sources/stats.c
+++ b/CUtils/c_sources/stats.c
@@ -3,7 +3,7 @@
 #include "stats.h"
 #include "fisher.h"
 #include "chisq.h"
-#include "myrand.h"
+#include "rhyper.h"
 #include <stdio.h>
 #include <stdlib.h>
 #include <gsl/gsl_cdf.h>
@@ -240,17 +240,12 @@ void random_clades(int nb_nodes, const struct cc *nodes,
 	bzero(clades, nb_nodes*sizeof(struct cc));
 	int c;
 	for(c=0; c<nb_nodes; c++) {
-		int i;
 		int s=nodes[c].cases+nodes[c].controls;
-		for(i=0; i<s; i++) {
-			int alea=myrand(cases+controls);
-			if (alea < cases) {
-				cases--;
-				clades[c].cases++;
-			} else {
-				controls--;
-				clades[c].controls++;
-				}
-		}
+		int ncases=rhyper(cases, controls, s);
+		//debug("clades: (%i, %i, %i)=%i", cases, controls, s, ncases);
+		clades[c].cases = ncases;
+		clades[c].controls = s-ncases;
+		cases -= ncases;
+		controls -= (s-ncases);
 	}
 }
-- 
GitLab