From 8c916bc6bfc89a7c582fd84ea52ea212ff66f8f2 Mon Sep 17 00:00:00 2001
From: Raphael Rieu-Helft <raphael.rieu-helft@lri.fr>
Date: Thu, 16 Aug 2018 15:55:11 +0200
Subject: [PATCH] Eliminate internal symbols for let-recs in Compile

---
 src/mlw/cakeml_printer.ml |  8 ++------
 src/mlw/compile.ml        | 25 ++++++++++++++++---------
 src/mlw/mlinterp.ml       |  7 ++-----
 src/mlw/mltree.ml         |  2 +-
 src/mlw/ocaml_printer.ml  |  8 ++------
 5 files changed, 23 insertions(+), 27 deletions(-)

diff --git a/src/mlw/cakeml_printer.ml b/src/mlw/cakeml_printer.ml
index b5d83079ff..3cbc7762e2 100644
--- a/src/mlw/cakeml_printer.ml
+++ b/src/mlw/cakeml_printer.ml
@@ -97,8 +97,6 @@ module Print = struct
   let pv_name pv = pv.pv_vs.vs_name
   let print_pv info fmt pv = print_lident info fmt (pv_name pv)
 
-  let ht_rs = Hrs.create 7 (* rec_rsym -> rec_sym *)
-
   (* FIXME put these in Compile*)
   let is_true e = match e.e_node with
     | Eapp (s, []) -> rs_equal s rs_true
@@ -222,9 +220,7 @@ module Print = struct
                 (print_lident info) rs1.rs_name
                 (print_fun_type_args info) (args, s, res, e);
               forget_vars args in
-        List.iter (fun fd -> Hrs.replace ht_rs fd.rec_rsym fd.rec_sym) rdef;
         print_list_next newline print_one fmt rdef;
-        List.iter (fun fd -> Hrs.remove ht_rs fd.rec_rsym) rdef
     | Lany (rs, _, _) ->
         check_val_in_drv info rs
 
@@ -254,14 +250,14 @@ module Print = struct
     | Eapp (rs, []) when rs_equal rs rs_false ->
         fprintf fmt "false"
     | Eapp (rs, [])  -> (* avoids parenthesis around values *)
-        fprintf fmt "%a" (print_apply info (Hrs.find_def ht_rs rs rs)) []
+        fprintf fmt "%a" (print_apply info rs) []
     | Eapp (rs, pvl) ->
         begin match query_syntax info.info_convert rs.rs_name, pvl with
           | Some s, [{e_node = Econst _}] ->
               syntax_arguments s print_constant fmt pvl
           | _ ->
               fprintf fmt (protect_on paren "%a")
-                (print_apply info (Hrs.find_def ht_rs rs rs)) pvl end
+                (print_apply info rs) pvl end
     | Ematch (e1, [p, e2], []) ->
         fprintf fmt (protect_on paren "let %a =@ %a in@ %a")
           (print_pat info) p (print_expr info) e1 (print_expr info) e2
diff --git a/src/mlw/compile.ml b/src/mlw/compile.ml
index f64966d61a..d6d3068040 100644
--- a/src/mlw/compile.ml
+++ b/src/mlw/compile.ml
@@ -49,6 +49,8 @@ module Translate = struct
 
   module ML = Mltree
 
+  let ht_rs = Hrs.create 7 (* rec_rsym -> rec_sym *)
+
   let debug_compile =
     Debug.register_info_flag ~desc:"Compilation" "compile"
 
@@ -188,6 +190,7 @@ module Translate = struct
 
   let mk_eta_expansion rs pvl ({cty_args = ca; cty_effect = ce} as c) =
     (* FIXME : effects and types of the expression in this situation *)
+    let rs = Hrs.find_def ht_rs rs rs in
     let mv = MaskVisible in
     let args_f =
       let def pv =
@@ -298,6 +301,7 @@ module Translate = struct
         let cmk = cty.cty_mask in
         let ceff = cty.cty_effect in
         let pvl = app pvl rs_app.rs_cty.cty_args (fun x -> x) in
+        let rs_app = Hrs.find_def ht_rs rs_app rs_app in
         let eapp = ML.e_app rs_app pvl (ML.C cty) cmk ceff Sattr.empty in
         let res = mlty_of_ity cty.cty_mask cty.cty_result in
         let ld = ML.sym_defn rsf res (params cty.cty_args) eapp in
@@ -305,9 +309,11 @@ module Translate = struct
         ML.e_let ld ein (ML.I e.e_ity) mask eff attrs
     | Elet (LDrec rdefl, ein) ->
         let rdefl = filter_out_ghost_rdef rdefl in
+        List.iter
+          (fun { rec_sym = rs1; rec_rsym = rs2; } ->
+            Hrs.replace ht_rs rs2 rs1) rdefl;
         let def = function
-          | { rec_sym = rs1; rec_rsym = rs2;
-              rec_fun = {c_node = Cfun ef; c_cty = cty} } ->
+          | { rec_sym = rs1; rec_fun = {c_node = Cfun ef; c_cty = cty} } ->
               let res = mlty_of_ity rs1.rs_cty.cty_mask rs1.rs_cty.cty_result in
               let args = params cty.cty_args in
               let new_svar =
@@ -316,9 +322,8 @@ module Translate = struct
                 add_tvar svar res in
               let new_svar = Stv.diff svar new_svar in
               let ef = expr info (Stv.union svar new_svar) ef.e_mask ef in
-              { ML.rec_sym  = rs1;  ML.rec_rsym = rs2;
-                ML.rec_args = args; ML.rec_exp  = ef;
-                ML.rec_res  = res;  ML.rec_svar = new_svar; }
+              { ML.rec_sym  = rs1; ML.rec_args = args; ML.rec_exp  = ef;
+                ML.rec_res  = res; ML.rec_svar = new_svar; }
           | _ -> assert false in
         let rdefl = List.map def rdefl in
         if rdefl <> [] then
@@ -343,6 +348,7 @@ module Translate = struct
         Debug.dprintf debug_compile "compiling total application of %s@."
           rs.rs_name.id_string;
         Debug.dprintf debug_compile "cty_args: %d@." (List.length cty.cty_args);
+        let rs = Hrs.find_def ht_rs rs rs in
         let add_unit = function [] -> [ML.e_unit] | args -> args in
         let id_f = fun x -> x in
         let f_zero = match rs.rs_logic with RLnone ->
@@ -542,7 +548,9 @@ module Translate = struct
         [ML.Dlet (ML.Lsym (rs, res, args, e))]
     | PDlet (LDrec rl) ->
         let rl = filter_out_ghost_rdef rl in
-        let def {rec_fun = e; rec_sym = rs1; rec_rsym = rs2} =
+        List.iter (fun {rec_sym = rs1; rec_rsym = rs2} ->
+            Hrs.replace ht_rs rs2 rs1) rl;
+        let def {rec_fun = e; rec_sym = rs1} =
           let e = match e.c_node with Cfun e -> e | _ -> assert false in
           let args = params rs1.rs_cty.cty_args in
           let res  = mlty_of_ity rs1.rs_cty.cty_mask rs1.rs_cty.cty_result in
@@ -551,9 +559,8 @@ module Translate = struct
             let svar  = List.fold_left add_tvar Stv.empty args' in
             add_tvar svar res in
           let e = expr info svar rs1.rs_cty.cty_mask e in
-          { ML.rec_sym  = rs1;  ML.rec_rsym = rs2;
-            ML.rec_args = args; ML.rec_exp  = e;
-            ML.rec_res  = res;  ML.rec_svar = svar; } in
+          { ML.rec_sym  = rs1; ML.rec_args = args; ML.rec_exp  = e;
+            ML.rec_res  = res; ML.rec_svar = svar; } in
         if rl = [] then [] else [ML.Dlet (ML.Lrec (List.map def rl))]
     | PDlet (LDsym _) | PDpure ->
         []
diff --git a/src/mlw/mlinterp.ml b/src/mlw/mlinterp.ml
index 14749b985b..9ec87b1d28 100644
--- a/src/mlw/mlinterp.ml
+++ b/src/mlw/mlinterp.ml
@@ -41,7 +41,6 @@ type info = {
     env : Env.env;
     mm  : Pmodule.pmodule Mstr.t;
     vars: value Mid.t;
-    recs: rsymbol Mrs.t;
     funs: decl Mrs.t;
     get_decl: rsymbol -> Mltree.decl;
     cur_rs: rsymbol; (* current function *)
@@ -632,7 +631,6 @@ let rec interp_expr info (e:Mltree.expr) : value =
           v end in
       Debug.dprintf debug_interp "eval call@.";
       let res = try begin
-        let rs = if Mrs.mem rs info.recs then Mrs.find rs info.recs else rs in
         if Hrs.mem builtin_progs rs
         then
           (Debug.dprintf debug_interp "%a is builtin@." Expr.print_rs rs;
@@ -646,8 +644,8 @@ let rec interp_expr info (e:Mltree.expr) : value =
           | Dlet (Lsym (rs, _ty, vl, e)) ->
              eval_call info vl e rs
           | Dlet(Lrec([{rec_args = vl; rec_exp = e;
-                        rec_sym = rs; rec_rsym = rrs; rec_res=_ty}])) ->
-             eval_call { info with recs = Mrs.add rrs rs info.recs } vl e rs
+                        rec_sym = rs; rec_res=_ty}])) ->
+             eval_call info vl e rs
           | Dlet (Lrec _) ->
              Debug.dprintf
                debug_interp "unhandled mutually recursive functions@.";
@@ -863,7 +861,6 @@ let init_info env mm rs vars =
   { env = env;
     mm = mm;
     funs = Mrs.empty;
-    recs = Mrs.empty;
     vars = vars;
     get_decl = get_decl env mm;
     cur_rs = rs;
diff --git a/src/mlw/mltree.ml b/src/mlw/mltree.ml
index 7e7c8b63d5..15fb20f20e 100644
--- a/src/mlw/mltree.ml
+++ b/src/mlw/mltree.ml
@@ -79,7 +79,7 @@ and let_def =
 
 and rdef = {
   rec_sym  : rsymbol; (* exported *)
-  rec_rsym : rsymbol; (* internal *)
+  (* rec_rsym : rsymbol;*) (* internal *)
   rec_args : var list;
   rec_exp  : expr;
   rec_res  : ty;
diff --git a/src/mlw/ocaml_printer.ml b/src/mlw/ocaml_printer.ml
index c6d1559b66..b281c27b5b 100644
--- a/src/mlw/ocaml_printer.ml
+++ b/src/mlw/ocaml_printer.ml
@@ -264,8 +264,6 @@ module Print = struct
   let pv_name pv = pv.pv_vs.vs_name
   let print_pv info fmt pv = print_lident info fmt (pv_name pv)
 
-  let ht_rs = Hrs.create 7 (* rec_rsym -> rec_sym *)
-
   (* FIXME put these in Compile*)
   let is_true e = match e.e_node with
     | Eapp (s, []) -> rs_equal s rs_true
@@ -411,9 +409,7 @@ module Print = struct
                 (print_fun_type_args info) (args, s, res, e);
               forget_vars args
         in
-        List.iter (fun fd -> Hrs.replace ht_rs fd.rec_rsym fd.rec_sym) rdef;
         print_list_next newline print_one fmt rdef;
-        List.iter (fun fd -> Hrs.remove ht_rs fd.rec_rsym) rdef
     | Lany (rs, res, []) when functor_arg ->
         fprintf fmt "@[<hov 2>val %a : %a@]"
           (print_lident info) rs.rs_name
@@ -456,14 +452,14 @@ module Print = struct
     | Eapp (rs, []) when rs_equal rs rs_false ->
         fprintf fmt "false"
     | Eapp (rs, [])  -> (* avoids parenthesis around values *)
-        fprintf fmt "%a" (print_apply info (Hrs.find_def ht_rs rs rs)) []
+        fprintf fmt "%a" (print_apply info rs) []
     | Eapp (rs, pvl) ->
         begin match query_syntax info.info_convert rs.rs_name, pvl with
           | Some s, [{e_node = Econst _}] ->
               syntax_arguments s print_constant fmt pvl
           | _ ->
               fprintf fmt (protect_on paren "%a")
-                (print_apply info (Hrs.find_def ht_rs rs rs)) pvl end
+                (print_apply info rs) pvl end
     | Ematch (e1, [p, e2], []) ->
         fprintf fmt (protect_on paren "let %a =@ %a in@ %a")
           (print_pat info) p (print_expr info) e1 (print_expr info) e2
-- 
GitLab