Commit 8c916bc6 authored by Raphael Rieu-Helft's avatar Raphael Rieu-Helft

Eliminate internal symbols for let-recs in Compile

parent 16645aa5
...@@ -97,8 +97,6 @@ module Print = struct ...@@ -97,8 +97,6 @@ module Print = struct
let pv_name pv = pv.pv_vs.vs_name let pv_name pv = pv.pv_vs.vs_name
let print_pv info fmt pv = print_lident info fmt (pv_name pv) let print_pv info fmt pv = print_lident info fmt (pv_name pv)
let ht_rs = Hrs.create 7 (* rec_rsym -> rec_sym *)
(* FIXME put these in Compile*) (* FIXME put these in Compile*)
let is_true e = match e.e_node with let is_true e = match e.e_node with
| Eapp (s, []) -> rs_equal s rs_true | Eapp (s, []) -> rs_equal s rs_true
...@@ -222,9 +220,7 @@ module Print = struct ...@@ -222,9 +220,7 @@ module Print = struct
(print_lident info) rs1.rs_name (print_lident info) rs1.rs_name
(print_fun_type_args info) (args, s, res, e); (print_fun_type_args info) (args, s, res, e);
forget_vars args in forget_vars args in
List.iter (fun fd -> Hrs.replace ht_rs fd.rec_rsym fd.rec_sym) rdef;
print_list_next newline print_one fmt rdef; print_list_next newline print_one fmt rdef;
List.iter (fun fd -> Hrs.remove ht_rs fd.rec_rsym) rdef
| Lany (rs, _, _) -> | Lany (rs, _, _) ->
check_val_in_drv info rs check_val_in_drv info rs
...@@ -254,14 +250,14 @@ module Print = struct ...@@ -254,14 +250,14 @@ module Print = struct
| Eapp (rs, []) when rs_equal rs rs_false -> | Eapp (rs, []) when rs_equal rs rs_false ->
fprintf fmt "false" fprintf fmt "false"
| Eapp (rs, []) -> (* avoids parenthesis around values *) | Eapp (rs, []) -> (* avoids parenthesis around values *)
fprintf fmt "%a" (print_apply info (Hrs.find_def ht_rs rs rs)) [] fprintf fmt "%a" (print_apply info rs) []
| Eapp (rs, pvl) -> | Eapp (rs, pvl) ->
begin match query_syntax info.info_convert rs.rs_name, pvl with begin match query_syntax info.info_convert rs.rs_name, pvl with
| Some s, [{e_node = Econst _}] -> | Some s, [{e_node = Econst _}] ->
syntax_arguments s print_constant fmt pvl syntax_arguments s print_constant fmt pvl
| _ -> | _ ->
fprintf fmt (protect_on paren "%a") fprintf fmt (protect_on paren "%a")
(print_apply info (Hrs.find_def ht_rs rs rs)) pvl end (print_apply info rs) pvl end
| Ematch (e1, [p, e2], []) -> | Ematch (e1, [p, e2], []) ->
fprintf fmt (protect_on paren "let %a =@ %a in@ %a") fprintf fmt (protect_on paren "let %a =@ %a in@ %a")
(print_pat info) p (print_expr info) e1 (print_expr info) e2 (print_pat info) p (print_expr info) e1 (print_expr info) e2
......
...@@ -49,6 +49,8 @@ module Translate = struct ...@@ -49,6 +49,8 @@ module Translate = struct
module ML = Mltree module ML = Mltree
let ht_rs = Hrs.create 7 (* rec_rsym -> rec_sym *)
let debug_compile = let debug_compile =
Debug.register_info_flag ~desc:"Compilation" "compile" Debug.register_info_flag ~desc:"Compilation" "compile"
...@@ -188,6 +190,7 @@ module Translate = struct ...@@ -188,6 +190,7 @@ module Translate = struct
let mk_eta_expansion rs pvl ({cty_args = ca; cty_effect = ce} as c) = let mk_eta_expansion rs pvl ({cty_args = ca; cty_effect = ce} as c) =
(* FIXME : effects and types of the expression in this situation *) (* FIXME : effects and types of the expression in this situation *)
let rs = Hrs.find_def ht_rs rs rs in
let mv = MaskVisible in let mv = MaskVisible in
let args_f = let args_f =
let def pv = let def pv =
...@@ -298,6 +301,7 @@ module Translate = struct ...@@ -298,6 +301,7 @@ module Translate = struct
let cmk = cty.cty_mask in let cmk = cty.cty_mask in
let ceff = cty.cty_effect in let ceff = cty.cty_effect in
let pvl = app pvl rs_app.rs_cty.cty_args (fun x -> x) in let pvl = app pvl rs_app.rs_cty.cty_args (fun x -> x) in
let rs_app = Hrs.find_def ht_rs rs_app rs_app in
let eapp = ML.e_app rs_app pvl (ML.C cty) cmk ceff Sattr.empty in let eapp = ML.e_app rs_app pvl (ML.C cty) cmk ceff Sattr.empty in
let res = mlty_of_ity cty.cty_mask cty.cty_result in let res = mlty_of_ity cty.cty_mask cty.cty_result in
let ld = ML.sym_defn rsf res (params cty.cty_args) eapp in let ld = ML.sym_defn rsf res (params cty.cty_args) eapp in
...@@ -305,9 +309,11 @@ module Translate = struct ...@@ -305,9 +309,11 @@ module Translate = struct
ML.e_let ld ein (ML.I e.e_ity) mask eff attrs ML.e_let ld ein (ML.I e.e_ity) mask eff attrs
| Elet (LDrec rdefl, ein) -> | Elet (LDrec rdefl, ein) ->
let rdefl = filter_out_ghost_rdef rdefl in let rdefl = filter_out_ghost_rdef rdefl in
List.iter
(fun { rec_sym = rs1; rec_rsym = rs2; } ->
Hrs.replace ht_rs rs2 rs1) rdefl;
let def = function let def = function
| { rec_sym = rs1; rec_rsym = rs2; | { rec_sym = rs1; rec_fun = {c_node = Cfun ef; c_cty = cty} } ->
rec_fun = {c_node = Cfun ef; c_cty = cty} } ->
let res = mlty_of_ity rs1.rs_cty.cty_mask rs1.rs_cty.cty_result in let res = mlty_of_ity rs1.rs_cty.cty_mask rs1.rs_cty.cty_result in
let args = params cty.cty_args in let args = params cty.cty_args in
let new_svar = let new_svar =
...@@ -316,9 +322,8 @@ module Translate = struct ...@@ -316,9 +322,8 @@ module Translate = struct
add_tvar svar res in add_tvar svar res in
let new_svar = Stv.diff svar new_svar in let new_svar = Stv.diff svar new_svar in
let ef = expr info (Stv.union svar new_svar) ef.e_mask ef in let ef = expr info (Stv.union svar new_svar) ef.e_mask ef in
{ ML.rec_sym = rs1; ML.rec_rsym = rs2; { ML.rec_sym = rs1; ML.rec_args = args; ML.rec_exp = ef;
ML.rec_args = args; ML.rec_exp = ef; ML.rec_res = res; ML.rec_svar = new_svar; }
ML.rec_res = res; ML.rec_svar = new_svar; }
| _ -> assert false in | _ -> assert false in
let rdefl = List.map def rdefl in let rdefl = List.map def rdefl in
if rdefl <> [] then if rdefl <> [] then
...@@ -343,6 +348,7 @@ module Translate = struct ...@@ -343,6 +348,7 @@ module Translate = struct
Debug.dprintf debug_compile "compiling total application of %s@." Debug.dprintf debug_compile "compiling total application of %s@."
rs.rs_name.id_string; rs.rs_name.id_string;
Debug.dprintf debug_compile "cty_args: %d@." (List.length cty.cty_args); Debug.dprintf debug_compile "cty_args: %d@." (List.length cty.cty_args);
let rs = Hrs.find_def ht_rs rs rs in
let add_unit = function [] -> [ML.e_unit] | args -> args in let add_unit = function [] -> [ML.e_unit] | args -> args in
let id_f = fun x -> x in let id_f = fun x -> x in
let f_zero = match rs.rs_logic with RLnone -> let f_zero = match rs.rs_logic with RLnone ->
...@@ -542,7 +548,9 @@ module Translate = struct ...@@ -542,7 +548,9 @@ module Translate = struct
[ML.Dlet (ML.Lsym (rs, res, args, e))] [ML.Dlet (ML.Lsym (rs, res, args, e))]
| PDlet (LDrec rl) -> | PDlet (LDrec rl) ->
let rl = filter_out_ghost_rdef rl in let rl = filter_out_ghost_rdef rl in
let def {rec_fun = e; rec_sym = rs1; rec_rsym = rs2} = List.iter (fun {rec_sym = rs1; rec_rsym = rs2} ->
Hrs.replace ht_rs rs2 rs1) rl;
let def {rec_fun = e; rec_sym = rs1} =
let e = match e.c_node with Cfun e -> e | _ -> assert false in let e = match e.c_node with Cfun e -> e | _ -> assert false in
let args = params rs1.rs_cty.cty_args in let args = params rs1.rs_cty.cty_args in
let res = mlty_of_ity rs1.rs_cty.cty_mask rs1.rs_cty.cty_result in let res = mlty_of_ity rs1.rs_cty.cty_mask rs1.rs_cty.cty_result in
...@@ -551,9 +559,8 @@ module Translate = struct ...@@ -551,9 +559,8 @@ module Translate = struct
let svar = List.fold_left add_tvar Stv.empty args' in let svar = List.fold_left add_tvar Stv.empty args' in
add_tvar svar res in add_tvar svar res in
let e = expr info svar rs1.rs_cty.cty_mask e in let e = expr info svar rs1.rs_cty.cty_mask e in
{ ML.rec_sym = rs1; ML.rec_rsym = rs2; { ML.rec_sym = rs1; ML.rec_args = args; ML.rec_exp = e;
ML.rec_args = args; ML.rec_exp = e; ML.rec_res = res; ML.rec_svar = svar; } in
ML.rec_res = res; ML.rec_svar = svar; } in
if rl = [] then [] else [ML.Dlet (ML.Lrec (List.map def rl))] if rl = [] then [] else [ML.Dlet (ML.Lrec (List.map def rl))]
| PDlet (LDsym _) | PDpure -> | PDlet (LDsym _) | PDpure ->
[] []
......
...@@ -41,7 +41,6 @@ type info = { ...@@ -41,7 +41,6 @@ type info = {
env : Env.env; env : Env.env;
mm : Pmodule.pmodule Mstr.t; mm : Pmodule.pmodule Mstr.t;
vars: value Mid.t; vars: value Mid.t;
recs: rsymbol Mrs.t;
funs: decl Mrs.t; funs: decl Mrs.t;
get_decl: rsymbol -> Mltree.decl; get_decl: rsymbol -> Mltree.decl;
cur_rs: rsymbol; (* current function *) cur_rs: rsymbol; (* current function *)
...@@ -632,7 +631,6 @@ let rec interp_expr info (e:Mltree.expr) : value = ...@@ -632,7 +631,6 @@ let rec interp_expr info (e:Mltree.expr) : value =
v end in v end in
Debug.dprintf debug_interp "eval call@."; Debug.dprintf debug_interp "eval call@.";
let res = try begin let res = try begin
let rs = if Mrs.mem rs info.recs then Mrs.find rs info.recs else rs in
if Hrs.mem builtin_progs rs if Hrs.mem builtin_progs rs
then then
(Debug.dprintf debug_interp "%a is builtin@." Expr.print_rs rs; (Debug.dprintf debug_interp "%a is builtin@." Expr.print_rs rs;
...@@ -646,8 +644,8 @@ let rec interp_expr info (e:Mltree.expr) : value = ...@@ -646,8 +644,8 @@ let rec interp_expr info (e:Mltree.expr) : value =
| Dlet (Lsym (rs, _ty, vl, e)) -> | Dlet (Lsym (rs, _ty, vl, e)) ->
eval_call info vl e rs eval_call info vl e rs
| Dlet(Lrec([{rec_args = vl; rec_exp = e; | Dlet(Lrec([{rec_args = vl; rec_exp = e;
rec_sym = rs; rec_rsym = rrs; rec_res=_ty}])) -> rec_sym = rs; rec_res=_ty}])) ->
eval_call { info with recs = Mrs.add rrs rs info.recs } vl e rs eval_call info vl e rs
| Dlet (Lrec _) -> | Dlet (Lrec _) ->
Debug.dprintf Debug.dprintf
debug_interp "unhandled mutually recursive functions@."; debug_interp "unhandled mutually recursive functions@.";
...@@ -863,7 +861,6 @@ let init_info env mm rs vars = ...@@ -863,7 +861,6 @@ let init_info env mm rs vars =
{ env = env; { env = env;
mm = mm; mm = mm;
funs = Mrs.empty; funs = Mrs.empty;
recs = Mrs.empty;
vars = vars; vars = vars;
get_decl = get_decl env mm; get_decl = get_decl env mm;
cur_rs = rs; cur_rs = rs;
......
...@@ -79,7 +79,7 @@ and let_def = ...@@ -79,7 +79,7 @@ and let_def =
and rdef = { and rdef = {
rec_sym : rsymbol; (* exported *) rec_sym : rsymbol; (* exported *)
rec_rsym : rsymbol; (* internal *) (* rec_rsym : rsymbol;*) (* internal *)
rec_args : var list; rec_args : var list;
rec_exp : expr; rec_exp : expr;
rec_res : ty; rec_res : ty;
......
...@@ -264,8 +264,6 @@ module Print = struct ...@@ -264,8 +264,6 @@ module Print = struct
let pv_name pv = pv.pv_vs.vs_name let pv_name pv = pv.pv_vs.vs_name
let print_pv info fmt pv = print_lident info fmt (pv_name pv) let print_pv info fmt pv = print_lident info fmt (pv_name pv)
let ht_rs = Hrs.create 7 (* rec_rsym -> rec_sym *)
(* FIXME put these in Compile*) (* FIXME put these in Compile*)
let is_true e = match e.e_node with let is_true e = match e.e_node with
| Eapp (s, []) -> rs_equal s rs_true | Eapp (s, []) -> rs_equal s rs_true
...@@ -411,9 +409,7 @@ module Print = struct ...@@ -411,9 +409,7 @@ module Print = struct
(print_fun_type_args info) (args, s, res, e); (print_fun_type_args info) (args, s, res, e);
forget_vars args forget_vars args
in in
List.iter (fun fd -> Hrs.replace ht_rs fd.rec_rsym fd.rec_sym) rdef;
print_list_next newline print_one fmt rdef; print_list_next newline print_one fmt rdef;
List.iter (fun fd -> Hrs.remove ht_rs fd.rec_rsym) rdef
| Lany (rs, res, []) when functor_arg -> | Lany (rs, res, []) when functor_arg ->
fprintf fmt "@[<hov 2>val %a : %a@]" fprintf fmt "@[<hov 2>val %a : %a@]"
(print_lident info) rs.rs_name (print_lident info) rs.rs_name
...@@ -456,14 +452,14 @@ module Print = struct ...@@ -456,14 +452,14 @@ module Print = struct
| Eapp (rs, []) when rs_equal rs rs_false -> | Eapp (rs, []) when rs_equal rs rs_false ->
fprintf fmt "false" fprintf fmt "false"
| Eapp (rs, []) -> (* avoids parenthesis around values *) | Eapp (rs, []) -> (* avoids parenthesis around values *)
fprintf fmt "%a" (print_apply info (Hrs.find_def ht_rs rs rs)) [] fprintf fmt "%a" (print_apply info rs) []
| Eapp (rs, pvl) -> | Eapp (rs, pvl) ->
begin match query_syntax info.info_convert rs.rs_name, pvl with begin match query_syntax info.info_convert rs.rs_name, pvl with
| Some s, [{e_node = Econst _}] -> | Some s, [{e_node = Econst _}] ->
syntax_arguments s print_constant fmt pvl syntax_arguments s print_constant fmt pvl
| _ -> | _ ->
fprintf fmt (protect_on paren "%a") fprintf fmt (protect_on paren "%a")
(print_apply info (Hrs.find_def ht_rs rs rs)) pvl end (print_apply info rs) pvl end
| Ematch (e1, [p, e2], []) -> | Ematch (e1, [p, e2], []) ->
fprintf fmt (protect_on paren "let %a =@ %a in@ %a") fprintf fmt (protect_on paren "let %a =@ %a in@ %a")
(print_pat info) p (print_expr info) e1 (print_expr info) e2 (print_pat info) p (print_expr info) e1 (print_expr info) e2
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment