Commit d1bb2ed9 authored by Raphaël Rieu-Helft's avatar Raphaël Rieu-Helft
Browse files

Merge branch 'extraction_cosmetics' into 'master'

C extraction cosmetics

See merge request !32
parents 25e4d7ee 6b94c96b
......@@ -27,17 +27,17 @@ module mach.int.Int32
syntax type int32 "int32_t"
syntax val (+) "(%1) + (%2)"
syntax val (-) "(%1) - (%2)"
syntax val (-_) "-(%1)"
syntax val ( * ) "(%1) * (%2)"
syntax val (/) "(%1) / (%2)"
syntax val (%) "(%1) % (%2)"
syntax val (=) "(%1) == (%2)"
syntax val (<=) "(%1) <= (%2)"
syntax val (<) "(%1) < (%2)"
syntax val (>=) "(%1) >= (%2)"
syntax val (>) "(%1) > (%2)"
syntax val (+) "%1 + %2"
syntax val (-) "%1 - %2"
syntax val (-_) "-%1"
syntax val ( * ) "%1 * %2"
syntax val (/) "%1 / %2"
syntax val (%) "%1 % %2"
syntax val (=) "%1 == %2"
syntax val (<=) "%1 <= %2"
syntax val (<) "%1 < %2"
syntax val (>=) "%1 >= %2"
syntax val (>) "%1 > %2"
end
module mach.int.UInt32Gen
......@@ -54,16 +54,16 @@ module mach.int.UInt32
syntax converter of_int "%1U"
syntax val (+) "(%1) + (%2)"
syntax val (-) "(%1) - (%2)"
syntax val ( * ) "(%1) * (%2)"
syntax val (/) "(%1) / (%2)"
syntax val (%) "(%1) % (%2)"
syntax val (=) "(%1) == (%2)"
syntax val (<=) "(%1) <= (%2)"
syntax val (<) "(%1) < (%2)"
syntax val (>=) "(%1) >= (%2)"
syntax val (>) "(%1) > (%2)"
syntax val (+) "%1 + %2"
syntax val (-) "%1 - %2"
syntax val ( * ) "%1 * %2"
syntax val (/) "%1 / %2"
syntax val (%) "%1 % %2"
syntax val (=) "%1 == %2"
syntax val (<=) "%1 <= %2"
syntax val (<) "%1 < %2"
syntax val (>=) "%1 >= %2"
syntax val (>) "%1 > %2"
end
......@@ -185,16 +185,16 @@ struct __lsld32_result lsld32(uint32_t x, uint32_t cnt);
syntax converter of_int "%1U"
syntax val (+) "(%1) + (%2)"
syntax val (-) "(%1) - (%2)"
syntax val ( * ) "(%1) * (%2)"
syntax val (/) "(%1) / (%2)"
syntax val (%) "(%1) % (%2)"
syntax val (=) "(%1) == (%2)"
syntax val (<=) "(%1) <= (%2)"
syntax val (<) "(%1) < (%2)"
syntax val (>=) "(%1) >= (%2)"
syntax val (>) "(%1) > (%2)"
syntax val (+) "%1 + %2"
syntax val (-) "%1 - %2"
syntax val ( * ) "%1 * %2"
syntax val (/) "%1 / %2"
syntax val (%) "%1 % %2"
syntax val (=) "%1 == %2"
syntax val (<=) "%1 <= %2"
syntax val (<) "%1 < %2"
syntax val (>=) "%1 >= %2"
syntax val (>) "%1 > %2"
syntax val add_with_carry "add32_with_carry"
syntax val sub_with_borrow "sub32_with_borrow"
......@@ -202,17 +202,17 @@ struct __lsld32_result lsld32(uint32_t x, uint32_t cnt);
syntax val add3 "add32_3"
syntax val lsld "lsld32"
syntax val add_mod "(%1) + (%2)"
syntax val sub_mod "(%1) - (%2)"
syntax val mul_mod "(%1) * (%2)"
syntax val add_mod "%1 + %2"
syntax val sub_mod "%1 - %2"
syntax val mul_mod "%1 * %2"
syntax val div2by1
"(uint32_t)((((uint64_t)%1) | (((uint64_t)%2) << 32))/(uint64_t)(%3))"
syntax val lsl "(%1) << (%2)"
syntax val lsr "(%1) >> (%2)"
syntax val lsl "%1 << %2"
syntax val lsr "%1 >> %2"
syntax val is_msb_set "(%1) & 0x80000000U"
syntax val is_msb_set "%1 & 0x80000000U"
syntax val count_leading_zeros "__builtin_clz(%1)"
......@@ -227,17 +227,17 @@ module mach.int.Int64
syntax type int64 "int64_t"
syntax val (+) "(%1) + (%2)"
syntax val (-) "(%1) - (%2)"
syntax val (-_) "-(%1)"
syntax val ( * ) "(%1) * (%2)"
syntax val (/) "(%1) / (%2)"
syntax val (%) "(%1) % (%2)"
syntax val (=) "(%1) == (%2)"
syntax val (<=) "(%1) <= (%2)"
syntax val (<) "(%1) < (%2)"
syntax val (>=) "(%1) >= (%2)"
syntax val (>) "(%1) > (%2)"
syntax val (+) "%1 + %2"
syntax val (-) "%1 - %2"
syntax val (-_) "-%1"
syntax val ( * ) "%1 * %2"
syntax val (/) "%1 / %2"
syntax val (%) "%1 % %2"
syntax val (=) "%1 == %2"
syntax val (<=) "%1 <= %2"
syntax val (<) "%1 < %2"
syntax val (>=) "%1 >= %2"
syntax val (>) "%1 > %2"
end
module mach.int.UInt64Gen
......@@ -253,17 +253,17 @@ module mach.int.UInt64
syntax converter of_int "%1ULL"
syntax val (+) "(%1) + (%2)"
syntax val (-) "(%1) - (%2)"
syntax val (-_) "-(%1)"
syntax val ( * ) "(%1) * (%2)"
syntax val (/) "(%1) / (%2)"
syntax val (%) "(%1) % (%2)"
syntax val (=) "(%1) == (%2)"
syntax val (<=) "(%1) <= (%2)"
syntax val (<) "(%1) < (%2)"
syntax val (>=) "(%1) >= (%2)"
syntax val (>) "(%1) > (%2)"
syntax val (+) "%1 + %2"
syntax val (-) "%1 - %2"
syntax val (-_) "-%1"
syntax val ( * ) "%1 * %2"
syntax val (/) "%1 / %2"
syntax val (%) "%1 % %2"
syntax val (=) "%1 == %2"
syntax val (<=) "%1 <= %2"
syntax val (<) "%1 < %2"
syntax val (>=) "%1 >= %2"
syntax val (>) "%1 > %2"
end
......@@ -508,16 +508,16 @@ static struct __lsld64_result lsld64(uint64_t x, uint64_t cnt)
"
syntax converter of_int "%1ULL"
syntax val (+) "(%1) + (%2)"
syntax val (-) "(%1) - (%2)"
syntax val ( * ) "(%1) * (%2)"
syntax val (/) "(%1) / (%2)"
syntax val (%) "(%1) % (%2)"
syntax val (=) "(%1) == (%2)"
syntax val (<=) "(%1) <= (%2)"
syntax val (<) "(%1) < (%2)"
syntax val (>=) "(%1) >= (%2)"
syntax val (>) "(%1) > (%2)"
syntax val (+) "%1 + %2"
syntax val (-) "%1 - %2"
syntax val ( * ) "%1 * %2"
syntax val (/) "%1 / %2"
syntax val (%) "%1 % %2"
syntax val (=) "%1 == %2"
syntax val (<=) "%1 <= %2"
syntax val (<) "%1 < %2"
syntax val (>=) "%1 >= %2"
syntax val (>) "%1 > %2"
syntax val add_with_carry "add64_with_carry"
syntax val add_double "add64_double"
......@@ -532,14 +532,14 @@ static struct __lsld64_result lsld64(uint64_t x, uint64_t cnt)
syntax val add3 "add64_3"
syntax val lsld "lsld64"
syntax val add_mod "(%1) + (%2)"
syntax val sub_mod "(%1) - (%2)"
syntax val mul_mod "(%1) * (%2)"
syntax val add_mod "%1 + %2"
syntax val sub_mod "%1 - %2"
syntax val mul_mod "%1 * %2"
syntax val lsl "(%1) << (%2)"
syntax val lsr "(%1) >> (%2)"
syntax val lsl "%1 << %2"
syntax val lsr "%1 >> %2"
syntax val is_msb_set "(%1) & 0x8000000000000000ULL"
syntax val is_msb_set "%1 & 0x8000000000000000ULL"
syntax val count_leading_zeros "__builtin_clzll(%1)"
......@@ -570,16 +570,16 @@ module mach.c.C
syntax val is_not_null "(%1) != NULL"
syntax val null "NULL"
syntax val incr "%1+(%2)"
syntax val incr "%1 + %2"
syntax val get "*(%1)"
syntax val get_ofs "*(%1+(%2))"
syntax val get "*%1"
syntax val get_ofs "%1[%2]"
syntax val set "*(%1) = %2"
syntax val set_ofs "*(%1+(%2)) = %3"
syntax val set "*%1 = %2"
syntax val set_ofs "%1[%2] = %3"
syntax val incr_split "%1+(%2)"
syntax val decr_split "%1-(%2)"
syntax val incr_split "%1 + %2"
syntax val decr_split "%1 - %2"
syntax val join "IGNORE2"
syntax val join_r "IGNORE2"
......
......@@ -360,6 +360,20 @@ module C = struct
else b
| d,s -> d, s
(* Operator precedence, needed to compute which parentheses can be removed *)
let prec_unop = function
| Unot | Ustar | Uaddr | Upreincr | Upredecr -> 2
| Upostincr | Upostdecr -> 1
let prec_binop = function
| Band -> 11
| Bor -> 11 (* really 12, but this avoids Wparentheses *)
| Beq | Bne -> 7
| Bassign -> 14
| Blt | Ble | Bgt | Bge -> 6
end
type info = Pdriver.printer_args = private {
......@@ -392,9 +406,13 @@ module Print = struct
let sanitizer = sanitizer char_to_lalpha char_to_alnumus
let sanitizer s = Strings.lowercase (sanitizer s)
let printer = create_ident_printer c_keywords ~sanitizer
let local_printer = create_ident_printer c_keywords ~sanitizer
let global_printer = create_ident_printer c_keywords ~sanitizer
let print_ident fmt id = fprintf fmt "%s" (id_unique printer id)
let print_local_ident fmt id = fprintf fmt "%s" (id_unique local_printer id)
let print_global_ident fmt id = fprintf fmt "%s" (id_unique global_printer id)
let clear_local_printer () = Ident.forget_all local_printer
let protect_on x s = if x then "(" ^^ s ^^ ")" else s
......@@ -415,10 +433,10 @@ module Print = struct
(* should be handled in extract_stars *)
| Tarray (ty, expr) ->
fprintf fmt (protect_on paren "%a[%a]")
(print_ty ~paren:true) ty (print_expr ~paren:false) expr
(print_ty ~paren:true) ty (print_expr ~prec:1) expr
| Tstruct (s,_) -> fprintf fmt "struct %s" s
| Tunion _ -> raise (Unprinted "unions")
| Tnamed id -> print_ident fmt id
| Tnamed id -> print_global_ident fmt id
| Tnosyntax -> raise (Unprinted "type without syntax")
and print_unop fmt = function
......@@ -443,47 +461,57 @@ module Print = struct
| Bgt -> fprintf fmt ">"
| Bge -> fprintf fmt ">="
and print_expr ~paren fmt = function
and print_expr ~prec fmt = function
| Enothing -> Debug.dprintf debug_c_extraction "enothing"; ()
| Eunop(u,e) ->
let p = prec_unop u in
if unop_postfix u
then fprintf fmt (protect_on paren "%a%a")
(print_expr ~paren:true) e print_unop u
else fprintf fmt (protect_on paren "%a%a")
print_unop u (print_expr ~paren:true) e
then fprintf fmt (protect_on (prec <= p) "%a%a")
(print_expr ~prec:p) e print_unop u
else fprintf fmt (protect_on (prec <= p) "%a%a")
print_unop u (print_expr ~prec:p) e
| Ebinop(b,e1,e2) ->
fprintf fmt (protect_on paren "%a %a %a")
(print_expr ~paren:true) e1 print_binop b (print_expr ~paren:true) e2
let p = prec_binop b in
fprintf fmt (protect_on (prec <= p) "%a %a %a")
(print_expr ~prec:p) e1 print_binop b (print_expr ~prec:p) e2
| Equestion(c,t,e) ->
fprintf fmt (protect_on paren "%a ? %a : %a")
(print_expr ~paren:true) c
(print_expr ~paren:true) t
(print_expr ~paren:true) e
fprintf fmt (protect_on (prec <= 13) "%a ? %a : %a")
(print_expr ~prec:13) c
(print_expr ~prec:13) t
(print_expr ~prec:13) e
| Ecast(ty, e) ->
fprintf fmt (protect_on paren "(%a)%a")
(print_ty ~paren:false) ty (print_expr ~paren:true) e
| Ecall (Esyntax (s, _, _, [], _), l) -> (* function defined in the prelude *)
fprintf fmt (protect_on paren "%s(%a)")
s (print_list comma (print_expr ~paren:false)) l
| Ecall (e,l) -> fprintf fmt (protect_on paren "%a(%a)")
(print_expr ~paren:true) e (print_list comma (print_expr ~paren:false)) l
fprintf fmt (protect_on (prec <= 2) "(%a)%a")
(print_ty ~paren:false) ty (print_expr ~prec:2) e
| Ecall (Esyntax (s, _, _, [], _), l) ->
(* function defined in the prelude *)
fprintf fmt (protect_on (prec <= 1) "%s(%a)")
s (print_list comma (print_expr ~prec:15)) l
| Ecall (e,l) ->
fprintf fmt (protect_on (prec <= 1) "%a(%a)")
(print_expr ~prec:1) e (print_list comma (print_expr ~prec:15)) l
| Econst c -> print_const fmt c
| Evar id -> print_ident fmt id
| Elikely e -> fprintf fmt (protect_on paren "__builtin_expect(%a,1)")
(print_expr ~paren:true) e
| Eunlikely e -> fprintf fmt (protect_on paren "__builtin_expect(%a,0)")
(print_expr ~paren:true) e
| Evar id -> print_local_ident fmt id
| Elikely e -> fprintf fmt
(protect_on (prec <= 1) "__builtin_expect(%a,1)")
(print_expr ~prec:15) e
| Eunlikely e -> fprintf fmt
(protect_on (prec <= 1) "__builtin_expect(%a,0)")
(print_expr ~prec:15) e
| Esize_expr e ->
fprintf fmt (protect_on paren "sizeof(%a)") (print_expr ~paren:false) e
fprintf fmt (protect_on (prec <= 2) "sizeof(%a)") (print_expr ~prec:15) e
| Esize_type ty ->
fprintf fmt (protect_on paren "sizeof(%a)") (print_ty ~paren:false) ty
| Edot (e,s) -> fprintf fmt "%a.%s" (print_expr ~paren:true) e s
| Eindex _ | Earrow _ -> raise (Unprinted "struct/array access")
| Esyntax (s, t, args, lte,_) ->
gen_syntax_arguments_typed snd (fun _ -> args)
(if paren then ("("^s^")") else s)
(fun fmt (e,_t) -> print_expr ~paren:false fmt e)
(print_ty ~paren:false) (C.Enothing,t) fmt lte
fprintf fmt (protect_on (prec <= 2) "sizeof(%a)")
(print_ty ~paren:false) ty
| Edot (e,s) ->
fprintf fmt (protect_on (prec <= 1) "%a.%s")
(print_expr ~prec:1) e s
| Eindex _ | Earrow _ -> raise (Unprinted "struct/union access")
| Esyntax (s, t, args, lte, c) ->
(* no way to know precedence, so full parentheses*)
gen_syntax_arguments_typed snd (fun _ -> args)
(if prec <= 13 && not c then ("("^s^")") else s)
(fun fmt (e,_t) -> print_expr ~prec:1 fmt e)
(print_ty ~paren:false) (C.Enothing,t) fmt lte
and print_const fmt = function
| Cint s | Cfloat s | Cchar s | Cstring s -> fprintf fmt "%s" s
......@@ -493,12 +521,15 @@ module Print = struct
then fprintf fmt "%s " (String.make stars '*')
else ());
match ie with
| id, Enothing -> print_ident fmt id
| id,e -> fprintf fmt "%a = %a" print_ident id (print_expr ~paren:false) e
| id, Enothing -> print_local_ident fmt id
| id,e -> fprintf fmt "%a = %a"
print_local_ident id (print_expr ~prec:(prec_binop Bassign)) e
let print_expr_no_paren fmt expr = print_expr ~prec:max_int fmt expr
let rec print_stmt ~braces fmt = function
| Snop -> Debug.dprintf debug_c_extraction "snop"; ()
| Sexpr e -> fprintf fmt "%a;" (print_expr ~paren:false) e;
| Sexpr e -> fprintf fmt "%a;" print_expr_no_paren e;
| Sblock ([] ,s) when not braces ->
(print_stmt ~braces:false) fmt s
| Sblock b -> fprintf fmt "@[<hov>{@\n @[<hov>%a@]@\n}@]" print_body b
......@@ -506,33 +537,33 @@ module Print = struct
(print_stmt ~braces:false) s1
(print_stmt ~braces:false) s2
| Sif(c,t,e) when is_nop e ->
fprintf fmt "if(%a)@\n%a" (print_expr ~paren:false) c
fprintf fmt "if(%a)@\n%a" print_expr_no_paren c
(print_stmt ~braces:true) (Sblock([],t))
| Sif (c,t,e) -> fprintf fmt "if(%a)@\n%a@\nelse@\n%a"
(print_expr ~paren:false) c
print_expr_no_paren c
(print_stmt ~braces:true) (Sblock([],t))
(print_stmt ~braces:true) (Sblock([],e))
| Swhile (e,b) -> fprintf fmt "while (%a)@;<1 2>%a"
(print_expr ~paren:false) e (print_stmt ~braces:true) (Sblock([],b))
print_expr_no_paren e (print_stmt ~braces:true) (Sblock([],b))
| Sfor (einit, etest, eincr, s) ->
fprintf fmt "for (%a; %a; %a)@;<1 2>%a"
(print_expr ~paren:false) einit
(print_expr ~paren:false) etest
(print_expr ~paren:false) eincr
print_expr_no_paren einit
print_expr_no_paren etest
print_expr_no_paren eincr
(print_stmt ~braces:true) (Sblock([],s))
| Sbreak -> fprintf fmt "break;"
| Sreturn Enothing -> fprintf fmt "return;"
| Sreturn e -> fprintf fmt "return %a;" (print_expr ~paren:true) e
| Sreturn e -> fprintf fmt "return %a;" print_expr_no_paren e
and print_def fmt def =
try match def with
| Dfun (id,(rt,args),body) ->
let s = sprintf "%a %a(@[%a@])@ @[<hov>{@;<1 2>@[<hov>%a@]@\n}@\n@]"
(print_ty ~paren:false) rt
print_ident id
print_global_ident id
(print_list comma
(print_pair_delim nothing space nothing
(print_ty ~paren:false) print_ident))
(print_ty ~paren:false) print_local_ident))
args
print_body body in
(* print into string first to print nothing in case of exception *)
......@@ -540,10 +571,10 @@ module Print = struct
| Dproto (id, (rt, args)) ->
let s = sprintf "%a %a(@[%a@]);@;"
(print_ty ~paren:false) rt
print_ident id
print_global_ident id
(print_list comma
(print_pair_delim nothing space nothing
(print_ty ~paren:false) print_ident))
(print_ty ~paren:false) print_local_ident))
args in
fprintf fmt "%s" s
| Ddecl (ty, lie) ->
......@@ -569,7 +600,7 @@ module Print = struct
fprintf fmt "#include \"%s.h\"@;" (sanitizer id.id_string)
| Dtypedef (ty,id) ->
let s = sprintf "@[<hov>typedef@ %a@;%a;@]"
(print_ty ~paren:false) ty print_ident id in
(print_ty ~paren:false) ty print_global_ident id in
fprintf fmt "%s" s
with Unprinted s ->
Debug.dprintf debug_c_extraction "Missed a def because : %s@." s
......@@ -583,10 +614,14 @@ module Print = struct
(print_stmt ~braces:true)
fmt (def,s)
let print_global_def fmt def =
clear_local_printer ();
print_def fmt def
let print_file fmt info ast =
Mid.iter (fun _ sl -> List.iter (fprintf fmt "%s\n") sl) info.thprelude;
newline fmt ();
fprintf fmt "@[<v>%a@]@." (print_list newline print_def) ast
fprintf fmt "@[<v>%a@]@." (print_list newline print_global_def) ast
end
......@@ -1145,7 +1180,7 @@ let header_gen = name_gen ".h"
let print_header_decl args fmt d =
let cds = MLToC.translate_decl args d ~header:true in
List.iter (Format.fprintf fmt "%a@." Print.print_def) cds
List.iter (Format.fprintf fmt "%a@." Print.print_global_def) cds
let print_header_decl =
let memo = Hashtbl.create 16 in
......@@ -1166,7 +1201,7 @@ let print_prelude args ?old ?fname ~flat deps fmt pm =
ignore pm;
ignore args;
let add_include id =
Format.fprintf fmt "%a@." Print.print_def C.(Dinclude (id,Proj)) in
Format.fprintf fmt "%a@." Print.print_global_def C.(Dinclude (id,Proj)) in
List.iter
(fun m ->
let id = m.Pmodule.mod_theory.Theory.th_name in
......@@ -1176,7 +1211,7 @@ let print_prelude args ?old ?fname ~flat deps fmt pm =
let print_decl args fmt d =
let cds = MLToC.translate_decl args d ~header:false in
let print_def d =
Format.fprintf fmt "%a@." Print.print_def d in
Format.fprintf fmt "%a@." Print.print_global_def d in
List.iter print_def cds
let print_decl =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment