Commit 063bb72e authored by Raphael Rieu-Helft's avatar Raphael Rieu-Helft

powm: extraction

parent 57c34b95
......@@ -201,7 +201,7 @@ struct __lsld32_result lsld32(uint32_t x, uint32_t cnt);
syntax val add_mod "%1 + %2" prec 4 4 3
syntax val sub_mod "%1 - %2" prec 4 4 3
syntax val minus_mod "-%2" prec 2 1
syntax val minus_mod "-%1" prec 2 1
syntax val mul_mod "%1 * %2" prec 3 3 2
syntax val div2by1
......@@ -422,7 +422,7 @@ static inline struct __lsld64_result lsld64(uint64_t x, uint64_t cnt)
syntax val add_mod "%1 + %2" prec 4 4 3
syntax val sub_mod "%1 - %2" prec 4 4 3
syntax val minus_mod "-%2" prec 2 1
syntax val minus_mod "-%1" prec 2 1
syntax val mul_mod "%1 * %2" prec 3 3 2
syntax val lsl "%1 << %2" prec 5 5 2
......
......@@ -26,7 +26,7 @@ clean:
why3:
make -C ../..
MLWFILES= $(addsuffix .mlw, sqrtrem sqrt toom logical div mul sub add compare util)
MLWFILES= $(addsuffix .mlw, powm sqrtrem sqrt toom logical div mul sub add compare util)
cfiles: why3
mkdir -p build
......@@ -38,9 +38,13 @@ build/sqrtinit.h: sqrtinit.ml
mkdir -p build
ocaml sqrtinit.ml > build/sqrtinit.h
extract: cfiles build/sqrtinit.h
build/binverttab.h: binverttab.ml
mkdir -p build
ocaml binverttab.ml > build/binverttab.h
extract: cfiles build/sqrtinit.h build/binverttab.h
CFILES = build/uint64gmp.c build/fxp.c build/sqrt.c build/sqrt1.c build/toom.c build/div.c build/logical.c build/mul.c build/sub.c build/add.c build/compare.c build/util.c build/int32.c
CFILES = build/uint64gmp.c build/powm.c build/fxp.c build/sqrt.c build/sqrt1.c build/toom.c build/div.c build/logical.c build/mul.c build/sub.c build/add.c build/compare.c build/util.c build/int32.c
tests: extract
gcc $(CFLAGS) tests.c $(CFILES) -Irandom -lgmp -o build/tests
......@@ -60,7 +64,7 @@ build/gmp%bench: extract
build/minigmp%bench: extract
gcc $(CFLAGS) -DTEST_MINIGMP -DTEST_`echo $* | tr [:lower:] [:upper:]` tests.c $(CFILES) -Iinclude -Imini-gmp -Irandom -o $@
alltests: tests build/why3addbench build/why3mulbench build/why3toombench build/why3divbench build/gmpaddbench build/gmpmulbench build/gmptoombench build/gmpdivbench build/minigmpaddbench build/minigmpmulbench build/minigmpdivbench build/minigmptoombench
alltests: tests build/why3addbench build/why3mulbench build/why3toombench build/why3divbench build/gmpaddbench build/gmpmulbench build/gmptoombench build/gmpdivbench build/minigmpaddbench build/minigmpmulbench build/minigmpdivbench build/minigmptoombench build/gmppowmbench build/why3powmbench
data: alltests
mkdir -p bench
......@@ -68,10 +72,12 @@ data: alltests
./build/why3mulbench > bench/why3mul
./build/why3toombench > bench/why3toom
./build/why3divbench > bench/why3div
./build/why3powmbench > bench/why3powm
./build/gmpaddbench > bench/gmpadd
./build/gmpmulbench > bench/gmpmul
./build/gmptoombench > bench/gmptoom
./build/gmpdivbench > bench/gmpdiv
./build/gmppowmbench > bench/gmppowm
./build/minigmpaddbench > bench/minigmpadd
./build/minigmpmulbench > bench/minigmpmul
./build/minigmptoombench > bench/minigmptoom
......
open Printf
let inverse x =
let rec loop t t' r r' =
if r' = 0 then
if t < 0 then t + 256 else t
else
let q = r / r' in
loop t' (t - q * t') r' (r - q * r')
in
loop 0 1 256 x
let () =
printf "/* binverttab[i] is the multiplicative inverse of 2*i+1 mod 256,\
\n ie. (binverttab[i] * (2*i+1)) %% 256 == 1 */\n";
printf "const unsigned char binverttab[128] = {\n";
for i1 = 0 to 15 do
printf " ";
for i2 = 0 to 7 do
let i = i1 * 8 + i2 in
let inv = inverse (2*i+1) in
assert (((2*i+1) * inv) mod 256 = 1);
printf "0x%02x," (inverse (2*i+1));
done;
printf "\n";
done;
printf "};\n%!"
......@@ -50,6 +50,9 @@ void wmpn_tdiv_qr_in_place (wmp_ptr, wmp_srcptr, wmp_size_t, wmp_srcptr, wmp_siz
wmp_size_t wmpn_sqrtrem (wmp_ptr, wmp_ptr, wmp_srcptr, wmp_size_t);
void wmpn_powm (wmp_ptr, wmp_srcptr, wmp_size_t, wmp_srcptr, wmp_size_t,
wmp_srcptr, wmp_size_t, wmp_ptr);
#ifdef __cplusplus
}
#endif
......
......@@ -10,5 +10,7 @@ all:
# ./gmpdivplot &
# ./gmpmulplot &
# ./gmpaddplot &
./toomrelative
./minitoomrelative
./toomrelative &
./minitoomrelative &
./powmrelative
......@@ -459,7 +459,7 @@ module Powm
inv
(* TODO rewrite this with array literal once they exist *)
let win_size [@extraction:inline] (eb:int32) : int32
let win_size [@extraction:c_inline] (eb:int32) : int32
ensures { 0 <= result <= 10 }
ensures { eb > 0 -> result > 0 }
= if eb = 0 then 0
......@@ -644,8 +644,8 @@ module Powm
= mod (2 * d + res) 2 = mod res 2 = res };
lps % 2
let getbits [@extraction:inline] (p: t) (ghost pn: int32) (bi:int32)
(nbits:int32) : limb
let getbits [@extraction:c_inline] (p: t) (ghost pn: int32) (bi:int32)
(nbits:int32) : limb
requires { 1 <= nbits < 64 }
requires { 0 <= bi }
requires { 1 <= pn }
......
......@@ -13,12 +13,15 @@
#define TEST_DIV
#define TEST_SQRT1
#define TEST_SQRTREM
#define TEST_POWM
#endif
#ifdef TEST_MINIGMP
#include "mini-gmp.c"
#else
#include <gmp.h>
extern void __gmpn_powm (mp_ptr, mp_srcptr, mp_size_t, mp_srcptr, mp_size_t,
mp_srcptr, mp_size_t, mp_ptr);
#endif
#ifdef TEST_LIB
......@@ -37,6 +40,7 @@ extern wmp_limb_t sqrt1(wmp_ptr, wmp_limb_t);
#include "build/toom.h"
#include "build/sqrt1.h"
#include "build/sqrt.h"
#include "build/powm.h"
#endif
#include "mt19937-64.c"
......@@ -64,9 +68,18 @@ void init_valid (mp_ptr ap, mp_ptr bp, mp_size_t an, mp_size_t bn) {
return;
}
void init_valid_1(mp_ptr ap, mp_size_t an) {
for (int i = 0; i < an; i++)
ap[i] = genrand64_int64();
while (ap[an-1]<2)
ap[an-1] = genrand64_int64();
return;
}
int main () {
mp_ptr ap, bp, rp, refp, rq, rr, refq, refr;
mp_size_t max_n, max_add, max_mul, max_toom, max_div, max_sqrt, an, bn, rn, cn;
mp_ptr ap, bp, rp, refp, rq, rr, refq, refr, ep, rep, mp, tp;
mp_size_t max_n, max_add, max_mul, max_toom, max_div, max_sqrt, max_powm,
an, bn, rn, cn;
int nb, nb_iter;
double elapsed;
#ifdef BENCH
......@@ -95,6 +108,7 @@ int main () {
max_toom = 95;
max_div = 20;
max_sqrt = 95;
max_powm = 50;
ap = TMP_ALLOC_LIMBS (max_n + 1);
bp = TMP_ALLOC_LIMBS (max_n + 1);
/* nap = TMP_ALLOC_LIMBS (max_n + 1); */
......@@ -105,6 +119,9 @@ int main () {
rr = TMP_ALLOC_LIMBS (max_n + 1);
refq = TMP_ALLOC_LIMBS (max_n + 1);
refr = TMP_ALLOC_LIMBS (max_n + 1);
tp = TMP_ALLOC_LIMBS(2 * max_n);
ep = TMP_ALLOC_LIMBS(max_n + 1);
mp = TMP_ALLOC_LIMBS(max_n + 1);
#ifdef TEST_ADD
#ifdef BENCH
......@@ -142,7 +159,7 @@ int main () {
}
elapsed = elapsed / (nb * nb_iter);
#ifdef BENCH
printf ("%d %d %g\n", an, bn, elapsed);
printf ("%ld %ld %g\n", an, bn, elapsed);
if (an==bn)
printf ("\n"); //for gnuplot
#endif
......@@ -210,7 +227,7 @@ int main () {
}
elapsed = elapsed / (nb * nb_iter);
#ifdef BENCH
printf ("%d %d %g\n", an, bn, elapsed);
printf ("%ld %ld %g\n", an, bn, elapsed);
if (an==bn)
printf ("\n"); //for gnuplot
#endif
......@@ -271,7 +288,7 @@ int main () {
}
elapsed = elapsed / (nb * nb_iter);
#ifdef BENCH
printf ("%d %d %g\n", an, bn, elapsed);
printf ("%ld %ld %g\n", an, bn, elapsed);
if (an==bn)
printf ("\n"); //for gnuplot
#endif
......@@ -336,7 +353,7 @@ int main () {
}
elapsed = elapsed / (nb * nb_iter);
#ifdef BENCH
printf ("%d %d %g\n", an, bn, elapsed);
printf ("%ld %ld %g\n", an, bn, elapsed);
if (an==bn)
printf ("\n"); //for gnuplot
#endif
......@@ -504,6 +521,75 @@ int main () {
printf ("sqrtrem ok\n");
#endif
#endif
#ifdef TEST_POWM
#ifdef TEST_MINIGMP
printf ("powm not available in mini-GMP\n");
goto skip;
#endif
#ifdef BENCH
printf ("#an bn t(µs)\n");
#endif
for (an = 2; an <= max_powm; an += 5)
{
for (rn = 1; rn <= max_powm; rn += 5)
{
elapsed = 0;
nb_iter = 1;
for (int iter = 0; iter != nb_iter; ++iter) {
init_valid_1 (ap, an);
init_valid_1 (ep, an);
init_valid_1 (tp, 2 * rn);
init_valid_1 (mp, rn);
mp[0] |= 1;
nb = 150 / an;
#ifdef BENCH
gettimeofday(&begin, NULL);
for (int i = 0; i != nb; ++i)
{
#endif
#ifdef TEST_GMP
__gmpn_powm(refr, ap, an, ep, an, mp, rn, tp);
#endif
#ifdef TEST_WHY3
wmpn_powm(rr, ap, an, ep, an, mp, rn, tp);
#endif
#ifdef BENCH
}
gettimeofday(&end, NULL);
elapsed +=
(end.tv_sec - begin.tv_sec) * 1000000.0
+ (end.tv_usec - begin.tv_usec);
#endif
}
elapsed = elapsed / (nb * nb_iter);
#ifdef BENCH
printf ("%ld %ld %g\n", an, rn, elapsed);
if (an==rn)
printf ("\n"); //for gnuplot
#endif
#ifdef COMPARE
if (mpn_cmp (refr, rr, rn))
{
printf ("ERROR, an = %d, rn = %d\n",
(int) an, (int) rn);
printf ("b: "); mpn_dump (ap, an);
printf ("e: "); mpn_dump (ep, an);
printf ("m: "); mpn_dump (mp, rn);
printf ("r: "); mpn_dump (rr, rn);
printf ("refr: "); mpn_dump (refr, rn);
abort();
}
#endif
}
}
#ifdef COMPARE
printf ("powm ok\n");
#endif
skip:
#endif
//TMP_FREE;
//tests_end ();
return 0;
......
......@@ -13,6 +13,17 @@ uint64_t rsa_estimate (uint64_t a) {
end
module powm.Powm
prelude "#include \"binverttab.h\"
uint64_t binvert_limb_table (uint64_t n) {
return (uint64_t)binverttab[n];
}
"
end
module mach.int.UInt64GMP
prelude "
......
......@@ -12,5 +12,6 @@ module Wmpn
use export toom.Toom
use export sqrt.Sqrt1
use export sqrtrem.Sqrt
use export powm.Powm
end
\ No newline at end of file
......@@ -132,7 +132,8 @@ module C = struct
(** [get_last_expr] extracts the expression computed by the given statement.
This is needed when loop conditions are more complex than a simple expression. *)
This is needed when loop conditions are more complex than a
simple expression. *)
let rec get_last_expr = function
| Snop -> raise NotAValue
| Sexpr e -> Snop, e
......@@ -488,6 +489,9 @@ module Print = struct
let local_printer = create_ident_printer c_keywords ~sanitizer
let global_printer = create_ident_printer c_keywords ~sanitizer
let c_inline = create_attribute "extraction:c_inline"
(* prints the c inline keyword *)
let print_local_ident fmt id = fprintf fmt "%s" (id_unique local_printer id)
let print_global_ident fmt id = fprintf fmt "%s" (id_unique global_printer id)
......@@ -669,15 +673,20 @@ module Print = struct
| Sreturn e -> fprintf fmt "return %a;" print_expr_no_paren e
and print_def fmt def =
let print_inline fmt id =
if Sattr.mem c_inline id.id_attrs
then fprintf fmt "inline "
else fprintf fmt "" in
try match def with
| Dfun (id,(rt,args),body) ->
let s = sprintf "@[@[<hv 2>%a %a(@[%a@]) {@\n@[%a@]@]\n}\n@]"
let s = sprintf "@[@[<hv 2>%a%a %a(@[%a@]) {@\n@[%a@]@]\n}\n@]"
print_inline id
(print_ty ~paren:false) rt
print_global_ident id
(print_list comma
(print_pair_delim nothing space_nolinebreak nothing
(print_ty ~paren:false) print_local_ident))
args
(print_pair_delim nothing space_nolinebreak nothing
(print_ty ~paren:false) print_local_ident))
args
print_body body in
(* print into string first to print nothing in case of exception *)
fprintf fmt "%s" s
......@@ -687,21 +696,22 @@ module Print = struct
print_global_ident id
(print_list comma
(print_pair_delim nothing space_nolinebreak nothing
(print_ty ~paren:false) print_local_ident))
args in
(print_ty ~paren:false) print_local_ident))
args in
fprintf fmt "%s" s
| Ddecl (Tarray(ty, e), lie) ->
let s = sprintf "%a @[<hov>%a@];"
(print_ty ~paren:false) ty
(print_list comma (print_id_init ~stars:0 ~size:(Some e)))
(print_ty ~paren:false) ty
(print_list comma (print_id_init ~stars:0 ~size:(Some e)))
lie in
fprintf fmt "%s" s
| Ddecl (ty, lie) ->
let nb, ty = extract_stars ty in
assert (nb=0);
let s = sprintf "%a @[<hov>%a@];"
(print_ty ~paren:false) ty
(print_list comma (print_id_init ~stars:nb ~size:None)) lie in
(print_ty ~paren:false) ty
(print_list comma (print_id_init ~stars:nb ~size:None))
lie in
fprintf fmt "%s" s
| Dstruct (s, lf) ->
let s = sprintf "struct %s@ @[<hov>{@;<1 2>@[<hov>%a@]@\n};@\n@]"
......@@ -719,7 +729,7 @@ module Print = struct
fprintf fmt "#include \"%s.h\"@;" (sanitizer id.id_string)
| Dtypedef (ty,id) ->
let s = sprintf "@[<hov>typedef@ %a@;%a;@]"
(print_ty ~paren:false) ty print_global_ident id in
(print_ty ~paren:false) ty print_global_ident id in
fprintf fmt "%s" s
with Unprinted s ->
Debug.dprintf debug_c_extraction "Missed a def because : %s@." s
......@@ -864,10 +874,11 @@ module MLToC = struct
computes_return_value : bool;
current_function : rsymbol;
ret_regs : Sreg.t;
breaks : Sid.t;
breaks : Sid.t;
returns : Sid.t;
array_sizes : C.expr Mid.t;
boxed : unit Hreg.t; (* is this struct boxed or passed by value? *)
boxed : unit Hreg.t;
(* is this struct boxed or passed by value? *)
}
let is_true e = match e.e_node with
......@@ -966,7 +977,7 @@ module MLToC = struct
match e.e_ity with
| I i when ity_equal i Ity.ity_unit -> false
| _ -> true)
el in
el in
let env_f = { env with computes_return_value = false } in
let args = List.map (fun e -> simplify_expr (expr info env_f e)) args in
let ((sname, sfields) as sd) = struct_of_constructor info rs in
......@@ -988,7 +999,7 @@ module MLToC = struct
match e.e_ity with
| I i when ity_equal i Ity.ity_unit -> false
| _ -> true)
el
el
in (*FIXME still needed with masks? *)
let env_f = { env with computes_return_value = false } in
if is_rs_tuple rs && env.computes_return_value
......@@ -1006,12 +1017,12 @@ module MLToC = struct
| e::t ->
let b = expr info env_f e in
C.Sseq(assign i b, assigns t (i+1)) in
C.([d_struct], Sseq(assigns args 0, Sreturn(e_struct)))
end
C.([d_struct], Sseq(assigns args 0, Sreturn(e_struct)))
end
else
let (prdefs, prstmt), e' =
let (prdefs, prstmt), e' =
let prelude, unboxed_params =
Lists.map_fold_left
Lists.map_fold_left
(fun ((accd, accs) as acc) e ->
let d, s = expr info env_f e in
let pty = ty_of_ty info (ty_of_ity (ity_of_expr e)) in
......@@ -1022,7 +1033,7 @@ module MLToC = struct
let s', e' = get_last_expr s in
(accd@d, Sseq(accs, s')), (e', pty))
([], Snop)
args in
args in
let params =
List.map2
(fun p mle ->
......@@ -1041,27 +1052,27 @@ module MLToC = struct
|| String.contains s '(' in
if complex s
then
let rty = ty_of_ity (match e.e_ity with
let rty = ty_of_ity (match e.e_ity with
| C _ -> assert false
| I i -> i) in
let rtyargs = match rty.ty_node with
| Tyvar _ -> [||]
| Tyapp (_,args) ->
let rtyargs = match rty.ty_node with
| Tyvar _ -> [||]
| Tyapp (_,args) ->
Array.of_list (List.map (ty_of_ty info) args)
in
in
let p = Mid.find rs.rs_name info.prec in
C.Esyntax(s,ty_of_ty info rty, rtyargs, params, p)
C.Esyntax(s,ty_of_ty info rty, rtyargs, params, p)
else
if args=[]
then C.(Esyntax(s, Tnosyntax, [||], [], [])) (*constant*)
else
(*function defined in the prelude *)
let cargs = List.map fst params in
C.(Esyntaxrename (s, cargs))
C.(Esyntaxrename (s, cargs))
| None ->
match rs.rs_field with
| None ->
C.(Ecall(Evar(rs.rs_name), List.map fst params))
C.(Ecall(Evar(rs.rs_name), List.map fst params))
| Some pv ->
assert (List.length el = 1);
begin match unboxed_params, args with
......@@ -1073,7 +1084,7 @@ module MLToC = struct
in
let s =
if env.computes_return_value
then
then
begin match e.e_ity with
| I ity when ity_equal ity Ity.ity_unit ->
Sseq(Sexpr e', Sreturn Enothing)
......@@ -1092,12 +1103,12 @@ module MLToC = struct
| [], C.Sexpr c ->
let c = handle_likely cond.e_attrs c in
if is_false th && is_true el
then C.([], Sexpr(Eunop(Unot, c)))
else [], C.Sif(c,C.Sblock t, C.Sblock e)
then C.([], Sexpr(Eunop(Unot, c)))
else [], C.Sif(c,C.Sblock t, C.Sblock e)
| cdef, cs ->
let cid = id_register (id_fresh "cond") in (* ? *)
C.Ddecl (C.Tsyntax ("int",[]), [cid, C.Enothing])::cdef,
C.Sseq (C.assignify (Evar cid) cs,
let cid = id_register (id_fresh "cond") in (* ? *)
C.Ddecl (C.Tsyntax ("int",[]), [cid, C.Enothing])::cdef,
C.Sseq (C.assignify (Evar cid) cs,
C.Sif ((handle_likely cond.e_attrs (C.Evar cid)),
C.Sblock t, C.Sblock e))
end
......@@ -1111,7 +1122,7 @@ module MLToC = struct
let env' = { env with
computes_return_value = false;
in_unguarded_loop = true;
breaks =
breaks =
if env.in_unguarded_loop
then Sid.empty else env.breaks } in
let b = expr info env' b in
......@@ -1239,10 +1250,10 @@ module MLToC = struct
let params = [ st, ty_of_ty info (ty_of_ity pv.pv_ity) ] in
let rty = ty_of_ity rs.rs_cty.cty_result in
let rtyargs = match rty.ty_node with
| Tyvar _ -> [||]
| Tyapp (_,args) ->
| Tyvar _ -> [||]
| Tyapp (_,args) ->
Array.of_list (List.map (ty_of_ty info) args)
in
in
let p = Mid.find rs.rs_name info.prec in
C.Esyntax(s,ty_of_ty info rty, rtyargs, params,p)
| None -> if boxed
......@@ -1325,7 +1336,8 @@ module MLToC = struct
| _ -> true in
let keep_pv pv =
not pv.pv_ghost &&
not (ity_equal pv.pv_ity Ity.ity_unit && is_dummy pv.pv_vs.vs_name) in
not (ity_equal pv.pv_ity Ity.ity_unit
&& is_dummy pv.pv_vs.vs_name) in
let ngvl = List.filter keep_var vl in
let ngargs = List.filter keep_pv rs.rs_cty.cty_args in
let params =
......@@ -1338,25 +1350,25 @@ module MLToC = struct
Hreg.add boxed r ();
C.Tptr cty, id
| _ -> (cty, id)) ngvl ngargs in
let ret_regs = ity_exp_fold Sreg.add_left Sreg.empty rs.rs_cty.cty_result in
let rity = rs.rs_cty.cty_result in
let is_simple_tuple ity =
let arity_zero = function
| Ityapp(_,a,r) -> a = [] && r = []
| Ityreg { reg_args = a; reg_regs = r } ->
let rity = rs.rs_cty.cty_result in
let ret_regs = ity_exp_fold Sreg.add_left Sreg.empty rity in
let is_simple_tuple ity =
let arity_zero = function
| Ityapp(_,a,r) -> a = [] && r = []
| Ityreg { reg_args = a; reg_regs = r } ->
a = [] && r = []
| Ityvar _ -> true
in
(match ity.ity_node with
| Ityapp ({its_ts = s},_,_)
| Ityvar _ -> true
in
(match ity.ity_node with
| Ityapp ({its_ts = s},_,_)
| Ityreg { reg_its = {its_ts = s}; }
-> is_ts_tuple s
| _ -> false)
&& (ity_fold
-> is_ts_tuple s
| _ -> false)
&& (ity_fold
(fun acc ity ->
acc && arity_zero ity.ity_node) true ity)
in
(* FIXME is it necessary to have arity 0 in regions ?*)
in
(* FIXME is it necessary to have arity 0 in regions ?*)
let rtype = try ty_of_mlty info mlty
with Unsupported _ -> (*FIXME*)
ty_of_ty info (ty_of_ity rity) in
......@@ -1370,21 +1382,21 @@ module MLToC = struct
if header
then sdecls@[C.Dproto (rs.rs_name, (rtype, params))]
else
let env = { computes_return_value = true;
in_unguarded_loop = false;
let env = { computes_return_value = true;
in_unguarded_loop = false;
current_function = rs;
ret_regs = ret_regs;
returns = Sid.empty;
breaks = Sid.empty;