reduction_engine.ml 30 KB
Newer Older
MARCHE Claude's avatar
headers  
MARCHE Claude committed
1 2 3 4 5 6 7 8 9 10 11
(********************************************************************)
(*                                                                  *)
(*  The Why3 Verification Platform   /   The Why3 Development Team  *)
(*  Copyright 2010-2014   --   INRIA - CNRS - Paris-Sud University  *)
(*                                                                  *)
(*  This software is distributed under the terms of the GNU Lesser  *)
(*  General Public License version 2.1, with the special exception  *)
(*  on linking described in file LICENSE.                           *)
(*                                                                  *)
(********************************************************************)

12 13
open Term

14
(* {2 Values} *)
15

16 17 18 19
type value =
| Term of term    (* invariant: is in normal form *)
| Int of BigInt.t

20 21 22 23 24
(* [v_label_copy orig v] applies [t_label_copy orig] to the term wrapped in
   [v]; an integer value is returned unchanged. *)
let v_label_copy orig v =
  match v with
  | Term t -> Term (t_label_copy orig t)
  | Int _ -> v

25 26
(* Builds the integer constant term whose decimal literal is the printed
   form of [n]. In this file it is only called on non-negative numbers. *)
let const_of_positive n =
  let digits = BigInt.to_string n in
  let literal = Number.int_const_dec digits in
  t_const (Number.ConstInt literal)
27 28 29

(* Symbol used by [const_of_big_int] to negate a constant. It starts as
   [ps_equ] as a placeholder and is overwritten by [add_builtin_th] with the
   unary minus symbol of the int.Int theory when built-ins are loaded. *)
let ls_minus = ref ps_equ (* temporary *)

30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
(* Term for an arbitrary integer: a plain constant when [n] is non-negative,
   otherwise the application of [!ls_minus] to the constant for [-n]. *)
let const_of_big_int n =
  if BigInt.lt n BigInt.zero then
    let magnitude = const_of_positive (BigInt.minus n) in
    t_app_infer !ls_minus [magnitude]
  else
    const_of_positive n

(* Converts a value back to a term; integers become constant terms. *)
let term_of_value = function
  | Int n -> const_of_big_int n
  | Term t -> t

(* Raised by [big_int_of_const] and [big_int_of_value] when the argument is
   not an integer constant. *)
exception NotNum

(* Integer denoted by the constant [c].
   Raises [NotNum] when [c] is not an integer constant. *)
let big_int_of_const = function
  | Number.ConstInt i -> Number.compute_int i
  | _ -> raise NotNum

(* Integer denoted by a value: either an [Int] or a term that is an integer
   constant. Raises [NotNum] for any other term. *)
let big_int_of_value = function
  | Term { t_node = Tconst c } -> big_int_of_const c
  | Int n -> n
  | Term _ -> raise NotNum

53

54 55 56 57 58 59
(* {2 Builtin symbols} *)

(* Table mapping each built-in logic symbol to its evaluation function.
   It is cleared and refilled by [get_builtins]. *)
let builtins = Hls.create 17

(* all builtin functions *)

60
(* Raised when a built-in, an equality or a pattern match cannot be decided
   on the given arguments; the callers in this file then fall back to
   another reduction rule or rebuild the term unchanged. *)
exception Undetermined
61

62 63
(* Formula [t_true] or [t_false] corresponding to an OCaml boolean. *)
let to_bool b =
  match b with
  | true -> t_true
  | false -> t_false

64
(* Rebuilds the application of [ls] to the values [l] as a term value.
   Currently unused, hence the leading underscore. *)
let _t_app_value ls l ty =
  let args = List.map term_of_value l in
  Term (t_app ls args ty)
66

67 68 69 70 71 72 73 74 75
(* [true] iff [v] is the integer 0; non-numeric values yield [false]. *)
let is_zero v =
  let as_int = try Some (big_int_of_value v) with NotNum -> None in
  match as_int with
  | Some n -> BigInt.eq n BigInt.zero
  | None -> false

(* [true] iff [v] is the integer 1; non-numeric values yield [false]. *)
let is_one v =
  let as_int = try Some (big_int_of_value v) with NotNum -> None in
  match as_int with
  | Some n -> BigInt.eq n BigInt.one
  | None -> false

(* Evaluator for a binary integer operation.
   [op] is applied when both arguments are integer constants; when an
   argument is not numeric, or [op] raises [Division_by_zero], the symbolic
   simplifier [simpl] gets a chance instead (and may raise [Undetermined]).
   The argument list must contain exactly two values. *)
let eval_int_op op simpl ls l ty =
  let compute v1 v2 =
    let n1 = big_int_of_value v1 in
    let n2 = big_int_of_value v2 in
    Int (op n1 n2)
  in
  match l with
  | [v1; v2] ->
    begin
      try compute v1 v2
      with NotNum | Division_by_zero -> simpl ls v1 v2 ty
    end
  | _ -> assert false (* t_app_value ls l ty *)
87

88
(* unused anymore, for the moment
89 90
let simpl_none ls t1 t2 ty =
  t_app_value ls [t1;t2] ty
91
*)
92

93
(* Symbolic simplification of an addition: 0 is neutral on either side.
   Raises [Undetermined] when neither argument is 0. *)
let simpl_add _ls t1 t2 _ty =
  match is_zero t1, is_zero t2 with
  | true, _ -> t2
  | false, true -> t1
  | false, false -> raise Undetermined
(*
98
  t_app_value ls [t1;t2] ty
99
*)
100

101
(* Symbolic simplification of a subtraction: only [t1 - 0] is simplified.
   Raises [Undetermined] otherwise. *)
let simpl_sub _ls t1 t2 _ty =
  if not (is_zero t2) then raise Undetermined;
  t1
(*
105
  t_app_value ls [t1;t2] ty
106
*)
107

108
(* Symbolic simplification of a product: a zero factor is returned as the
   result, and a factor equal to one is dropped. The tests are tried in the
   order zero-left, zero-right, one-left, one-right.
   Raises [Undetermined] when none applies. *)
let simpl_mul _ls t1 t2 _ty =
  match is_zero t1, is_zero t2, is_one t1, is_one t2 with
  | true, _, _, _ -> t1
  | _, true, _, _ -> t2
  | _, _, true, _ -> t2
  | _, _, _, true -> t1
  | _ -> raise Undetermined
(*
115
  t_app_value ls [t1;t2] ty
116
*)
117

118
(* Symbolic fallback shared by the div and mod built-ins (computer and
   Euclidean variants), used when [eval_int_op] could not compute a number.

   [ls] is the symbol being evaluated; its name tells div from mod.
   - A literal zero divisor is never simplified. [eval_int_op] routes
     [Division_by_zero] here, so [t2] may well be 0, and rewriting
     [0 div 0] to 0 would give a value to a division by zero.
   - [0 div y] and [0 mod y] are 0.
     NOTE(review): this is kept from the original and assumes a non-zero
     [y] when [y] is symbolic -- confirm against the int theories.
   - [x div 1] is [x], whereas [x mod 1] is 0 (the original returned [x]
     for both).
   Raises [Undetermined] in every other case. *)
let simpl_divmod ls t1 t2 _ty =
  if is_zero t2 then raise Undetermined;
  if is_zero t1 then t1 else
  if is_one t2 then
    if ls.ls_name.Ident.id_string = "mod" then Int BigInt.zero else t1
  else
    raise Undetermined
(*
123
  t_app_value ls [t1;t2] ty
124
*)
125

126
(* Symbolic simplification of min and max: when both arguments are the same
   term, that term is the result. Raises [Undetermined] otherwise. *)
let simpl_minmax _ls v1 v2 _ty =
  match v1, v2 with
  | Term t1, Term t2 when t_equal t1 t2 -> v1
  | _ -> raise Undetermined
(*
  t_app_value ls [v1;v2] ty
*)
      
(* Evaluator for a binary integer relation: yields the formula [t_true] or
   [t_false] when both arguments are integer constants, and raises
   [Undetermined] otherwise. The argument list must have two elements. *)
let eval_int_rel op _ls l _ty =
  let decide v1 v2 =
    let n1 = big_int_of_value v1 in
    let n2 = big_int_of_value v2 in
    Term (to_bool (op n1 n2))
  in
  match l with
  | [v1; v2] ->
    begin
      try decide v1 v2
      with NotNum | Division_by_zero -> raise Undetermined
    end
  | _ -> assert false
    (* t_app_value ls l ty *)
MARCHE Claude's avatar
MARCHE Claude committed
154

155
(* Evaluator for a unary integer operation: applies [op] when the single
   argument is an integer constant, raises [Undetermined] otherwise. *)
let eval_int_uop op _ls l _ty =
  match l with
  | [v] ->
    let compute () = Int (op (big_int_of_value v)) in
    begin
      try compute ()
      with NotNum | Division_by_zero -> raise Undetermined
    end
  | _ -> assert false
166

167 168

(* Description of the built-in theories. Each entry gives the library path
   and name of a theory, a list of built-in type symbols (empty for all
   active entries), and a list of triples: symbol name, optional reference
   that [add_builtin_th] sets to the resolved symbol, and evaluator. *)
let built_in_theories =
  [
(*
    ["bool"],"Bool", [],
    [ "True", None, eval_true ;
      "False", None, eval_false ;
    ] ;
*)
    ["int"], "Int", [],
    [ "infix +",  None,          eval_int_op BigInt.add simpl_add;
      "infix -",  None,          eval_int_op BigInt.sub simpl_sub;
      "infix *",  None,          eval_int_op BigInt.mul simpl_mul;
      "prefix -", Some ls_minus, eval_int_uop BigInt.minus;
      "infix <",  None,          eval_int_rel BigInt.lt;
      "infix <=", None,          eval_int_rel BigInt.le;
      "infix >",  None,          eval_int_rel BigInt.gt;
      "infix >=", None,          eval_int_rel BigInt.ge;
    ] ;

    ["int"], "MinMax", [],
    [ "min", None, eval_int_op BigInt.min simpl_minmax;
      "max", None, eval_int_op BigInt.max simpl_minmax;
    ] ;

    ["int"], "ComputerDivision", [],
    [ "div", None, eval_int_op BigInt.computer_div simpl_divmod;
      "mod", None, eval_int_op BigInt.computer_mod simpl_divmod;
    ] ;

    ["int"], "EuclideanDivision", [],
    [ "div", None, eval_int_op BigInt.euclidean_div simpl_divmod;
      "mod", None, eval_int_op BigInt.euclidean_mod simpl_divmod;
    ] ;
(*
    ["map"],"Map", ["map", builtin_map_type],
    [ "const", Some ls_map_const, eval_map_const;
      "get", Some ls_map_get, eval_map_get;
      "set", Some ls_map_set, eval_map_set;
    ] ;
*)
  ]

(* Registers one entry of [built_in_theories]: reads theory [n] at path [l]
   from [env], passes each listed type symbol to its callback, and records
   each listed logic symbol with its evaluator in [builtins], also storing
   the symbol in the optional reference when one is given. *)
let add_builtin_th env (l,n,t,d) =
  let th = Env.read_theory env l n in
  let register_type (id, callback) =
    let ts = Theory.ns_find_ts th.Theory.th_export [id] in
    callback ts
  in
  let register_symbol (id, cell, eval) =
    let ls = Theory.ns_find_ls th.Theory.th_export [id] in
    Hls.add builtins ls eval;
    match cell with
    | Some r -> r := ls
    | None -> ()
  in
  List.iter register_type t;
  List.iter register_symbol d
222 223 224 225 226 227 228 229 230 231

(* Resets [builtins] and reloads every theory of [built_in_theories]
   from [env]. *)
let get_builtins env =
  let () = Hls.clear builtins in
  List.iter (fun entry -> add_builtin_th env entry) built_in_theories



(* {2 the reduction machine} *)


232 233
(* A rewrite rule [(vars, lhs_args, rhs)]: the set of variables that may be
   instantiated, the arguments of the left-hand side application, and the
   right-hand side (see [add_rule] and [one_step_reduce]). *)
type rule = Svs.t * term list * term

234 235 236 237 238 239
(* Options of a reduction engine. *)
type params =
  { compute_defs : bool;
    (* unfold the definition of every defined logic symbol *)
    compute_builtin : bool;
    (* load the built-in evaluators when the engine is created *)
    compute_def_set : Term.Sls.t;
    (* symbols whose definition is unfolded even if [compute_defs] is false *)
  }

240 241 242
(* A reduction engine: the known declarations (indexed by identifier), the
   rewrite rules attached to each head symbol, and the options. *)
type engine = {
  known_map : Decl.decl Ident.Mid.t;
  rules : rule list Mls.t;
  params : params;
}
245 246


MARCHE Claude's avatar
MARCHE Claude committed
247
(* OBSOLETE COMMENT
248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294

  A configuration is a pair (t,s) where t is a stack of terms and s is a
  stack of function symbols.

  A configuration ([t1;..;tn],[f1;..;fk]) represents a whole term, its
  model, as defined recursively by

    model([t],[]) = t

    model(t1::..::tn::t,f::s) = model(f(t1,..,tn)::t,s)
      where f has arity n

  A given term can be "exploded" into a configuration by reversing the
  rules above

  During reduction, the terms in the first stack are kept in normal
  form. The normalization process can be defined as the repeated
  application of the following rules.

  ([t],[]) --> t  // t is in normal form

  if f(t1,..,tn) is not a redex then
  (t1::..::tn::t,f::s) --> (f(t1,..,tn)::t,s)

  if f(t1,..,tn) is a redex l sigma for a rule l -> r then
  (t1::..::tn::t,f::s) --> (subst(sigma) @ t,explode(r) @ s)




*)

(* A substitution maps variables to terms. *)
type substitution = term Mvs.t

(* Continuation frames of the reduction machine. Every frame except [Keval]
   waits for values on the value stack and rebuilds or reduces the
   corresponding construct (see [reduce]). *)
type cont =
| Kapp of lsymbol * Ty.ty option (* apply the symbol to the top values *)
| Kif of term * term * substitution (* both branches, test on the stack *)
| Klet of vsymbol * term * substitution (* bound variable and body *)
| Kcase of term_branch list * substitution (* branches, scrutinee on stack *)
| Keps of vsymbol (* close an epsilon over the top value *)
| Kquant of quant * vsymbol list * trigger (* close a quantifier *)
| Kbinop of binop (* combine the two top values *)
| Knot (* negate the top value *)
| Keval of term * substitution (* evaluate the term under the substitution *)

(* A machine configuration: the values computed so far and the stack of
   pending continuation frames. *)
type config = {
  value_stack : value list;
  cont_stack : (cont * term) list;
  (* second term is the original term, for label and loc copy *)
}


(* Raised by [first_order_matching] and [matching] when a pattern definitely
   does not match. *)
exception NoMatch

(* [first_order_matching vars largs args] matches the pattern terms [largs]
   against [args], position by position. Only variables in [vars] may be
   instantiated; a variable occurring several times must be matched by
   equal terms. Applications are matched structurally when the head symbols
   agree; any other pattern must be equal to the matched term.

   Returns the type substitution and the term substitution found.
   Raises [NoMatch] when the lists do not match.

   The variable lookup and the type matching are each wrapped in their own
   narrow handler, so that the recursive calls stay in tail position and a
   handler never catches an exception raised while matching the remaining
   arguments. *)
let first_order_matching (vars : Svs.t) (largs : term list)
    (args : term list) : Ty.ty Ty.Mtv.t * substitution =
  let rec loop ((mt,mv) as sigma) largs args =
    match largs, args with
    | [], [] -> sigma
    | t1 :: r1, t2 :: r2 ->
      begin
        match t1.t_node with
        | Tvar vs when Svs.mem vs vars ->
          let bound = try Some (Mvs.find vs mv) with Not_found -> None in
          begin
            match bound with
            | Some t ->
              (* already instantiated: both occurrences must agree *)
              if t_equal t t2 then loop sigma r1 r2 else raise NoMatch
            | None ->
              let mt =
                try Ty.ty_match mt vs.vs_ty (t_type t2)
                with Ty.TypeMismatch _ -> raise NoMatch
              in
              loop (mt, Mvs.add vs t2 mv) r1 r2
          end
        | Tapp (ls1, args1) ->
          begin
            match t2.t_node with
            | Tapp (ls2, args2) when ls_equal ls1 ls2 ->
              (* match the sub-terms before the remaining arguments *)
              loop sigma (List.rev_append args1 r1)
                (List.rev_append args2 r2)
            | _ -> raise NoMatch
          end
        | _ ->
          if t_equal t1 t2 then loop sigma r1 r2 else raise NoMatch
      end
    | _ -> raise NoMatch
  in
  loop (Ty.Mtv.empty, Mvs.empty) largs args
357

358
(* Raised by [one_step_reduce] when no rewrite rule applies. *)
exception Irreducible
359

360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376
(* Tries the rewrite rules registered for [ls], in order, against [args].
   Returns the matching substitution of the first applicable rule together
   with its right-hand side. Raises [Irreducible] when [ls] has no rule or
   no rule matches. *)
let one_step_reduce engine ls args =
  let rec try_rules = function
    | [] -> raise Irreducible
    | (vars, largs, rhs) :: others ->
      begin
        try
          let sigma = first_order_matching vars largs args in
          sigma, rhs
        with NoMatch -> try_rules others
      end
  in
  try try_rules (Mls.find ls engine.rules)
  with Not_found -> raise Irreducible
377

378
(* [matching sigma t p] matches the term [t] against the pattern [p] and
   extends the term substitution of [sigma] with the variables bound by [p].
   Raises [NoMatch] when [t] is headed by a different constructor than [p],
   and [Undetermined] when [t] is not headed by a constructor at all, so
   that the match cannot be decided yet. *)
let rec matching ((mt,mv) as sigma) t p =
  match p.pat_node with
  | Pwild -> sigma
  | Pvar v -> (mt, Mvs.add v t mv)
  | Pas (sub, v) -> matching (mt, Mvs.add v t mv) t sub
  | Por (left, right) ->
    begin
      try matching sigma t left
      with NoMatch -> matching sigma t right
    end
  | Papp (cs, sub_pats) ->
    begin
      match t.t_node with
      | Tapp (head, sub_terms) when ls_equal cs head ->
        List.fold_left2 matching sigma sub_terms sub_pats
      | Tapp (head, _) when head.ls_constr > 0 -> raise NoMatch
      | _ -> raise Undetermined
    end


399 400 401 402 403 404
(* [extract_first n acc l] moves the first [n] elements of [l] onto [acc],
   one at a time (so they end up in reverse order in front of [acc]), and
   returns the new accumulator with the rest of [l]. Fails with an
   assertion when [l] has fewer than [n] elements. *)
let rec extract_first n acc l =
  match n, l with
  | 0, _ -> acc, l
  | _, x :: r -> extract_first (n - 1) (x :: acc) r
  | _, [] -> assert false
405 406 407 408 409


(* One step of the reduction machine: pops the top continuation frame of
   [c] and returns the next configuration. The continuation stack must not
   be empty; the [assert false] cases correspond to value stacks that do
   not fit the frame being popped. *)
let rec reduce engine c =
  (* configuration obtained by pushing [t], relabelled from [orig] *)
  let push orig t st rem =
    { value_stack = Term (t_label_copy orig t) :: st; cont_stack = rem }
  in
  (* configuration that goes on evaluating [t] under [sigma] *)
  let resume orig t sigma st rem =
    { value_stack = st;
      cont_stack = (Keval (t, sigma), t_label_copy orig t) :: rem }
  in
  match c.value_stack, c.cont_stack with
  | _, [] -> assert false
  | st, (Keval (t, sigma), orig) :: rem -> reduce_eval st t ~orig sigma rem
  (* conditional: the test is on top of the value stack *)
  | [], (Kif _, _) :: _ -> assert false
  | v :: st, (Kif (t2, t3, sigma), orig) :: rem ->
    begin
      match v with
      | Term { t_node = Ttrue } -> resume orig t2 sigma st rem
      | Term { t_node = Tfalse } -> resume orig t3 sigma st rem
      | Term t1 ->
        (* undecided test: rebuild the conditional *)
        push orig (t_if t1 (t_subst sigma t2) (t_subst sigma t3)) st rem
      | Int _ -> assert false (* would be ill-typed *)
    end
  (* let: bind the variable to the computed value and evaluate the body *)
  | [], (Klet _, _) :: _ -> assert false
  | bound :: st, (Klet (v, t2, sigma), orig) :: rem ->
    let bound = term_of_value bound in
    resume orig t2 (Mvs.add v bound sigma) st rem
  (* pattern matching on the computed scrutinee *)
  | ([] | Int _ :: _), (Kcase _, _) :: _ -> assert false
  | Term t1 :: st, (Kcase (tbl, sigma), orig) :: rem ->
    reduce_match st t1 ~orig tbl sigma rem
  (* binary connective: the second operand is on top of the first *)
  | ([] | [_] | Int _ :: _ | Term _ :: Int _ :: _),
    (Kbinop _, _) :: _ -> assert false
  | Term t1 :: Term t2 :: st, (Kbinop op, orig) :: rem ->
    push orig (t_binary_simp op t2 t1) st rem
  | ([] | Int _ :: _), (Knot, _) :: _ -> assert false
  | Term t :: st, (Knot, orig) :: rem ->
    push orig (t_not t) st rem
  | st, (Kapp (ls, ty), orig) :: rem ->
    reduce_app engine st ~orig ls ty rem
  | ([] | Int _ :: _), (Keps _, _) :: _ -> assert false
  | Term t :: st, (Keps v, orig) :: rem ->
    push orig (t_eps_close v t) st rem
  | ([] | Int _ :: _), (Kquant _, _) :: _ -> assert false
  | Term t :: st, (Kquant (q, vl, tr), orig) :: rem ->
    push orig (t_quant_close_simp q vl tr t) st rem


468
(* Reduces a pattern matching on the evaluated scrutinee [u]: the branches
   of [tbl] are tried in order and the body of the first matching one is
   scheduled for evaluation under the extended substitution. When a pattern
   cannot be decided ([Undetermined]), the whole match is rebuilt and pushed
   as a value. A match with no applicable branch is an assertion failure. *)
and reduce_match st u ~orig tbl sigma cont =
  (* configuration evaluating the body of a branch whose pattern matched *)
  let select p body =
    let mt, mv = matching (Ty.Mtv.empty, sigma) u p in
    let mv, body = t_subst_types mt mv body in
    { value_stack = st;
      cont_stack = (Keval (body, mv), t_label_copy orig body) :: cont;
    }
  in
  let rec first_match branches =
    match branches with
    | [] -> assert false (* pattern matching not exhaustive *)
    | b :: others ->
      let p, body = t_open_branch b in
      begin
        try select p body
        with NoMatch -> first_match others
      end
  in
  try first_match tbl
  with Undetermined ->
    let stuck = t_subst sigma (t_case u tbl) in
    { value_stack = Term (t_label_copy orig stuck) :: st;
      cont_stack = cont;
    }




513
(* Starts the evaluation of the term [t] under the substitution [sigma].
   Variables and literals are pushed directly as values (relabelled from
   [orig]); for every other construct the sub-term(s) to evaluate first are
   scheduled, followed by the frame that will finish the construct.

   The finishing frame is always tagged with [orig]. The conditional case
   used to be tagged with [t] instead, unlike every other case, which
   dropped the labels that [orig] carries in addition to those of [t]. *)
and reduce_eval st t ~orig sigma rem =
  (* push [t'] as a finished value *)
  let push_value t' =
    { value_stack = Term (t_label_copy orig t') :: st;
      cont_stack = rem;
    }
  in
  (* evaluate [sub] first, then resume with the frame [k] *)
  let eval_then sub k =
    { value_stack = st;
      cont_stack = (Keval (sub, sigma), sub) :: (k, orig) :: rem;
    }
  in
  match t.t_node with
  | Tvar v ->
    (* an unbound variable may happen, e.g. when computing below a
       quantified formula: it is then kept as is *)
    let t' = try Mvs.find v sigma with Not_found -> t in
    push_value t'
  | Tif (t1, t2, t3) ->
    eval_then t1 (Kif (t2, t3, sigma))
  | Tlet (t1, tb) ->
    let v, t2 = t_open_bound tb in
    eval_then t1 (Klet (v, t2, sigma))
  | Tcase (t1, tbl) ->
    eval_then t1 (Kcase (tbl, sigma))
  | Tbinop (op, t1, t2) ->
    { value_stack = st;
      cont_stack =
        (Keval (t1, sigma), t1) ::
          (Keval (t2, sigma), t2) :: (Kbinop op, orig) :: rem;
    }
  | Tnot t1 ->
    eval_then t1 Knot
  | Teps tb ->
    let v, t1 = t_open_bound tb in
    eval_then t1 (Keps v)
  | Tquant (q, tq) ->
    let vl, tr, t1 = t_open_quant tq in
    eval_then t1 (Kquant (q, vl, tr))
  | Tapp (ls, tl) ->
    (* the arguments are evaluated from left to right, then applied *)
    let args = List.rev_map (fun a -> (Keval (a, sigma), a)) tl in
    { value_stack = st;
      cont_stack = List.rev_append args ((Kapp (ls, t.t_ty), orig) :: rem);
    }
  | Ttrue | Tfalse | Tconst _ ->
    push_value t


573
(* Reduces the application of [ls] to the values on top of [st]. Equality
   and higher-order application get a dedicated treatment first; when that
   treatment raises [Undetermined], or for any other symbol, the generic
   [reduce_app_no_equ] is used. *)
and reduce_app engine st ls ~orig ty rem_cont =
  let generic () = reduce_app_no_equ engine st ls ~orig ty rem_cont in
  (* both special symbols are binary: the second argument is on top *)
  let binary special =
    match st with
    | t2 :: t1 :: rem_st ->
      begin
        try special rem_st t1 t2
        with Undetermined -> generic ()
      end
    | _ -> assert false
  in
  if ls_equal ls ps_equ then
    binary (fun rem_st t1 t2 -> reduce_equ ~orig rem_st t1 t2 rem_cont)
  else if ls_equal ls fs_func_app then
    binary
      (fun rem_st t1 t2 -> reduce_func_app ~orig ty rem_st t1 t2 rem_cont)
  else
    generic ()


598
(* Reduces the higher-order application of [t1] to [t2].
   It attempts to decompile t1 under the form
     (epsilon fc. forall x. fc @ x = body)
   that is equivalent to \x.body, and then evaluates [body] with the first
   quantified variable bound to [t2]. When several variables are quantified,
   the remaining ones are kept under a rebuilt epsilon. The variant where
   the body is a formula, written  fc @ x = True <-> body,  is also
   recognised. Raises [Undetermined] when [t1] does not have this shape. *)
and reduce_func_app ~orig _ty rem_st t1 t2 rem_cont =
  (* configuration evaluating [body] with [vh] bound to [arg] *)
  let eval_body body vh arg cont =
    { value_stack = rem_st;
      cont_stack =
        (Keval (body, Mvs.add vh arg Mvs.empty),
         t_label_copy orig body) :: cont;
    }
  in
  match t1 with
  | Term { t_node = Teps tb } ->
    let fc, eps_body = Term.t_open_bound tb in
    begin
      match eps_body.t_node with
      | Tquant (Tforall, tq) ->
        let vl, trig, qbody = t_open_quant tq in
        let process lhs body equ elim =
          (* strips the application to the variable [rvh] from [lhs],
             checking that [lhs] is [fc] applied to the quantified
             variables in order; returns the new left-hand side and the
             fresh symbol standing for the partially applied function *)
          let rec remove_var lhs rvh rvt =
            match lhs.t_node with
            | Tapp (ls2, [lhs1; {t_node = Tvar v1} as arg])
              when ls_equal ls2 fs_func_app && vs_equal v1 rvh ->
              begin
                match rvt, lhs1 with
                | rvh :: rvt, _ ->
                  let lhs1, fc2 = remove_var lhs1 rvh rvt in
                  let lhs2 = t_app ls2 [lhs1; arg] lhs.t_ty in
                  t_label_copy lhs lhs2, fc2
                | [], { t_node = Tvar fc1 } when vs_equal fc1 fc ->
                  let fcn = fc.vs_name in
                  let fc2 = Ident.id_derive fcn.Ident.id_string fcn in
                  let fc2 = create_vsymbol fc2 (t_type lhs) in
                  t_label_copy lhs (t_var fc2), fc2
                | _ -> raise Undetermined
              end
            | _ -> raise Undetermined
          in
          match List.rev vl, vl with
          | rvh :: rvt, vh :: others ->
            let lhs, fc2 = remove_var lhs rvh rvt in
            let arg = term_of_value t2 in
            begin
              match others with
              | [] -> elim body vh arg
              | _ ->
                (* keep the remaining variables under a new epsilon *)
                let eq = equ lhs body in
                let tq = t_quant Tforall (t_close_quant others trig eq) in
                let body = t_label_copy qbody (t_eps_close fc2 tq) in
                eval_body body vh arg rem_cont
            end
          | _ -> raise Undetermined
        in
        begin
          match qbody.t_node with
          | Tapp (ls1, [lhs; body]) when ls_equal ls1 ps_equ ->
            let equ lhs body =
              t_label_copy qbody (t_app ps_equ [lhs; body] None) in
            let elim body vh arg = eval_body body vh arg rem_cont in
            process lhs body equ elim
          | Tbinop (Tiff,
                    ({t_node = Tapp (ls1, [lhs; tr])} as teq),
                    body)
            when ls_equal ls1 ps_equ && t_equal tr t_bool_true ->
            let equ lhs body =
              let lhs = t_label_copy teq (t_app ps_equ [lhs; tr] None) in
              t_label_copy qbody (t_binary Tiff lhs body) in
            let elim body vh arg =
              match rem_cont with
              | (Keval (tr, _), _) :: (Kapp (ls, _), _) :: rem_cont
                when t_equal tr t_bool_true && ls_equal ls ps_equ ->
                (* the result is about to be compared with True: skip the
                   comparison and evaluate the formula itself *)
                eval_body body vh arg rem_cont
              | _ ->
                let body = t_if body t_bool_true t_bool_false in
                eval_body body vh arg rem_cont
            in
            process lhs body equ elim
          | _ -> raise Undetermined
        end
      | _ -> raise Undetermined
    end
  | _ -> raise Undetermined
690

691
(* Generic reduction of the application of [ls] to its arguments, which are
   popped from [st]. In order:
   - a built-in evaluator registered in [builtins] is applied;
   - otherwise, depending on the declaration of [ls] in the known map: the
     definition of a defined symbol is unfolded when the engine options
     allow it, a projection applied to a constructor application is
     computed, and in the remaining cases a rewrite rule is tried;
   - when nothing applies, the application is rebuilt and pushed as a
     value. *)
and reduce_app_no_equ engine st ls ~orig ty rem_cont =
  let args, rem_st = extract_first (List.length ls.ls_args) [] st in
  try
    let eval = Hls.find builtins ls in
    let v = eval ls args ty in
    { value_stack = (v_label_copy orig v) :: rem_st;
      cont_stack = rem_cont;
    }
  with Not_found | Undetermined ->
    let args = List.map term_of_value args in
    (* replace the application by the evaluation of [t] under [mv] *)
    let continue_with t mv =
      { value_stack = rem_st;
        cont_stack = (Keval (t, mv), orig) :: rem_cont;
      }
    in
    try
      let d = Ident.Mid.find ls.ls_name engine.known_map in
      (* try a rewrite rule *)
      let rewrite () =
        try
          let (mt, mv), rhs = one_step_reduce engine ls args in
          let mv, rhs = t_subst_types mt mv rhs in
          continue_with rhs mv
        with Irreducible ->
          raise Not_found
      in
      match d.Decl.d_node with
      | Decl.Dtype _ | Decl.Dprop _ -> assert false
      | Decl.Dparam _ | Decl.Dind _ -> rewrite ()
      | Decl.Dlogic dl ->
        (* regular definition *)
        let defn = List.assq ls dl in
        let unfold =
          engine.params.compute_defs ||
          Term.Sls.mem ls engine.params.compute_def_set
        in
        if not unfold then rewrite ()
        else begin
          let vl, e = Decl.open_ls_defn defn in
          let add (mt, mv) x y =
            Ty.ty_match mt x.vs_ty (t_type y), Mvs.add x y mv
          in
          let (mt, mv) =
            List.fold_left2 add (Ty.Mtv.empty, Mvs.empty) vl args in
          let mt = Ty.oty_match mt e.t_ty ty in
          let mv, e = t_subst_types mt mv e in
          continue_with e mv
        end
      | Decl.Ddata dl ->
        (* constructor or projection *)
        begin
          match args with
          | [ { t_node = Tapp (ls1, tl1) } ] ->
            (* if ls is a projection and ls1 is a constructor,
               we should compute that projection *)
            let rec project prs fields =
              match prs, fields with
              | Some pr :: _, field :: _ when ls_equal ls pr ->
                (* projection found! *)
                { value_stack = (Term (t_label_copy orig field)) :: rem_st;
                  cont_stack = rem_cont;
                }
              | _ :: prs, _ :: fields -> project prs fields
              | _ -> raise Not_found
            in
            let constructors = List.concat (List.map snd dl) in
            let _, prs =
              List.find (fun (cs, _) -> ls_equal cs ls1) constructors in
            project prs tl1
          | _ -> raise Not_found
        end
    with Not_found ->
      { value_stack = Term (t_label_copy orig (t_app ls args ty)) :: rem_st;
        cont_stack = rem_cont;
      }


814
(* Decides the equality of two values. Two integers, or an integer and an
   integer constant term, are compared numerically; two terms are handed to
   [reduce_term_equ]. Raises [Undetermined] when an integer is compared
   with a term that is not an integer constant. *)
and reduce_equ (* engine *) ~orig st v1 v2 cont =
  let answer b =
    { value_stack = Term (t_label_copy orig (to_bool b)) :: st;
      cont_stack = cont;
    }
  in
  match v1, v2 with
  | Int n1, Int n2 -> answer (BigInt.eq n1 n2)
  | Int n, Term {t_node = Tconst c} | Term {t_node = Tconst c}, Int n ->
    let n' = try big_int_of_const c with NotNum -> raise Undetermined in
    answer (BigInt.eq n n')
  | Int _, Term _ | Term _, Int _ -> raise Undetermined
  | Term t1, Term t2 -> reduce_term_equ ~orig st t1 t2 cont
836
(*
837 838 839 840
  with Undetermined ->
    { value_stack = Term (t_equ (term_of_value v1) (term_of_value v2)) :: st;
      cont_stack = cont;
    }
841
*)
842

843
(* Decides the equality of two terms:
   - equal terms are equal;
   - two integer constants are compared numerically;
   - two constructor applications are different when the constructors
     differ, otherwise their equality is replaced by the conjunction of the
     equalities of their arguments, which is scheduled for evaluation.
   Raises [Undetermined] in every other case. *)
and reduce_term_equ ~orig st t1 t2 cont =
  let answer b =
    { value_stack = Term (t_label_copy orig b) :: st;
      cont_stack = cont;
    }
  in
  if t_equal t1 t2 then answer t_true
  else
  match t1.t_node, t2.t_node with
  | Tconst (Number.ConstInt i1), Tconst (Number.ConstInt i2) ->
    let b = BigInt.eq (Number.compute_int i1) (Number.compute_int i2) in
    answer (to_bool b)
  | Tconst _, Tconst _ -> raise Undetermined
  | Tapp (ls1, tl1), Tapp (ls2, tl2)
    when ls1.ls_constr > 0 && ls2.ls_constr > 0 ->
    if not (ls_equal ls1 ls2) then answer t_false
    else begin
      (* one pair of fresh variables per argument position, bound to the
         corresponding arguments, and the conjunction of their equalities *)
      let rec pair_up sigma conj tyl l1 l2 =
        match tyl, l1, l2 with
        | [], [], [] -> sigma, conj
        | ty :: tyl, a1 :: l1, a2 :: l2 ->
          let v1 = create_vsymbol (Ident.id_fresh "") ty in
          let v2 = create_vsymbol (Ident.id_fresh "") ty in
          pair_up
            (Mvs.add v1 a1 (Mvs.add v2 a2 sigma))
            (t_and_simp (t_equ (t_var v1) (t_var v2)) conj)
            tyl l1 l2
        | _ -> assert false
      in
      let sigma, conj = pair_up Mvs.empty t_true ls1.ls_args tl1 tl2 in
      { value_stack = st;
        cont_stack = (Keval (conj, sigma), orig) :: cont;
      }
    end
  | _ -> raise Undetermined


887

888
(* Rebuilds a term from a configuration without reducing any further: each
   pending frame is turned back into the construct it stands for, using the
   values already computed, until a single term remains. *)
let rec reconstruct c =
  match c.value_stack, c.cont_stack with
  | [Term t], [] -> t
  | _, [] -> assert false
  | values, (k, orig) :: rem ->
    let t, st =
      match values, k with
      | st, Keval (t, sigma) -> t_subst sigma t, st
      | [], (Kif _ | Klet _ | Kcase _ | Knot | Keps _ | Kquant _) ->
        assert false
      | ([] | [_]), Kbinop _ -> assert false
      | v :: st, Kif (t2, t3, sigma) ->
        t_if (term_of_value v) (t_subst sigma t2) (t_subst sigma t3), st
      | v1 :: st, Klet (v, t2, sigma) ->
        t_let_close v (term_of_value v1) (t_subst sigma t2), st
      | v :: st, Kcase (tbl, sigma) ->
        t_subst sigma (t_case (term_of_value v) tbl), st
      | v1 :: v2 :: st, Kbinop op ->
        t_binary_simp op (term_of_value v2) (term_of_value v1), st
      | v :: st, Knot -> t_not (term_of_value v), st
      | st, Kapp (ls, ty) ->
        let args, rem_st = extract_first (List.length ls.ls_args) [] st in
        let args = List.map term_of_value args in
        t_app ls args ty, rem_st
      | v :: st, Keps vs -> t_eps_close vs (term_of_value v), st
      | v :: st, Kquant (q, vl, tr) ->
        t_quant_close_simp q vl tr (term_of_value v), st
    in
    reconstruct {
      value_stack = (Term (t_label_copy orig t)) :: st;
      cont_stack = rem;
    }

925

926 927
(** iterated reductions *)

928
(* [normalize ~limit engine t0] reduces [t0] step by step until a single
   value remains. When the number of steps reaches [limit], a warning is
   emitted and the term rebuilt from the current configuration is returned
   instead. *)
let normalize ~limit engine t0 =
  let rec many_steps c n =
    match c.value_stack, c.cont_stack with
    | [Term t], [] -> t
    | _, [] -> assert false
    | _ when n = limit ->
      Warning.emit "reduction of term %a takes more than %d steps, aborted.@."
        Pretty.print_term t0 limit;
      reconstruct c
    | _ ->
      many_steps (reduce engine c) (n + 1)
  in
  let initial =
    { value_stack = [];
      cont_stack = [Keval (t0, Mvs.empty), t0];
    }
  in
  many_steps initial 0
949

950 951 952 953 954 955 956





(* the rewrite engine *)

957 958 959 960
(* Creates an engine with no rewrite rule over the known map [km]. The
   global table of built-ins is loaded from [env] when [p.compute_builtin]
   is set, and emptied otherwise. *)
let create p env km =
  (if p.compute_builtin then get_builtins env else Hls.clear builtins);
  { known_map = km; rules = Mls.empty; params = p }
965 966 967

(* Raised by [extract_rule] (hence by [add_rule]) when a term does not have
   the shape of a rewrite rule; the string explains why. *)
exception NotARewriteRule of string

968 969
(* Decomposes a term of the form  forall ... . f(args) = rhs  or
   forall ... . p(args) <-> rhs  into the set of quantified variables, the
   head symbol, its arguments and the right-hand side.
   Raises [NotARewriteRule] when the term has another shape.
   The known map argument is only needed by the disabled check below.

   The equality symbol is recognised with [ls_equal], as everywhere else in
   this file, rather than with physical equality. *)
let extract_rule _km t =
(*
  let check_ls ls =
    try let _ = Hls.find builtins ls in
        raise (NotARewriteRule "root of lhs of rule must not be a built-in symbol")
    with Not_found ->
      let d = Ident.Mid.find ls.ls_name km in
      match d.Decl.d_node with
      | Decl.Dtype _ | Decl.Dprop _ -> assert false
      | Decl.Dlogic _ ->
        raise (NotARewriteRule "root of lhs of rule must not be defined symbol")
      | Decl.Ddata _ ->
        raise (NotARewriteRule "root of lhs of rule must not be a constructor nor a projection")
      | Decl.Dparam _ | Decl.Dind _ -> ()
  in
*)
  (* the left-hand side must be an application; [msg] is the error otherwise *)
  let split acc msg lhs rhs =
    match lhs.t_node with
    | Tapp (ls, args) -> (* check_ls ls; *) acc, ls, args, rhs
    | _ -> raise (NotARewriteRule msg)
  in
  let rec aux acc t =
    match t.t_node with
    | Tquant (Tforall, q) ->
      let vs, _, t = t_open_quant q in
      aux (List.fold_left (fun acc v -> Svs.add v acc) acc vs) t
    | Tbinop (Tiff, t1, t2) ->
      split acc "lhs of <-> should be a predicate symbol" t1 t2
    | Tapp (ls, [t1; t2]) when ls_equal ls ps_equ ->
      split acc "lhs of = should be a function symbol" t1 t2
    | _ ->
      raise
        (NotARewriteRule "rule should be of the form forall ... t1 = t2 or f1 <-> f2")
  in
  aux Svs.empty t
1007 1008 1009


(* Returns the engine [e] extended with the rewrite rule described by the
   term [t]; the new rule is tried before the rules already registered for
   the same head symbol. Raises [NotARewriteRule] when [t] is not a rule. *)
let add_rule t e =
  let vars, ls, args, r = extract_rule e.known_map t in
  let previous = try Mls.find ls e.rules with Not_found -> [] in
  { e with rules = Mls.add ls ((vars, args, r) :: previous) e.rules }