Commit b7e23f63 authored by Mário Pereira's avatar Mário Pereira

Extraction: fix extraction of zero-argument functions and partial application

parent 44f44412
......@@ -177,12 +177,23 @@ module ML = struct
| Dexn (_, Some ty) -> iter_deps_ty f ty
| Dmodule (_, dl) -> List.iter (iter_deps f) dl
let mk_expr e_node e_ity e_effect e_label =
{ e_node; e_ity; e_effect; e_label; }
let ity_unit = I Ity.ity_unit
let tunit = Ttuple []
let ity_of_mask ity mask =
let mk_ty acc ty = function MaskGhost -> acc | _ -> ty :: acc in
match ity, mask with
| _, MaskGhost -> ity_unit
| _, MaskVisible -> ity
| I ({ity_node = Ityapp ({its_ts = s}, tl, _)}), MaskTuple m
when is_ts_tuple s && List.length tl = List.length m ->
let tl = List.fold_left2 mk_ty [] tl m in
I (ity_tuple tl)
| _ -> ity (* FIXME ? *)
let ity_unit = I Ity.ity_unit
let mk_expr e_node e_ity mask e_effect e_label =
{ e_node; e_ity = ity_of_mask e_ity mask; e_effect; e_label; }
let tunit = Ttuple []
let is_unit = function
| I i -> ity_equal i Ity.ity_unit
......@@ -191,19 +202,18 @@ module ML = struct
let enope = Eblock []
let mk_hole =
mk_expr Ehole (I Ity.ity_unit) Ity.eff_empty Slab.empty
mk_expr Ehole (I Ity.ity_unit) MaskVisible Ity.eff_empty Slab.empty
let mk_var id ty ghost = (id, ty, ghost)
let mk_var_unit () = id_register (id_fresh "_"), tunit, true
let mk_its_defn id args private_ def =
{ its_name = id ; its_args = args;
its_private = private_; its_def = def; }
let mk_its_defn its_name its_args its_private its_def =
{ its_name; its_args; its_private; its_def; }
(* smart constructors *)
let e_unit =
mk_expr enope (I Ity.ity_unit) Ity.eff_empty Slab.empty
mk_expr enope (I Ity.ity_unit) MaskVisible Ity.eff_empty Slab.empty
let var_defn pv e =
Lvar (pv, e)
......@@ -220,7 +230,7 @@ module ML = struct
let e_ignore e =
if is_unit e.e_ity then e
else mk_expr (Eignore e) ity_unit e.e_effect e.e_label
else mk_expr (Eignore e) ity_unit MaskVisible e.e_effect e.e_label
let e_if e1 e2 e3 =
mk_expr (Mltree.Eif (e1, e2, e3)) e2.e_ity
......@@ -241,8 +251,8 @@ module ML = struct
mk_expr (Mltree.Etry (e, true, xl))
*)
let e_assign al ity eff lbl =
if al = [] then e_unit else mk_expr (Mltree.Eassign al) ity eff lbl
let e_assign al ity mask eff lbl =
if al = [] then e_unit else mk_expr (Mltree.Eassign al) ity mask eff lbl
let e_absurd =
mk_expr Eabsurd
......@@ -257,8 +267,8 @@ module ML = struct
mk_expr e
let var_list_of_pv_list pvl =
let mk_var pv =
mk_expr (Mltree.Evar pv) (Mltree.I pv.pv_ity) eff_empty Slab.empty in
let mk_var pv = mk_expr (Mltree.Evar pv) (Mltree.I pv.pv_ity)
MaskVisible eff_empty Slab.empty in
List.map mk_var pvl
end
......@@ -323,7 +333,7 @@ module Translate = struct
| Pas (p, vs) ->
Mltree.Pas (pat m p, vs)
| Papp (ls, pl) when is_fs_tuple ls ->
let pl = visible_of_mask m pl in
let pl = List.rev (visible_of_mask m pl) in
begin match pl with
| [] -> Mltree.Pwild
| [p] -> pat m p
......@@ -363,7 +373,7 @@ module Translate = struct
| To -> Mltree.To
| DownTo -> Mltree.DownTo
let isconstructor info rs =
let isconstructor info rs = (* TODO *)
match Mid.find_opt rs.rs_name info.Mltree.from_km with
| Some {pd_node = PDtype its} ->
let is_constructor its =
......@@ -406,17 +416,19 @@ module Translate = struct
let mk_eta_expansion rs pvl ({cty_args = ca; cty_effect = ce} as c) =
(* FIXME : effects and types of the expression in this situation *)
let mv = MaskVisible in
let args_f =
let def pv =
(pv_name pv, mlty_of_ity (mask_of_pv pv) pv.pv_ity, pv.pv_ghost) in
pv_name pv, mlty_of_ity (mask_of_pv pv) pv.pv_ity, pv.pv_ghost in
filter_ghost_params pv_not_ghost def ca in
let args =
let def pv =
ML.mk_expr (Mltree.Evar pv) (Mltree.I pv.pv_ity) eff_empty Slab.empty in
let def pv = ML.mk_expr (Mltree.Evar pv) (Mltree.I pv.pv_ity) mv
eff_empty Slab.empty in
let args = filter_ghost_params pv_not_ghost def pvl in
let extra_args = List.map def ca in args @ extra_args in
let eapp = ML.mk_expr (Mltree.Eapp (rs, args)) (Mltree.C c) ce Slab.empty in
ML.mk_expr (Mltree.Efun (args_f, eapp)) (Mltree.C c) ce Slab.empty
let eapp = ML.mk_expr (Mltree.Eapp (rs, args)) (Mltree.C c) mv
ce Slab.empty in
ML.mk_expr (Mltree.Efun (args_f, eapp)) (Mltree.C c) mv ce Slab.empty
(* function arguments *)
let filter_params args =
......@@ -424,10 +436,9 @@ module Translate = struct
let p (_, _, is_ghost) = not is_ghost in
List.filter p args
let params = function
| [] -> []
| args -> let args = filter_params args in
if args = [] then [ML.mk_var_unit ()] else args
let params args =
let args = filter_params args in
if args = [] then [ML.mk_var_unit ()] else args
let filter_params_cty p def pvl cty_args =
let rec loop = function
......@@ -437,10 +448,11 @@ module Translate = struct
| _ -> assert false
in loop (pvl, cty_args)
let app pvl cty_args =
let def pv =
ML.mk_expr (Mltree.Evar pv) (Mltree.I pv.pv_ity) eff_empty Slab.empty in
filter_params_cty pv_not_ghost def pvl cty_args
let app pvl cty_args f_zero =
let def pv = ML.mk_expr (Mltree.Evar pv) (Mltree.I pv.pv_ity) MaskVisible
eff_empty Slab.empty in
let args = filter_params_cty pv_not_ghost def pvl cty_args in
f_zero args
(* build the set of type variables from functions arguments *)
let rec add_tvar acc = function
......@@ -461,32 +473,34 @@ module Translate = struct
let pvl = List.fold_left2 mk_pv_of_mask [] pvl m in
List.rev pvl in
match e.e_node with
| Econst _ | Evar _ | Eexec ({c_node = Cfun _}, _) | Eassign _
| Ewhile _ | Efor _ | Eraise _ | Eexn _ | Eabsurd when mask = MaskGhost ->
| Econst _ | Evar _ | Eexec ({c_node = Cfun _}, _) (* FIXME *)
when mask = MaskGhost ->
ML.e_unit
| Econst c ->
let c = match c with Number.ConstInt c -> c | _ -> assert false in
ML.mk_expr (Mltree.Econst c) (Mltree.I e.e_ity) eff lbl
| Evar pv -> ML.mk_expr (Mltree.Evar pv) (Mltree.I e.e_ity) eff lbl
| Elet (LDvar (_, e1), e2) when e_ghost e1 -> expr info svar mask e2
ML.mk_expr (Mltree.Econst c) (Mltree.I e.e_ity) mask eff lbl
| Evar pv ->
ML.mk_expr (Mltree.Evar pv) (Mltree.I e.e_ity) mask eff lbl
| Elet (LDvar (_, e1), e2) when e_ghost e1 ->
expr info svar mask e2
| Elet (LDvar (_, e1), e2) when e_ghost e2 ->
(* sequences are transformed into [let o = e1 in e2] by A-normal form *)
expr info svar e1.e_mask e1
expr info svar MaskGhost e1
| Elet (LDvar (pv, e1), e2)
when pv.pv_ghost || not (Mpv.mem pv e2.e_effect.eff_reads) ->
if eff_pure e1.e_effect then expr info svar mask e2
else let e1 = ML.e_ignore (expr info svar e1.e_mask e1) in
ML.e_seq e1 (expr info svar mask e2) (Mltree.I e.e_ity) eff lbl
else let e1 = expr info svar MaskGhost e1 in
ML.e_seq e1 (expr info svar mask e2) (Mltree.I e.e_ity) mask eff lbl
| Elet (LDvar (pv, e1), e2) ->
let ld = ML.var_defn pv (expr info svar e1.e_mask e1) in
ML.e_let ld (expr info svar mask e2) (Mltree.I e.e_ity) eff lbl
let ld = ML.var_defn pv (expr info svar MaskVisible e1) in
ML.e_let ld (expr info svar mask e2) (Mltree.I e.e_ity) mask eff lbl
| Elet (LDsym (rs, _), ein) when rs_ghost rs ->
expr info svar mask ein
| Elet (LDsym (rs, {c_node = Cfun ef; c_cty = cty}), ein) ->
let args = params cty.cty_args in
let res = mlty_of_ity cty.cty_mask cty.cty_result in
let ld = ML.sym_defn rs res args (expr info svar ef.e_mask ef) in
ML.e_let ld (expr info svar mask ein) (Mltree.I e.e_ity) eff lbl
let ld = ML.sym_defn rs res args (expr info svar cty.cty_mask ef) in
ML.e_let ld (expr info svar mask ein) (Mltree.I e.e_ity) mask eff lbl
| Elet (LDsym (rs, {c_node = Capp (rs_app, pvl); c_cty = cty}), ein)
when isconstructor info rs_app -> (* partial application of constructor *)
let eta_app = mk_eta_expansion rs_app pvl cty in
......@@ -494,15 +508,18 @@ module Translate = struct
let func = List.fold_right mk_func cty.cty_args cty.cty_result in
let res = mlty_of_ity cty.cty_mask func in
let ld = ML.sym_defn rs res [] eta_app in
ML.e_let ld (expr info svar mask ein) (Mltree.I e.e_ity) e.e_effect lbl
let ein = expr info svar mask ein in
ML.e_let ld ein (Mltree.I e.e_ity) mask eff lbl
| Elet (LDsym (rsf, {c_node = Capp (rs_app, pvl); c_cty = cty}), ein) ->
(* partial application *)
let pvl = app pvl rs_app.rs_cty.cty_args in
let eff = cty.cty_effect in
let eapp = ML.e_app rs_app pvl (Mltree.C cty) eff Slab.empty in
(* partial application *) (* FIXME -> zero arguments functions *)
let cmk = cty.cty_mask in
let ceff = cty.cty_effect in
let pvl = app pvl rs_app.rs_cty.cty_args (fun x -> x) in
let eapp = ML.e_app rs_app pvl (Mltree.C cty) cmk ceff Slab.empty in
let res = mlty_of_ity cty.cty_mask cty.cty_result in
let ld = ML.sym_defn rsf res (params cty.cty_args) eapp in
ML.e_let ld (expr info svar mask ein) (Mltree.I e.e_ity) e.e_effect lbl
let ein = expr info svar mask ein in
ML.e_let ld ein (Mltree.I e.e_ity) mask eff lbl
| Elet (LDrec rdefl, ein) ->
let rdefl = filter_out_ghost_rdef rdefl in
let def = function
......@@ -524,7 +541,7 @@ module Translate = struct
if rdefl <> [] then
let ein = expr info svar mask ein in
let ml_letrec = Mltree.Elet (Mltree.Lrec rdefl, ein) in
ML.mk_expr ml_letrec (Mltree.I e.e_ity) e.e_effect lbl
ML.mk_expr ml_letrec (Mltree.I e.e_ity) mask e.e_effect lbl
else expr info svar mask ein
| Eexec ({c_node = Capp (rs, [])}, _) when is_rs_tuple rs ->
ML.e_unit
......@@ -532,29 +549,30 @@ module Translate = struct
let pvl = pv_list_of_mask pvl mask in
let res_ity = ity_tuple (List.map (fun v -> v.pv_ity) pvl) in
let pvl = ML.var_list_of_pv_list pvl in
ML.e_app rs pvl (Mltree.I res_ity) eff lbl
| Eexec ({c_node = Capp (rs, _)}, _)
when is_empty_record info rs || rs_ghost rs ->
ML.e_app rs pvl (Mltree.I res_ity) mask eff lbl
| Eexec ({c_node = Capp (rs, _)}, _) when is_empty_record info rs ->
ML.e_unit
| Eexec ({c_node = Capp (rs, pvl); c_cty = cty}, _)
when isconstructor info rs && cty.cty_args <> [] ->
(* partial application of constructors *)
mk_eta_expansion rs pvl cty
| Eexec ({c_node = Capp (rs, pvl); _}, _) ->
let pvl = app pvl rs.rs_cty.cty_args in
let add_unit = function [] -> [ML.e_unit] | args -> args in
let f_zero = if isconstructor info rs then fun x -> x else add_unit in
let pvl = app pvl rs.rs_cty.cty_args f_zero in
begin match pvl with
| [pv_expr] when is_optimizable_record_rs info rs -> pv_expr
| _ -> ML.e_app rs pvl (Mltree.I e.e_ity) eff lbl end
| _ -> ML.e_app rs pvl (Mltree.I e.e_ity) mask eff lbl end
| Eexec ({c_node = Cfun e; c_cty = {cty_args = []}}, _) ->
(* abstract block *)
expr info svar e.e_mask e
| Eexec ({c_node = Cfun e; c_cty = cty}, _) ->
ML.e_fun (params cty.cty_args) (expr info svar e.e_mask e)
(Mltree.I e.e_ity) eff lbl
| Eexec ({c_node = Cfun ef; c_cty = cty}, _) ->
let ef = expr info svar e.e_mask ef in
ML.e_fun (params cty.cty_args) ef (Mltree.I e.e_ity) mask eff lbl
| Eexec ({c_node = Cany}, _) ->
ML.mk_hole
| Eabsurd ->
ML.e_absurd (Mltree.I e.e_ity) eff lbl
ML.e_absurd (Mltree.I e.e_ity) mask eff lbl
| Eassert _ ->
ML.e_unit
| Eif (e1, e2, e3) when e_ghost e1 ->
......@@ -565,32 +583,30 @@ module Translate = struct
| Eif (e1, e2, e3) when e_ghost e3 ->
let e1 = expr info svar e1.e_mask e1 in
let e2 = expr info svar mask e2 in
ML.e_if e1 e2 ML.e_unit eff lbl
ML.e_if e1 e2 ML.e_unit mask eff lbl
| Eif (e1, e2, e3) when e_ghost e2 ->
let e1 = expr info svar e1.e_mask e1 in
let e3 = expr info svar mask e3 in
ML.e_if e1 ML.e_unit e3 eff lbl
ML.e_if e1 ML.e_unit e3 mask eff lbl
| Eif (e1, e2, e3) ->
let e1 = expr info svar e1.e_mask e1 in
let e2 = expr info svar mask e2 in
let e3 = expr info svar mask e3 in
ML.e_if e1 e2 e3 eff lbl
ML.e_if e1 e2 e3 mask eff lbl
| Ewhile (e1, _, _, e2) ->
assert (mask = MaskVisible);
let e1 = expr info svar e1.e_mask e1 in
let e2 = expr info svar e2.e_mask e2 in
ML.e_while e1 e2 eff lbl
ML.e_while e1 e2 mask eff lbl
| Efor (pv1, (pv2, dir, pv3), _, _, efor) ->
assert (mask = MaskVisible);
let dir = for_direction dir in
let efor = expr info svar efor.e_mask efor in
ML.e_for pv1 pv2 dir pv3 efor eff lbl
ML.e_for pv1 pv2 dir pv3 efor mask eff lbl
| Eghost _ | Epure _ ->
assert false
| Eassign al ->
let rm_ghost (_, rs, _) = not (rs_ghost rs) in
let al = List.filter rm_ghost al in
ML.e_assign al (Mltree.I e.e_ity) eff lbl
ML.e_assign al (Mltree.I e.e_ity) mask eff lbl
| Ematch (e1, [], xl) when Mxs.is_empty xl ->
expr info svar e1.e_mask e1
| Ematch (e1, bl, xl) when e_ghost e1 ->
......@@ -605,18 +621,18 @@ module Translate = struct
(* NOTE: why no pv_list_of_mask here? *)
let mk_xl (xs, (pvl, e)) = xs, pvl, expr info svar mask e in
let xl = List.map mk_xl (Mxs.bindings xl) in
ML.e_match e1 bl xl (Mltree.I e.e_ity) eff lbl
ML.e_match e1 bl xl (Mltree.I e.e_ity) mask eff lbl
| Eraise (xs, ex) -> let ex = match expr info svar xs.xs_mask ex with
| {Mltree.e_node = Mltree.Eblock []} -> None
| e -> Some e in
ML.mk_expr (Mltree.Eraise (xs, ex)) (Mltree.I e.e_ity) eff lbl
ML.mk_expr (Mltree.Eraise (xs, ex)) (Mltree.I e.e_ity) mask eff lbl
| Eexn (xs, e1) ->
if mask_ghost e1.e_mask then ML.mk_expr
(Mltree.Eexn (xs, None, ML.e_unit)) (Mltree.I e.e_ity) eff lbl
(Mltree.Eexn (xs, None, ML.e_unit)) (Mltree.I e.e_ity) mask eff lbl
else let e1 = expr info svar xs.xs_mask e1 in
let ty = if ity_equal xs.xs_ity ity_unit then None
else Some (mlty_of_ity xs.xs_mask xs.xs_ity) in
ML.mk_expr (Mltree.Eexn (xs, ty, e1)) (Mltree.I e.e_ity) eff lbl
ML.mk_expr (Mltree.Eexn (xs, ty, e1)) (Mltree.I e.e_ity) mask eff lbl
| Elet (LDsym (_, {c_node=(Cany|Cpur (_, _)); _ }), _)
| Eexec ({c_node=Cpur (_, _); _ }, _) -> ML.mk_hole
......
......@@ -315,8 +315,10 @@ module Print = struct
else fprintf fmt "%a" (print_expr ~paren:true info) expr;
if exprl <> [] then fprintf fmt "@ ";
print_apply_args info fmt (exprl, pvl)
| expr :: exprl, [] ->
fprintf fmt "%a" (print_expr ~paren:true info) expr;
print_apply_args info fmt (exprl, [])
| [], _ -> ()
| _, [] -> assert false
and print_apply info rs fmt pvl =
let isfield =
......@@ -360,7 +362,7 @@ module Print = struct
end
| _, None, [] ->
(print_lident info) fmt rs.rs_name
| _, _, tl -> (* FIXME? when is in driver but is not a local id *)
| _, _, tl ->
fprintf fmt "@[<hov 2>%a %a@]"
(print_lident info) rs.rs_name
(print_apply_args info) (tl, rs.rs_cty.cty_args)
......
......@@ -8,11 +8,11 @@ let () = assert (test_array () = 42)
let (=) = Z.equal
let b42 = Z.of_int 42
let () = assert (test_int () = b42)
let () = assert (test_int63 () = b42)
let () = assert (test_ref () = b42)
let () = assert (test_array63 () = b42)
let () = assert (test_int () = b42)
let () = assert (test_int63 () = b42)
let () = assert (test_ref () = b42)
let () = assert (test_array63 () = b42)
let () = assert (test_partial2 () = b42)
let () = main ()
let () = Format.printf "OCaml extraction test successful@."
......
......@@ -172,6 +172,15 @@ module TestExtraction
let partial = test_filter_ghost_args 3 in
42
let constant test_partial2 : int =
let r = ref 0 in
let f (x: int) (ghost y) = r := !r + 21 in
let g = f 17 in
g (0:int); g (1:int); !r
let test_zero_args () : int =
test_partial2 + 0
let test_filter_ghost_args2 (x: int) (ghost y: int) (z: int) : int =
x + z
......@@ -199,7 +208,7 @@ module TestExtraction
let res = yxz - xzy in
res
let test_partial2 (x: int) : int =
let test_partial3 (x: int) : int =
let sum : int -> int -> int = fun x y -> x + y in
let incr_a (a: int) = sum a in
incr_a x x
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment