From cd3db7eb2ea83502edc93f46c8db8ac4069e975d Mon Sep 17 00:00:00 2001
From: Jean-Christophe Filliatre <Jean-Christophe.Filliatre@lri.fr>
Date: Tue, 14 Feb 2012 11:18:52 +0100
Subject: [PATCH] PVS printer: inductive predicates, mutually recursive
 functions

---
 drivers/pvs-common.gen                |  6 ++-
 src/printer/pvs.ml                    | 73 ++++++++++++++++-----------
 src/transform/eliminate_inductive.mli |  2 +
 tests/test-jcf.why                    | 14 ++---
 4 files changed, 54 insertions(+), 41 deletions(-)

diff --git a/drivers/pvs-common.gen b/drivers/pvs-common.gen
index 36f5a2bbe8..ba7a78b344 100644
--- a/drivers/pvs-common.gen
+++ b/drivers/pvs-common.gen
@@ -5,11 +5,15 @@ fail "Syntax error: \\(.*\\)$" "\\1"
 time "why3cpulimit time : %s s"
 
 transformation "inline_trivial"
-
 transformation "eliminate_builtin"
+
+(* PVS does not support mutual recursion *)
 transformation "eliminate_mutual_recursion"
+transformation "simplify_recursive_definition"
+(* though we could do better, we only use recursion on one argument *)
 transformation "eliminate_non_struct_recursion"
 
+(* PVS only has simple patterns *)
 transformation "compile_match"
 
 transformation "simplify_formula"
diff --git a/src/printer/pvs.ml b/src/printer/pvs.ml
index bbea651382..db39e1e0a7 100644
--- a/src/printer/pvs.ml
+++ b/src/printer/pvs.ml
@@ -44,18 +44,16 @@ QUESTIONS FOR THE PVS TEAM
 
   * pattern-matching
 
-    - is there a catch-all pattern in PVS's CASES construct?
+    - By mistake, I used _ as a catch-all in a CASES expressions
+      (as in ML and in Why3) and it made PVS go wild!
 
-      Note: I tried to use _ (as in ML and in Why3) and it made PVS go wild!
-
-  * I intend to use the script "proveit" to replay PVS proofs (when they exist)
+  * I intend to use the script "proveit" to replay PVS proofs (when they exist).
     What is the canonical way to check that all proofs have indeed been
-    successfully replayed?
+    successfully replayed? (exit code, grep on proveit's output, etc.)
 
 TODO
 ----
-  * eliminate mutual recursion in PVS driver
-
+  * use PVS maps
 
 *)
 
@@ -232,19 +230,21 @@ let print_or_list fmt x = print_list (fun fmt () -> fprintf fmt " OR@\n") fmt x
 let comma_newline fmt () = fprintf fmt ",@\n"
 
 let rec print_pat info fmt p = match p.pat_node with
-  | Pwild -> fprintf fmt "_"
-  | Pvar v -> print_vs fmt v
-  | Pas _ | Por _ ->
-      assert false (* compile_match must have taken care of that *)
-  | Papp (cs,pl) when is_fs_tuple cs ->
-      fprintf fmt "%a" (print_paren_r (print_pat info)) pl
-  | Papp (cs,pl) ->
+  | Pvar v ->
+      print_vs fmt v
+  | Papp (cs, _) when is_fs_tuple cs ->
+      assert false (* is handled earlier in print_term/fmla *)
+  | Papp (cs, pl) ->
       begin match query_syntax info.info_syn cs.ls_name with
         | Some s -> syntax_arguments s (print_pat info) fmt pl
         | _ when pl = [] -> (print_ls_real info) fmt cs
         | _ -> fprintf fmt "%a(%a)"
           (print_ls_real info) cs (print_list comma (print_pat info)) pl
       end
+  | Pas _ | Por _ ->
+      assert false (* compile_match must have taken care of that *)
+  | Pwild ->
+      assert false (* is handled in print_branches *)
 
 let print_vsty_nopar info fmt v =
   fprintf fmt "%a:%a" print_vs v (print_ty info) v.vs_ty
@@ -306,7 +306,7 @@ and print_tnode opl opr info fmt t = match t.t_node with
       in
       Print_number.print number_format fmt c
   | Tif (f, t1, t2) ->
-    fprintf fmt "IF %a@ THEN %a@ ELSE %a ENDIF"
+      fprintf fmt "IF %a@ THEN %a@ ELSE %a ENDIF"
         (print_fmla info) f (print_term info) t1 (print_opl_term info) t2
   | Tlet (t1, tb) ->
       let v,t2 = t_open_bound tb in
@@ -323,8 +323,7 @@ and print_tnode opl opr info fmt t = match t.t_node with
       Svs.iter forget_var p.pat_vars
   | Tcase (t, bl) ->
       fprintf fmt "CASES %a OF@\n@[<hov>%a@]@\nENDCASES"
-        (print_term info) t
-        (print_list comma_newline (print_tbranch info)) bl
+        (print_term info) t (print_branches print_term info) bl
   | Teps fb ->
       let v,f = t_open_bound fb in
       fprintf fmt (protect_on opr "epsilon(LAMBDA (%a):@ %a)")
@@ -391,8 +390,7 @@ and print_fnode opl opr info fmt f = match f.t_node with
       Svs.iter forget_var p.pat_vars
   | Tcase (t, bl) ->
       fprintf fmt "CASES %a OF@\n@[<hov>%a@]@\nENDCASES"
-        (print_term info) t
-        (print_list comma_newline (print_fbranch info)) bl
+        (print_term info) t (print_branches print_fmla info) bl
   | Tif (f1, f2, f3) ->
       fprintf fmt (protect_on opr "IF %a@ THEN %a@ ELSE %a ENDIF")
         (print_fmla info) f1 (print_fmla info) f2 (print_opl_fmla info) f3
@@ -425,17 +423,29 @@ and print_tuple_pat info t fmt p =
   in
   print_comma_list print fmt l
 
-and print_tbranch info fmt br =
+and print_branch print info fmt br =
   let p,t = t_open_branch br in
   fprintf fmt "@[<hov 4> %a:@ %a@]"
-    (print_pat info) p (print_term info) t;
+    (print_pat info) p (print info) t;
   Svs.iter forget_var p.pat_vars
 
-and print_fbranch info fmt br =
-  let p,f = t_open_branch br in
-  fprintf fmt "@[<hov 4> %a:@ %a@]"
-    (print_pat info) p (print_fmla info) f;
-  Svs.iter forget_var p.pat_vars
+and print_branches ?(first=true) print info fmt = function
+  | [] ->
+      ()
+  | br :: bl ->
+      let p, t = t_open_branch br in
+      begin match p.pat_node with
+        | Pwild ->
+            assert (bl = []);
+            if not first then fprintf fmt "@\n";
+            fprintf fmt "@[<hov 4>ELSE@ %a@]" (print info) t
+        | _ ->
+            if not first then fprintf fmt ",@\n";
+            fprintf fmt "@[<hov 4> %a:@ %a@]"
+              (print_pat info) p (print info) t;
+            Svs.iter forget_var p.pat_vars;
+            print_branches ~first:false print info fmt bl
+      end
 
 let print_expr info fmt =
   TermTF.t_select (print_term info fmt) (print_fmla info fmt)
@@ -715,12 +725,15 @@ let print_recursive_decl info fmt dl =
 let print_ind info fmt (pr,f) =
   fprintf fmt "@[%% %a:@\n(%a)@]" print_pr pr (print_fmla info) f
 
-let print_ind_decl info fmt (ps,bl) =
+let print_ind_decl info fmt (ps,al) =
   let _ty_vars_args, _ty_vars_value, all_ty_params = ls_ty_vars ps in
+  let vl = List.map (create_vsymbol (id_fresh "z")) ps.ls_args in
+  let tl = List.map t_var vl in
+  let dj = Util.map_join_left (Eliminate_inductive.exi tl) t_or al in
   fprintf fmt "@[<hov 2>%a%a(%a): INDUCTIVE bool =@ @[<hov>%a@]@]@\n"
-     print_ls ps print_implicit_params all_ty_params
-    (print_comma_list (print_ty info)) ps.ls_args
-     (print_or_list (print_ind info)) bl;
+    print_ls ps print_implicit_params all_ty_params
+    (print_comma_list (print_vsty_nopar info)) vl
+    (print_fmla info) dj;
   fprintf fmt "@\n"
 
 let print_ind_decl info fmt d =
diff --git a/src/transform/eliminate_inductive.mli b/src/transform/eliminate_inductive.mli
index e32760f75a..90581b3f65 100644
--- a/src/transform/eliminate_inductive.mli
+++ b/src/transform/eliminate_inductive.mli
@@ -19,3 +19,5 @@
 
 val eliminate_inductive : Task.task Trans.trans
 
+(* exported to be used in the PVS printer *)
+val exi: Term.term list -> 'a * Term.term -> Term.term
diff --git a/tests/test-jcf.why b/tests/test-jcf.why
index f6275ace2a..01c8d2b1c5 100644
--- a/tests/test-jcf.why
+++ b/tests/test-jcf.why
@@ -3,7 +3,6 @@ theory TestPVS
 
   use import int.Int
 
-(***
   function f int : int
 
   axiom f_def: forall x: int. f(x) = x+1
@@ -16,9 +15,8 @@ theory TestPVS
 
   goal G1: match g 1 with
              | A x ((y,z) as p) -> y=1+1 /\ p = (2,3)
-             | B (() as p) -> p=()
-             | C -> false end
-***)
+             | _ -> false
+           end
 
   type elt
   type tree = Null | Node tree elt tree
@@ -43,14 +41,10 @@ theory TestPVS
     | Node l _ r -> size2 l + size2 r + 1
   end
 
-  type u = t
-  with t = A | B u
-
-(*
   inductive even int =
     | even0: even 0
-    | evens: forall n: int. even n -> even (n+2)
-*)
+    | even2: forall n: int. even n -> even (n+2)
+    | even4: forall n: int. even n -> even (n+4)
 
   lemma size_nonneg: forall t: tree. size t >= 0
 
-- 
GitLab