Commit 132bd08d authored by Andrei Paskevich's avatar Andrei Paskevich

rework eval_match

- in particular, we simplify function definitions before testing
them for linearity (this allows to inline array updates which are
defined via a record update {| a with map = a.map[i <- v] |}).
We really need to add some memoization here, otherwise we might
pay too much.
parent a6a418c3
...@@ -1034,7 +1034,7 @@ expr: ...@@ -1034,7 +1034,7 @@ expr:
| Eident x -> | Eident x ->
mk_expr (Eassign (e12, x, $3)) mk_expr (Eassign (e12, x, $3))
| Eapply ({ expr_desc = Eident (Qident x) }, e11) | Eapply ({ expr_desc = Eident (Qident x) }, e11)
when x.id = mixfix "[]" -> when x.id = mixfix "[]" ->
mk_mixfix3 "[]<-" e11 e12 $3 mk_mixfix3 "[]<-" e11 e12 $3
| _ -> | _ ->
raise Parsing.Parse_error raise Parsing.Parse_error
......
...@@ -37,23 +37,16 @@ let is_constructor kn ls = match Mid.find ls.ls_name kn with ...@@ -37,23 +37,16 @@ let is_constructor kn ls = match Mid.find ls.ls_name kn with
| { d_node = Dtype _ } -> true | { d_node = Dtype _ } -> true
| _ -> false | _ -> false
(* checks that all branches ``start'' with constructors *) let rec dive env t = match t.t_node with
let rec update kn t = match t.t_node with | Tvar x ->
| Tapp (ls, _) -> is_constructor kn ls (try dive env (Mvs.find x env) with Not_found -> t)
| Tlet (_, t) -> let _, t = t_open_bound t in update kn t
| _ -> false
let rec dive fn env t = match t.t_node with
| Tvar x when Mvs.mem x env ->
dive fn env (Mvs.find x env)
| Tlet (t1, tb) -> | Tlet (t1, tb) ->
let x, t2, close = t_open_bound_cb tb in let x, t2 = t_open_bound tb in
let t2 = dive fn (Mvs.add x t1 env) t2 in dive (Mvs.add x t1 env) t2
t_label_copy t (t_let_simp t1 (close x t2)) | _ -> t_map (dive env) t
| _ -> fn t
let make_case kn fn t bl = let make_flat_case kn t bl =
let mk_b b = let p,t = t_open_branch b in [p], fn t in let mk_b b = let p,t = t_open_branch b in [p],t in
Pattern.CompileTerm.compile (find_constructors kn) [t] (List.map mk_b bl) Pattern.CompileTerm.compile (find_constructors kn) [t] (List.map mk_b bl)
let rec add_quant kn (vl,tl,f) v = let rec add_quant kn (vl,tl,f) v =
...@@ -88,21 +81,31 @@ let eval_match ~inline kn t = ...@@ -88,21 +81,31 @@ let eval_match ~inline kn t =
let x, t2, close = t_open_bound_cb tb2 in let x, t2, close = t_open_bound_cb tb2 in
let t2 = eval (Mvs.add x t1 env) t2 in let t2 = eval (Mvs.add x t1 env) t2 in
t_label_copy t (t_let_simp t1 (close x t2)) t_label_copy t (t_let_simp t1 (close x t2))
| Tcase (t1, bl) -> | Tcase (t1, bl1) ->
let t1 = eval env t1 in let t1 = eval env t1 in
let process t1 = let t1flat = dive env t1 in
let r = make_case kn (eval env) t1 bl in let r = try match t1flat.t_node with
match r.t_node with | Tapp (ls,_) when is_constructor kn ls ->
| Tcase ({ t_node = Tcase (t1, bl1) }, bl2) -> eval env (make_flat_case kn t1flat bl1)
let branch b = | Tcase (t2, bl2) ->
let p,t,close = t_open_branch_cb b in let mk_b b =
if not (update kn t) then raise Exit; let p,t = t_open_branch b in
close p (make_case kn (fun x -> x) t bl2) match t.t_node with
in | Tapp (ls,_) when is_constructor kn ls ->
(try t_case t1 (List.map branch bl1) with Exit -> r) t_close_branch p (eval env (make_flat_case kn t bl1))
| _ -> r | _ -> raise Exit
in
t_case t2 (List.map mk_b bl2)
| _ -> raise Exit
with
| Exit ->
let mk_b b =
let p,t,close = t_open_branch_cb b in
close p (eval env t)
in
t_case t1 (List.map mk_b bl1)
in in
t_label_copy t (dive process env t1) t_label_copy t r
| Tquant (q, qf) -> | Tquant (q, qf) ->
let vl,tl,f,close = t_open_quant_cb qf in let vl,tl,f,close = t_open_quant_cb qf in
let vl,tl,f = List.fold_left (add_quant kn) ([],tl,f) vl in let vl,tl,f = List.fold_left (add_quant kn) ([],tl,f) vl in
...@@ -127,7 +130,7 @@ let is_algebraic_type kn ty = match ty.ty_node with ...@@ -127,7 +130,7 @@ let is_algebraic_type kn ty = match ty.ty_node with
| Tyapp (ts, _) -> find_constructors kn ts <> [] | Tyapp (ts, _) -> find_constructors kn ts <> []
| Tyvar _ -> false | Tyvar _ -> false
let inline_nonrec_linear kn ls tyl ty = let rec inline_nonrec_linear kn ls tyl ty =
let d = Mid.find ls.ls_name kn in let d = Mid.find ls.ls_name kn in
(* at least one actual parameter (or the result) has an algebraic type *) (* at least one actual parameter (or the result) has an algebraic type *)
List.exists (is_algebraic_type kn) (oty_cons tyl ty) && List.exists (is_algebraic_type kn) (oty_cons tyl ty) &&
...@@ -139,8 +142,9 @@ let inline_nonrec_linear kn ls tyl ty = ...@@ -139,8 +142,9 @@ let inline_nonrec_linear kn ls tyl ty =
true true
| Some def -> | Some def ->
let _, t = open_ls_defn def in let _, t = open_ls_defn def in
let eval = eval_match ~inline:inline_nonrec_linear in
not (t_s_any Util.ffalse (ls_equal ls) t) && not (t_s_any Util.ffalse (ls_equal ls) t) &&
(not (ls_equal ls ls') || linear t) (not (ls_equal ls ls') || linear (eval kn t))
in in
List.for_all no_occ dl List.for_all no_occ dl
| _ -> | _ ->
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment