Commit 6b94c96b authored by Raphael Rieu-Helft's avatar Raphael Rieu-Helft

C extraction: print much fewer parentheses

parent 0f7b3e3d
...@@ -27,17 +27,17 @@ module mach.int.Int32 ...@@ -27,17 +27,17 @@ module mach.int.Int32
syntax type int32 "int32_t" syntax type int32 "int32_t"
syntax val (+) "(%1) + (%2)" syntax val (+) "%1 + %2"
syntax val (-) "(%1) - (%2)" syntax val (-) "%1 - %2"
syntax val (-_) "-(%1)" syntax val (-_) "-%1"
syntax val ( * ) "(%1) * (%2)" syntax val ( * ) "%1 * %2"
syntax val (/) "(%1) / (%2)" syntax val (/) "%1 / %2"
syntax val (%) "(%1) % (%2)" syntax val (%) "%1 % %2"
syntax val (=) "(%1) == (%2)" syntax val (=) "%1 == %2"
syntax val (<=) "(%1) <= (%2)" syntax val (<=) "%1 <= %2"
syntax val (<) "(%1) < (%2)" syntax val (<) "%1 < %2"
syntax val (>=) "(%1) >= (%2)" syntax val (>=) "%1 >= %2"
syntax val (>) "(%1) > (%2)" syntax val (>) "%1 > %2"
end end
module mach.int.UInt32Gen module mach.int.UInt32Gen
...@@ -54,16 +54,16 @@ module mach.int.UInt32 ...@@ -54,16 +54,16 @@ module mach.int.UInt32
syntax converter of_int "%1U" syntax converter of_int "%1U"
syntax val (+) "(%1) + (%2)" syntax val (+) "%1 + %2"
syntax val (-) "(%1) - (%2)" syntax val (-) "%1 - %2"
syntax val ( * ) "(%1) * (%2)" syntax val ( * ) "%1 * %2"
syntax val (/) "(%1) / (%2)" syntax val (/) "%1 / %2"
syntax val (%) "(%1) % (%2)" syntax val (%) "%1 % %2"
syntax val (=) "(%1) == (%2)" syntax val (=) "%1 == %2"
syntax val (<=) "(%1) <= (%2)" syntax val (<=) "%1 <= %2"
syntax val (<) "(%1) < (%2)" syntax val (<) "%1 < %2"
syntax val (>=) "(%1) >= (%2)" syntax val (>=) "%1 >= %2"
syntax val (>) "(%1) > (%2)" syntax val (>) "%1 > %2"
end end
...@@ -185,16 +185,16 @@ struct __lsld32_result lsld32(uint32_t x, uint32_t cnt); ...@@ -185,16 +185,16 @@ struct __lsld32_result lsld32(uint32_t x, uint32_t cnt);
syntax converter of_int "%1U" syntax converter of_int "%1U"
syntax val (+) "(%1) + (%2)" syntax val (+) "%1 + %2"
syntax val (-) "(%1) - (%2)" syntax val (-) "%1 - %2"
syntax val ( * ) "(%1) * (%2)" syntax val ( * ) "%1 * %2"
syntax val (/) "(%1) / (%2)" syntax val (/) "%1 / %2"
syntax val (%) "(%1) % (%2)" syntax val (%) "%1 % %2"
syntax val (=) "(%1) == (%2)" syntax val (=) "%1 == %2"
syntax val (<=) "(%1) <= (%2)" syntax val (<=) "%1 <= %2"
syntax val (<) "(%1) < (%2)" syntax val (<) "%1 < %2"
syntax val (>=) "(%1) >= (%2)" syntax val (>=) "%1 >= %2"
syntax val (>) "(%1) > (%2)" syntax val (>) "%1 > %2"
syntax val add_with_carry "add32_with_carry" syntax val add_with_carry "add32_with_carry"
syntax val sub_with_borrow "sub32_with_borrow" syntax val sub_with_borrow "sub32_with_borrow"
...@@ -202,17 +202,17 @@ struct __lsld32_result lsld32(uint32_t x, uint32_t cnt); ...@@ -202,17 +202,17 @@ struct __lsld32_result lsld32(uint32_t x, uint32_t cnt);
syntax val add3 "add32_3" syntax val add3 "add32_3"
syntax val lsld "lsld32" syntax val lsld "lsld32"
syntax val add_mod "(%1) + (%2)" syntax val add_mod "%1 + %2"
syntax val sub_mod "(%1) - (%2)" syntax val sub_mod "%1 - %2"
syntax val mul_mod "(%1) * (%2)" syntax val mul_mod "%1 * %2"
syntax val div2by1 syntax val div2by1
"(uint32_t)((((uint64_t)%1) | (((uint64_t)%2) << 32))/(uint64_t)(%3))" "(uint32_t)((((uint64_t)%1) | (((uint64_t)%2) << 32))/(uint64_t)(%3))"
syntax val lsl "(%1) << (%2)" syntax val lsl "%1 << %2"
syntax val lsr "(%1) >> (%2)" syntax val lsr "%1 >> %2"
syntax val is_msb_set "(%1) & 0x80000000U" syntax val is_msb_set "%1 & 0x80000000U"
syntax val count_leading_zeros "__builtin_clz(%1)" syntax val count_leading_zeros "__builtin_clz(%1)"
...@@ -227,17 +227,17 @@ module mach.int.Int64 ...@@ -227,17 +227,17 @@ module mach.int.Int64
syntax type int64 "int64_t" syntax type int64 "int64_t"
syntax val (+) "(%1) + (%2)" syntax val (+) "%1 + %2"
syntax val (-) "(%1) - (%2)" syntax val (-) "%1 - %2"
syntax val (-_) "-(%1)" syntax val (-_) "-%1"
syntax val ( * ) "(%1) * (%2)" syntax val ( * ) "%1 * %2"
syntax val (/) "(%1) / (%2)" syntax val (/) "%1 / %2"
syntax val (%) "(%1) % (%2)" syntax val (%) "%1 % %2"
syntax val (=) "(%1) == (%2)" syntax val (=) "%1 == %2"
syntax val (<=) "(%1) <= (%2)" syntax val (<=) "%1 <= %2"
syntax val (<) "(%1) < (%2)" syntax val (<) "%1 < %2"
syntax val (>=) "(%1) >= (%2)" syntax val (>=) "%1 >= %2"
syntax val (>) "(%1) > (%2)" syntax val (>) "%1 > %2"
end end
module mach.int.UInt64Gen module mach.int.UInt64Gen
...@@ -253,17 +253,17 @@ module mach.int.UInt64 ...@@ -253,17 +253,17 @@ module mach.int.UInt64
syntax converter of_int "%1ULL" syntax converter of_int "%1ULL"
syntax val (+) "(%1) + (%2)" syntax val (+) "%1 + %2"
syntax val (-) "(%1) - (%2)" syntax val (-) "%1 - %2"
syntax val (-_) "-(%1)" syntax val (-_) "-%1"
syntax val ( * ) "(%1) * (%2)" syntax val ( * ) "%1 * %2"
syntax val (/) "(%1) / (%2)" syntax val (/) "%1 / %2"
syntax val (%) "(%1) % (%2)" syntax val (%) "%1 % %2"
syntax val (=) "(%1) == (%2)" syntax val (=) "%1 == %2"
syntax val (<=) "(%1) <= (%2)" syntax val (<=) "%1 <= %2"
syntax val (<) "(%1) < (%2)" syntax val (<) "%1 < %2"
syntax val (>=) "(%1) >= (%2)" syntax val (>=) "%1 >= %2"
syntax val (>) "(%1) > (%2)" syntax val (>) "%1 > %2"
end end
...@@ -508,16 +508,16 @@ static struct __lsld64_result lsld64(uint64_t x, uint64_t cnt) ...@@ -508,16 +508,16 @@ static struct __lsld64_result lsld64(uint64_t x, uint64_t cnt)
" "
syntax converter of_int "%1ULL" syntax converter of_int "%1ULL"
syntax val (+) "(%1) + (%2)" syntax val (+) "%1 + %2"
syntax val (-) "(%1) - (%2)" syntax val (-) "%1 - %2"
syntax val ( * ) "(%1) * (%2)" syntax val ( * ) "%1 * %2"
syntax val (/) "(%1) / (%2)" syntax val (/) "%1 / %2"
syntax val (%) "(%1) % (%2)" syntax val (%) "%1 % %2"
syntax val (=) "(%1) == (%2)" syntax val (=) "%1 == %2"
syntax val (<=) "(%1) <= (%2)" syntax val (<=) "%1 <= %2"
syntax val (<) "(%1) < (%2)" syntax val (<) "%1 < %2"
syntax val (>=) "(%1) >= (%2)" syntax val (>=) "%1 >= %2"
syntax val (>) "(%1) > (%2)" syntax val (>) "%1 > %2"
syntax val add_with_carry "add64_with_carry" syntax val add_with_carry "add64_with_carry"
syntax val add_double "add64_double" syntax val add_double "add64_double"
...@@ -532,14 +532,14 @@ static struct __lsld64_result lsld64(uint64_t x, uint64_t cnt) ...@@ -532,14 +532,14 @@ static struct __lsld64_result lsld64(uint64_t x, uint64_t cnt)
syntax val add3 "add64_3" syntax val add3 "add64_3"
syntax val lsld "lsld64" syntax val lsld "lsld64"
syntax val add_mod "(%1) + (%2)" syntax val add_mod "%1 + %2"
syntax val sub_mod "(%1) - (%2)" syntax val sub_mod "%1 - %2"
syntax val mul_mod "(%1) * (%2)" syntax val mul_mod "%1 * %2"
syntax val lsl "(%1) << (%2)" syntax val lsl "%1 << %2"
syntax val lsr "(%1) >> (%2)" syntax val lsr "%1 >> %2"
syntax val is_msb_set "(%1) & 0x8000000000000000ULL" syntax val is_msb_set "%1 & 0x8000000000000000ULL"
syntax val count_leading_zeros "__builtin_clzll(%1)" syntax val count_leading_zeros "__builtin_clzll(%1)"
...@@ -570,16 +570,16 @@ module mach.c.C ...@@ -570,16 +570,16 @@ module mach.c.C
syntax val is_not_null "(%1) != NULL" syntax val is_not_null "(%1) != NULL"
syntax val null "NULL" syntax val null "NULL"
syntax val incr "%1+(%2)" syntax val incr "%1 + %2"
syntax val get "*(%1)" syntax val get "*%1"
syntax val get_ofs "*(%1+(%2))" syntax val get_ofs "%1[%2]"
syntax val set "*(%1) = %2" syntax val set "*%1 = %2"
syntax val set_ofs "*(%1+(%2)) = %3" syntax val set_ofs "%1[%2] = %3"
syntax val incr_split "%1+(%2)" syntax val incr_split "%1 + %2"
syntax val decr_split "%1-(%2)" syntax val decr_split "%1 - %2"
syntax val join "IGNORE2" syntax val join "IGNORE2"
syntax val join_r "IGNORE2" syntax val join_r "IGNORE2"
......
...@@ -360,6 +360,20 @@ module C = struct ...@@ -360,6 +360,20 @@ module C = struct
else b else b
| d,s -> d, s | d,s -> d, s
(* Operator precedence, needed to compute which parentheses can be removed *)
let prec_unop = function
| Unot | Ustar | Uaddr | Upreincr | Upredecr -> 2
| Upostincr | Upostdecr -> 1
let prec_binop = function
| Band -> 11
| Bor -> 11 (* really 12, but this avoids Wparentheses *)
| Beq | Bne -> 7
| Bassign -> 14
| Blt | Ble | Bgt | Bge -> 6
end end
type info = Pdriver.printer_args = private { type info = Pdriver.printer_args = private {
...@@ -419,7 +433,7 @@ module Print = struct ...@@ -419,7 +433,7 @@ module Print = struct
(* should be handled in extract_stars *) (* should be handled in extract_stars *)
| Tarray (ty, expr) -> | Tarray (ty, expr) ->
fprintf fmt (protect_on paren "%a[%a]") fprintf fmt (protect_on paren "%a[%a]")
(print_ty ~paren:true) ty (print_expr ~paren:false) expr (print_ty ~paren:true) ty (print_expr ~prec:1) expr
| Tstruct (s,_) -> fprintf fmt "struct %s" s | Tstruct (s,_) -> fprintf fmt "struct %s" s
| Tunion _ -> raise (Unprinted "unions") | Tunion _ -> raise (Unprinted "unions")
| Tnamed id -> print_global_ident fmt id | Tnamed id -> print_global_ident fmt id
...@@ -447,47 +461,57 @@ module Print = struct ...@@ -447,47 +461,57 @@ module Print = struct
| Bgt -> fprintf fmt ">" | Bgt -> fprintf fmt ">"
| Bge -> fprintf fmt ">=" | Bge -> fprintf fmt ">="
and print_expr ~paren fmt = function and print_expr ~prec fmt = function
| Enothing -> Debug.dprintf debug_c_extraction "enothing"; () | Enothing -> Debug.dprintf debug_c_extraction "enothing"; ()
| Eunop(u,e) -> | Eunop(u,e) ->
let p = prec_unop u in
if unop_postfix u if unop_postfix u
then fprintf fmt (protect_on paren "%a%a") then fprintf fmt (protect_on (prec <= p) "%a%a")
(print_expr ~paren:true) e print_unop u (print_expr ~prec:p) e print_unop u
else fprintf fmt (protect_on paren "%a%a") else fprintf fmt (protect_on (prec <= p) "%a%a")
print_unop u (print_expr ~paren:true) e print_unop u (print_expr ~prec:p) e
| Ebinop(b,e1,e2) -> | Ebinop(b,e1,e2) ->
fprintf fmt (protect_on paren "%a %a %a") let p = prec_binop b in
(print_expr ~paren:true) e1 print_binop b (print_expr ~paren:true) e2 fprintf fmt (protect_on (prec <= p) "%a %a %a")
(print_expr ~prec:p) e1 print_binop b (print_expr ~prec:p) e2
| Equestion(c,t,e) -> | Equestion(c,t,e) ->
fprintf fmt (protect_on paren "%a ? %a : %a") fprintf fmt (protect_on (prec <= 13) "%a ? %a : %a")
(print_expr ~paren:true) c (print_expr ~prec:13) c
(print_expr ~paren:true) t (print_expr ~prec:13) t
(print_expr ~paren:true) e (print_expr ~prec:13) e
| Ecast(ty, e) -> | Ecast(ty, e) ->
fprintf fmt (protect_on paren "(%a)%a") fprintf fmt (protect_on (prec <= 2) "(%a)%a")
(print_ty ~paren:false) ty (print_expr ~paren:true) e (print_ty ~paren:false) ty (print_expr ~prec:2) e
| Ecall (Esyntax (s, _, _, [], _), l) -> (* function defined in the prelude *) | Ecall (Esyntax (s, _, _, [], _), l) ->
fprintf fmt (protect_on paren "%s(%a)") (* function defined in the prelude *)
s (print_list comma (print_expr ~paren:false)) l fprintf fmt (protect_on (prec <= 1) "%s(%a)")
| Ecall (e,l) -> fprintf fmt (protect_on paren "%a(%a)") s (print_list comma (print_expr ~prec:15)) l
(print_expr ~paren:true) e (print_list comma (print_expr ~paren:false)) l | Ecall (e,l) ->
fprintf fmt (protect_on (prec <= 1) "%a(%a)")
(print_expr ~prec:1) e (print_list comma (print_expr ~prec:15)) l
| Econst c -> print_const fmt c | Econst c -> print_const fmt c
| Evar id -> print_local_ident fmt id | Evar id -> print_local_ident fmt id
| Elikely e -> fprintf fmt (protect_on paren "__builtin_expect(%a,1)") | Elikely e -> fprintf fmt
(print_expr ~paren:true) e (protect_on (prec <= 1) "__builtin_expect(%a,1)")
| Eunlikely e -> fprintf fmt (protect_on paren "__builtin_expect(%a,0)") (print_expr ~prec:15) e
(print_expr ~paren:true) e | Eunlikely e -> fprintf fmt
(protect_on (prec <= 1) "__builtin_expect(%a,0)")
(print_expr ~prec:15) e
| Esize_expr e -> | Esize_expr e ->
fprintf fmt (protect_on paren "sizeof(%a)") (print_expr ~paren:false) e fprintf fmt (protect_on (prec <= 2) "sizeof(%a)") (print_expr ~prec:15) e
| Esize_type ty -> | Esize_type ty ->
fprintf fmt (protect_on paren "sizeof(%a)") (print_ty ~paren:false) ty fprintf fmt (protect_on (prec <= 2) "sizeof(%a)")
| Edot (e,s) -> fprintf fmt "%a.%s" (print_expr ~paren:true) e s (print_ty ~paren:false) ty
| Eindex _ | Earrow _ -> raise (Unprinted "struct/array access") | Edot (e,s) ->
| Esyntax (s, t, args, lte,_) -> fprintf fmt (protect_on (prec <= 1) "%a.%s")
gen_syntax_arguments_typed snd (fun _ -> args) (print_expr ~prec:1) e s
(if paren then ("("^s^")") else s) | Eindex _ | Earrow _ -> raise (Unprinted "struct/union access")
(fun fmt (e,_t) -> print_expr ~paren:false fmt e) | Esyntax (s, t, args, lte, c) ->
(print_ty ~paren:false) (C.Enothing,t) fmt lte (* no way to know precedence, so full parentheses*)
gen_syntax_arguments_typed snd (fun _ -> args)
(if prec <= 13 && not c then ("("^s^")") else s)
(fun fmt (e,_t) -> print_expr ~prec:1 fmt e)
(print_ty ~paren:false) (C.Enothing,t) fmt lte
and print_const fmt = function and print_const fmt = function
| Cint s | Cfloat s | Cchar s | Cstring s -> fprintf fmt "%s" s | Cint s | Cfloat s | Cchar s | Cstring s -> fprintf fmt "%s" s
...@@ -499,11 +523,13 @@ module Print = struct ...@@ -499,11 +523,13 @@ module Print = struct
match ie with match ie with
| id, Enothing -> print_local_ident fmt id | id, Enothing -> print_local_ident fmt id
| id,e -> fprintf fmt "%a = %a" | id,e -> fprintf fmt "%a = %a"
print_local_ident id (print_expr ~paren:false) e print_local_ident id (print_expr ~prec:(prec_binop Bassign)) e
let print_expr_no_paren fmt expr = print_expr ~prec:max_int fmt expr
let rec print_stmt ~braces fmt = function let rec print_stmt ~braces fmt = function
| Snop -> Debug.dprintf debug_c_extraction "snop"; () | Snop -> Debug.dprintf debug_c_extraction "snop"; ()
| Sexpr e -> fprintf fmt "%a;" (print_expr ~paren:false) e; | Sexpr e -> fprintf fmt "%a;" print_expr_no_paren e;
| Sblock ([] ,s) when not braces -> | Sblock ([] ,s) when not braces ->
(print_stmt ~braces:false) fmt s (print_stmt ~braces:false) fmt s
| Sblock b -> fprintf fmt "@[<hov>{@\n @[<hov>%a@]@\n}@]" print_body b | Sblock b -> fprintf fmt "@[<hov>{@\n @[<hov>%a@]@\n}@]" print_body b
...@@ -511,23 +537,23 @@ module Print = struct ...@@ -511,23 +537,23 @@ module Print = struct
(print_stmt ~braces:false) s1 (print_stmt ~braces:false) s1
(print_stmt ~braces:false) s2 (print_stmt ~braces:false) s2
| Sif(c,t,e) when is_nop e -> | Sif(c,t,e) when is_nop e ->
fprintf fmt "if(%a)@\n%a" (print_expr ~paren:false) c fprintf fmt "if(%a)@\n%a" print_expr_no_paren c
(print_stmt ~braces:true) (Sblock([],t)) (print_stmt ~braces:true) (Sblock([],t))
| Sif (c,t,e) -> fprintf fmt "if(%a)@\n%a@\nelse@\n%a" | Sif (c,t,e) -> fprintf fmt "if(%a)@\n%a@\nelse@\n%a"
(print_expr ~paren:false) c print_expr_no_paren c
(print_stmt ~braces:true) (Sblock([],t)) (print_stmt ~braces:true) (Sblock([],t))
(print_stmt ~braces:true) (Sblock([],e)) (print_stmt ~braces:true) (Sblock([],e))
| Swhile (e,b) -> fprintf fmt "while (%a)@;<1 2>%a" | Swhile (e,b) -> fprintf fmt "while (%a)@;<1 2>%a"
(print_expr ~paren:false) e (print_stmt ~braces:true) (Sblock([],b)) print_expr_no_paren e (print_stmt ~braces:true) (Sblock([],b))
| Sfor (einit, etest, eincr, s) -> | Sfor (einit, etest, eincr, s) ->
fprintf fmt "for (%a; %a; %a)@;<1 2>%a" fprintf fmt "for (%a; %a; %a)@;<1 2>%a"
(print_expr ~paren:false) einit print_expr_no_paren einit
(print_expr ~paren:false) etest print_expr_no_paren etest
(print_expr ~paren:false) eincr print_expr_no_paren eincr
(print_stmt ~braces:true) (Sblock([],s)) (print_stmt ~braces:true) (Sblock([],s))
| Sbreak -> fprintf fmt "break;" | Sbreak -> fprintf fmt "break;"
| Sreturn Enothing -> fprintf fmt "return;" | Sreturn Enothing -> fprintf fmt "return;"
| Sreturn e -> fprintf fmt "return %a;" (print_expr ~paren:true) e | Sreturn e -> fprintf fmt "return %a;" print_expr_no_paren e
and print_def fmt def = and print_def fmt def =
try match def with try match def with
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment