From b7e23f63552f2f76252570826c333796fc2972f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1rio=20Pereira?= <mpereira@lri.fr> Date: Fri, 23 Mar 2018 15:29:05 +0100 Subject: [PATCH] Extraction: fix extraction of zero-argument functions and partial application --- src/mlw/compile.ml | 154 +++++++++++++++++++--------------- src/mlw/ocaml_printer.ml | 6 +- tests/test-extraction/main.ml | 10 +-- tests/test_extraction.mlw | 11 ++- 4 files changed, 104 insertions(+), 77 deletions(-) diff --git a/src/mlw/compile.ml b/src/mlw/compile.ml index 3c62667e02..2965ea7c0f 100644 --- a/src/mlw/compile.ml +++ b/src/mlw/compile.ml @@ -177,12 +177,23 @@ module ML = struct | Dexn (_, Some ty) -> iter_deps_ty f ty | Dmodule (_, dl) -> List.iter (iter_deps f) dl - let mk_expr e_node e_ity e_effect e_label = - { e_node; e_ity; e_effect; e_label; } + let ity_unit = I Ity.ity_unit - let tunit = Ttuple [] + let ity_of_mask ity mask = + let mk_ty acc ty = function MaskGhost -> acc | _ -> ty :: acc in + match ity, mask with + | _, MaskGhost -> ity_unit + | _, MaskVisible -> ity + | I ({ity_node = Ityapp ({its_ts = s}, tl, _)}), MaskTuple m + when is_ts_tuple s && List.length tl = List.length m -> + let tl = List.fold_left2 mk_ty [] tl m in + I (ity_tuple tl) + | _ -> ity (* FIXME ? *) - let ity_unit = I Ity.ity_unit + let mk_expr e_node e_ity mask e_effect e_label = + { e_node; e_ity = ity_of_mask e_ity mask; e_effect; e_label; } + + let tunit = Ttuple [] let is_unit = function | I i -> ity_equal i Ity.ity_unit @@ -191,19 +202,18 @@ module ML = struct let enope = Eblock [] let mk_hole = - mk_expr Ehole (I Ity.ity_unit) Ity.eff_empty Slab.empty + mk_expr Ehole (I Ity.ity_unit) MaskVisible Ity.eff_empty Slab.empty let mk_var id ty ghost = (id, ty, ghost) let mk_var_unit () = id_register (id_fresh "_"), tunit, true - let mk_its_defn id args private_ def = - { its_name = id ; its_args = args; - its_private = private_; its_def = def; } + let mk_its_defn its_name its_args its_private its_def = + { its_name; its_args; its_private; its_def; } (* smart constructors *) let e_unit = - mk_expr enope (I Ity.ity_unit) Ity.eff_empty Slab.empty + mk_expr enope (I Ity.ity_unit) MaskVisible Ity.eff_empty Slab.empty let var_defn pv e = Lvar (pv, e) @@ -220,7 +230,7 @@ module ML = struct let e_ignore e = if is_unit e.e_ity then e - else mk_expr (Eignore e) ity_unit e.e_effect e.e_label + else mk_expr (Eignore e) ity_unit MaskVisible e.e_effect e.e_label let e_if e1 e2 e3 = mk_expr (Mltree.Eif (e1, e2, e3)) e2.e_ity @@ -241,8 +251,8 @@ module ML = struct mk_expr (Mltree.Etry (e, true, xl)) *) - let e_assign al ity eff lbl = - if al = [] then e_unit else mk_expr (Mltree.Eassign al) ity eff lbl + let e_assign al ity mask eff lbl = + if al = [] then e_unit else mk_expr (Mltree.Eassign al) ity mask eff lbl let e_absurd = mk_expr Eabsurd @@ -257,8 +267,8 @@ module ML = struct mk_expr e let var_list_of_pv_list pvl = - let mk_var pv = - mk_expr (Mltree.Evar pv) (Mltree.I pv.pv_ity) eff_empty Slab.empty in + let mk_var pv = mk_expr (Mltree.Evar pv) (Mltree.I pv.pv_ity) + MaskVisible eff_empty Slab.empty in List.map mk_var pvl end @@ -323,7 +333,7 @@ module Translate = struct | Pas (p, vs) -> Mltree.Pas (pat m p, vs) | Papp (ls, pl) when is_fs_tuple ls -> - let pl = visible_of_mask m pl in + let pl = List.rev (visible_of_mask m pl) in begin match pl with | [] -> Mltree.Pwild | [p] -> pat m p @@ -363,7 +373,7 @@ module Translate = struct | To -> Mltree.To | DownTo -> Mltree.DownTo - let isconstructor info rs = + let isconstructor info rs = (* TODO *) match Mid.find_opt rs.rs_name info.Mltree.from_km with | Some {pd_node = PDtype its} -> let is_constructor its = @@ -406,17 +416,19 @@ module Translate = struct let mk_eta_expansion rs pvl ({cty_args = ca; cty_effect = ce} as c) = (* FIXME : effects and types of the expression in this situation *) + let mv = MaskVisible in let args_f = let def pv = - (pv_name pv, mlty_of_ity (mask_of_pv pv) pv.pv_ity, pv.pv_ghost) in + pv_name pv, mlty_of_ity (mask_of_pv pv) pv.pv_ity, pv.pv_ghost in filter_ghost_params pv_not_ghost def ca in let args = - let def pv = - ML.mk_expr (Mltree.Evar pv) (Mltree.I pv.pv_ity) eff_empty Slab.empty in + let def pv = ML.mk_expr (Mltree.Evar pv) (Mltree.I pv.pv_ity) mv + eff_empty Slab.empty in let args = filter_ghost_params pv_not_ghost def pvl in let extra_args = List.map def ca in args @ extra_args in - let eapp = ML.mk_expr (Mltree.Eapp (rs, args)) (Mltree.C c) ce Slab.empty in - ML.mk_expr (Mltree.Efun (args_f, eapp)) (Mltree.C c) ce Slab.empty + let eapp = ML.mk_expr (Mltree.Eapp (rs, args)) (Mltree.C c) mv + ce Slab.empty in + ML.mk_expr (Mltree.Efun (args_f, eapp)) (Mltree.C c) mv ce Slab.empty (* function arguments *) let filter_params args = @@ -424,10 +436,9 @@ module Translate = struct let p (_, _, is_ghost) = not is_ghost in List.filter p args - let params = function - | [] -> [] - | args -> let args = filter_params args in - if args = [] then [ML.mk_var_unit ()] else args + let params args = + let args = filter_params args in + if args = [] then [ML.mk_var_unit ()] else args let filter_params_cty p def pvl cty_args = let rec loop = function @@ -437,10 +448,11 @@ module Translate = struct | _ -> assert false in loop (pvl, cty_args) - let app pvl cty_args = - let def pv = - ML.mk_expr (Mltree.Evar pv) (Mltree.I pv.pv_ity) eff_empty Slab.empty in - filter_params_cty pv_not_ghost def pvl cty_args + let app pvl cty_args f_zero = + let def pv = ML.mk_expr (Mltree.Evar pv) (Mltree.I pv.pv_ity) MaskVisible + eff_empty Slab.empty in + let args = filter_params_cty pv_not_ghost def pvl cty_args in + f_zero args (* build the set of type variables from functions arguments *) let rec add_tvar acc = function @@ -461,32 +473,34 @@ module Translate = struct let pvl = List.fold_left2 mk_pv_of_mask [] pvl m in List.rev pvl in match e.e_node with - | Econst _ | Evar _ | Eexec ({c_node = Cfun _}, _) | Eassign _ - | Ewhile _ | Efor _ | Eraise _ | Eexn _ | Eabsurd when mask = MaskGhost -> + | Econst _ | Evar _ | Eexec ({c_node = Cfun _}, _) (* FIXME *) + when mask = MaskGhost -> ML.e_unit | Econst c -> let c = match c with Number.ConstInt c -> c | _ -> assert false in - ML.mk_expr (Mltree.Econst c) (Mltree.I e.e_ity) eff lbl - | Evar pv -> ML.mk_expr (Mltree.Evar pv) (Mltree.I e.e_ity) eff lbl - | Elet (LDvar (_, e1), e2) when e_ghost e1 -> expr info svar mask e2 + ML.mk_expr (Mltree.Econst c) (Mltree.I e.e_ity) mask eff lbl + | Evar pv -> + ML.mk_expr (Mltree.Evar pv) (Mltree.I e.e_ity) mask eff lbl + | Elet (LDvar (_, e1), e2) when e_ghost e1 -> + expr info svar mask e2 | Elet (LDvar (_, e1), e2) when e_ghost e2 -> (* sequences are transformed into [let o = e1 in e2] by A-normal form *) - expr info svar e1.e_mask e1 + expr info svar MaskGhost e1 | Elet (LDvar (pv, e1), e2) when pv.pv_ghost || not (Mpv.mem pv e2.e_effect.eff_reads) -> if eff_pure e1.e_effect then expr info svar mask e2 - else let e1 = ML.e_ignore (expr info svar e1.e_mask e1) in - ML.e_seq e1 (expr info svar mask e2) (Mltree.I e.e_ity) eff lbl + else let e1 = expr info svar MaskGhost e1 in + ML.e_seq e1 (expr info svar mask e2) (Mltree.I e.e_ity) mask eff lbl | Elet (LDvar (pv, e1), e2) -> - let ld = ML.var_defn pv (expr info svar e1.e_mask e1) in - ML.e_let ld (expr info svar mask e2) (Mltree.I e.e_ity) eff lbl + let ld = ML.var_defn pv (expr info svar MaskVisible e1) in + ML.e_let ld (expr info svar mask e2) (Mltree.I e.e_ity) mask eff lbl | Elet (LDsym (rs, _), ein) when rs_ghost rs -> expr info svar mask ein | Elet (LDsym (rs, {c_node = Cfun ef; c_cty = cty}), ein) -> let args = params cty.cty_args in let res = mlty_of_ity cty.cty_mask cty.cty_result in - let ld = ML.sym_defn rs res args (expr info svar ef.e_mask ef) in - ML.e_let ld (expr info svar mask ein) (Mltree.I e.e_ity) eff lbl + let ld = ML.sym_defn rs res args (expr info svar cty.cty_mask ef) in + ML.e_let ld (expr info svar mask ein) (Mltree.I e.e_ity) mask eff lbl | Elet (LDsym (rs, {c_node = Capp (rs_app, pvl); c_cty = cty}), ein) when isconstructor info rs_app -> (* partial application of constructor *) let eta_app = mk_eta_expansion rs_app pvl cty in @@ -494,15 +508,18 @@ module Translate = struct let func = List.fold_right mk_func cty.cty_args cty.cty_result in let res = mlty_of_ity cty.cty_mask func in let ld = ML.sym_defn rs res [] eta_app in - ML.e_let ld (expr info svar mask ein) (Mltree.I e.e_ity) e.e_effect lbl + let ein = expr info svar mask ein in + ML.e_let ld ein (Mltree.I e.e_ity) mask eff lbl | Elet (LDsym (rsf, {c_node = Capp (rs_app, pvl); c_cty = cty}), ein) -> - (* partial application *) - let pvl = app pvl rs_app.rs_cty.cty_args in - let eff = cty.cty_effect in - let eapp = ML.e_app rs_app pvl (Mltree.C cty) eff Slab.empty in + (* partial application *) (* FIXME -> zero arguments functions *) + let cmk = cty.cty_mask in + let ceff = cty.cty_effect in + let pvl = app pvl rs_app.rs_cty.cty_args (fun x -> x) in + let eapp = ML.e_app rs_app pvl (Mltree.C cty) cmk ceff Slab.empty in let res = mlty_of_ity cty.cty_mask cty.cty_result in let ld = ML.sym_defn rsf res (params cty.cty_args) eapp in - ML.e_let ld (expr info svar mask ein) (Mltree.I e.e_ity) e.e_effect lbl + let ein = expr info svar mask ein in + ML.e_let ld ein (Mltree.I e.e_ity) mask eff lbl | Elet (LDrec rdefl, ein) -> let rdefl = filter_out_ghost_rdef rdefl in let def = function @@ -524,7 +541,7 @@ module Translate = struct if rdefl <> [] then let ein = expr info svar mask ein in let ml_letrec = Mltree.Elet (Mltree.Lrec rdefl, ein) in - ML.mk_expr ml_letrec (Mltree.I e.e_ity) e.e_effect lbl + ML.mk_expr ml_letrec (Mltree.I e.e_ity) mask e.e_effect lbl else expr info svar mask ein | Eexec ({c_node = Capp (rs, [])}, _) when is_rs_tuple rs -> ML.e_unit @@ -532,29 +549,30 @@ module Translate = struct let pvl = pv_list_of_mask pvl mask in let res_ity = ity_tuple (List.map (fun v -> v.pv_ity) pvl) in let pvl = ML.var_list_of_pv_list pvl in - ML.e_app rs pvl (Mltree.I res_ity) eff lbl - | Eexec ({c_node = Capp (rs, _)}, _) - when is_empty_record info rs || rs_ghost rs -> + ML.e_app rs pvl (Mltree.I res_ity) mask eff lbl + | Eexec ({c_node = Capp (rs, _)}, _) when is_empty_record info rs -> ML.e_unit | Eexec ({c_node = Capp (rs, pvl); c_cty = cty}, _) when isconstructor info rs && cty.cty_args <> [] -> (* partial application of constructors *) mk_eta_expansion rs pvl cty | Eexec ({c_node = Capp (rs, pvl); _}, _) -> - let pvl = app pvl rs.rs_cty.cty_args in + let add_unit = function [] -> [ML.e_unit] | args -> args in + let f_zero = if isconstructor info rs then fun x -> x else add_unit in + let pvl = app pvl rs.rs_cty.cty_args f_zero in begin match pvl with | [pv_expr] when is_optimizable_record_rs info rs -> pv_expr - | _ -> ML.e_app rs pvl (Mltree.I e.e_ity) eff lbl end + | _ -> ML.e_app rs pvl (Mltree.I e.e_ity) mask eff lbl end | Eexec ({c_node = Cfun e; c_cty = {cty_args = []}}, _) -> (* abstract block *) expr info svar e.e_mask e - | Eexec ({c_node = Cfun e; c_cty = cty}, _) -> - ML.e_fun (params cty.cty_args) (expr info svar e.e_mask e) - (Mltree.I e.e_ity) eff lbl + | Eexec ({c_node = Cfun ef; c_cty = cty}, _) -> + let ef = expr info svar e.e_mask ef in + ML.e_fun (params cty.cty_args) ef (Mltree.I e.e_ity) mask eff lbl | Eexec ({c_node = Cany}, _) -> ML.mk_hole | Eabsurd -> - ML.e_absurd (Mltree.I e.e_ity) eff lbl + ML.e_absurd (Mltree.I e.e_ity) mask eff lbl | Eassert _ -> ML.e_unit | Eif (e1, e2, e3) when e_ghost e1 -> @@ -565,32 +583,30 @@ module Translate = struct | Eif (e1, e2, e3) when e_ghost e3 -> let e1 = expr info svar e1.e_mask e1 in let e2 = expr info svar mask e2 in - ML.e_if e1 e2 ML.e_unit eff lbl + ML.e_if e1 e2 ML.e_unit mask eff lbl | Eif (e1, e2, e3) when e_ghost e2 -> let e1 = expr info svar e1.e_mask e1 in let e3 = expr info svar mask e3 in - ML.e_if e1 ML.e_unit e3 eff lbl + ML.e_if e1 ML.e_unit e3 mask eff lbl | Eif (e1, e2, e3) -> let e1 = expr info svar e1.e_mask e1 in let e2 = expr info svar mask e2 in let e3 = expr info svar mask e3 in - ML.e_if e1 e2 e3 eff lbl + ML.e_if e1 e2 e3 mask eff lbl | Ewhile (e1, _, _, e2) -> - assert (mask = MaskVisible); let e1 = expr info svar e1.e_mask e1 in let e2 = expr info svar e2.e_mask e2 in - ML.e_while e1 e2 eff lbl + ML.e_while e1 e2 mask eff lbl | Efor (pv1, (pv2, dir, pv3), _, _, efor) -> - assert (mask = MaskVisible); let dir = for_direction dir in let efor = expr info svar efor.e_mask efor in - ML.e_for pv1 pv2 dir pv3 efor eff lbl + ML.e_for pv1 pv2 dir pv3 efor mask eff lbl | Eghost _ | Epure _ -> assert false | Eassign al -> let rm_ghost (_, rs, _) = not (rs_ghost rs) in let al = List.filter rm_ghost al in - ML.e_assign al (Mltree.I e.e_ity) eff lbl + ML.e_assign al (Mltree.I e.e_ity) mask eff lbl | Ematch (e1, [], xl) when Mxs.is_empty xl -> expr info svar e1.e_mask e1 | Ematch (e1, bl, xl) when e_ghost e1 -> @@ -605,18 +621,18 @@ module Translate = struct (* NOTE: why no pv_list_of_mask here? *) let mk_xl (xs, (pvl, e)) = xs, pvl, expr info svar mask e in let xl = List.map mk_xl (Mxs.bindings xl) in - ML.e_match e1 bl xl (Mltree.I e.e_ity) eff lbl + ML.e_match e1 bl xl (Mltree.I e.e_ity) mask eff lbl | Eraise (xs, ex) -> let ex = match expr info svar xs.xs_mask ex with | {Mltree.e_node = Mltree.Eblock []} -> None | e -> Some e in - ML.mk_expr (Mltree.Eraise (xs, ex)) (Mltree.I e.e_ity) eff lbl + ML.mk_expr (Mltree.Eraise (xs, ex)) (Mltree.I e.e_ity) mask eff lbl | Eexn (xs, e1) -> if mask_ghost e1.e_mask then ML.mk_expr - (Mltree.Eexn (xs, None, ML.e_unit)) (Mltree.I e.e_ity) eff lbl + (Mltree.Eexn (xs, None, ML.e_unit)) (Mltree.I e.e_ity) mask eff lbl else let e1 = expr info svar xs.xs_mask e1 in let ty = if ity_equal xs.xs_ity ity_unit then None else Some (mlty_of_ity xs.xs_mask xs.xs_ity) in - ML.mk_expr (Mltree.Eexn (xs, ty, e1)) (Mltree.I e.e_ity) eff lbl + ML.mk_expr (Mltree.Eexn (xs, ty, e1)) (Mltree.I e.e_ity) mask eff lbl | Elet (LDsym (_, {c_node=(Cany|Cpur (_, _)); _ }), _) | Eexec ({c_node=Cpur (_, _); _ }, _) -> ML.mk_hole diff --git a/src/mlw/ocaml_printer.ml b/src/mlw/ocaml_printer.ml index b6e0e56f99..8d81151c6f 100644 --- a/src/mlw/ocaml_printer.ml +++ b/src/mlw/ocaml_printer.ml @@ -315,8 +315,10 @@ module Print = struct else fprintf fmt "%a" (print_expr ~paren:true info) expr; if exprl <> [] then fprintf fmt "@ "; print_apply_args info fmt (exprl, pvl) + | expr :: exprl, [] -> + fprintf fmt "%a" (print_expr ~paren:true info) expr; + print_apply_args info fmt (exprl, []) | [], _ -> () - | _, [] -> assert false and print_apply info rs fmt pvl = let isfield = @@ -360,7 +362,7 @@ module Print = struct end | _, None, [] -> (print_lident info) fmt rs.rs_name - | _, _, tl -> (* FIXME? when is in driver but is not a local id *) + | _, _, tl -> fprintf fmt "@[<hov 2>%a %a@]" (print_lident info) rs.rs_name (print_apply_args info) (tl, rs.rs_cty.cty_args) diff --git a/tests/test-extraction/main.ml b/tests/test-extraction/main.ml index 6b8f3aaac1..bbcc97e166 100644 --- a/tests/test-extraction/main.ml +++ b/tests/test-extraction/main.ml @@ -8,11 +8,11 @@ let () = assert (test_array () = 42) let (=) = Z.equal let b42 = Z.of_int 42 -let () = assert (test_int () = b42) -let () = assert (test_int63 () = b42) -let () = assert (test_ref () = b42) -let () = assert (test_array63 () = b42) - +let () = assert (test_int () = b42) +let () = assert (test_int63 () = b42) +let () = assert (test_ref () = b42) +let () = assert (test_array63 () = b42) +let () = assert (test_partial2 () = b42) let () = main () let () = Format.printf "OCaml extraction test successful@." diff --git a/tests/test_extraction.mlw b/tests/test_extraction.mlw index 859910d96e..5e9c002c33 100644 --- a/tests/test_extraction.mlw +++ b/tests/test_extraction.mlw @@ -172,6 +172,15 @@ module TestExtraction let partial = test_filter_ghost_args 3 in 42 + let constant test_partial2 : int = + let r = ref 0 in + let f (x: int) (ghost y) = r := !r + 21 in + let g = f 17 in + g (0:int); g (1:int); !r + + let test_zero_args () : int = + test_partial2 + 0 + let test_filter_ghost_args2 (x: int) (ghost y: int) (z: int) : int = x + z @@ -199,7 +208,7 @@ module TestExtraction let res = yxz - xzy in res - let test_partial2 (x: int) : int = + let test_partial3 (x: int) : int = let sum : int -> int -> int = fun x y -> x + y in let incr_a (a: int) = sum a in incr_a x x -- GitLab