Commit 2e7810ad authored by Raphael Rieu-Helft's avatar Raphael Rieu-Helft

C extraction: use driver symbols precedence

parent 6895b227
...@@ -7,13 +7,13 @@ prelude "#include <assert.h>" ...@@ -7,13 +7,13 @@ prelude "#include <assert.h>"
module Ref module Ref
syntax type ref "%1" syntax type ref "%1"
syntax val ref "%1" syntax val ref "%1" prec 0
syntax val contents "%1" syntax val contents "%1" prec 0
end end
module ref.Ref module ref.Ref
syntax val (!_) "%1" syntax val (!_) "%1" prec 0
syntax val (:=) "%1 = %2" syntax val (:=) "%1 = %2" prec 14
end end
module mach.int.Unsigned module mach.int.Unsigned
...@@ -27,25 +27,25 @@ module mach.int.Int32 ...@@ -27,25 +27,25 @@ module mach.int.Int32
syntax type int32 "int32_t" syntax type int32 "int32_t"
syntax literal int32 "%d" syntax literal int32 "%d"
syntax val (+) "%1 + %2" syntax val (+) "%1 + %2" prec 4
syntax val (-) "%1 - %2" syntax val (-) "%1 - %2" prec 4
syntax val (-_) "-%1" syntax val (-_) "-%1" prec 2
syntax val ( * ) "%1 * %2" syntax val ( * ) "%1 * %2" prec 3
syntax val (/) "%1 / %2" syntax val (/) "%1 / %2" prec 3
syntax val (%) "%1 % %2" syntax val (%) "%1 % %2" prec 3
syntax val (=) "%1 == %2" syntax val (=) "%1 == %2" prec 7
syntax val (<=) "%1 <= %2" syntax val (<=) "%1 <= %2" prec 6
syntax val (<) "%1 < %2" syntax val (<) "%1 < %2" prec 6
syntax val (>=) "%1 >= %2" syntax val (>=) "%1 >= %2" prec 6
syntax val (>) "%1 > %2" syntax val (>) "%1 > %2" prec 6
end end
module mach.int.UInt32Gen module mach.int.UInt32Gen
syntax type uint32 "uint32_t" syntax type uint32 "uint32_t"
syntax val max_uint32 "0xffffffff" syntax val max_uint32 "0xffffffff" prec 0
syntax val length "32" syntax val length "32" prec 0
end end
...@@ -53,16 +53,17 @@ module mach.int.UInt32 ...@@ -53,16 +53,17 @@ module mach.int.UInt32
syntax literal uint32 "0x%8xU" syntax literal uint32 "0x%8xU"
syntax val (+) "%1 + %2" syntax val (+) "%1 + %2" prec 4
syntax val (-) "%1 - %2" syntax val (-) "%1 - %2" prec 4
syntax val ( * ) "%1 * %2" syntax val (-_) "-%1" prec 2
syntax val (/) "%1 / %2" syntax val ( * ) "%1 * %2" prec 3
syntax val (%) "%1 % %2" syntax val (/) "%1 / %2" prec 3
syntax val (=) "%1 == %2" syntax val (%) "%1 % %2" prec 3
syntax val (<=) "%1 <= %2" syntax val (=) "%1 == %2" prec 7
syntax val (<) "%1 < %2" syntax val (<=) "%1 <= %2" prec 6
syntax val (>=) "%1 >= %2" syntax val (<) "%1 < %2" prec 6
syntax val (>) "%1 > %2" syntax val (>=) "%1 >= %2" prec 6
syntax val (>) "%1 > %2" prec 6
end end
...@@ -184,16 +185,17 @@ struct __lsld32_result lsld32(uint32_t x, uint32_t cnt); ...@@ -184,16 +185,17 @@ struct __lsld32_result lsld32(uint32_t x, uint32_t cnt);
syntax literal uint32 "0x%8xU" syntax literal uint32 "0x%8xU"
syntax val (+) "%1 + %2" syntax val (+) "%1 + %2" prec 4
syntax val (-) "%1 - %2" syntax val (-) "%1 - %2" prec 4
syntax val ( * ) "%1 * %2" syntax val (-_) "-%1" prec 2
syntax val (/) "%1 / %2" syntax val ( * ) "%1 * %2" prec 3
syntax val (%) "%1 % %2" syntax val (/) "%1 / %2" prec 3
syntax val (=) "%1 == %2" syntax val (%) "%1 % %2" prec 3
syntax val (<=) "%1 <= %2" syntax val (=) "%1 == %2" prec 7
syntax val (<) "%1 < %2" syntax val (<=) "%1 <= %2" prec 6
syntax val (>=) "%1 >= %2" syntax val (<) "%1 < %2" prec 6
syntax val (>) "%1 > %2" syntax val (>=) "%1 >= %2" prec 6
syntax val (>) "%1 > %2" prec 6
syntax val add_with_carry "add32_with_carry" syntax val add_with_carry "add32_with_carry"
syntax val sub_with_borrow "sub32_with_borrow" syntax val sub_with_borrow "sub32_with_borrow"
...@@ -201,21 +203,22 @@ struct __lsld32_result lsld32(uint32_t x, uint32_t cnt); ...@@ -201,21 +203,22 @@ struct __lsld32_result lsld32(uint32_t x, uint32_t cnt);
syntax val add3 "add32_3" syntax val add3 "add32_3"
syntax val lsld "lsld32" syntax val lsld "lsld32"
syntax val add_mod "%1 + %2" syntax val add_mod "%1 + %2" prec 4
syntax val sub_mod "%1 - %2" syntax val sub_mod "%1 - %2" prec 4
syntax val mul_mod "%1 * %2" syntax val mul_mod "%1 * %2" prec 3
syntax val div2by1 syntax val div2by1
"(uint32_t)((((uint64_t)%1) | (((uint64_t)%2) << 32))/(uint64_t)(%3))" "(uint32_t)((((uint64_t)%1) | (((uint64_t)%2) << 32))/(uint64_t)(%3))"
prec 2
syntax val lsl "%1 << %2" syntax val lsl "%1 << %2" prec 5
syntax val lsr "%1 >> %2" syntax val lsr "%1 >> %2" prec 5
syntax val is_msb_set "%1 & 0x80000000U" syntax val is_msb_set "%1 & 0x80000000U" prec 8
syntax val count_leading_zeros "__builtin_clz(%1)" syntax val count_leading_zeros "__builtin_clz(%1)" prec 1
syntax val of_int32 "(uint32_t)(%1)" syntax val of_int32 "(uint32_t)(%1)" prec 2
end end
...@@ -224,25 +227,26 @@ module mach.int.Int64 ...@@ -224,25 +227,26 @@ module mach.int.Int64
syntax type int64 "int64_t" syntax type int64 "int64_t"
syntax literal int64 "%dLL" syntax literal int64 "%dLL"
syntax val (+) "%1 + %2" syntax val (+) "%1 + %2" prec 4
syntax val (-) "%1 - %2" syntax val (-) "%1 - %2" prec 4
syntax val (-_) "-%1" syntax val (-_) "-%1" prec 2
syntax val ( * ) "%1 * %2" syntax val ( * ) "%1 * %2" prec 3
syntax val (/) "%1 / %2" syntax val (/) "%1 / %2" prec 3
syntax val (%) "%1 % %2" syntax val (%) "%1 % %2" prec 3
syntax val (=) "%1 == %2" syntax val (=) "%1 == %2" prec 7
syntax val (<=) "%1 <= %2" syntax val (<=) "%1 <= %2" prec 6
syntax val (<) "%1 < %2" syntax val (<) "%1 < %2" prec 6
syntax val (>=) "%1 >= %2" syntax val (>=) "%1 >= %2" prec 6
syntax val (>) "%1 > %2" syntax val (>) "%1 > %2" prec 6
end end
module mach.int.UInt64Gen module mach.int.UInt64Gen
syntax type uint64 "uint64_t" syntax type uint64 "uint64_t"
syntax val max_uint64 "0xffffffffffffffff" syntax val max_uint64 "0xffffffffffffffff" prec 0
syntax val length "64" syntax val length "64" prec 0
end end
...@@ -250,17 +254,17 @@ module mach.int.UInt64 ...@@ -250,17 +254,17 @@ module mach.int.UInt64
syntax literal uint64 "0x%16xULL" syntax literal uint64 "0x%16xULL"
syntax val (+) "%1 + %2" syntax val (+) "%1 + %2" prec 4
syntax val (-) "%1 - %2" syntax val (-) "%1 - %2" prec 4
syntax val (-_) "-%1" syntax val (-_) "-%1" prec 2
syntax val ( * ) "%1 * %2" syntax val ( * ) "%1 * %2" prec 3
syntax val (/) "%1 / %2" syntax val (/) "%1 / %2" prec 3
syntax val (%) "%1 % %2" syntax val (%) "%1 % %2" prec 3
syntax val (=) "%1 == %2" syntax val (=) "%1 == %2" prec 7
syntax val (<=) "%1 <= %2" syntax val (<=) "%1 <= %2" prec 6
syntax val (<) "%1 < %2" syntax val (<) "%1 < %2" prec 6
syntax val (>=) "%1 >= %2" syntax val (>=) "%1 >= %2" prec 6
syntax val (>) "%1 > %2" syntax val (>) "%1 > %2" prec 6
end end
...@@ -505,18 +509,19 @@ static struct __lsld64_result lsld64(uint64_t x, uint64_t cnt) ...@@ -505,18 +509,19 @@ static struct __lsld64_result lsld64(uint64_t x, uint64_t cnt)
" "
syntax literal uint64 "0x%16xULL" syntax literal uint64 "0x%16xULL"
syntax val uint64_max "0xffffffffffffffffULL" syntax val uint64_max "0xffffffffffffffffULL" prec 0
syntax val (+) "%1 + %2" syntax val (+) "%1 + %2" prec 4
syntax val (-) "%1 - %2" syntax val (-) "%1 - %2" prec 4
syntax val ( * ) "%1 * %2" syntax val (-_) "-%1" prec 2
syntax val (/) "%1 / %2" syntax val ( * ) "%1 * %2" prec 3
syntax val (%) "%1 % %2" syntax val (/) "%1 / %2" prec 3
syntax val (=) "%1 == %2" syntax val (%) "%1 % %2" prec 3
syntax val (<=) "%1 <= %2" syntax val (=) "%1 == %2" prec 7
syntax val (<) "%1 < %2" syntax val (<=) "%1 <= %2" prec 6
syntax val (>=) "%1 >= %2" syntax val (<) "%1 < %2" prec 6
syntax val (>) "%1 > %2" syntax val (>=) "%1 >= %2" prec 6
syntax val (>) "%1 > %2" prec 6
syntax val add_with_carry "add64_with_carry" syntax val add_with_carry "add64_with_carry"
syntax val add_double "add64_double" syntax val add_double "add64_double"
...@@ -531,20 +536,20 @@ static struct __lsld64_result lsld64(uint64_t x, uint64_t cnt) ...@@ -531,20 +536,20 @@ static struct __lsld64_result lsld64(uint64_t x, uint64_t cnt)
syntax val add3 "add64_3" syntax val add3 "add64_3"
syntax val lsld "lsld64" syntax val lsld "lsld64"
syntax val add_mod "%1 + %2" syntax val add_mod "%1 + %2" prec 4
syntax val sub_mod "%1 - %2" syntax val sub_mod "%1 - %2" prec 4
syntax val mul_mod "%1 * %2" syntax val mul_mod "%1 * %2" prec 3
syntax val lsl "%1 << %2" syntax val lsl "%1 << %2" prec 5
syntax val lsr "%1 >> %2" syntax val lsr "%1 >> %2" prec 5
syntax val is_msb_set "%1 & 0x8000000000000000ULL" syntax val is_msb_set "%1 & 0x8000000000000000ULL" prec 8
syntax val count_leading_zeros "__builtin_clzll(%1)" syntax val count_leading_zeros "__builtin_clzll(%1)" prec 1
syntax val of_int32 "(uint64_t)(%1)" syntax val of_int32 "(uint64_t)(%1)" prec 2
syntax val to_int64 "(int64_t)(%1)" syntax val to_int64 "(int64_t)(%1)" prec 2
syntax val of_int64 "(uint64_t)(%1)" syntax val of_int64 "(uint64_t)(%1)" prec 2
end end
...@@ -563,34 +568,34 @@ module mach.c.C ...@@ -563,34 +568,34 @@ module mach.c.C
syntax type ptr "%1 *" syntax type ptr "%1 *"
syntax type bool "int" (* ? *) syntax type bool "int" (* ? *)
syntax val malloc "malloc((%1) * sizeof(%v0))" syntax val malloc "malloc((%1) * sizeof(%v0))" prec 1
syntax val free "free(%1)" syntax val free "free(%1)" prec 1
syntax val realloc "realloc(%1, (%2) * sizeof(%v0))" syntax val realloc "realloc(%1, (%2) * sizeof(%v0))" prec 1
syntax val salloc "alloca((%1) * sizeof(%v0))" syntax val salloc "alloca((%1) * sizeof(%v0))" prec 1
syntax val sfree "(void)(%1)" syntax val sfree "(void)(%1)" prec 2
(* syntax val is_null "(%1) == NULL" *) (* syntax val is_null "(%1) == NULL" *)
syntax val is_not_null "%1" syntax val is_not_null "%1" prec 0
syntax val null "NULL" syntax val null "NULL" prec 0
syntax val incr "%1 + %2" syntax val incr "%1 + %2" prec 4
syntax val get "*%1" syntax val get "*%1" prec 2
syntax val get_ofs "%1[%2]" syntax val get_ofs "%1[%2]" prec 1
syntax val set "*%1 = %2" syntax val set "*(%1) = %2" prec 14
syntax val set_ofs "%1[%2] = %3" syntax val set_ofs "%1[%2] = %3" prec 14
syntax val incr_split "%1 + %2" syntax val incr_split "%1 + %2" prec 4
syntax val decr_split "%1 - %2" syntax val decr_split "%1 - %2" prec 4
syntax val join "IGNORE2" syntax val join "IGNORE2"
syntax val join_r "IGNORE2" syntax val join_r "IGNORE2"
syntax val c_assert "assert ( %1 )" syntax val c_assert "assert ( %1 )" prec 1
syntax val print_space "printf(\" \")" syntax val print_space "printf(\" \")" prec 1
syntax val print_newline "printf(\"\\n\")" syntax val print_newline "printf(\"\\n\")" prec 1
syntax val print_uint32 "printf(\"%#010x\",%1)" syntax val print_uint32 "printf(\"%#010x\",%1)" prec 1
end end
...@@ -59,8 +59,8 @@ module C = struct ...@@ -59,8 +59,8 @@ module C = struct
| Eindex of expr * expr (* Array access *) | Eindex of expr * expr (* Array access *)
| Edot of expr * string (* Field access with dot *) | Edot of expr * string (* Field access with dot *)
| Earrow of expr * string (* Pointer access with arrow *) | Earrow of expr * string (* Pointer access with arrow *)
| Esyntax of string * ty * (ty array) * (expr*ty) list | Esyntax of string * ty * (ty array) * (expr*ty) list * int option
(* template, type and type arguments of result, typed arguments *) (* template, type and type arguments of result, typed arguments, precedence level *)
and constant = and constant =
| Cint of string | Cint of string
...@@ -171,8 +171,8 @@ module C = struct ...@@ -171,8 +171,8 @@ module C = struct
propagate_in_expr id v e2) propagate_in_expr id v e2)
| Edot (e,i) -> Edot (propagate_in_expr id v e, i) | Edot (e,i) -> Edot (propagate_in_expr id v e, i)
| Earrow (e,i) -> Earrow (propagate_in_expr id v e, i) | Earrow (e,i) -> Earrow (propagate_in_expr id v e, i)
| Esyntax (s,t,ta,l) -> | Esyntax (s,t,ta,l,p) ->
Esyntax (s,t,ta,List.map (fun (e,t) -> (propagate_in_expr id v e),t) l) Esyntax (s,t,ta,List.map (fun (e,t) -> (propagate_in_expr id v e),t) l,p)
| Enothing -> Enothing | Enothing -> Enothing
| Econst c -> Econst c | Econst c -> Econst c
| Elikely e -> Elikely (propagate_in_expr id v e) | Elikely e -> Elikely (propagate_in_expr id v e)
...@@ -399,7 +399,7 @@ module C = struct ...@@ -399,7 +399,7 @@ module C = struct
| Esize_expr _ -> false | Esize_expr _ -> false
| Esize_type _ -> true | Esize_type _ -> true
| Eindex (_,_) | Edot (_,_) | Earrow (_,_) -> false | Eindex (_,_) | Edot (_,_) | Earrow (_,_) -> false
| Esyntax (_,_,_,_) -> false | Esyntax (_,_,_,_,_) -> false
let rec get_const_expr (d,s) = let rec get_const_expr (d,s) =
let fail () = raise (Unsupported "non-constant array size") in let fail () = raise (Unsupported "non-constant array size") in
...@@ -442,6 +442,8 @@ type info = { ...@@ -442,6 +442,8 @@ type info = {
syntax : Printer.syntax_map; syntax : Printer.syntax_map;
literal : Printer.syntax_map; (*TODO handle literals*) literal : Printer.syntax_map; (*TODO handle literals*)
kn : Pdecl.known_map; kn : Pdecl.known_map;
prec : (int option) Mid.t;
assoc : (Driver_ast.assoc_dir option) Mid.t;
} }
let debug_c_extraction = Debug.register_info_flag let debug_c_extraction = Debug.register_info_flag
...@@ -528,6 +530,7 @@ module Print = struct ...@@ -528,6 +530,7 @@ module Print = struct
| Bge -> fprintf fmt ">=" | Bge -> fprintf fmt ">="
and print_expr ~prec fmt = function and print_expr ~prec fmt = function
(* invariant: 0 <= prec <= 15 *)
| Enothing -> () | Enothing -> ()
| Eunop(u,e) -> | Eunop(u,e) ->
let p = prec_unop u in let p = prec_unop u in
...@@ -548,7 +551,7 @@ module Print = struct ...@@ -548,7 +551,7 @@ module Print = struct
| Ecast(ty, e) -> | Ecast(ty, e) ->
fprintf fmt (protect_on (prec <= 2) "(%a)%a") fprintf fmt (protect_on (prec <= 2) "(%a)%a")
(print_ty ~paren:false) ty (print_expr ~prec:2) e (print_ty ~paren:false) ty (print_expr ~prec:2) e
| Ecall (Esyntax (s, _, _, []), l) -> | Ecall (Esyntax (s, _, _, [],_), l) ->
(* function defined in the prelude *) (* function defined in the prelude *)
fprintf fmt (protect_on (prec <= 1) "%s(%a)") fprintf fmt (protect_on (prec <= 1) "%s(%a)")
s (print_list comma (print_expr ~prec:15)) l s (print_list comma (print_expr ~prec:15)) l
...@@ -578,11 +581,16 @@ module Print = struct ...@@ -578,11 +581,16 @@ module Print = struct
| Earrow (e,s) -> | Earrow (e,s) ->
fprintf fmt (protect_on (prec <= 1) "%a->%s") fprintf fmt (protect_on (prec <= 1) "%a->%s")
(print_expr ~prec:1) e s (print_expr ~prec:1) e s
| Esyntax (s, t, args, lte) -> | Esyntax (s, t, args, lte, p) ->
(* no way to know precedence, so full parentheses*) if s = "%1" (*identity*)
then begin
assert (List.length lte = 1);
print_expr ~prec fmt (fst (List.hd lte)) end
else
let p = match p with Some n -> n | None -> 15 in
gen_syntax_arguments_typed snd (fun _ -> args) gen_syntax_arguments_typed snd (fun _ -> args)
(if prec <= 13 then ("("^s^")") else s) (if prec <= p then ("("^s^")") else s)
(fun fmt (e,_t) -> print_expr ~prec:1 fmt e) (fun fmt (e,_t) -> print_expr ~prec:p fmt e)
(print_ty ~paren:false) (C.Enothing,t) fmt lte (print_ty ~paren:false) (C.Enothing,t) fmt lte
and print_const fmt = function and print_const fmt = function
...@@ -983,14 +991,15 @@ module MLToC = struct ...@@ -983,14 +991,15 @@ module MLToC = struct
| Tyapp (_,args) -> | Tyapp (_,args) ->
Array.of_list (List.map (ty_of_ty info) args) Array.of_list (List.map (ty_of_ty info) args)
in in
C.Esyntax(s,ty_of_ty info rty, rtyargs, params) let p = Mid.find rs.rs_name info.prec in
C.Esyntax(s,ty_of_ty info rty, rtyargs, params, p)
with Not_found -> with Not_found ->
if args=[] if args=[]
then C.(Esyntax(s, Tnosyntax, [||], [])) (*constant*) then C.(Esyntax(s, Tnosyntax, [||], [], None)) (*constant*)
else else
(*function defined in the prelude *) (*function defined in the prelude *)
let cargs = List.map fst params in let cargs = List.map fst params in
C.(Ecall(Esyntax(s, Tnosyntax, [||], []), cargs)) C.(Ecall(Esyntax(s, Tnosyntax, [||], [], None), cargs))
end end
| None -> | None ->
match rs.rs_field with match rs.rs_field with
...@@ -1174,7 +1183,8 @@ module MLToC = struct ...@@ -1174,7 +1183,8 @@ module MLToC = struct
| Tyapp (_,args) -> | Tyapp (_,args) ->
Array.of_list (List.map (ty_of_ty info) args) Array.of_list (List.map (ty_of_ty info) args)
in in
C.Esyntax(s,ty_of_ty info rty, rtyargs, params) let p = Mid.find rs.rs_name info.prec in
C.Esyntax(s,ty_of_ty info rty, rtyargs, params,p)
| None -> if boxed | None -> if boxed
then C.(Earrow(Evar id, rs.rs_name.id_string)) then C.(Earrow(Evar id, rs.rs_name.id_string))
else C.(Edot(Evar id, rs.rs_name.id_string)) in else C.(Edot(Evar id, rs.rs_name.id_string)) in
...@@ -1388,6 +1398,8 @@ let mk_info (args:Pdriver.printer_args) m = { ...@@ -1388,6 +1398,8 @@ let mk_info (args:Pdriver.printer_args) m = {
blacklist = args.Pdriver.blacklist; blacklist = args.Pdriver.blacklist;
syntax = args.Pdriver.syntax; syntax = args.Pdriver.syntax;
literal = args.Pdriver.literal; literal = args.Pdriver.literal;
prec = args.Pdriver.prec;
assoc = args.Pdriver.assoc;
kn = m.Pmodule.mod_known } kn = m.Pmodule.mod_known }
let print_header_decl = let print_header_decl =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment