Commit 25ad8d0f authored by Raphael Rieu-Helft's avatar Raphael Rieu-Helft

Support modular C extraction using headers

parent 4c5a33d9
......@@ -3,6 +3,7 @@ printer "c"
prelude "#include <stdlib.h>"
prelude "#include <stdint.h>"
prelude "#include <stdio.h>"
prelude "#include <assert.h>"
module ref.Ref
......@@ -77,7 +78,7 @@ struct __add32_with_carry_result
uint32_t __field_1;
};
static struct __add32_with_carry_result add32_with_carry(uint32_t x, uint32_t y, uint32_t c)
struct __add32_with_carry_result add32_with_carry(uint32_t x, uint32_t y, uint32_t c)
{
struct __add32_with_carry_result result;
uint64_t r = (uint64_t)x + (uint64_t)y + (uint64_t) c;
......@@ -91,7 +92,7 @@ struct __sub32_with_borrow_result
uint32_t __field_1;
};
static struct __sub32_with_borrow_result sub32_with_borrow(uint32_t x, uint32_t y, uint32_t b)
struct __sub32_with_borrow_result sub32_with_borrow(uint32_t x, uint32_t y, uint32_t b)
{
struct __sub32_with_borrow_result result;
uint64_t r = (uint64_t)x - (uint64_t)y - (uint64_t) b;
......@@ -105,7 +106,7 @@ struct __mul32_double_result
uint32_t __field_1;
};
static struct __mul32_double_result mul32_double(uint32_t x, uint32_t y)
struct __mul32_double_result mul32_double(uint32_t x, uint32_t y)
{
struct __mul32_double_result result;
uint64_t r = (uint64_t)x * (uint64_t)y;
......@@ -119,7 +120,7 @@ struct __add32_3_result
uint32_t __field_1;
};
static struct __add32_3_result add32_3(uint32_t x, uint32_t y, uint32_t z)
struct __add32_3_result add32_3(uint32_t x, uint32_t y, uint32_t z)
{
struct __add32_3_result result;
uint64_t r = (uint64_t)x + (uint64_t)y + (uint64_t) z;
......@@ -133,7 +134,7 @@ struct __lsld32_result
uint32_t __field_1;
};
static struct __lsld32_result lsld32(uint32_t x, uint32_t cnt)
struct __lsld32_result lsld32(uint32_t x, uint32_t cnt)
{
struct __lsld32_result result;
uint64_t r = (uint64_t)x << cnt;
......@@ -150,35 +151,35 @@ struct __add32_with_carry_result
uint32_t __field_1;
};
static struct __add32_with_carry_result add32_with_carry(uint32_t x, uint32_t y, uint32_t c);
struct __add32_with_carry_result add32_with_carry(uint32_t x, uint32_t y, uint32_t c);
struct __sub32_with_borrow_result
{ uint32_t __field_0;
uint32_t __field_1;
};
static struct __sub32_with_borrow_result sub32_with_borrow(uint32_t x, uint32_t y, uint32_t b);
struct __sub32_with_borrow_result sub32_with_borrow(uint32_t x, uint32_t y, uint32_t b);
struct __mul32_double_result
{ uint32_t __field_0;
uint32_t __field_1;
};
static struct __mul32_double_result mul32_double(uint32_t x, uint32_t y);
struct __mul32_double_result mul32_double(uint32_t x, uint32_t y);
struct __add32_3_result
{ uint32_t __field_0;
uint32_t __field_1;
};
static struct __add32_3_result add32_3(uint32_t x, uint32_t y, uint32_t z);
struct __add32_3_result add32_3(uint32_t x, uint32_t y, uint32_t z);
struct __lsld32_result
{ uint32_t __field_0;
uint32_t __field_1;
};
static struct __lsld32_result lsld32(uint32_t x, uint32_t cnt);
struct __lsld32_result lsld32(uint32_t x, uint32_t cnt);
"
syntax converter of_int "%1U"
......@@ -263,7 +264,7 @@ struct __add64_with_carry_result
uint64_t __field_1;
};
static struct __add64_with_carry_result
struct __add64_with_carry_result
add64_with_carry(uint64_t x, uint64_t y, uint64_t c)
{
struct __add64_with_carry_result result;
......@@ -279,7 +280,7 @@ struct __add64_double_result
uint64_t __field_1;
};
static struct __add64_double_result
struct __add64_double_result
add64_double(uint64_t a1, uint64_t a0, uint64_t b1, uint64_t b0)
{
struct __add64_double_result result;
......@@ -292,7 +293,7 @@ struct __sub64_with_borrow_result
uint64_t __field_1;
};
static struct __sub64_with_borrow_result
struct __sub64_with_borrow_result
sub64_with_borrow(uint64_t x, uint64_t y, uint64_t b)
{
struct __sub64_with_borrow_result result;
......@@ -309,7 +310,7 @@ struct __sub64_double_result
uint64_t __field_1;
};
static struct __sub64_double_result
struct __sub64_double_result
sub64_double(uint64_t a1, uint64_t a0, uint64_t b1, uint64_t b0)
{
struct __sub64_double_result result;
......@@ -322,14 +323,14 @@ struct __mul64_double_result
uint64_t __field_1;
};
static struct __mul64_double_result mul64_double(uint64_t x, uint64_t y)
struct __mul64_double_result mul64_double(uint64_t x, uint64_t y)
{
struct __mul64_double_result result;
umul_ppmm(result.__field_1,result.__field_0,x,y);
return result;
}
static uint64_t div64_2by1(uint64_t ul, uint64_t uh, uint64_t d)
uint64_t div64_2by1(uint64_t ul, uint64_t uh, uint64_t d)
{
uint64_t q;
uint64_t _dummy __attribute__((unused));
......@@ -342,7 +343,7 @@ struct __add64_3_result
uint64_t __field_1;
};
static struct __add64_3_result add64_3(uint64_t x, uint64_t y, uint64_t z)
struct __add64_3_result add64_3(uint64_t x, uint64_t y, uint64_t z)
{
struct __add64_3_result result;
uint64_t r, c1, c2;
......@@ -360,7 +361,7 @@ struct __lsld64_result
uint64_t __field_1;
};
static struct __lsld64_result lsld64(uint64_t x, uint64_t cnt)
struct __lsld64_result lsld64(uint64_t x, uint64_t cnt)
{
struct __lsld64_result result;
result.__field_1 = x >> (64 - cnt);
......@@ -383,7 +384,7 @@ struct __add64_with_carry_result
uint64_t __field_1;
};
static struct __add64_with_carry_result
struct __add64_with_carry_result
add64_with_carry(uint64_t x, uint64_t y, uint64_t c);
struct __add64_double_result
......@@ -391,7 +392,7 @@ struct __add64_double_result
uint64_t __field_1;
};
static struct __add64_double_result
struct __add64_double_result
add64_double(uint64_t a1, uint64_t a0, uint64_t b1, uint64_t b0);
struct __sub64_with_borrow_result
......@@ -399,7 +400,7 @@ struct __sub64_with_borrow_result
uint64_t __field_1;
};
static struct __sub64_with_borrow_result
struct __sub64_with_borrow_result
sub64_with_borrow(uint64_t x, uint64_t y, uint64_t b);
struct __sub64_double_result
......@@ -407,7 +408,7 @@ struct __sub64_double_result
uint64_t __field_1;
};
static struct __sub64_double_result
struct __sub64_double_result
sub64_double(uint64_t a1, uint64_t a0, uint64_t b1, uint64_t b0);
struct __mul64_double_result
......@@ -415,23 +416,23 @@ struct __mul64_double_result
uint64_t __field_1;
};
static struct __mul64_double_result mul64_double(uint64_t x, uint64_t y);
struct __mul64_double_result mul64_double(uint64_t x, uint64_t y);
static uint64_t div64_2by1(uint64_t ul, uint64_t uh, uint64_t d);
uint64_t div64_2by1(uint64_t ul, uint64_t uh, uint64_t d);
struct __add64_3_result
{ uint64_t __field_0;
uint64_t __field_1;
};
static struct __add64_3_result add64_3(uint64_t x, uint64_t y, uint64_t z);
struct __add64_3_result add64_3(uint64_t x, uint64_t y, uint64_t z);
struct __lsld64_result
{ uint64_t __field_0;
uint64_t __field_1;
};
static struct __lsld64_result lsld64(uint64_t x, uint64_t cnt);
struct __lsld64_result lsld64(uint64_t x, uint64_t cnt);
"
syntax converter of_int "%1ULL"
......
......@@ -79,9 +79,10 @@ module C = struct
| Sbreak
| Sreturn of expr
and include_kind = Sys | Proj (* include <...> vs. include "..." *)
and definition =
| Dfun of ident * proto * body
| Dinclude of ident
| Dinclude of ident * include_kind
| Dproto of ident * proto
| Ddecl of names
| Dstruct of struct_def
......@@ -206,7 +207,7 @@ module C = struct
| Ddecl (ty,l) ->
let l,b = aux l in
Ddecl (ty, l), b
| Dinclude i -> Dinclude i, true
| Dinclude (i,k) -> Dinclude (i,k), true
| Dstruct _ -> raise (Unsupported "struct declaration inside function")
| Dfun _ -> raise (Unsupported "nested function")
| Dtypedef _ -> raise (Unsupported "typedef inside function")
......@@ -387,6 +388,7 @@ module Print = struct
let () = assert (List.length c_keywords = 32)
let sanitizer = sanitizer char_to_lalpha char_to_alnumus
let sanitizer s = String.lowercase_ascii (sanitizer s)
let printer = create_ident_printer c_keywords ~sanitizer
let print_ident fmt id = fprintf fmt "%s" (id_unique printer id)
......@@ -527,7 +529,8 @@ module Print = struct
(print_ty ~paren:false) print_ident))
args
print_body body in
fprintf fmt "%s" s (* print into string first to print nothing in case of exception *)
(* print into string first to print nothing in case of exception *)
fprintf fmt "%s" s
| Dproto (id, (rt, args)) ->
let s = sprintf "%a %a(@[%a@]);@;"
(print_ty ~paren:false) rt
......@@ -552,13 +555,16 @@ module Print = struct
(print_ty ~paren:false) ty s))
lf in
fprintf fmt "%s" s
| Dinclude id ->
fprintf fmt "#include<%a.h>@;" print_ident id
| Dinclude (id, Sys) ->
fprintf fmt "#include <%s.h>@;" (sanitizer id.id_string)
| Dinclude (id, Proj) ->
fprintf fmt "#include \"%s.h\"@;" (sanitizer id.id_string)
| Dtypedef (ty,id) ->
let s = sprintf "@[<hov>typedef@ %a@;%a;@]"
(print_ty ~paren:false) ty print_ident id in
fprintf fmt "%s" s
with Unprinted s -> Format.printf "Missed a def because : %s@." s
with Unprinted s ->
Debug.dprintf debug_c_extraction "Missed a def because : %s@." s
and print_body fmt (def, s) =
if def = []
......@@ -569,6 +575,33 @@ module Print = struct
(print_stmt ~braces:true)
fmt (def,s)
let print_header_def fmt def =
try match def with
| Dfun (id,(rt,args),_) | Dproto (id, (rt, args)) ->
let s = sprintf "%a %a(@[%a@]);@;"
(print_ty ~paren:false) rt
print_ident id
(print_list comma
(print_pair_delim nothing space nothing
(print_ty ~paren:false) print_ident))
args in
fprintf fmt "%s" s
| Dstruct (s, lf) ->
let s = sprintf "struct %s@ @[<hov>{@;<1 2>@[<hov>%a@]@\n};@\n@]"
s
(print_list newline
(fun fmt (s,ty) -> fprintf fmt "%a %s;"
(print_ty ~paren:false) ty s))
lf in
fprintf fmt "%s" s
| Dinclude _ | Ddecl _ -> ()
| Dtypedef (ty,id) ->
let s = sprintf "@[<hov>typedef@ %a@;%a;@]"
(print_ty ~paren:false) ty print_ident id in
fprintf fmt "%s" s
with Unprinted s ->
Debug.dprintf debug_c_extraction "Missed a def because : %s@." s
let print_file fmt info ast =
Mid.iter (fun _ sl -> List.iter (fprintf fmt "%s\n") sl) info.thprelude;
newline fmt ();
......@@ -576,7 +609,8 @@ module Print = struct
end
(*TODO simplifications : propagate constants, collapse blocks with only a statement and no def*)
(*TODO simplifications : propagate constants, collapse blocks with
only a statement and no def*)
module MLToC = struct
......@@ -1080,19 +1114,21 @@ module MLToC = struct
Debug.dprintf debug_c_extraction "PDtype %s@." id.id_string;
begin
match idef with
| Some (Dalias ty) -> [C.Dtypedef (ty_of_mlty info ty, id)]
| Some (Dalias _ty) -> [] (*[C.Dtypedef (ty_of_mlty info ty, id)] *)
| Some _ -> raise (Unsupported "Ddata/Drecord@.")
| None ->
begin match query_syntax info.syntax id with
| Some _ -> []
| None ->
raise (Unsupported ("type declaration without syntax or alias: "^id.id_string))
raise (Unsupported
("type declaration without syntax or alias: "
^id.id_string))
end
end
| _ -> [] (*TODO exn ? *) end
with Unsupported s ->
Format.printf "Unsupported : %s@." s; []
Debug.dprintf debug_c_extraction "Unsupported : %s@." s; []
let translate_decl (info:info) (d:Mltree.decl) : C.definition list
=
......@@ -1113,6 +1149,7 @@ end
let name_gen suffix ?fname m =
let n = m.Pmodule.mod_theory.Theory.th_name.Ident.id_string in
let n = Print.sanitizer n in
let r = match fname with
| None -> n ^ suffix
| Some f -> f ^ "__" ^ n ^ suffix in
......@@ -1126,24 +1163,27 @@ let print_header_decl args ?old ?fname ~flat m fmt d =
ignore fname;
ignore flat;
ignore m;
ignore args;
ignore fmt;
ignore d;
() (* TODO *)
let cds = MLToC.translate_decl args d in
List.iter (Format.fprintf fmt "%a@." Print.print_header_def) cds
let print_prelude args ?old ?fname ~flat fmt m =
let print_prelude args ?old ?fname ~flat deps fmt pm =
ignore old;
ignore fname;
ignore flat;
ignore pm;
ignore args;
ignore fmt;
ignore m;
() (* TODO *)
let add_include id =
Format.fprintf fmt "%a@." Print.print_def C.(Dinclude (id,Proj)) in
List.iter
(fun m ->
let id = m.Pmodule.mod_theory.Theory.th_name in
add_include id)
(List.rev deps)
let print_decl args ?old ?fname ~flat m fmt d =
ignore old;
ignore fname;
ignore flat; (*FIXME*)
ignore flat;
ignore m;
let cds = MLToC.translate_decl args d in
List.iter (Format.fprintf fmt "%a@." Print.print_def) cds
......
......@@ -245,9 +245,9 @@ type interf_printer =
type prelude_printer =
printer_args -> ?old:in_channel -> ?fname:string -> flat:bool
-> Pmodule.pmodule Pp.pp
-> Pmodule.pmodule list -> Pmodule.pmodule Pp.pp
let print_empty_prelude _ ?old:_ ?fname:_ ~flat:_ _ _ = ()
let print_empty_prelude _ ?old:_ ?fname:_ ~flat:_ _ _ _ = ()
type decl_printer =
printer_args -> ?old:in_channel -> ?fname:string -> flat:bool ->
......
......@@ -52,7 +52,7 @@ type interf_printer =
Only used in modular extraction. *)
type prelude_printer =
printer_args -> ?old:in_channel -> ?fname:string -> flat:bool
-> Pmodule.pmodule Pp.pp
-> Pmodule.pmodule list -> Pmodule.pmodule Pp.pp
val print_empty_prelude: prelude_printer
......
......@@ -146,7 +146,7 @@ let print_preludes =
let l = List.fold_left add [] th_pm in
Printer.print_prelude fmt l
let print_mdecls ?fname m mdecls =
let print_mdecls ?fname m mdecls deps =
let pargs, printer = Pdriver.lookup_printer opt_driver in
let fg = printer.Pdriver.file_gen in
let pr = printer.Pdriver.decl_printer in
......@@ -155,8 +155,12 @@ let print_mdecls ?fname m mdecls =
let test_id_not_driver id =
Printer.query_syntax pargs.Pdriver.syntax id = None in
List.exists test_id_not_driver decl_name in
if List.exists test_decl_not_driver mdecls then begin
let prelude_exists =
Ident.Mid.mem m.mod_theory.Theory.th_name pargs.Pdriver.thprelude in
if List.exists test_decl_not_driver mdecls || prelude_exists
then begin
let flat = opt_modu_flat = Flat in
let thname = m.mod_theory.Theory.th_name in
(* print interface file *)
if !opt_interface then begin
match printer.Pdriver.interf_gen, printer.Pdriver.interf_printer with
......@@ -166,21 +170,28 @@ let print_mdecls ?fname m mdecls =
| Some ig, Some ipr ->
let iout, old = get_cout_old ig m ?fname in
let ifmt = formatter_of_out_channel iout in
Printer.print_prelude ifmt pargs.Pdriver.prelude;
let inter_p = Ident.Mid.find_def [] thname pargs.Pdriver.thinterface in
Printer.print_interface ifmt inter_p;
(* printer.Pdriver.prelude_printer pargs ?old ?fname ~flat deps ifmt m;*)
let pr_idecl fmt d =
fprintf fmt "%a" (ipr pargs ?old ?fname ~flat m) d in
Pp.print_list Pp.nothing pr_idecl ifmt mdecls;
if iout <> stdout then close_out iout end;
let cout, old = get_cout_old fg m ?fname in
let fmt = formatter_of_out_channel cout in
(* print module prelude *)
printer.Pdriver.prelude_printer pargs ?old ?fname ~flat deps fmt m;
(* print driver prelude *)
Printer.print_prelude fmt pargs.Pdriver.prelude;
let pm = pargs.Pdriver.thprelude in
print_preludes m.mod_theory.Theory.th_name fmt pm;
(* print module prelude *)
printer.Pdriver.prelude_printer pargs ?old ?fname ~flat fmt m;
print_preludes thname fmt pm;
(* print decls *)
let pr_decl fmt d = fprintf fmt "%a" (pr pargs ?old ?fname ~flat m) d in
Pp.print_list Pp.nothing pr_decl fmt mdecls;
if cout <> stdout then close_out cout end
if cout <> stdout then close_out cout;
true end
else false
let find_module_path mm path m = match path with
| [] -> Mstr.find m mm
......@@ -206,24 +217,27 @@ let is_not_extractable_theory =
let h = Hstr.create 16 in
List.iter (fun s -> Hstr.add h s ()) not_extractable_theories;
Hstr.mem h
let extract_to =
let memo = Ident.Hid.create 16 in
fun ?fname ?decl m ->
fun ?fname ?decl m deps ->
match m.mod_theory.Theory.th_path with
| t::_ when is_not_extractable_theory t -> ()
| t::_ when is_not_extractable_theory t -> false
| _ -> let name = m.mod_theory.Theory.th_name in
if not (Ident.Hid.mem memo name) then begin
Ident.Hid.add memo name ();
let mdecls = match decl with
| None -> (translate_module m).Mltree.mod_decl
| Some d -> Translate.pdecl_m m d in
print_mdecls ?fname m mdecls
let file_exists = print_mdecls ?fname m mdecls deps in
Ident.Hid.add memo name file_exists;
file_exists
end
else Ident.Hid.find memo name
let rec use_iter f l =
List.iter
(function Uuse t -> f t | Uscope (_,l) -> use_iter f l | _ -> ()) l
let rec use_fold f l =
List.fold_left
(fun acc -> function | Uuse t -> if f t then t::acc else acc
| Uscope (_,l) -> (use_fold f l)@acc
| _ -> acc) [] l
let rec do_extract_module ?fname m =
let extract_use m' =
......@@ -231,11 +245,12 @@ let rec do_extract_module ?fname m =
if m'.mod_theory.Theory.th_path = [] then fname else None in
do_extract_module ?fname m'
in
begin match opt_rec_single with
| Recursive -> use_iter extract_use m.mod_units
| Single -> ()
end;
extract_to ?fname m
let deps =
match opt_rec_single with
| Recursive -> use_fold extract_use m.mod_units
| Single -> []
in
extract_to ?fname m deps
let do_extract_module_from fname mm m =
try
......@@ -285,16 +300,16 @@ let do_modular target =
match target with
| File fname ->
let mm = read_mlw_file ?format env fname in
let do_m _ m = do_extract_module ~fname m in
let do_m _ m = ignore (do_extract_module ~fname m) in
Mstr.iter do_m mm
| Module (path, m) ->
let mm = Mstr.empty in
let m = find_module_path mm path m in
do_extract_module m
ignore (do_extract_module m)
| Symbol (path, m, s) ->
let mm = Mstr.empty in
let m = find_module_path mm path m in
do_extract_symbol_from m s
ignore (do_extract_symbol_from m s []) (* FIXME empty deps ? *)
type extract_info = {
info_rec : bool;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment