Commit a1cc4c0a authored by Raphael Rieu-Helft's avatar Raphael Rieu-Helft

Add support for mutually recursive definitions in C extractions

Toom multiplication can now be extracted
parent 8c916bc6
......@@ -27,17 +27,17 @@ module mach.int.Int32
syntax type int32 "int32_t"
syntax val (+) "%1 + %2"
syntax val (-) "%1 - %2"
syntax val (+) "(%1) + (%2)"
syntax val (-) "(%1) - (%2)"
syntax val (-_) "-(%1)"
syntax val (*) "%1 * %2"
syntax val (/) "%1 / %2"
syntax val (%) "%1 % %2"
syntax val (=) "%1 == %2"
syntax val (<=) "%1 <= %2"
syntax val (<) "%1 < %2"
syntax val (>=) "%1 >= %2"
syntax val (>) "%1 > %2"
syntax val (*) "(%1) * (%2)"
syntax val (/) "(%1) / (%2)"
syntax val (%) "(%1) % (%2)"
syntax val (=) "(%1) == (%2)"
syntax val (<=) "(%1) <= (%2)"
syntax val (<) "(%1) < (%2)"
syntax val (>=) "(%1) >= (%2)"
syntax val (>) "(%1) > (%2)"
end
module mach.int.UInt32Gen
......@@ -53,16 +53,17 @@ module mach.int.UInt32
syntax converter of_int "%1U"
syntax val (+) "%1 + %2"
syntax val (-) "%1 - %2"
syntax val (*) "%1 * %2"
syntax val (/) "%1 / %2"
syntax val (%) "%1 % %2"
syntax val (=) "%1 == %2"
syntax val (<=) "%1 <= %2"
syntax val (<) "%1 < %2"
syntax val (>=) "%1 >= %2"
syntax val (>) "%1 > %2"
syntax val (+) "(%1) + (%2)"
syntax val (-) "(%1) - (%2)"
syntax val (*) "(%1) * (%2)"
syntax val (/) "(%1) / (%2)"
syntax val (%) "(%1) % (%2)"
syntax val (=) "(%1) == (%2)"
syntax val (<=) "(%1) <= (%2)"
syntax val (<) "(%1) < (%2)"
syntax val (>=) "(%1) >= (%2)"
syntax val (>) "(%1) > (%2)"
end
......@@ -184,16 +185,16 @@ struct __lsld32_result lsld32(uint32_t x, uint32_t cnt);
syntax converter of_int "%1U"
syntax val (+) "%1 + %2"
syntax val (-) "%1 - %2"
syntax val (*) "%1 * %2"
syntax val (/) "%1 / %2"
syntax val (%) "%1 % %2"
syntax val (=) "%1 == %2"
syntax val (<=) "%1 <= %2"
syntax val (<) "%1 < %2"
syntax val (>=) "%1 >= %2"
syntax val (>) "%1 > %2"
syntax val (+) "(%1) + (%2)"
syntax val (-) "(%1) - (%2)"
syntax val (*) "(%1) * (%2)"
syntax val (/) "(%1) / (%2)"
syntax val (%) "(%1) % (%2)"
syntax val (=) "(%1) == (%2)"
syntax val (<=) "(%1) <= (%2)"
syntax val (<) "(%1) < (%2)"
syntax val (>=) "(%1) >= (%2)"
syntax val (>) "(%1) > (%2)"
syntax val add_with_carry "add32_with_carry"
syntax val sub_with_borrow "sub32_with_borrow"
......@@ -201,23 +202,42 @@ struct __lsld32_result lsld32(uint32_t x, uint32_t cnt);
syntax val add3 "add32_3"
syntax val lsld "lsld32"
syntax val add_mod "%1 + %2"
syntax val sub_mod "%1 - %2"
syntax val mul_mod "%1 * %2"
syntax val add_mod "(%1) + (%2)"
syntax val sub_mod "(%1) - (%2)"
syntax val mul_mod "(%1) * (%2)"
syntax val div2by1
"(uint32_t)(((uint64_t)%1 | ((uint64_t)%2 << 32))/(uint64_t)%3)"
"(uint32_t)((((uint64_t)%1) | (((uint64_t)%2) << 32))/(uint64_t)(%3))"
syntax val lsl "%1 << %2"
syntax val lsr "%1 >> %2"
syntax val lsl "(%1) << (%2)"
syntax val lsr "(%1) >> (%2)"
syntax val is_msb_set "%1 & 0x80000000U"
syntax val is_msb_set "(%1) & 0x80000000U"
syntax val count_leading_zeros "__builtin_clz(%1)"
syntax val of_int32 "(uint32_t)%1"
syntax val of_int32 "(uint32_t)(%1)"
end
module mach.int.Int64
syntax val of_int "%1"
syntax converter of_int "%1"
syntax type int64 "int64_t"
syntax val (+) "(%1) + (%2)"
syntax val (-) "(%1) - (%2)"
syntax val (-_) "-(%1)"
syntax val (*) "(%1) * (%2)"
syntax val (/) "(%1) / (%2)"
syntax val (%) "(%1) % (%2)"
syntax val (=) "(%1) == (%2)"
syntax val (<=) "(%1) <= (%2)"
syntax val (<) "(%1) < (%2)"
syntax val (>=) "(%1) >= (%2)"
syntax val (>) "(%1) > (%2)"
end
module mach.int.UInt64Gen
......@@ -233,17 +253,17 @@ module mach.int.UInt64
syntax converter of_int "%1ULL"
syntax val (+) "%1 + %2"
syntax val (-) "%1 - %2"
syntax val (*) "%1 * %2"
syntax val (/) "%1 / %2"
syntax val (%) "%1 % %2"
syntax val (=) "%1 == %2"
syntax val (<=) "%1 <= %2"
syntax val (<) "%1 < %2"
syntax val (>=) "%1 >= %2"
syntax val (>) "%1 > %2"
syntax val (+) "(%1) + (%2)"
syntax val (-) "(%1) - (%2)"
syntax val (-_) "-(%1)"
syntax val (*) "(%1) * (%2)"
syntax val (/) "(%1) / (%2)"
syntax val (%) "(%1) % (%2)"
syntax val (=) "(%1) == (%2)"
syntax val (<=) "(%1) <= (%2)"
syntax val (<) "(%1) < (%2)"
syntax val (>=) "(%1) >= (%2)"
syntax val (>) "(%1) > (%2)"
end
......@@ -434,16 +454,16 @@ struct __lsld64_result lsld64(uint64_t x, uint64_t cnt);
"
syntax converter of_int "%1ULL"
syntax val (+) "%1 + %2"
syntax val (-) "%1 - %2"
syntax val (*) "%1 * %2"
syntax val (/) "%1 / %2"
syntax val (%) "%1 % %2"
syntax val (=) "%1 == %2"
syntax val (<=) "%1 <= %2"
syntax val (<) "%1 < %2"
syntax val (>=) "%1 >= %2"
syntax val (>) "%1 > %2"
syntax val (+) "(%1) + (%2)"
syntax val (-) "(%1) - (%2)"
syntax val (*) "(%1) * (%2)"
syntax val (/) "(%1) / (%2)"
syntax val (%) "(%1) % (%2)"
syntax val (=) "(%1) == (%2)"
syntax val (<=) "(%1) <= (%2)"
syntax val (<) "(%1) < (%2)"
syntax val (>=) "(%1) >= (%2)"
syntax val (>) "(%1) > (%2)"
syntax val add_with_carry "add64_with_carry"
syntax val add_double "add64_double"
......@@ -458,20 +478,20 @@ struct __lsld64_result lsld64(uint64_t x, uint64_t cnt);
syntax val add3 "add64_3"
syntax val lsld "lsld64"
syntax val add_mod "%1 + %2"
syntax val sub_mod "%1 - %2"
syntax val mul_mod "%1 * %2"
syntax val add_mod "(%1) + (%2)"
syntax val sub_mod "(%1) - (%2)"
syntax val mul_mod "(%1) * (%2)"
syntax val lsl "%1 << %2"
syntax val lsr "%1 >> %2"
syntax val lsl "(%1) << (%2)"
syntax val lsr "(%1) >> (%2)"
syntax val is_msb_set "%1 & 0x8000000000000000ULL"
syntax val is_msb_set "(%1) & 0x8000000000000000ULL"
syntax val count_leading_zeros "__builtin_clzll(%1)"
syntax val of_int32 "(uint64_t)%1"
syntax val to_int64 "(int64_t)%1"
syntax val of_int64 "(uint64_t)%1"
syntax val of_int32 "(uint64_t)(%1)"
syntax val to_int64 "(int64_t)(%1)"
syntax val of_int64 "(uint64_t)(%1)"
end
......@@ -480,6 +500,8 @@ end
module mach.c.C
prelude "#include <assert.h>"
prelude "#define IGNORE2(x,y) do {} while (0)"
interface "#define IGNORE2(x,y) do {} while (0)"
syntax type ptr "%1 *"
syntax type bool "int" (* ? *)
......@@ -501,7 +523,9 @@ module mach.c.C
syntax val set_ofs "*(%1+(%2)) = %3"
syntax val incr_split "%1+(%2)"
syntax val join "(void)0"
syntax val decr_split "%1-(%2)"
syntax val join "IGNORE2"
syntax val join_r "IGNORE2"
syntax val c_assert "assert ( %1 )"
......
......@@ -380,9 +380,9 @@ module Sub
writes { x.data.elts }
=
let ghost ox = { x } in
let ghost b : ref limb = ref 1 in
let lx : ref limb = ref 0 in
let i : ref int32 = ref 0 in
let ghost b = ref (Limb.of_int 1) in
let lx = ref (Limb.of_int 0) in
let i = ref (Int32.of_int 0) in
while (Limb.(=) !lx 0) do
invariant { 0 <= !i <= sz }
invariant { !i = sz -> !lx <> 0 }
......
This diff is collapsed.
......@@ -327,9 +327,12 @@ module C = struct
| Sexpr _ -> true
| _ -> false
let simplify_expr (d,s) : expr =
let rec simplify_expr (d,s) : expr =
match (d,s) with
| [], Sblock([],s) -> simplify_expr ([],s)
| [], Sexpr e -> e
| [], Sif(c,t,e) ->
Equestion (c, simplify_expr([],t), simplify_expr([],e))
| _ -> raise (Invalid_argument "simplify_expr")
let rec simplify_cond (cd, cs) =
......@@ -459,6 +462,9 @@ module Print = struct
| Ecast(ty, e) ->
fprintf fmt (protect_on paren "(%a)%a")
(print_ty ~paren:false) ty (print_expr ~paren:true) e
| Ecall (Esyntax (s, _, _, [], _), l) -> (* function defined in the prelude *)
fprintf fmt (protect_on paren "%s(%a)")
s (print_list comma (print_expr ~paren:false)) l
| Ecall (e,l) -> fprintf fmt (protect_on paren "%a(%a)")
(print_expr ~paren:true) e (print_list comma (print_expr ~paren:false)) l
| Econst c -> print_const fmt c
......@@ -754,6 +760,7 @@ module MLToC = struct
let e = C.(Econst (Cint (BigInt.to_string n))) in
([], return_or_expr env e)
| Eapp (rs, el) ->
Debug.dprintf debug_c_extraction "call to %s@." rs.rs_name.id_string;
if is_rs_tuple rs && env.computes_return_value
then begin
let args =
......@@ -1063,9 +1070,8 @@ module MLToC = struct
let translate_decl (info:info) (d:decl) : C.definition list
=
try
begin match d with
| Dlet (Lsym(rs, _, vl, e)) ->
let translate_fun rs vl e =
Debug.dprintf debug_c_extraction "print %s@." rs.rs_name.id_string;
if rs_ghost rs
then begin Debug.dprintf debug_c_extraction "is ghost@."; [] end
else
......@@ -1109,7 +1115,10 @@ module MLToC = struct
let d = C.group_defs_by_type d in
let s = C.elim_nop s in
let s = C.elim_empty_blocks s in
sdecls@[C.Dfun (rs.rs_name, (rtype,params), (d,s))]
sdecls@[C.Dfun (rs.rs_name, (rtype,params), (d,s))] in
try
begin match d with
| Dlet (Lsym(rs, _, vl, e)) -> translate_fun rs vl e
| Dtype [{its_name=id; its_def=idef}] ->
Debug.dprintf debug_c_extraction "PDtype %s@." id.id_string;
begin
......@@ -1125,7 +1134,15 @@ module MLToC = struct
^id.id_string))
end
end
| Dlet (Lrec rl) ->
let translate_rdef rd =
translate_fun rd.rec_sym rd.rec_args rd.rec_exp in
let defs = List.flatten (List.map translate_rdef rl) in
let proto_of_fun = function
| C.Dfun (id, pr, _) -> [C.Dproto (id, pr)]
| _ -> [] in
let protos = List.flatten (List.map proto_of_fun defs) in
protos@defs
| _ -> [] (*TODO exn ? *) end
with Unsupported s ->
Debug.dprintf debug_c_extraction "Unsupported : %s@." s; []
......@@ -1133,16 +1150,9 @@ module MLToC = struct
let translate_decl (info:info) (d:Mltree.decl) : C.definition list
=
let decide_print id = query_syntax info.syntax id = None in
match Mltree.get_decl_name d with
| [id] when decide_print id ->
Debug.dprintf debug_c_extraction "print %s@." id.id_string;
translate_decl info d
| [_] | [] -> []
| l -> Debug.dprintf debug_c_extraction "%d defs: %a@."
(List.length l)
(Pp.print_list Pp.space Pretty.print_id_attrs) l;
[]
match List.filter decide_print (Mltree.get_decl_name d) with
| [] -> []
| _ -> translate_decl info d
end
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment