From b7e23f63552f2f76252570826c333796fc2972f6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A1rio=20Pereira?= <mpereira@lri.fr>
Date: Fri, 23 Mar 2018 15:29:05 +0100
Subject: [PATCH] Extraction: fix extraction of zero-argument functions and
 partial application

---
 src/mlw/compile.ml            | 154 +++++++++++++++++++---------------
 src/mlw/ocaml_printer.ml      |   6 +-
 tests/test-extraction/main.ml |  10 +--
 tests/test_extraction.mlw     |  11 ++-
 4 files changed, 104 insertions(+), 77 deletions(-)

diff --git a/src/mlw/compile.ml b/src/mlw/compile.ml
index 3c62667e02..2965ea7c0f 100644
--- a/src/mlw/compile.ml
+++ b/src/mlw/compile.ml
@@ -177,12 +177,23 @@ module ML = struct
     | Dexn (_, Some ty) -> iter_deps_ty f ty
     | Dmodule (_, dl) -> List.iter (iter_deps f) dl
 
-  let mk_expr e_node e_ity e_effect e_label =
-    { e_node; e_ity; e_effect; e_label; }
+  let ity_unit = I Ity.ity_unit
 
-  let tunit = Ttuple []
+  let ity_of_mask ity mask =
+    let mk_ty acc ty = function MaskGhost -> acc | _ -> ty :: acc in
+    match ity, mask with
+    | _, MaskGhost   -> ity_unit
+    | _, MaskVisible -> ity
+    | I ({ity_node = Ityapp ({its_ts = s}, tl, _)}), MaskTuple m
+      when is_ts_tuple s && List.length tl = List.length m ->
+        let tl = List.fold_left2 mk_ty [] tl m in
+        I (ity_tuple tl)
+    | _ -> ity (* FIXME ? *)
 
-  let ity_unit = I Ity.ity_unit
+  let mk_expr e_node e_ity mask e_effect e_label =
+    { e_node; e_ity = ity_of_mask e_ity mask; e_effect; e_label; }
+
+  let tunit = Ttuple []
 
   let is_unit = function
     | I i -> ity_equal i Ity.ity_unit
@@ -191,19 +202,18 @@ module ML = struct
   let enope = Eblock []
 
   let mk_hole =
-    mk_expr Ehole (I Ity.ity_unit) Ity.eff_empty Slab.empty
+    mk_expr Ehole (I Ity.ity_unit) MaskVisible Ity.eff_empty Slab.empty
 
   let mk_var id ty ghost = (id, ty, ghost)
 
   let mk_var_unit () = id_register (id_fresh "_"), tunit, true
 
-  let mk_its_defn id args private_ def =
-    { its_name    = id      ; its_args = args;
-      its_private = private_; its_def  = def; }
+  let mk_its_defn its_name its_args its_private its_def =
+    { its_name; its_args; its_private; its_def; }
 
   (* smart constructors *)
   let e_unit =
-    mk_expr enope (I Ity.ity_unit) Ity.eff_empty Slab.empty
+    mk_expr enope (I Ity.ity_unit) MaskVisible Ity.eff_empty Slab.empty
 
   let var_defn pv e =
     Lvar (pv, e)
@@ -220,7 +230,7 @@ module ML = struct
 
   let e_ignore e =
     if is_unit e.e_ity then e
-    else mk_expr (Eignore e) ity_unit e.e_effect e.e_label
+    else mk_expr (Eignore e) ity_unit MaskVisible e.e_effect e.e_label
 
   let e_if e1 e2 e3 =
     mk_expr (Mltree.Eif (e1, e2, e3)) e2.e_ity
@@ -241,8 +251,8 @@ module ML = struct
     mk_expr (Mltree.Etry (e, true, xl))
 *)
 
-  let e_assign al ity eff lbl =
-    if al = [] then e_unit else mk_expr (Mltree.Eassign al) ity eff lbl
+  let e_assign al ity mask eff lbl =
+    if al = [] then e_unit else mk_expr (Mltree.Eassign al) ity mask eff lbl
 
   let e_absurd =
     mk_expr Eabsurd
@@ -257,8 +267,8 @@ module ML = struct
     mk_expr e
 
   let var_list_of_pv_list pvl =
-    let mk_var pv =
-      mk_expr (Mltree.Evar pv) (Mltree.I pv.pv_ity) eff_empty Slab.empty in
+    let mk_var pv = mk_expr (Mltree.Evar pv) (Mltree.I pv.pv_ity)
+      MaskVisible eff_empty Slab.empty in
     List.map mk_var pvl
 
 end
@@ -323,7 +333,7 @@ module Translate = struct
     | Pas (p, vs) ->
         Mltree.Pas (pat m p, vs)
     | Papp (ls, pl) when is_fs_tuple ls ->
-        let pl = visible_of_mask m pl in
+        let pl = List.rev (visible_of_mask m pl) in
         begin match pl with
           | [] -> Mltree.Pwild
           | [p] -> pat m p
@@ -363,7 +373,7 @@ module Translate = struct
     | To -> Mltree.To
     | DownTo -> Mltree.DownTo
 
-  let isconstructor info rs =
+  let isconstructor info rs = (* TODO *)
     match Mid.find_opt rs.rs_name info.Mltree.from_km with
     | Some {pd_node = PDtype its} ->
         let is_constructor its =
@@ -406,17 +416,19 @@ module Translate = struct
 
   let mk_eta_expansion rs pvl ({cty_args = ca; cty_effect = ce} as c) =
     (* FIXME : effects and types of the expression in this situation *)
+    let mv = MaskVisible in
     let args_f =
       let def pv =
-        (pv_name pv, mlty_of_ity (mask_of_pv pv) pv.pv_ity, pv.pv_ghost) in
+        pv_name pv, mlty_of_ity (mask_of_pv pv) pv.pv_ity, pv.pv_ghost in
       filter_ghost_params pv_not_ghost def ca in
     let args =
-      let def pv =
-        ML.mk_expr (Mltree.Evar pv) (Mltree.I pv.pv_ity) eff_empty Slab.empty in
+      let def pv = ML.mk_expr (Mltree.Evar pv) (Mltree.I pv.pv_ity) mv
+        eff_empty Slab.empty in
       let args = filter_ghost_params pv_not_ghost def pvl in
       let extra_args = List.map def ca in args @ extra_args in
-    let eapp = ML.mk_expr (Mltree.Eapp (rs, args)) (Mltree.C c) ce Slab.empty in
-    ML.mk_expr (Mltree.Efun (args_f, eapp)) (Mltree.C c) ce Slab.empty
+    let eapp = ML.mk_expr (Mltree.Eapp (rs, args)) (Mltree.C c) mv
+      ce Slab.empty in
+    ML.mk_expr (Mltree.Efun (args_f, eapp)) (Mltree.C c) mv ce Slab.empty
 
   (* function arguments *)
   let filter_params args =
@@ -424,10 +436,9 @@ module Translate = struct
     let p (_, _, is_ghost) = not is_ghost in
     List.filter p args
 
-  let params = function
-    | []   -> []
-    | args -> let args = filter_params args in
-        if args = [] then [ML.mk_var_unit ()] else args
+  let params args =
+    let args = filter_params args in
+    if args = [] then [ML.mk_var_unit ()] else args
 
   let filter_params_cty p def pvl cty_args =
     let rec loop = function
@@ -437,10 +448,11 @@ module Translate = struct
       | _ -> assert false
     in loop (pvl, cty_args)
 
-  let app pvl cty_args =
-    let def pv =
-      ML.mk_expr (Mltree.Evar pv) (Mltree.I pv.pv_ity) eff_empty Slab.empty in
-    filter_params_cty pv_not_ghost def pvl cty_args
+  let app pvl cty_args f_zero =
+    let def pv = ML.mk_expr (Mltree.Evar pv) (Mltree.I pv.pv_ity) MaskVisible
+      eff_empty Slab.empty in
+    let args = filter_params_cty pv_not_ghost def pvl cty_args in
+    f_zero args
 
   (* build the set of type variables from functions arguments *)
   let rec add_tvar acc = function
@@ -461,32 +473,34 @@ module Translate = struct
           let pvl = List.fold_left2 mk_pv_of_mask [] pvl m in
           List.rev pvl in
     match e.e_node with
-    | Econst  _ | Evar _ | Eexec ({c_node = Cfun _}, _) | Eassign _
-    | Ewhile  _ | Efor _ | Eraise _ | Eexn _ | Eabsurd when mask = MaskGhost ->
+    | Econst _ | Evar _ | Eexec ({c_node = Cfun _}, _) (* FIXME *)
+      when mask = MaskGhost ->
         ML.e_unit
     | Econst c ->
         let c = match c with Number.ConstInt c -> c | _ -> assert false in
-        ML.mk_expr (Mltree.Econst c) (Mltree.I e.e_ity) eff lbl
-    | Evar pv -> ML.mk_expr (Mltree.Evar pv) (Mltree.I e.e_ity) eff lbl
-    | Elet (LDvar (_, e1), e2) when e_ghost e1 -> expr info svar mask e2
+        ML.mk_expr (Mltree.Econst c) (Mltree.I e.e_ity) mask eff lbl
+    | Evar pv ->
+        ML.mk_expr (Mltree.Evar pv) (Mltree.I e.e_ity) mask eff lbl
+    | Elet (LDvar (_, e1), e2) when e_ghost e1 ->
+        expr info svar mask e2
     | Elet (LDvar (_, e1), e2) when e_ghost e2 ->
         (* sequences are transformed into [let o = e1 in e2] by A-normal form *)
-        expr info svar e1.e_mask e1
+        expr info svar MaskGhost e1
     | Elet (LDvar (pv, e1), e2)
       when pv.pv_ghost || not (Mpv.mem pv e2.e_effect.eff_reads) ->
         if eff_pure e1.e_effect then expr info svar mask e2
-        else let e1 = ML.e_ignore (expr info svar e1.e_mask e1) in
-          ML.e_seq e1 (expr info svar mask e2) (Mltree.I e.e_ity) eff lbl
+        else let e1 = expr info svar MaskGhost e1 in
+          ML.e_seq e1 (expr info svar mask e2) (Mltree.I e.e_ity) mask eff lbl
     | Elet (LDvar (pv, e1), e2) ->
-        let ld = ML.var_defn pv (expr info svar e1.e_mask e1) in
-        ML.e_let ld (expr info svar mask e2) (Mltree.I e.e_ity) eff lbl
+        let ld = ML.var_defn pv (expr info svar MaskVisible e1) in
+        ML.e_let ld (expr info svar mask e2) (Mltree.I e.e_ity) mask eff lbl
     | Elet (LDsym (rs, _), ein) when rs_ghost rs ->
         expr info svar mask ein
     | Elet (LDsym (rs, {c_node = Cfun ef; c_cty = cty}), ein) ->
         let args = params cty.cty_args in
         let res = mlty_of_ity cty.cty_mask cty.cty_result in
-        let ld = ML.sym_defn rs res args (expr info svar ef.e_mask ef) in
-        ML.e_let ld (expr info svar mask ein) (Mltree.I e.e_ity) eff lbl
+        let ld = ML.sym_defn rs res args (expr info svar cty.cty_mask ef) in
+        ML.e_let ld (expr info svar mask ein) (Mltree.I e.e_ity) mask eff lbl
     | Elet (LDsym (rs, {c_node = Capp (rs_app, pvl); c_cty = cty}), ein)
       when isconstructor info rs_app -> (* partial application of constructor *)
         let eta_app = mk_eta_expansion rs_app pvl cty in
@@ -494,15 +508,18 @@ module Translate = struct
         let func = List.fold_right mk_func cty.cty_args cty.cty_result in
         let res = mlty_of_ity cty.cty_mask func in
         let ld = ML.sym_defn rs res [] eta_app in
-        ML.e_let ld (expr info svar mask ein) (Mltree.I e.e_ity) e.e_effect lbl
+        let ein = expr info svar mask ein in
+        ML.e_let ld ein (Mltree.I e.e_ity) mask eff lbl
     | Elet (LDsym (rsf, {c_node = Capp (rs_app, pvl); c_cty = cty}), ein) ->
-        (* partial application *)
-        let pvl = app pvl rs_app.rs_cty.cty_args in
-        let eff = cty.cty_effect in
-        let eapp = ML.e_app rs_app pvl (Mltree.C cty) eff Slab.empty in
+        (* partial application *) (* FIXME -> zero arguments functions *)
+        let cmk = cty.cty_mask in
+        let ceff = cty.cty_effect in
+        let pvl = app pvl rs_app.rs_cty.cty_args (fun x -> x) in
+        let eapp = ML.e_app rs_app pvl (Mltree.C cty) cmk ceff Slab.empty in
         let res = mlty_of_ity cty.cty_mask cty.cty_result in
         let ld = ML.sym_defn rsf res (params cty.cty_args) eapp in
-        ML.e_let ld (expr info svar mask ein) (Mltree.I e.e_ity) e.e_effect lbl
+        let ein = expr info svar mask ein in
+        ML.e_let ld ein (Mltree.I e.e_ity) mask eff lbl
     | Elet (LDrec rdefl, ein) ->
         let rdefl = filter_out_ghost_rdef rdefl in
         let def = function
@@ -524,7 +541,7 @@ module Translate = struct
         if rdefl <> [] then
           let ein = expr info svar mask ein in
           let ml_letrec = Mltree.Elet (Mltree.Lrec rdefl, ein) in
-          ML.mk_expr ml_letrec (Mltree.I e.e_ity) e.e_effect lbl
+          ML.mk_expr ml_letrec (Mltree.I e.e_ity) mask e.e_effect lbl
         else expr info svar mask ein
     | Eexec ({c_node = Capp (rs, [])}, _)  when is_rs_tuple rs ->
         ML.e_unit
@@ -532,29 +549,30 @@ module Translate = struct
         let pvl = pv_list_of_mask pvl mask in
         let res_ity = ity_tuple (List.map (fun v -> v.pv_ity) pvl) in
         let pvl = ML.var_list_of_pv_list pvl in
-        ML.e_app rs pvl (Mltree.I res_ity) eff lbl
-    | Eexec ({c_node = Capp (rs, _)}, _)
-      when is_empty_record info rs || rs_ghost rs ->
+        ML.e_app rs pvl (Mltree.I res_ity) mask eff lbl
+    | Eexec ({c_node = Capp (rs, _)}, _) when is_empty_record info rs ->
         ML.e_unit
     | Eexec ({c_node = Capp (rs, pvl); c_cty = cty}, _)
       when isconstructor info rs && cty.cty_args <> [] ->
         (* partial application of constructors *)
         mk_eta_expansion rs pvl cty
     | Eexec ({c_node = Capp (rs, pvl); _}, _) ->
-        let pvl = app pvl rs.rs_cty.cty_args in
+        let add_unit = function [] -> [ML.e_unit] | args -> args in
+        let f_zero = if isconstructor info rs then fun x -> x else add_unit in
+        let pvl = app pvl rs.rs_cty.cty_args f_zero in
         begin match pvl with
           | [pv_expr] when is_optimizable_record_rs info rs -> pv_expr
-          | _ -> ML.e_app rs pvl (Mltree.I e.e_ity) eff lbl end
+          | _ -> ML.e_app rs pvl (Mltree.I e.e_ity) mask eff lbl end
     | Eexec ({c_node = Cfun e; c_cty = {cty_args = []}}, _) ->
         (* abstract block *)
         expr info svar e.e_mask e
-    | Eexec ({c_node = Cfun e; c_cty = cty}, _) ->
-        ML.e_fun (params cty.cty_args) (expr info svar e.e_mask e)
-          (Mltree.I e.e_ity) eff lbl
+    | Eexec ({c_node = Cfun ef; c_cty = cty}, _) ->
+        let ef = expr info svar e.e_mask ef in
+        ML.e_fun (params cty.cty_args) ef (Mltree.I e.e_ity) mask eff lbl
     | Eexec ({c_node = Cany}, _) ->
         ML.mk_hole
     | Eabsurd ->
-        ML.e_absurd (Mltree.I e.e_ity) eff lbl
+        ML.e_absurd (Mltree.I e.e_ity) mask eff lbl
     | Eassert _ ->
         ML.e_unit
     | Eif (e1, e2, e3) when e_ghost e1 ->
@@ -565,32 +583,30 @@ module Translate = struct
     | Eif (e1, e2, e3) when e_ghost e3 ->
         let e1 = expr info svar e1.e_mask e1 in
         let e2 = expr info svar mask e2 in
-        ML.e_if e1 e2  ML.e_unit eff lbl
+        ML.e_if e1 e2 ML.e_unit mask eff lbl
     | Eif (e1, e2, e3) when e_ghost e2 ->
         let e1 = expr info svar e1.e_mask e1 in
         let e3 = expr info svar mask e3 in
-        ML.e_if e1 ML.e_unit e3 eff lbl
+        ML.e_if e1 ML.e_unit e3 mask eff lbl
     | Eif (e1, e2, e3) ->
         let e1 = expr info svar e1.e_mask e1 in
         let e2 = expr info svar mask e2 in
         let e3 = expr info svar mask e3 in
-        ML.e_if e1 e2 e3 eff lbl
+        ML.e_if e1 e2 e3 mask eff lbl
     | Ewhile (e1, _, _, e2) ->
-        assert (mask = MaskVisible);
         let e1 = expr info svar e1.e_mask e1 in
         let e2 = expr info svar e2.e_mask e2 in
-        ML.e_while e1 e2 eff lbl
+        ML.e_while e1 e2 mask eff lbl
     | Efor (pv1, (pv2, dir, pv3), _, _, efor) ->
-        assert (mask = MaskVisible);
         let dir = for_direction dir in
         let efor = expr info svar efor.e_mask efor in
-        ML.e_for pv1 pv2 dir pv3 efor eff lbl
+        ML.e_for pv1 pv2 dir pv3 efor mask eff lbl
     | Eghost _ | Epure _ ->
         assert false
     | Eassign al ->
         let rm_ghost (_, rs, _) = not (rs_ghost rs) in
         let al = List.filter rm_ghost al in
-        ML.e_assign al (Mltree.I e.e_ity) eff lbl
+        ML.e_assign al (Mltree.I e.e_ity) mask eff lbl
     | Ematch (e1, [], xl) when Mxs.is_empty xl ->
         expr info svar e1.e_mask e1
     | Ematch (e1, bl, xl) when e_ghost e1 ->
@@ -605,18 +621,18 @@ module Translate = struct
         (* NOTE: why no pv_list_of_mask here? *)
         let mk_xl (xs, (pvl, e)) = xs, pvl, expr info svar mask e in
         let xl = List.map mk_xl (Mxs.bindings xl) in
-        ML.e_match e1 bl xl (Mltree.I e.e_ity) eff lbl
+        ML.e_match e1 bl xl (Mltree.I e.e_ity) mask eff lbl
     | Eraise (xs, ex) -> let ex = match expr info svar xs.xs_mask ex with
         | {Mltree.e_node = Mltree.Eblock []} -> None
         | e -> Some e in
-        ML.mk_expr (Mltree.Eraise (xs, ex)) (Mltree.I e.e_ity) eff lbl
+        ML.mk_expr (Mltree.Eraise (xs, ex)) (Mltree.I e.e_ity) mask eff lbl
     | Eexn (xs, e1) ->
         if mask_ghost e1.e_mask then ML.mk_expr
-          (Mltree.Eexn (xs, None, ML.e_unit)) (Mltree.I e.e_ity) eff lbl
+          (Mltree.Eexn (xs, None, ML.e_unit)) (Mltree.I e.e_ity) mask eff lbl
         else let e1 = expr info svar xs.xs_mask e1 in
           let ty = if ity_equal xs.xs_ity ity_unit then None
             else Some (mlty_of_ity xs.xs_mask xs.xs_ity) in
-        ML.mk_expr (Mltree.Eexn (xs, ty, e1)) (Mltree.I e.e_ity) eff lbl
+        ML.mk_expr (Mltree.Eexn (xs, ty, e1)) (Mltree.I e.e_ity) mask eff lbl
     | Elet (LDsym (_, {c_node=(Cany|Cpur (_, _)); _ }), _)
     | Eexec ({c_node=Cpur (_, _); _ }, _) -> ML.mk_hole
 
diff --git a/src/mlw/ocaml_printer.ml b/src/mlw/ocaml_printer.ml
index b6e0e56f99..8d81151c6f 100644
--- a/src/mlw/ocaml_printer.ml
+++ b/src/mlw/ocaml_printer.ml
@@ -315,8 +315,10 @@ module Print = struct
         else fprintf fmt "%a" (print_expr ~paren:true info) expr;
         if exprl <> [] then fprintf fmt "@ ";
         print_apply_args info fmt (exprl, pvl)
+    | expr :: exprl, [] ->
+        fprintf fmt "%a" (print_expr ~paren:true info) expr;
+        print_apply_args info fmt (exprl, [])
     | [], _ -> ()
-    | _, [] -> assert false
 
   and print_apply info rs fmt pvl =
     let isfield =
@@ -360,7 +362,7 @@ module Print = struct
         end
     | _, None, [] ->
         (print_lident info) fmt rs.rs_name
-    | _, _, tl -> (* FIXME? when is in driver but is not a local id *)
+    | _, _, tl ->
         fprintf fmt "@[<hov 2>%a %a@]"
           (print_lident info) rs.rs_name
           (print_apply_args info) (tl, rs.rs_cty.cty_args)
diff --git a/tests/test-extraction/main.ml b/tests/test-extraction/main.ml
index 6b8f3aaac1..bbcc97e166 100644
--- a/tests/test-extraction/main.ml
+++ b/tests/test-extraction/main.ml
@@ -8,11 +8,11 @@ let () = assert (test_array   () = 42)
 let (=) = Z.equal
 
 let b42 = Z.of_int 42
-let () = assert (test_int     () = b42)
-let () = assert (test_int63   () = b42)
-let () = assert (test_ref     () = b42)
-let () = assert (test_array63 () = b42)
-
+let () = assert (test_int      () = b42)
+let () = assert (test_int63    () = b42)
+let () = assert (test_ref      () = b42)
+let () = assert (test_array63  () = b42)
+let () = assert (test_partial2 ()   = b42)
 let () = main ()
 
 let () = Format.printf "OCaml extraction test successful@."
diff --git a/tests/test_extraction.mlw b/tests/test_extraction.mlw
index 859910d96e..5e9c002c33 100644
--- a/tests/test_extraction.mlw
+++ b/tests/test_extraction.mlw
@@ -172,6 +172,15 @@ module TestExtraction
     let partial = test_filter_ghost_args 3 in
     42
 
+  let constant test_partial2 : int =
+    let r = ref 0 in
+    let f (x: int) (ghost y) = r := !r + 21 in
+    let g = f 17 in
+    g (0:int); g (1:int); !r
+
+  let test_zero_args () : int =
+    test_partial2 + 0
+
   let test_filter_ghost_args2 (x: int) (ghost y: int) (z: int) : int =
     x + z
 
@@ -199,7 +208,7 @@ module TestExtraction
     let res = yxz - xzy in
     res
 
-  let test_partial2 (x: int) : int =
+  let test_partial3 (x: int) : int =
     let sum : int -> int -> int = fun x y -> x + y in
     let incr_a (a: int) = sum a in
     incr_a x x
-- 
GitLab