Commit b48d1573 authored by Andrei Paskevich's avatar Andrei Paskevich

optimise case statement when the matched term is a constructor

parent 602cd848
......@@ -22,7 +22,7 @@ theory Test1
logic zip (l1 : list 'a) (l2 : list 'b) : list ('a,'b) =
match l1, l2 with
| Cons x1 r1, Cons x2 r2 -> Cons (x1,x2) (zip r1 r2)
| (Cons _ _|Nil), _ -> Nil (* to make it total *)
| _, _ -> Nil (* to make it total *)
end
logic foo (l1 : list 'a) (l2 : list 'b) =
......
......@@ -86,6 +86,28 @@ module Compile (X : Action) = struct
in
let cases,wilds = List.fold_right dispatch rl (Mls.empty,[]) in
(* assemble the primitive case statement *)
let pat_cont cs vl pl =
let rec cont acc vl pl = match vl,pl with
| (_::vl), (p::pl) -> cont (p::acc) vl pl
| [], pl -> pat_app cs acc ty :: pl
| _, _ -> assert false
in
cont [] vl pl
in
match t.t_node with
| Tapp (cs,al) when Sls.mem cs css ->
if Mls.mem cs cases then
let tl = List.rev_append al tl in
try compile constructors tl (Mls.find cs cases)
with NonExhaustive pl -> raise (NonExhaustive (pat_cont cs al pl))
else begin
try compile constructors tl wilds
with NonExhaustive pl ->
let al = List.map pat_wild cs.ls_args in
let pat = pat_app cs al (of_option cs.ls_value) in
raise (NonExhaustive (pat::pl))
end
| _ -> begin
let pw = pat_wild ty in
let nopat =
if Sls.is_empty css then Some pw else
......@@ -107,17 +129,13 @@ module Compile (X : Action) = struct
let vl = List.map (fun q -> create_vsymbol id q.pat_ty) ql in
let tl = List.fold_left (fun tl v -> t_var v :: tl) tl vl in
let pat = pat_app fs (List.map pat_var vl) ty in
let rec pat_cont acc vl pl = match vl,pl with
| (_::vl), (p::pl) -> pat_cont (p::acc) vl pl
| [], pl -> pat_app fs acc ty :: pl
| _, _ -> assert false
in
try (pat, compile constructors tl (Mls.find fs cases)) :: acc
with NonExhaustive pl -> raise (NonExhaustive (pat_cont [] vl pl))
with NonExhaustive pl -> raise (NonExhaustive (pat_cont fs vl pl))
in
match Mls.fold add types base with
| [{ pat_node = Pwild }, a] -> a
| bl -> mk_case t bl
end
end
......
......@@ -152,11 +152,12 @@ let pat_var v = mk_pattern (Pvar v) (Svs.singleton v) v.vs_ty
let pat_as p v = mk_pattern (Pas (p,v)) (add_no_dup v p.pat_vars) v.vs_ty
let pat_or p q =
(if not (Svs.equal p.pat_vars q.pat_vars) then
if Svs.equal p.pat_vars q.pat_vars then
mk_pattern (Por (p,q)) p.pat_vars p.pat_ty
else
let s1, s2 = p.pat_vars, q.pat_vars in
Svs.iter (fun vs -> raise (UncoveredVar vs))
(Svs.union (Svs.diff s1 s2) (Svs.diff s2 s1)));
mk_pattern (Por (p,q)) p.pat_vars p.pat_ty
let vs = Svs.choose (Svs.union (Svs.diff s1 s2) (Svs.diff s2 s1)) in
raise (UncoveredVar vs)
let pat_app f pl ty =
let merge s p = Svs.fold add_no_dup s p.pat_vars in
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment