Commit 80e80b7d authored by Andrei Paskevich's avatar Andrei Paskevich

Expr: termination check for let-functions/predicates/lemmas

use variant inference from Decl to prove termination
for pure recursive functions without variant.
parent b8c5222c
...@@ -106,69 +106,94 @@ type descent = ...@@ -106,69 +106,94 @@ type descent =
| Equal of int | Equal of int
| Unknown | Unknown
type call_set = (ident * descent array) Hid.t
type vs_graph = descent Mvs.t list
let create_call_set () = Hid.create 5
let create_vs_graph vl =
let i = ref (-1) in
let add vm v = incr i; Mvs.add v (Equal !i) vm in
[List.fold_left add Mvs.empty vl]
(* TODO: can we handle projections somehow? *)
let register_call cgr caller vsg callee tl =
let call vm =
let describe t = match t.t_node with
| Tvar v -> Mvs.find_def Unknown v vm
| _ -> Unknown in
let dl = List.map describe tl in
Hid.add cgr callee (caller, Array.of_list dl) in
List.iter call vsg
let vs_graph_drop vsg u = List.rev_map (Mvs.remove u) vsg
(* TODO: can we handle projections somehow? *)
let vs_graph_let vsg t u = match t.t_node with
| Tvar v ->
let add vm = try Mvs.add u (Mvs.find v vm) vm
with Not_found -> Mvs.remove u vm in
List.rev_map add vsg
| _ ->
vs_graph_drop vsg u
let rec match_var link acc p = match p.pat_node with let rec match_var link acc p = match p.pat_node with
| Pwild -> acc | Pwild -> acc
| Pvar u -> List.rev_map (Mvs.add u link) acc | Pvar u -> List.rev_map (Mvs.add u link) acc
| Pas (p,u) -> List.rev_map (Mvs.add u link) (match_var link acc p) | Pas (p,u) -> List.rev_map (Mvs.add u link) (match_var link acc p)
| Por (p1,p2) -> | Por (p1,p2) ->
let acc1 = match_var link acc p1 in List.rev_append (match_var link acc p1) (match_var link acc p2)
let acc2 = match_var link acc p2 in
List.rev_append acc1 acc2
| Papp _ -> | Papp _ ->
let link = match link with let link = match link with
| Unknown -> Unknown | Unknown -> Unknown
| Equal i -> Less i | Equal i -> Less i
| Less i -> Less i | Less i -> Less i in
in
let join u = Mvs.add u link in let join u = Mvs.add u link in
List.rev_map (Svs.fold join p.pat_vars) acc List.rev_map (Svs.fold join p.pat_vars) acc
let rec match_term vm t acc p = match t.t_node, p.pat_node with let rec match_term vm t acc p = match t.t_node, p.pat_node with
| _, Pwild -> acc | _, Pwild -> acc
| Tvar v, _ when not (Mvs.mem v vm) -> acc | Tvar v, _ when Mvs.mem v vm ->
| Tvar v, _ -> match_var (Mvs.find v vm) acc p match_var (Mvs.find v vm) acc p
| Tapp _, Pvar _ -> acc | Tapp _, Pvar u ->
| Tapp _, Pas (p,_) -> match_term vm t acc p vs_graph_drop acc u
| Tapp _, Pas (p,u) ->
match_term vm t (vs_graph_drop acc u) p
| Tapp _, Por (p1,p2) -> | Tapp _, Por (p1,p2) ->
let acc1 = match_term vm t acc p1 in List.rev_append (match_term vm t acc p1) (match_term vm t acc p2)
let acc2 = match_term vm t acc p2 in
List.rev_append acc1 acc2
| Tapp (c1,tl), Papp (c2,pl) when ls_equal c1 c2 -> | Tapp (c1,tl), Papp (c2,pl) when ls_equal c1 c2 ->
let down l t p = match_term vm t l p in let down l t p = match_term vm t l p in
List.fold_left2 down acc tl pl List.fold_left2 down acc tl pl
| _,_ -> acc | _,_ ->
List.rev_map (fun vm -> Mvs.set_diff vm p.pat_vars) acc
let build_call_graph cgr syms ls = let vs_graph_pat vsg t p =
let call vm s tl = let add acc vm = List.rev_append (match_term vm t [vm] p) acc in
let desc t = match t.t_node with List.fold_left add [] vsg
| Tvar v -> Mvs.find_def Unknown v vm
| _ -> Unknown let build_call_graph cgr syms ls (vl,e) =
in
Hls.add cgr s (ls, Array.of_list (List.map desc tl))
in
let rec term vm () t = match t.t_node with let rec term vm () t = match t.t_node with
| Tapp (s,tl) when Mls.mem s syms -> | Tapp (s,tl) when Mls.mem s syms ->
t_fold (term vm) () t; call vm s tl t_fold (term vm) () t;
| Tlet ({t_node = Tvar v}, b) when Mvs.mem v vm -> register_call cgr ls.ls_name vm s.ls_name tl
| Tlet (t,b) ->
term vm () t;
let u,e = t_open_bound b in let u,e = t_open_bound b in
term (Mvs.add u (Mvs.find v vm) vm) () e term (vs_graph_let vm t u) () e
| Tcase (e,bl) -> | Tcase (t,bl) ->
term vm () e; List.iter (fun b -> term vm () t;
let p,t = t_open_branch b in List.iter (fun b ->
let vml = match_term vm e [vm] p in let p,e = t_open_branch b in
List.iter (fun vm -> term vm () t) vml) bl term (vs_graph_pat vm t p) () e) bl
| Tquant (_,b) -> | Tquant (_,b) -> (* ignore triggers *)
let _,_,f = t_open_quant b in term vm () f let _,_,f = t_open_quant b in term vm () f
| _ -> t_fold (term vm) () t | _ ->
t_fold (term vm) () t
in in
fun (vl,e) -> term (create_vs_graph vl) () e
let i = ref (-1) in
let add vm v = incr i; Mvs.add v (Equal !i) vm in let build_call_list cgr id =
let vm = List.fold_left add Mvs.empty vl in let htb = Hid.create 5 in
term vm () e
let build_call_list cgr ls =
let htb = Hls.create 5 in
let local v = Array.mapi (fun i -> function let local v = Array.mapi (fun i -> function
| (Less j) as d when i = j -> d | (Less j) as d when i = j -> d
| (Equal j) as d when i = j -> d | (Equal j) as d when i = j -> d
...@@ -186,7 +211,7 @@ let build_call_list cgr ls = ...@@ -186,7 +211,7 @@ let build_call_list cgr ls =
try Array.iteri test v1; true with Not_found -> false try Array.iteri test v1; true with Not_found -> false
in in
let subsumed s c = let subsumed s c =
List.exists (subsumes c) (Hls.find_all htb s) List.exists (subsumes c) (Hid.find_all htb s)
in in
let multiply v1 v2 = let multiply v1 v2 =
let to_less = function let to_less = function
...@@ -200,21 +225,20 @@ let build_call_list cgr ls = ...@@ -200,21 +225,20 @@ let build_call_list cgr ls =
| Less i -> to_less (Array.get v2 i)) v1 | Less i -> to_less (Array.get v2 i)) v1
in in
let resolve s c = let resolve s c =
Hls.add htb s c; Hid.add htb s c;
let mult (s,v) = (s, multiply c v) in let mult (s,v) = (s, multiply c v) in
List.rev_map mult (Hls.find_all cgr s) List.rev_map mult (Hid.find_all cgr s)
in in
let rec add_call lc = function let rec add_call lc = function
| [] -> lc | [] -> lc
| (s,c)::r when ls_equal ls s -> add_call (local c :: lc) r | (s,c)::r when id_equal id s -> add_call (local c :: lc) r
| (s,c)::r when subsumed s c -> add_call lc r | (s,c)::r when subsumed s c -> add_call lc r
| (s,c)::r -> add_call lc (List.rev_append (resolve s c) r) | (s,c)::r -> add_call lc (List.rev_append (resolve s c) r)
in in
add_call [] (Hls.find_all cgr ls) add_call [] (Hid.find_all cgr id)
exception NoTerminationProof of lsymbol
let check_call_list ls cl = let find_variant exn cgr id =
let cl = build_call_list cgr id in
let add d1 d2 = match d1, d2 with let add d1 d2 = match d1, d2 with
| Unknown, _ -> d1 | Unknown, _ -> d1
| _, Unknown -> d2 | _, Unknown -> d2
...@@ -234,7 +258,7 @@ let check_call_list ls cl = ...@@ -234,7 +258,7 @@ let check_call_list ls cl =
let find l = function Less i -> i :: l | _ -> l in let find l = function Less i -> i :: l | _ -> l in
let res = Array.fold_left find [] p in let res = Array.fold_left find [] p in
(* eliminate the decreasing calls *) (* eliminate the decreasing calls *)
if res = [] then raise (NoTerminationProof ls); if res = [] then raise exn;
let test a = let test a =
List.for_all (fun i -> Array.get a i <> Less i) res List.for_all (fun i -> Array.get a i <> Less i) res
in in
...@@ -242,15 +266,15 @@ let check_call_list ls cl = ...@@ -242,15 +266,15 @@ let check_call_list ls cl =
in in
check [] cl check [] cl
exception NoTerminationProof of lsymbol
let check_termination ldl = let check_termination ldl =
let cgr = Hls.create 5 in let cgr = create_call_set () in
let add acc (ls,ld) = Mls.add ls (open_ls_defn ld) acc in let add acc (ls,ld) = Mls.add ls (open_ls_defn ld) acc in
let syms = List.fold_left add Mls.empty ldl in let syms = List.fold_left add Mls.empty ldl in
Mls.iter (build_call_graph cgr syms) syms; Mls.iter (build_call_graph cgr syms) syms;
let check ls _ = let check ls _ =
let cl = build_call_list cgr ls in find_variant (NoTerminationProof ls) cgr ls.ls_name in
check_call_list ls cl
in
let res = Mls.mapi check syms in let res = Mls.mapi check syms in
List.map (fun (ls,(_,f,_)) -> (ls,(ls,f,Mls.find ls res))) ldl List.map (fun (ls,(_,f,_)) -> (ls,(ls,f,Mls.find ls res))) ldl
......
...@@ -51,6 +51,20 @@ val ls_defn_decrease : ls_defn -> int list ...@@ -51,6 +51,20 @@ val ls_defn_decrease : ls_defn -> int list
from a declaration; on the result of [make_ls_defn], from a declaration; on the result of [make_ls_defn],
[ls_defn_decrease] will always return an empty list. *) [ls_defn_decrease] will always return an empty list. *)
(** {2 Structural descent checking} *)
type call_set
type vs_graph
val create_call_set : unit -> call_set
val create_vs_graph : vsymbol list -> vs_graph
val register_call : call_set -> ident ->
vs_graph -> ident -> term list -> unit
val vs_graph_drop : vs_graph -> vsymbol -> vs_graph
val vs_graph_let : vs_graph -> term -> vsymbol -> vs_graph
val vs_graph_pat : vs_graph -> term -> pattern -> vs_graph
val find_variant : exn -> call_set -> ident -> int list
(** {2 Proposition names} *) (** {2 Proposition names} *)
type prsymbol = private { type prsymbol = private {
......
...@@ -807,6 +807,6 @@ let () = Exn_printer.register ...@@ -807,6 +807,6 @@ let () = Exn_printer.register
fprintf fmt "Ident %a is already declared, with a different declaration" fprintf fmt "Ident %a is already declared, with a different declaration"
Ident.print_decoded s Ident.print_decoded s
| Decl.NoTerminationProof ls -> | Decl.NoTerminationProof ls ->
fprintf fmt "Cannot prove the termination of %a" print_ls ls fprintf fmt "Cannot prove termination for %a" print_ls ls
| _ -> raise exn | _ -> raise exn
end end
...@@ -1023,6 +1023,72 @@ let e_assert ak f = ...@@ -1023,6 +1023,72 @@ let e_assert ak f =
let e_absurd ity = mk_expr Eabsurd ity MaskVisible eff_empty let e_absurd ity = mk_expr Eabsurd ity MaskVisible eff_empty
(* structural descent *)
exception NoVariantFound of rsymbol
let term_check rdl =
let cgr = Decl.create_call_set () in
let add acc rd = Mrs.add rd.rec_rsym rd acc in
let syms = List.fold_left add Mrs.empty rdl in
let term_of e =
try pure_of_expr false e with Exit -> t_void in
let rec expr rs vm e = match e.e_node with
| Evar _ | Econst _ | Epure _ | Eassert _ -> ()
| Eghost e | Eexn (_,e) -> expr rs vm e
| Eexec (c, _) ->
cexp rs vm c
| Elet (LDvar (v,d),e) ->
expr rs vm d;
let t = term_of d in
expr rs (Decl.vs_graph_let vm t v.pv_vs) e
| Elet (LDsym (_, c),e) ->
cexp rs vm c; expr rs vm e
| Elet (LDrec rdl,e) ->
List.iter (fun rd -> cexp rs vm rd.rec_fun) rdl;
expr rs vm e;
| Eif (e0,e1,e2) ->
expr rs vm e0; expr rs vm e1; expr rs vm e2
| Ematch (d,bl,xl) when Mxs.is_empty xl ->
expr rs vm d;
let t = term_of d in
let check (p,e) =
expr rs (Decl.vs_graph_pat vm t p.pp_pat) e in
List.iter check bl
| Eassign _ | Ewhile _ | Efor _
| Ematch _ | Eraise _ | Eabsurd ->
raise Exit
and cexp rs vm c = match c.c_node with
| Cfun d ->
if not (eff_pure d.e_effect) then raise Exit;
let drop vm v = Decl.vs_graph_drop vm v.pv_vs in
expr rs (List.fold_left drop vm c.c_cty.cty_args) d
| Capp (s,vl) when Mrs.mem s syms ->
if c.c_cty.cty_args <> [] then raise Exit;
let tl = List.map (fun v -> t_var v.pv_vs) vl in
Decl.register_call cgr rs.rs_name vm s.rs_name tl
| Cany | Cpur _ | Capp _ -> ()
in
let build_call_graph rs rd =
let vl = List.map (fun v -> v.pv_vs) rs.rs_cty.cty_args in
let e = match rd.rec_fun.c_node with
Cfun e -> e | _ -> assert false in
if not (eff_pure e.e_effect) then
raise (NoVariantFound rs);
try expr rs (Decl.create_vs_graph vl) e
with Exit -> raise (NoVariantFound rs)
in
Mrs.iter build_call_graph syms;
let check rs _ =
Decl.find_variant (NoVariantFound rs) cgr rs.rs_name in
ignore (Mrs.mapi check syms)
let term_check rdl = match rdl with
| {rec_varl = []}::_
when List.for_all (fun rd -> rd.rec_sym.rs_logic <> RLnone) rdl ->
term_check rdl
| _ -> ()
(* recursive definitions *) (* recursive definitions *)
let cty_add_variant d varl = let add s (t,_) = t_freepvs s t in let cty_add_variant d varl = let add s (t,_) = t_freepvs s t in
...@@ -1104,13 +1170,9 @@ let let_rec fdl = ...@@ -1104,13 +1170,9 @@ let let_rec fdl =
"All functions in a recursive definition must use the same \ "All functions in a recursive definition must use the same \
well-founded order for the first component of the variant" in well-founded order for the first component of the variant" in
List.iter check_variant (List.tl fdl); List.iter check_variant (List.tl fdl);
(* if we have a top-level total let-function definition and let start_eff = (* pure functions cannot diverge *)
no variants are supplied, then we expect the definition if varl1 = [] && List.exists (fun (_,_,_,k) -> k = RKnone) fdl
to be terminating with respect to Decl.check_termination *) then eff_diverge eff_empty else eff_empty in
let impure (_,d,_,k) =
(k <> RKfunc && k <> RKpred) || d.c_cty.cty_pre <> [] in
let start_eff = if varl1 = [] && List.exists impure fdl then
eff_diverge eff_empty else eff_empty in
(* create the first substitution *) (* create the first substitution *)
let update sm (s,({c_cty = c} as d),varl,_) = let update sm (s,({c_cty = c} as d),varl,_) =
(* check that the type signatures are consistent *) (* check that the type signatures are consistent *)
...@@ -1145,6 +1207,8 @@ let let_rec fdl = ...@@ -1145,6 +1207,8 @@ let let_rec fdl =
let s = create_rsymbol id ~kind ~ghost:(rs_ghost rs) c in let s = create_rsymbol id ~kind ~ghost:(rs_ghost rs) c in
{ rec_sym = s; rec_rsym = rs; rec_fun = d; rec_varl = varl } in { rec_sym = s; rec_rsym = rs; rec_fun = d; rec_varl = varl } in
let rdl = List.map2 merge fdl (rec_fixp (List.map conv dl)) in let rdl = List.map2 merge fdl (rec_fixp (List.map conv dl)) in
(* try to infer the missing variant if termination is assumed *)
term_check rdl;
LDrec rdl, rdl LDrec rdl, rdl
let ls_decr_of_rec_defn = function let ls_decr_of_rec_defn = function
...@@ -1446,4 +1510,6 @@ let () = Exn_printer.register (fun fmt e -> match e with ...@@ -1446,4 +1510,6 @@ let () = Exn_printer.register (fun fmt e -> match e with
"Function %a is not a mutable field" print_rs s "Function %a is not a mutable field" print_rs s
| ExceptionLeak xs -> fprintf fmt | ExceptionLeak xs -> fprintf fmt
"Uncaught local exception %a" print_xs xs "Uncaught local exception %a" print_xs xs
| NoVariantFound rs -> fprintf fmt
"Cannot prove termination for %a" print_rs rs
| _ -> raise e) | _ -> raise e)
...@@ -537,16 +537,9 @@ let create_let_decl ld = ...@@ -537,16 +537,9 @@ let create_let_decl ld =
List.fold_right add_rd rdl ([],[],[]) List.fold_right add_rd rdl ([],[],[])
| LDsym (s,c) -> | LDsym (s,c) ->
add_rs Mrs.empty s c ([],[],[]) in add_rs Mrs.empty s c ([],[],[]) in
let fail_trusted_rec ls =
Loc.error ?loc:ls.ls_name.id_loc (Decl.NoTerminationProof ls) in
let is_trusted_rec = match ld with
| LDrec ({rec_sym = {rs_logic = RLls ls; rs_cty = c}; rec_varl = []}::_)
when total c.cty_effect.eff_oneway -> abst = [] || fail_trusted_rec ls
| _ -> false in
let defn = if defn = [] then [] else let defn = if defn = [] then [] else
let dl = List.map (fun (s,vl,t) -> make_ls_defn s vl t) defn in let dl = List.map (fun (s,vl,t) -> make_ls_defn s vl t) defn in
try [create_logic_decl dl] with Decl.NoTerminationProof ls -> try [create_logic_decl dl] with Decl.NoTerminationProof _ ->
if is_trusted_rec then fail_trusted_rec ls;
let abst = List.map (fun (s,_) -> create_param_decl s) dl in let abst = List.map (fun (s,_) -> create_param_decl s) dl in
let mk_ax ({ls_name = id} as s, vl, t) = let mk_ax ({ls_name = id} as s, vl, t) =
let nm = id.id_string ^ "_def" in let nm = id.id_string ^ "_def" in
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment