(********************************************************************)
(*                                                                  *)
(*  The Why3 Verification Platform   /   The Why3 Development Team  *)
(*  Copyright 2010-2019   --   Inria - CNRS - Paris-Sud University  *)
(*                                                                  *)
(*  This software is distributed under the terms of the GNU Lesser  *)
(*  General Public License version 2.1, with the special exception  *)
(*  on linking described in file LICENSE.                           *)
(*                                                                  *)
(********************************************************************)

12 13 14
open Term
open Ty
open Decl
15
open Ident
16
open Task
17
open Args_wrapper
18
open Generic_arg_trans_utils
19

(* Raised whenever a term cannot be matched against the reflection term;
   the matching functions below catch it to try alternative cases. *)
exception NoReification

(* Debug flag for the matching/inversion steps of reification. *)
let debug_reification = Debug.register_info_flag
                          ~desc:"Reification"
                          "reification"

(* Debug flag for the reflection transformations themselves. *)
let debug_refl = Debug.register_info_flag
                     ~desc:"Reflection transformations"
                     "reflection"

(* Attribute attached to the goal [GR] created by [build_goals]. *)
let expl_reification_check = Ident.create_attribute "expl:reification check"

(* State threaded through reification. *)
type reify_env = {
  kn: known_map;
    (* known declarations, used to unfold logic definitions *)
  store: (vsymbol * int) Mterm.t;
    (* concrete term -> (map variable, index); each entry becomes an
       axiom relating the map at that index to the term *)
  fr: int;
    (* next free non-negative index for [store] *)
  subst: term Mvs.t;
    (* values found for the variables of the reflection term *)
  lv: vsymbol list;
    (* variables of the reflection term that may be instantiated *)
  var_maps: ty Mvs.t; (* type of values pointed by each map*)
  crc_map: Coercion.t;
    (* coercions tried when the matched terms have different types *)
  ty_to_map: vsymbol Mty.t;
    (* representative map variable for each map type *)
  env: Env.env;
  interps: Sls.t; (* functions that were inverted*)
  task: Task.task;
    (* task scanned for hypotheses when reifying a context *)
  bound_vars: Svs.t; (* bound variables, do not map them in a var map*)
  bound_fr: int; (* separate, negative index for bound vars*)
}

(* Fresh reification environment: nothing stored, no substitution yet,
   non-negative indices start at 0 and bound-variable indices at -1. *)
let init_renv kn crc lv env task =
  { kn; lv; env; task;
    crc_map = crc;
    store = Mterm.empty;
    fr = 0;
    subst = Mvs.empty;
    var_maps = Mvs.empty;
    ty_to_map = Mty.empty;
    interps = Sls.empty;
    bound_vars = Svs.empty;
    bound_fr = -1;
  }

(* [reify_term renv t rt] matches the concrete formula/term [t] against the
   reflection term [rt] (e.g. the conclusion of a reflection lemma), whose
   variables listed in [renv.lv] are to be instantiated.  Interpretation
   functions applied to such variables in [rt] are inverted on the matching
   parts of [t] by following their logic definitions.
   Returns [renv] with the instantiations added to [subst], together with
   the variable maps and stored terms discovered on the way.
   Raises [NoReification] when [t] cannot be matched against [rt]. *)
let rec reify_term renv t rt =
  let is_pvar p = match p.pat_node with Pvar _ -> true | _ -> false in
  (* Whether [t] should be handled by inverting an interpretation function:
     constants and applications yes, variables no; nullary symbols are
     decided by looking at their definition. *)
  let rec use_interp t =
    let r = match t.t_node with
      | Tconst _ -> true
      | Tvar _ -> false
      | Tapp (ls, []) ->
         begin match find_logic_definition renv.kn ls with
         | None -> false
         | Some ld ->
            let _, t = open_ls_defn ld in
            use_interp t
         end
      | Tapp (_, _) -> true
      | _ -> false in
    Debug.dprintf debug_reification "use_interp %a: %b@." Pretty.print_term t r;
    r in
  (* Register the map variables [vyl] (all expected to be [Tvar]s); the
     first variable seen for a given type becomes that type's
     representative in [ty_to_map]. *)
  let add_to_maps renv vyl =
    let add_map (var_maps, ty_to_map) t =
      match t.t_node with
      | Tvar vy ->
         let ty_to_map =
           if Mty.mem vy.vs_ty ty_to_map then ty_to_map
           else Mty.add vy.vs_ty vy ty_to_map in
         (Mvs.add vy vy.vs_ty var_maps, ty_to_map)
      | _ -> assert false in
    let var_maps, ty_to_map =
      List.fold_left add_map (renv.var_maps, renv.ty_to_map) vyl in
    { renv with var_maps; ty_to_map }
  in
  (* Printer for term types that also accepts formulas (type [None]),
     so debug messages never fail on [Opt.get]. *)
  let print_oty fmt = function
    | Some ty -> Pretty.print_ty fmt ty
    | None -> Format.pp_print_string fmt "prop" in
  (* [renv.lv] is never modified during reification. *)
  let in_lv v = List.exists (vs_equal v) renv.lv in
  let all_lv_vars tl =
    List.for_all
      (fun t -> match t.t_node with Tvar vy -> in_lv vy | _ -> false) tl in
  (* Whether [v] is the first formal parameter in [vl]; false for nullary
     definitions instead of failing on [List.hd]. *)
  let is_first_param vl v =
    match vl with v0 :: _ -> vs_equal v v0 | [] -> false in
  (* Logic definition of [ls], or [NoReification] when it has none. *)
  let find_defn renv ls =
    match find_logic_definition renv.kn ls with
    | Some ld -> ld
    | None ->
       Debug.dprintf debug_reification "did not find def of %a@."
         Pretty.print_ls ls;
       raise NoReification in
  let open Theory in
  let th_list = Env.read_theory renv.env ["list"] "List" in
  let ty_list = ns_find_ts th_list.th_export ["list"] in
  (* Same head symbol, or same kind of quantifier. *)
  let compat_h t rt =
    match t.t_node, rt.t_node with
    | Tapp (ls1,_), Tapp(ls2, _) -> ls_equal ls1 ls2
    | Tquant (Tforall, _), Tquant (Tforall, _)
    | Tquant (Texists, _), Tquant (Texists, _)-> true
    | _ -> false in
  let is_eq_true t = match t.t_node with
    | Tapp (eq, [_; tr])
       when ls_equal eq ps_equ && t_equal tr t_bool_true -> true
    | _ -> false in
  let lhs_eq_true t = match t.t_node with
    | Tapp (eq, [t; tr])
         when ls_equal eq ps_equ && t_equal tr t_bool_true -> t
    | _ -> assert false in
  (* Match the branch [(p, f)] of an interpretation function against the
     concrete term [t], producing the reified term for pattern [p]. *)
  let rec invert_nonvar_pat vl (renv:reify_env) (p,f) t =
    Debug.dprintf debug_reification
      "invert_nonvar_pat p %a f %a t %a@."
      Pretty.print_pat p Pretty.print_term f Pretty.print_term t;
    if is_eq_true f && not (is_eq_true t)
    then invert_nonvar_pat vl renv (p, lhs_eq_true f) t else
    match p.pat_node, f.t_node, t.t_node with
    | Pwild , _, _ | Pvar _,_,_ when t_equal_nt_na f t ->
       Debug.dprintf debug_reification "case equal@.";
       renv, t
    | Papp (cs, pl), _,_
         when compat_h f t
              && Svs.for_all (fun v -> t_v_occurs v f = 1) p.pat_vars
              && List.for_all is_pvar pl
              (* could remove this with a bit more work in term reconstruction *)
      ->
       Debug.dprintf debug_reification "case app@.";
       (* Find the reified value of pattern variable [v] by descending into
          [f] and [t] in parallel until [v] is isolated. *)
       let rec rt_of_var svs f t v (renv, acc) =
         assert (not (Mvs.mem v acc));
         Debug.dprintf debug_reification "rt_of_var %a %a@."
           Pretty.print_vs v Pretty.print_term f;
         if t_v_occurs v f = 1
            && Svs.for_all (fun v' -> vs_equal v v' || t_v_occurs v' f = 0) svs
         then begin
           let renv, rt = invert_pat vl renv (pat_var v, f) t in
           renv, Mvs.add v rt acc
         end else
           match f.t_node, t.t_node with
           | Tapp (ls1, la1), Tapp (ls2, la2) when ls_equal ls1 ls2 ->
              let rec aux la1 la2 =
                match la1, la2 with
                | f' :: l1, t' :: l2 ->
                   if t_v_occurs v f' = 1 then rt_of_var svs f' t' v (renv, acc)
                   else aux l1 l2
                | _ -> assert false in
              aux la1 la2
           | Tquant (Tforall, tq1), Tquant (Tforall, tq2)
           | Tquant (Texists, tq1), Tquant (Texists, tq2) ->
              let _, _, t1 = t_open_quant tq1 in
              let bvl, _, t2 = t_open_quant tq2 in
              let bv = List.fold_left Svs.add_left renv.bound_vars bvl in
              let renv = { renv with bound_vars = bv } in
              rt_of_var svs t1 t2 v (renv, acc)
           | _ -> raise NoReification
       in
       (* Reject terms whose non-variable parts disagree with [f]. *)
       let rec check_nonvar f t =
         match f.t_node, t.t_node with
         | Tapp (ls1, la1), Tapp (ls2, la2) ->
            if ls_equal ls1 ls2 then List.iter2 check_nonvar la1 la2
            else if Svs.for_all (fun v -> t_v_occurs v f = 0) p.pat_vars
            then raise NoReification
         | Tapp (ls,_), Tconst _ ->
            (* reject constants that do not match the
               definitions of logic constants*)
            if Svs.for_all (fun v -> t_v_occurs v f = 0) p.pat_vars
            then begin
              match find_logic_definition renv.kn ls with
              | None -> raise NoReification
              | Some ld ->
                 let v, f = open_ls_defn ld in
                 assert (v = []);
                 check_nonvar f t
            end
         | Tconst (Number.ConstInt c1), Tconst (Number.ConstInt c2) ->
            let open Number in
            if not (BigInt.eq c1.il_int c2.il_int)
            then raise NoReification
         | _ -> () (* FIXME add more failure cases if needed *)
       in
       check_nonvar f t;
       let renv, mvs = Svs.fold (rt_of_var p.pat_vars f t) p.pat_vars
                                (renv, Mvs.empty) in
       let lrt = List.map (function | {pat_node = Pvar v} -> Mvs.find v mvs
                                    | _ -> assert false) pl in
       renv, t_app cs lrt (Some p.pat_ty)
    | Pvar v, Tapp (ls1, la1), Tapp(ls2, la2)
         when ls_equal ls1 ls2 && t_v_occurs v f = 1
      -> Debug.dprintf debug_reification "case app_var@.";
         (* Invert the unique argument of [f] mentioning [v]. *)
         let step (renv, acc) f t =
           match acc with
           | None when t_v_occurs v f > 0 ->
              let renv, rt = invert_pat vl renv (p, f) t in
              (renv, Some rt)
           | None -> (renv, None)
           | Some _ ->
              assert (t_v_occurs v f = 0);
              (renv, acc) in
         let renv, rt = List.fold_left2 step (renv, None) la1 la2 in
         renv, Opt.get rt
    | Pvar v, Tquant(Tforall, tq1), Tquant(Tforall, tq2)
    | Pvar v, Tquant(Texists, tq1), Tquant(Texists, tq2)
         when t_v_occurs v f = 1 ->
       Debug.dprintf debug_reification "case quant_var@.";
       let _,_,t1 = t_open_quant tq1 in
       let vl,_,t2 = t_open_quant tq2 in
       let bv = List.fold_left Svs.add_left renv.bound_vars vl in
       let renv = { renv with bound_vars = bv } in
       (* NOTE(review): [vl] is rebound above to the variables bound by
          [t]'s quantifier and passed on as the interpretation parameters,
          unlike the quantifier case of [rt_of_var] which keeps the outer
          [vl]; confirm this is intended. *)
       invert_nonvar_pat vl renv (p, t1) t2
    | Por (p1, p2), _, _ ->
       Debug.dprintf debug_reification "case or@.";
       begin try invert_pat vl renv (p1, f) t
             with NoReification -> invert_pat vl renv (p2, f) t
       end
    | Pvar _, Tvar _, Tvar _ | Pvar _, Tvar _, Tapp (_, [])
    | Pvar _, Tvar _, Tconst _
      -> Debug.dprintf debug_reification "case vars@.";
         (renv, t)
    | Pvar _, Tapp (ls, _hd::_tl), _
      -> Debug.dprintf debug_reification "case interp@.";
         invert_interp renv ls t
    | Papp (cs, [{pat_node = Pvar v}]), Tvar v', Tconst _
         when vs_equal v v'
      -> Debug.dprintf debug_reification "case var_const@.";
         renv, t_app cs [t] (Some p.pat_ty)
    | Papp (cs, [{pat_node = Pvar _}]), Tapp(ls, _hd::_tl), _
         when use_interp t (*FIXME*)
      -> Debug.dprintf debug_reification "case interp_var@.";
         let renv, rt = invert_interp renv ls t in
         renv, (t_app cs [rt] (Some p.pat_ty))
    | Papp _, Tapp (ls1, _), Tapp(ls2, _) ->
       Debug.dprintf debug_reification "head symbol mismatch %a %a@."
         Pretty.print_ls ls1 Pretty.print_ls ls2;
       raise NoReification
    | _ -> raise NoReification
  (* Branch whose body is a map lookup [y v]: the concrete term [t] is
     stored (or reused) in the map and replaced by its integer index. *)
  and invert_var_pat vl (renv:reify_env) (p,f) t =
    Debug.dprintf debug_reification
      "invert_var_pat p %a f %a t %a@."
      Pretty.print_pat p Pretty.print_term f Pretty.print_term t;
    match p.pat_node, f.t_node with
    | Papp (_, [{pat_node = Pvar v1}]),
      Tapp (ffa,[{t_node = Tvar vy}; {t_node = Tvar v2}])
    | Pvar v1, Tapp (ffa,[{t_node = Tvar vy}; {t_node = Tvar v2}])
         when ty_equal v1.vs_ty ty_int
              && Svs.mem v1 p.pat_vars
              && vs_equal v1 v2
              && ls_equal ffa fs_func_app
              && List.exists (fun vs -> vs_equal vs vy) vl (*FIXME*)
      ->
       Debug.dprintf debug_reification "case var@.";
       let rty = Some p.pat_ty in
       let app_pat trv = match p.pat_node with
         | Papp (cs, _) -> t_app cs [trv] rty
         | Pvar _ -> trv
         | _ -> assert false in
       (* remove attributes to identify terms modulo attributes *)
       let rec rm t =
         let t = match t.t_node with
           | Tapp (f,tl) -> t_app f (List.map rm tl) t.t_ty
           | Tvar _ | Tconst _ -> t
           | Tif (f,t1,t2) -> t_if (rm f) (rm t1) (rm t2)
           | Tbinop (op,f1,f2) -> t_binary op (rm f1) (rm f2)
           | Tnot f1 -> t_not (rm f1)
           | Ttrue | Tfalse -> t
           | _ -> t (* FIXME some cases missing *)
         in
         t_attr_set ?loc:t.t_loc Sattr.empty t
       in
       let t = rm t in
       if Mterm.mem t renv.store
       then begin
         Debug.dprintf debug_reification "%a exists@." Pretty.print_term t;
         (renv, app_pat (t_nat_const (snd (Mterm.find t renv.store))))
       end else begin
         Debug.dprintf debug_reification "%a is new@." Pretty.print_term t;
         let bound = match t.t_node with
           | Tvar v -> Svs.mem v renv.bound_vars
           | _ -> false in
         let renv, i =
           if bound then begin
             let i = renv.bound_fr in
             ({ renv with bound_fr = i - 1 }, i)
           end else begin
             (* A missing map for this type is a reification failure, not
                a crash: let callers try their alternatives. *)
             let vy =
               try Mty.find vy.vs_ty renv.ty_to_map
               with Not_found -> begin
                 Debug.dprintf debug_reification "no var map for type %a@."
                   Pretty.print_ty vy.vs_ty;
                 raise NoReification
               end in
             let fr = renv.fr in
             let store = Mterm.add t (vy, fr) renv.store in
             ({ renv with store = store; fr = fr + 1 }, fr)
           end in
         let const = Number.int_const_of_int i in
         (renv, app_pat (t_const const Ty.ty_int))
       end
    | _ -> raise NoReification
  (* Structural match first, map lookup as fallback; when the types
     differ, try inserting a coercion on [t]. *)
  and invert_pat vl renv (p,f) t =
    if oty_equal f.t_ty t.t_ty
    then (try invert_nonvar_pat vl renv (p,f) t
          with NoReification -> invert_var_pat vl renv (p,f) t)
    else
      match t.t_ty, f.t_ty with
      | Some ty_t, Some ty_f ->
         begin try
           let crc = Coercion.find renv.crc_map ty_t ty_f in
           let apply_crc t ls = t_app_infer ls [t] in
           let crc_t = List.fold_left apply_crc t crc in
           assert (oty_equal f.t_ty crc_t.t_ty);
           invert_pat vl renv (p,f) crc_t
         with Not_found ->
           Debug.dprintf debug_reification "type mismatch between %a and %a@."
             Pretty.print_ty ty_f
             Pretty.print_ty ty_t;
           raise NoReification
         end
      | _ ->
         (* a formula against a term: no coercion can apply *)
         Debug.dprintf debug_reification "type mismatch between %a and %a@."
           print_oty f.t_ty print_oty t.t_ty;
         raise NoReification
  (* Invert the interpretation function [ls] on the concrete term [t]. *)
  and invert_interp renv ls (t:term) =
    let ld = find_defn renv ls in
    let vl, f = open_ls_defn ld in
    Debug.dprintf debug_reification "invert_interp ls %a t %a@."
      Pretty.print_ls ls Pretty.print_term t;
    invert_body { renv with interps = Sls.add ls renv.interps } ls vl f t
  and invert_body renv ls vl f t =
    match f.t_node with
    | Tvar v when is_first_param vl v -> renv, t
    | Tif (f, th, el) when t_equal th t_bool_true && t_equal el t_bool_false ->
       invert_body renv ls vl f t
    | Tcase (x, bl) ->
       begin match x.t_node with
       | Tvar v when is_first_param vl v -> ()
       | _ ->
          Debug.dprintf debug_reification "not matching on first param@.";
          raise NoReification
       end;
       Debug.dprintf debug_reification "case match@.";
       let rec aux invert = function
         | [] -> raise NoReification
         | tb :: l ->
            try invert vl renv (t_open_branch tb) t
            with NoReification ->
              Debug.dprintf debug_reification "match failed@.";
              aux invert l in
       (try aux invert_nonvar_pat bl with NoReification -> aux invert_var_pat bl)
    | Tapp (ls', _) ->
       Debug.dprintf debug_reification "case app@.";
       invert_interp renv ls' t
    | _ -> Debug.dprintf debug_reification "function body not handled@.";
           Debug.dprintf debug_reification "f: %a@." Pretty.print_term f;
           raise NoReification
  (* Context reification: [ls l g] interprets a list [l] of hypotheses
     implying goal [g]; fills [l] from the task's axioms. *)
  and invert_ctx_interp renv ls t l g =
    let ld = find_defn renv ls in
    let vl, f = open_ls_defn ld in
    Debug.dprintf debug_reification "invert_ctx_interp ls %a @."
      Pretty.print_ls ls;
    let renv = { renv with interps = Sls.add ls renv.interps } in
    invert_ctx_body renv ls vl f t l g
  and invert_ctx_body renv ls vl f t l g =
    match f.t_node with
    | Tcase ({t_node = Tvar v}, [tbn; tbc]) when is_first_param vl v ->
       let ty_g = g.vs_ty in
       let ty_list_g = ty_app ty_list [ty_g] in
       if not (ty_equal ty_list_g l.vs_ty)
       then (Debug.dprintf debug_reification
               "bad type for context interp function@.";
             raise NoReification);
       let nil = ns_find_ls th_list.th_export ["Nil"] in
       let cons = ns_find_ls th_list.th_export ["Cons"] in
       let (pn, fn) = t_open_branch tbn in
       let (pc, fc) = t_open_branch tbc in
       begin match pn.pat_node, fn.t_node, pc.pat_node, fc.t_node with
       | Papp(n, []),
         Tapp(eq'', [{t_node=Tapp(leq,{t_node = Tvar g'}::_)};btr'']),
         Papp (c, [{pat_node = Pvar hdl};{pat_node = Pvar tll}]),
         Tbinop(Timplies,
                {t_node=(Tapp(eq, [({t_node = Tapp(leq', _)} as thd); btr]))},
                {t_node = (Tapp(eq',
                    [({t_node =
                         Tapp(ls', {t_node = Tvar tll'}::{t_node=Tvar g''}::_)}
                         as ttl); btr']))})
            when ls_equal n nil && ls_equal c cons && ls_equal ls ls'
                 && vs_equal tll tll'
                 && vs_equal g' g'' && ls_equal leq leq'
                 && List.exists (vs_equal g') vl
                 && not (Mvs.mem tll (t_vars thd))
                 && not (Mvs.mem hdl (t_vars ttl))
                 && ls_equal eq ps_equ && ls_equal eq' ps_equ
                 && ls_equal eq'' ps_equ && t_equal btr t_bool_true
                 && t_equal btr' t_bool_true && t_equal btr'' t_bool_true
         ->
          Debug.dprintf debug_reification "reifying goal@.";
          let (renv, rg) = invert_interp renv leq t in
          let renv = { renv with subst = Mvs.add g rg renv.subst } in
          Debug.dprintf debug_reification "filling context@.";
          (* Hypotheses that cannot be reified are silently skipped. *)
          let rec add_to_ctx (renv, ctx) e =
            try
              match e.t_node with
              | Teps _ -> (renv, ctx)
              | Tbinop (Tand,e1,e2) ->
                 add_to_ctx (add_to_ctx (renv, ctx) e1) e2
              | _ ->
                 let (renv,req) = invert_interp renv leq e in
                 (renv,(t_app cons [req; ctx] (Some ty_list_g)))
            with
            | NoReification -> renv,ctx
          in
          let renv, ctx =
            task_fold
              (fun (renv,ctx) td ->
                match td.td_node with
                | Decl {d_node = Dprop (Paxiom, _, e)}
                  -> add_to_ctx (renv, ctx) e
                | Decl {d_node = Dlogic [ls, ld]} when ls.ls_args = []
                  ->
                   add_to_ctx (renv, ctx) (ls_defn_axiom ld)
                | _-> renv,ctx)
              (renv, (t_app nil [] (Some ty_list_g))) renv.task in
          { renv with subst = Mvs.add l ctx renv.subst }
       | _ -> Debug.dprintf debug_reification "unhandled interp structure@.";
              raise NoReification
       end
    | Tif (c, th, el) when t_equal th t_bool_true && t_equal el t_bool_false ->
       invert_ctx_body renv ls vl c t l g
    | _ -> Debug.dprintf debug_reification "not a match on list@.";
           raise NoReification
  in
  Debug.dprintf debug_reification "reify_term t %a rt %a@."
    Pretty.print_term t Pretty.print_term rt;
  if not (oty_equal t.t_ty rt.t_ty)
  then (Debug.dprintf debug_reification "reification type mismatch %a %a@."
          print_oty t.t_ty
          print_oty rt.t_ty;
        raise NoReification);
  match t.t_node, rt.t_node with
  | _, Tapp(interp, {t_node = Tvar vx}::vyl)
       when in_lv vx && all_lv_vars vyl ->
     Debug.dprintf debug_reification "case interp@.";
     let renv = add_to_maps renv vyl in
     let renv, x = invert_interp renv interp t in
     { renv with subst = Mvs.add vx x renv.subst }
  | Tapp(eq, [t1; t2]), Tapp (eq', [rt1; rt2])
       when ls_equal eq ps_equ && ls_equal eq' ps_equ
            && oty_equal t1.t_ty rt1.t_ty && oty_equal t2.t_ty rt2.t_ty
    ->
     Debug.dprintf debug_reification "case eq@.";
     reify_term (reify_term renv t1 rt1) t2 rt2
  | _, Tapp(eq,[{t_node=Tapp(interp, {t_node = Tvar l}::{t_node = Tvar g}::vyl)}; tr])
       when ls_equal eq ps_equ && t_equal tr t_bool_true
            && ty_equal (ty_app ty_list [g.vs_ty]) l.vs_ty
            && in_lv l
            && in_lv g
            && all_lv_vars vyl
    ->
     Debug.dprintf debug_reification "case context@.";
     let renv = add_to_maps renv vyl in
     invert_ctx_interp renv interp t l g
  | Tbinop(Tiff,t,{t_node=Ttrue}), Tapp(eq,[{t_node=Tapp(interp, {t_node = Tvar f}::vyl)}; tr])
       when ls_equal eq ps_equ && t_equal tr t_bool_true
            && t.t_ty = None
            (* [add_to_maps] requires every [vyl] to be a lemma variable *)
            && all_lv_vars vyl ->
     Debug.dprintf debug_reification "case interp_fmla@.";
     Debug.dprintf debug_reification "t %a rt %a@." Pretty.print_term t Pretty.print_term rt;
     let renv = add_to_maps renv vyl in
     let renv, rf = invert_interp renv interp t in
     { renv with subst = Mvs.add f rf renv.subst }
  | _ -> Debug.dprintf debug_reification "no reify_term match@.";
         Debug.dprintf debug_reification "lv = [%a]@."
           (Pp.print_list Pp.space Pretty.print_vs)
           renv.lv;
         raise NoReification

(* Turn the maps and stored terms collected in [renv] into declarations
   added to task [prev]:
   - a fresh constant for every map variable (added to the substitution);
   - a [map_alias] axiom equating each map with the representative map of
     its type, when it is not that representative itself;
   - a [y_val<i>] axiom stating that the map applied to index [i] equals
     the stored term, for every entry of [renv.store].
   Returns the completed substitution, the extended task, and the names of
   the alias and value axioms.
   Raises [Arg_error] if some variable of [renv.lv] received no value. *)
let build_vars_map renv prev =
  Debug.dprintf debug_reification "building vars map@.";
  let declare_map vy ty_vars (subst, task) =
    Debug.dprintf debug_reification "creating var map %a@."
      Pretty.print_vs vy;
    let map_ls =
      create_fsymbol (Ident.id_fresh vy.vs_name.id_string) [] ty_vars in
    let map_t = t_app map_ls [] (Some ty_vars) in
    let task = Task.add_decl task (create_param_decl map_ls) in
    (Mvs.add vy map_t subst, task)
  in
  let subst, prev = Mvs.fold declare_map renv.var_maps (renv.subst, prev) in
  let alias_map vy _ (task, names) =
    Debug.dprintf debug_reification "checking %a@." Pretty.print_vs vy;
    let repr = Mty.find vy.vs_ty renv.ty_to_map in
    if vs_equal vy repr then (task, names)
    else begin
      Debug.dprintf debug_reification "aliasing %a and %a@."
        Pretty.print_vs vy Pretty.print_vs repr;
      let lhs = Mvs.find vy subst in
      let rhs = Mvs.find repr subst in
      let ax = t_equ lhs rhs in
      let name = create_prsymbol (Ident.id_fresh "map_alias") in
      let task = Task.add_decl task (create_prop_decl Paxiom name ax) in
      (task, name :: names)
    end
  in
  let prev, mapdecls = Mvs.fold alias_map renv.var_maps (prev, []) in
  begin match List.filter (fun v -> not (Mvs.mem v subst)) renv.lv with
  | [] -> ()
  | missing ->
     Debug.dprintf debug_reification "vars not matched: %a@."
       (Pp.print_list Pp.space Pretty.print_vs) missing;
     raise (Arg_error "vars not matched")
  end;
  Debug.dprintf debug_reification "all vars matched@.";
  let define_value t (vy, i) (task, names) =
    let map_t = Mvs.find vy subst in
    let ax = t_equ (t_app fs_func_app [map_t; t_nat_const i] t.t_ty) t in
    Debug.dprintf debug_reification "%a %d = %a@."
      Pretty.print_vs vy i Pretty.print_term t;
    let name = create_prsymbol (Ident.id_fresh (Format.sprintf "y_val%d" i)) in
    let task = Task.add_decl task (create_prop_decl Paxiom name ax) in
    (task, name :: names)
  in
  let prev, defdecls = Mterm.fold define_value renv.store (prev, []) in
  subst, prev, mapdecls, defdecls

(* Build the tasks produced by reflection.
   The main task adds the instantiated reflection term as axiom [HR] and
   keeps the original goal [g] as [GR] (tagged [expl_reification_check]).
   When [rt] and [g] share a head symbol (or are terms), a cut with the
   argument-wise equalities is tried; otherwise, if [do_trans] is set, [HR]
   is normalised and rewritten with the map/value axioms.
   Each instantiated premise of [lp] also becomes a separate goal [G],
   normalised when [do_trans] is set. *)
let build_goals do_trans renv prev mapdecls defdecls subst env lp g rt =
  Debug.dprintf debug_refl "building goals@.";
  let inst_rt = t_subst subst rt in
  Debug.dprintf debug_refl "reified goal instantiated@.";
  let inst_lp = List.map (t_subst subst) lp in
  Debug.dprintf debug_refl "premises instantiated@.";
  let hr = create_prsymbol (id_fresh "HR") in
  let d_hr = create_prop_decl Paxiom hr inst_rt in
  let gr_attrs = Sattr.singleton expl_reification_check in
  let gr = create_prsymbol (id_fresh "GR" ~attrs:gr_attrs) in
  let d_gr = create_prop_decl Pgoal gr g in
  let task_r = Task.add_decl (Task.add_decl prev d_hr) d_gr in
  Debug.dprintf debug_refl "building cut indication rt %a g %a@."
    Pretty.print_term rt Pretty.print_term g;
  let compute_hyp_few pr = Compute.normalize_hyp_few None (Some pr) env in
  let compute_in_goal = Compute.normalize_goal_transf_all env in
  (* Formula to cut on; raises [Not_found] when none can be stated. *)
  let cut_indication () =
    match rt.t_node, g.t_node with
    | Tapp (eq, rh :: rl), Tapp (eq', h :: l) when ls_equal eq eq' ->
       let first = t_equ (t_subst subst rh) h in
       List.fold_left2
         (fun acc st rst -> t_and acc (t_equ (t_subst subst rst) st))
         first l rl
    | _ when g.t_ty <> None -> t_equ (t_subst subst rt) g
    | _ -> raise Not_found
  in
  (* Used when no cut applies. *)
  let without_cut () =
    if not do_trans then [task_r]
    else begin
      let goal, hyps = task_separate_goal task_r in
      let hyps =
        Sls.fold
          (fun ls acc ->
            Task.add_meta acc Compute.meta_rewrite_def [Theory.MAls ls])
          renv.interps hyps in
      let normalized =
        Trans.apply (compute_hyp_few hr) (Task.add_tdecl hyps goal) in
      match normalized with
      | [task] ->
         let rewrite =
           Apply.rewrite_list false true (mapdecls @ defdecls) (Some hr) in
         Trans.apply rewrite task
      | [] -> []
      | _ -> assert false
    end
  in
  let ltask_r =
    try
      let ci = cut_indication () in
      Debug.dprintf debug_refl "cut ok@.";
      Trans.apply (Cut.cut ci (Some "interp")) task_r
    with Arg_trans _ | TypeMismatch _ | Not_found ->
      Debug.dprintf debug_refl "no cut found@.";
      without_cut ()
  in
  let goal_of_premise ng =
    Task.add_decl prev
      (create_prop_decl Pgoal (create_prsymbol (id_fresh "G")) ng) in
  let premise_tasks = List.map goal_of_premise inst_lp in
  let premise_tasks =
    if do_trans
    then Lists.apply (Trans.apply compute_in_goal) premise_tasks
    else premise_tasks in
  Debug.dprintf debug_refl "done@.";
  ltask_r @ premise_tasks

(* Transformation: reify the goal against the conclusion of lemma [pr],
   then build the resulting goals with [build_goals].
   Raises [Arg_error] if [pr] is not found, [NoReification] if the lemma
   contains lets or the goal cannot be matched. *)
let reflection_by_lemma pr env : Task.task Trans.tlist =
  Trans.store (fun task ->
    let kn = task_known task in
    let goal_decl, prev = Task.task_separate_goal task in
    let g = Apply.term_decl goal_decl in
    Debug.dprintf debug_refl "start@.";
    let lemma =
      let kn' = task_known prev in (* TODO Do we want kn here ? *)
      try snd (find_prop_decl kn' pr)
      with Not_found -> raise (Arg_error "lemma not found")
    in
    let lp, lv, llet, rt = Apply.intros lemma in
    begin match llet with
    | [] -> ()
    | _ :: _ ->
       (* TODO handle lets *)
       Debug.dprintf debug_refl "let in procedure postcondition@.";
       raise NoReification
    end;
    let crc = (Args_wrapper.build_naming_tables task).Trans.coercion in
    let renv = reify_term (init_renv kn crc lv env prev) g rt in
    let subst, prev, mds, dds = build_vars_map renv prev in
    build_goals true renv prev mds dds subst env lp g rt)

open Expr
open Ity
Guillaume Melquiond's avatar
Guillaume Melquiond committed
622
open Wstdlib
623
open Mlinterp
624

(* Exception carrying a reification environment.  It is not raised in the
   part of this file shown here; presumably used by
   [reflection_by_function] below — TODO confirm. *)
exception ReductionFail of reify_env

let reflection_by_function do_trans s env = Trans.store (fun task ->
628
  Debug.dprintf debug_refl "reflection_f start@.";
629
  let kn = task_known task in
Raphael Rieu-Helft's avatar
Raphael Rieu-Helft committed
630 631
  let nt = Args_wrapper.build_naming_tables task in
  let crc = nt.Trans.coercion in
632 633
  let g, prev = Task.task_separate_goal task in
  let g = Apply.term_decl g in
634 635 636 637 638 639 640 641
  let ths = Task.used_theories task in
  let o =
    Mid.fold
      (fun _ th o ->
        try
          let pmod = Pmodule.restore_module th in
          let rs = Pmodule.ns_find_rs pmod.Pmodule.mod_export [s] in
          if o = None then Some (pmod, rs)
642
          else (let es = Format.sprintf "module or function %s found twice" s in
643
                raise (Arg_error es))
644 645
        with Not_found -> o)
      ths None in
646
  let (_pmod, rs) = if o = None
647 648
                    then (let es = Format.sprintf "Symbol %s not found@." s in
                          raise (Arg_error es))
649
                   else Opt.get o in
650 651
  let lpost = List.map open_post rs.rs_cty.cty_post in
  if List.exists (fun pv -> pv.pv_ghost) rs.rs_cty.cty_args
652
  then (Debug.dprintf debug_refl "ghost parameter@.";
653
        raise (Arg_error "function has ghost parameters"));
654
  Debug.dprintf debug_refl "building module map@.";
655 656 657 658 659 660 661
  let mm = Mid.fold
             (fun id th acc ->
               try
                 let pm = Pmodule.restore_module th in
                 Mstr.add id.id_string pm acc
               with Not_found -> acc)
             ths Mstr.empty in
662
  Debug.dprintf debug_refl "module map built@.";
663 664
  let args = List.map (fun pv -> pv.pv_vs) rs.rs_cty.cty_args in
  let rec reify_post = function
665
    | [] -> Debug.dprintf debug_refl "no postcondition reifies@.";
666
            raise NoReification
667 668
    | (vres, p)::t -> begin
        try
669 670 671
          Debug.dprintf debug_refl "new post@.";
          Debug.dprintf debug_refl "post: %a, %a@."
            Pretty.print_vs vres Pretty.print_term p;
672
          let (lp, lv, llet, rt) = Apply.intros p in
673 674 675 676 677 678
          if llet <> []
          then begin
              (* TODO handle lets *)
              Debug.dprintf debug_refl "let in procedure postcondition@.";
              raise NoReification
            end;
679
          let lv = lv @ args in
Raphael Rieu-Helft's avatar
Raphael Rieu-Helft committed
680
          let renv = reify_term (init_renv kn crc lv env prev) g rt in
681
          Debug.dprintf debug_refl "computing args@.";
682 683 684 685 686
          let vars =
            List.fold_left
              (fun vars (vs, t) ->
                if List.mem vs args
                then begin
687
                    Debug.dprintf debug_refl "value of term %a for arg %a@."
688 689
                                                Pretty.print_term t
                                                Pretty.print_vs vs;
690
                    Mid.add vs.vs_name (value_of_term kn t) vars end
691 692 693
                else vars)
              Mid.empty
              (Mvs.bindings renv.subst) in
694
          Debug.dprintf debug_refl "evaluating@.";
695
          let res =
696
            try term_of_value (Mlinterp.interp env mm rs vars)
697 698 699
            with Raised (xs,_,cs) ->
              Format.eprintf "Raised %s %a@." (xs.xs_name.id_string)
                (Pp.print_list Pp.semi Expr.print_rs) cs;
700
              raise (ReductionFail renv) in
701
          Debug.dprintf debug_refl "res %a@." Pretty.print_term res;
702 703 704 705 706
          let rinfo = {renv with subst = Mvs.add vres res renv.subst} in
          rinfo, lp, lv, rt
        with NoReification -> reify_post t
      end
  in
707 708 709
  try
    let rinfo, lp, _lv, rt = reify_post lpost in
    let lp = (rs.rs_cty.cty_pre)@lp in
710
    let subst, prev, mds, dds = build_vars_map rinfo prev in
711
    build_goals do_trans rinfo prev mds dds subst env lp g rt
712 713 714
  with
    ReductionFail renv ->
    (* proof failed, show reification context for debugging *)
715
    let _, prev, _, _ = build_vars_map renv prev in
716 717 718
    let fg = create_prsymbol (id_fresh "Failure") in
    let df = create_prop_decl Pgoal fg t_false in
    [Task.add_decl prev df] )
719

720
let () = wrap_and_register
721 722 723
           ~desc:"reflection_l <prop>@ \
            attempts@ to@ prove@ the@ goal@ by@ reflection@ \
            using@ the@ lemma@ <prop>."
724
           "reflection_l"
725
           (Tprsymbol Tenvtrans_l) reflection_by_lemma
726

727
let () = wrap_and_register
728 729 730
           ~desc:"reflection_f <f>@ \
            attempts@ to@ prove@ the@ goal@ by@ reflection@ \
            using@ the@ contract@ of@ the@ program@ function@ <f>."
731
           "reflection_f"
732
           (Tstring Tenvtrans_l) (reflection_by_function true)
733

734
let () = wrap_and_register
735 736 737 738 739
           ~desc:"reflection_f <f>@ \
            attempts@ to@ prove@ the@ goal@ by@ reflection@ \
            using@ the@ contract@ of@ the@ program@ function@ <f>.@ \
            Does@ not@ automatically@ perform@ transformations@ \
            afterwards.@ Use@ for@ debugging."
740 741
           "reflection_f_nt"
           (Tstring Tenvtrans_l) (reflection_by_function false)