From 8c916bc6bfc89a7c582fd84ea52ea212ff66f8f2 Mon Sep 17 00:00:00 2001 From: Raphael Rieu-Helft <raphael.rieu-helft@lri.fr> Date: Thu, 16 Aug 2018 15:55:11 +0200 Subject: [PATCH] Eliminate internal symbols for let-recs in Compile --- src/mlw/cakeml_printer.ml | 8 ++------ src/mlw/compile.ml | 25 ++++++++++++++++--------- src/mlw/mlinterp.ml | 7 ++----- src/mlw/mltree.ml | 2 +- src/mlw/ocaml_printer.ml | 8 ++------ 5 files changed, 23 insertions(+), 27 deletions(-) diff --git a/src/mlw/cakeml_printer.ml b/src/mlw/cakeml_printer.ml index b5d83079ff..3cbc7762e2 100644 --- a/src/mlw/cakeml_printer.ml +++ b/src/mlw/cakeml_printer.ml @@ -97,8 +97,6 @@ module Print = struct let pv_name pv = pv.pv_vs.vs_name let print_pv info fmt pv = print_lident info fmt (pv_name pv) - let ht_rs = Hrs.create 7 (* rec_rsym -> rec_sym *) - (* FIXME put these in Compile*) let is_true e = match e.e_node with | Eapp (s, []) -> rs_equal s rs_true @@ -222,9 +220,7 @@ module Print = struct (print_lident info) rs1.rs_name (print_fun_type_args info) (args, s, res, e); forget_vars args in - List.iter (fun fd -> Hrs.replace ht_rs fd.rec_rsym fd.rec_sym) rdef; print_list_next newline print_one fmt rdef; - List.iter (fun fd -> Hrs.remove ht_rs fd.rec_rsym) rdef | Lany (rs, _, _) -> check_val_in_drv info rs @@ -254,14 +250,14 @@ module Print = struct | Eapp (rs, []) when rs_equal rs rs_false -> fprintf fmt "false" | Eapp (rs, []) -> (* avoids parenthesis around values *) - fprintf fmt "%a" (print_apply info (Hrs.find_def ht_rs rs rs)) [] + fprintf fmt "%a" (print_apply info rs) [] | Eapp (rs, pvl) -> begin match query_syntax info.info_convert rs.rs_name, pvl with | Some s, [{e_node = Econst _}] -> syntax_arguments s print_constant fmt pvl | _ -> fprintf fmt (protect_on paren "%a") - (print_apply info (Hrs.find_def ht_rs rs rs)) pvl end + (print_apply info rs) pvl end | Ematch (e1, [p, e2], []) -> fprintf fmt (protect_on paren "let %a =@ %a in@ %a") (print_pat info) p (print_expr info) e1 (print_expr info) e2 diff --git a/src/mlw/compile.ml b/src/mlw/compile.ml index f64966d61a..d6d3068040 100644 --- a/src/mlw/compile.ml +++ b/src/mlw/compile.ml @@ -49,6 +49,8 @@ module Translate = struct module ML = Mltree + let ht_rs = Hrs.create 7 (* rec_rsym -> rec_sym *) + let debug_compile = Debug.register_info_flag ~desc:"Compilation" "compile" @@ -188,6 +190,7 @@ module Translate = struct let mk_eta_expansion rs pvl ({cty_args = ca; cty_effect = ce} as c) = (* FIXME : effects and types of the expression in this situation *) + let rs = Hrs.find_def ht_rs rs rs in let mv = MaskVisible in let args_f = let def pv = @@ -298,6 +301,7 @@ module Translate = struct let cmk = cty.cty_mask in let ceff = cty.cty_effect in let pvl = app pvl rs_app.rs_cty.cty_args (fun x -> x) in + let rs_app = Hrs.find_def ht_rs rs_app rs_app in let eapp = ML.e_app rs_app pvl (ML.C cty) cmk ceff Sattr.empty in let res = mlty_of_ity cty.cty_mask cty.cty_result in let ld = ML.sym_defn rsf res (params cty.cty_args) eapp in @@ -305,9 +309,11 @@ module Translate = struct ML.e_let ld ein (ML.I e.e_ity) mask eff attrs | Elet (LDrec rdefl, ein) -> let rdefl = filter_out_ghost_rdef rdefl in + List.iter + (fun { rec_sym = rs1; rec_rsym = rs2; } -> + Hrs.replace ht_rs rs2 rs1) rdefl; let def = function - | { rec_sym = rs1; rec_rsym = rs2; - rec_fun = {c_node = Cfun ef; c_cty = cty} } -> + | { rec_sym = rs1; rec_fun = {c_node = Cfun ef; c_cty = cty} } -> let res = mlty_of_ity rs1.rs_cty.cty_mask rs1.rs_cty.cty_result in let args = params cty.cty_args in let new_svar = @@ -316,9 +322,8 @@ module Translate = struct add_tvar svar res in let new_svar = Stv.diff svar new_svar in let ef = expr info (Stv.union svar new_svar) ef.e_mask ef in - { ML.rec_sym = rs1; ML.rec_rsym = rs2; - ML.rec_args = args; ML.rec_exp = ef; - ML.rec_res = res; ML.rec_svar = new_svar; } + { ML.rec_sym = rs1; ML.rec_args = args; ML.rec_exp = ef; + ML.rec_res = res; ML.rec_svar = new_svar; } | _ -> assert false in let rdefl = List.map def rdefl in if rdefl <> [] then @@ -343,6 +348,7 @@ module Translate = struct Debug.dprintf debug_compile "compiling total application of %s@." rs.rs_name.id_string; Debug.dprintf debug_compile "cty_args: %d@." (List.length cty.cty_args); + let rs = Hrs.find_def ht_rs rs rs in let add_unit = function [] -> [ML.e_unit] | args -> args in let id_f = fun x -> x in let f_zero = match rs.rs_logic with RLnone -> @@ -542,7 +548,9 @@ module Translate = struct [ML.Dlet (ML.Lsym (rs, res, args, e))] | PDlet (LDrec rl) -> let rl = filter_out_ghost_rdef rl in - let def {rec_fun = e; rec_sym = rs1; rec_rsym = rs2} = + List.iter (fun {rec_sym = rs1; rec_rsym = rs2} -> + Hrs.replace ht_rs rs2 rs1) rl; + let def {rec_fun = e; rec_sym = rs1} = let e = match e.c_node with Cfun e -> e | _ -> assert false in let args = params rs1.rs_cty.cty_args in let res = mlty_of_ity rs1.rs_cty.cty_mask rs1.rs_cty.cty_result in @@ -551,9 +559,8 @@ module Translate = struct let svar = List.fold_left add_tvar Stv.empty args' in add_tvar svar res in let e = expr info svar rs1.rs_cty.cty_mask e in - { ML.rec_sym = rs1; ML.rec_rsym = rs2; - ML.rec_args = args; ML.rec_exp = e; - ML.rec_res = res; ML.rec_svar = svar; } in + { ML.rec_sym = rs1; ML.rec_args = args; ML.rec_exp = e; + ML.rec_res = res; ML.rec_svar = svar; } in if rl = [] then [] else [ML.Dlet (ML.Lrec (List.map def rl))] | PDlet (LDsym _) | PDpure -> [] diff --git a/src/mlw/mlinterp.ml b/src/mlw/mlinterp.ml index 14749b985b..9ec87b1d28 100644 --- a/src/mlw/mlinterp.ml +++ b/src/mlw/mlinterp.ml @@ -41,7 +41,6 @@ type info = { env : Env.env; mm : Pmodule.pmodule Mstr.t; vars: value Mid.t; - recs: rsymbol Mrs.t; funs: decl Mrs.t; get_decl: rsymbol -> Mltree.decl; cur_rs: rsymbol; (* current function *) @@ -632,7 +631,6 @@ let rec interp_expr info (e:Mltree.expr) : value = v end in Debug.dprintf debug_interp "eval call@."; let res = try begin - let rs = if Mrs.mem rs info.recs then Mrs.find rs info.recs else rs in if Hrs.mem builtin_progs rs then (Debug.dprintf debug_interp "%a is builtin@." Expr.print_rs rs; @@ -646,8 +644,8 @@ let rec interp_expr info (e:Mltree.expr) : value = | Dlet (Lsym (rs, _ty, vl, e)) -> eval_call info vl e rs | Dlet(Lrec([{rec_args = vl; rec_exp = e; - rec_sym = rs; rec_rsym = rrs; rec_res=_ty}])) -> - eval_call { info with recs = Mrs.add rrs rs info.recs } vl e rs + rec_sym = rs; rec_res=_ty}])) -> + eval_call info vl e rs | Dlet (Lrec _) -> Debug.dprintf debug_interp "unhandled mutually recursive functions@."; @@ -863,7 +861,6 @@ let init_info env mm rs vars = { env = env; mm = mm; funs = Mrs.empty; - recs = Mrs.empty; vars = vars; get_decl = get_decl env mm; cur_rs = rs; diff --git a/src/mlw/mltree.ml b/src/mlw/mltree.ml index 7e7c8b63d5..15fb20f20e 100644 --- a/src/mlw/mltree.ml +++ b/src/mlw/mltree.ml @@ -79,7 +79,7 @@ and let_def = and rdef = { rec_sym : rsymbol; (* exported *) - rec_rsym : rsymbol; (* internal *) + (* rec_rsym : rsymbol;*) (* internal *) rec_args : var list; rec_exp : expr; rec_res : ty; diff --git a/src/mlw/ocaml_printer.ml b/src/mlw/ocaml_printer.ml index c6d1559b66..b281c27b5b 100644 --- a/src/mlw/ocaml_printer.ml +++ b/src/mlw/ocaml_printer.ml @@ -264,8 +264,6 @@ module Print = struct let pv_name pv = pv.pv_vs.vs_name let print_pv info fmt pv = print_lident info fmt (pv_name pv) - let ht_rs = Hrs.create 7 (* rec_rsym -> rec_sym *) - (* FIXME put these in Compile*) let is_true e = match e.e_node with | Eapp (s, []) -> rs_equal s rs_true @@ -411,9 +409,7 @@ module Print = struct (print_fun_type_args info) (args, s, res, e); forget_vars args in - List.iter (fun fd -> Hrs.replace ht_rs fd.rec_rsym fd.rec_sym) rdef; print_list_next newline print_one fmt rdef; - List.iter (fun fd -> Hrs.remove ht_rs fd.rec_rsym) rdef | Lany (rs, res, []) when functor_arg -> fprintf fmt "@[<hov 2>val %a : %a@]" (print_lident info) rs.rs_name @@ -456,14 +452,14 @@ module Print = struct | Eapp (rs, []) when rs_equal rs rs_false -> fprintf fmt "false" | Eapp (rs, []) -> (* avoids parenthesis around values *) - fprintf fmt "%a" (print_apply info (Hrs.find_def ht_rs rs rs)) [] + fprintf fmt "%a" (print_apply info rs) [] | Eapp (rs, pvl) -> begin match query_syntax info.info_convert rs.rs_name, pvl with | Some s, [{e_node = Econst _}] -> syntax_arguments s print_constant fmt pvl | _ -> fprintf fmt (protect_on paren "%a") - (print_apply info (Hrs.find_def ht_rs rs rs)) pvl end + (print_apply info rs) pvl end | Ematch (e1, [p, e2], []) -> fprintf fmt (protect_on paren "let %a =@ %a in@ %a") (print_pat info) p (print_expr info) e1 (print_expr info) e2 -- GitLab