Commit d15599f7 authored by Andrei Paskevich's avatar Andrei Paskevich
Browse files

add term/formula matching, export safe maps and folds

parent e5b7d169
......@@ -89,14 +89,19 @@ module Ty = struct
exception TypeMismatch
let rec matching s ty1 ty2 = match ty1.ty_node, ty2.ty_node with
| Tyvar n1, _ ->
(try if Name.M.find n1 s == ty2 then s else raise TypeMismatch
with Not_found -> Name.M.add n1 ty2 s)
| Tyapp (f1, l1), Tyapp (f2, l2) when f1 == f2 ->
List.fold_left2 matching s l1 l2
| _ ->
raise TypeMismatch
let rec matching s ty1 ty2 =
if ty1 == ty2 then s
else match ty1.ty_node, ty2.ty_node with
| Tyvar n1, _ ->
(try if Name.M.find n1 s == ty2 then s else raise TypeMismatch
with Not_found -> Name.M.add n1 ty2 s)
| Tyapp (f1, l1), Tyapp (f2, l2) when f1 == f2 ->
List.fold_left2 matching s l1 l2
| _ ->
raise TypeMismatch
let ty_match ty1 ty2 s =
try Some (matching s ty1 ty2) with TypeMismatch -> None
end
......@@ -124,8 +129,6 @@ module Hvs = Hashcons.Make(Vsym)
module Mvs = Map.Make(Vsym)
module Svs = Set.Make(Vsym)
type vsymbol_set = Svs.t
let mk_vs name ty = { vs_name = name; vs_ty = ty; vs_tag = -1 }
let create_vsymbol name ty = Hvs.hashcons (mk_vs name ty)
......@@ -362,8 +365,12 @@ module T = struct
let tag n t = { t with t_tag = n }
let compare t1 t2 = Pervasives.compare t1.t_tag t2.t_tag
end
module Hterm = Hashcons.Make(T)
module Mterm = Map.Make(T)
module Sterm = Set.Make(T)
module F = struct
......@@ -376,8 +383,8 @@ module F = struct
let equal_fmla_node f1 f2 = match f1, f2 with
| Fapp (s1, l1), Fapp (s2, l2) ->
s1 == s2 && List.for_all2 (==) l1 l2
| Fquant (q1, bf1), Fquant (q2, bf2) ->
q1 == q2 && eq_bind_fmla bf1 bf2
| Fquant (q1, b1), Fquant (q2, b2) ->
q1 == q2 && eq_bind_fmla b1 b2
| Fbinop (op1, f1, g1), Fbinop (op2, f2, g2) ->
op1 == op2 && f1 == f2 && g1 == g2
| Fnot f1, Fnot f2 ->
......@@ -387,11 +394,11 @@ module F = struct
true
| Fif (f1, g1, h1), Fif (f2, g2, h2) ->
f1 == f2 && g1 == g2 && h1 == h2
| Flet (t1, bf1), Flet (t2, bf2) ->
t1 == t2 && eq_bind_fmla bf1 bf2
| Fcase (t1, bl1), Fcase (t2, bl2) ->
| Flet (t1, b1), Flet (t2, b2) ->
t1 == t2 && eq_bind_fmla b1 b2
| Fcase (t1, l1), Fcase (t2, l2) ->
t1 == t2 &&
(try List.for_all2 eq_fbranch bl1 bl2
(try List.for_all2 eq_fbranch l1 l2
with Invalid_argument _ -> false)
| _ ->
false
......@@ -426,8 +433,12 @@ module F = struct
let tag n f = { f with f_tag = n }
let compare f1 f2 = Pervasives.compare f1.f_tag f2.f_tag
end
module Hfmla = Hashcons.Make(F)
module Mfmla = Map.Make(F)
module Sfmla = Set.Make(F)
(* hash-consing constructors for terms *)
......@@ -513,6 +524,8 @@ exception FoldSkip
let forall_fnT prT lvl _ t = prT lvl t || raise FoldSkip
let forall_fnF prF lvl _ f = prF lvl f || raise FoldSkip
let exists_fnT prT lvl _ t = prT lvl t && raise FoldSkip
let exists_fnF prF lvl _ f = prF lvl f && raise FoldSkip
let forall_term_unsafe prT prF lvl t =
try fold_term_unsafe (forall_fnT prT) (forall_fnF prF) lvl true t
......@@ -522,9 +535,6 @@ let forall_fmla_unsafe prT prF lvl f =
try fold_fmla_unsafe (forall_fnT prT) (forall_fnF prF) lvl true f
with FoldSkip -> false
let exists_fnT prT lvl _ t = prT lvl t && raise FoldSkip
let exists_fnF prF lvl _ f = prF lvl f && raise FoldSkip
let exists_term_unsafe prT prF lvl t =
try fold_term_unsafe (exists_fnT prT) (exists_fnF prF) lvl false t
with FoldSkip -> true
......@@ -607,10 +617,13 @@ and freevars_fmla lvl acc t =
let freevars_term t = freevars_term 0 Svs.empty t
let freevars_fmla f = freevars_fmla 0 Svs.empty f
(* USE PHYSICAL EQUALITY *)
(*
(* equality *)
let t_equal = (==)
let f_equal = (==)
*)
(* alpha-equivalence *)
......@@ -623,6 +636,7 @@ let rec t_alpha_equal t1 t2 =
| Tvar v1, Tvar v2 ->
v1 == v2
| Tapp (s1, l1), Tapp (s2, l2) ->
(* assert (List.length l1 == List.length l2); *)
s1 == s2 && List.for_all2 t_alpha_equal l1 l2
| Tlet (t1, (v1, b1)), Tlet (t2, (v2, b2)) ->
(* assert (v1.vs_ty == t1.t_ty && v2.vs_ty == t2.t_ty); *)
......@@ -639,9 +653,9 @@ let rec t_alpha_equal t1 t2 =
and f_alpha_equal f1 f2 =
f1 == f2 ||
match f1.f_node, f2.f_node with
| Fapp (s1, tl1), Fapp (s2, tl2) ->
(* assert (List.length tl1 == List.length tl2); *)
s1 == s2 && List.for_all2 t_alpha_equal tl1 tl2
| Fapp (s1, l1), Fapp (s2, l2) ->
(* assert (List.length l1 == List.length l2); *)
s1 == s2 && List.for_all2 t_alpha_equal l1 l2
| Fquant (q1, (v1, f1)), Fquant (q2, (v2, f2)) ->
q1 == q2 && v1.vs_ty == v2.vs_ty && f_alpha_equal f1 f2
| Fbinop (op1, f1, g1), Fbinop (op2, f2, g2) ->
......@@ -656,12 +670,11 @@ and f_alpha_equal f1 f2 =
| Flet (t1, (v1, f1)), Flet (t2, (v2, f2)) ->
(* assert (v1.vs_ty == t1.t_ty && v2.vs_ty == t2.t_ty); *)
t_alpha_equal t1 t2 && f_alpha_equal f1 f2
| Fcase (t1, bl1), Fcase (t2, bl2) ->
| Fcase (t1, l1), Fcase (t2, l2) ->
t_alpha_equal t1 t2 &&
(try List.for_all2 fbranch_alpha_equal bl1 bl2
(try List.for_all2 fbranch_alpha_equal l1 l2
with Invalid_argument _ -> false)
| _ ->
false
| _ -> false
and tbranch_alpha_equal (pat1, _, t1) (pat2, _, t2) =
pat_alpha_equal pat1 pat2 && t_alpha_equal t1 t2
......@@ -669,6 +682,94 @@ and tbranch_alpha_equal (pat1, _, t1) (pat2, _, t2) =
and fbranch_alpha_equal (pat1, _, f1) (pat2, _, f2) =
pat_alpha_equal pat1 pat2 && f_alpha_equal f1 f2
(* calculate the greatest free de Bruijn index *)
let ix_empty = (Mterm.empty, Mfmla.empty)
let max_ix_term mT lvl acc t = max acc (Mterm.find t mT - lvl)
let max_ix_fmla mF lvl acc f = max acc (Mfmla.find f mF - lvl)
let rec build_max_term lvl acc t = match t.t_node with
| Tbvar ix -> let mT,mF = acc in (Mterm.add t ix mT, mF)
| _ ->
let mT,mF = fold_term_unsafe build_max_term build_max_fmla lvl acc t in
let ix = fold_term_unsafe (max_ix_term mT) (max_ix_fmla mF) 0 (-1) t in
(Mterm.add t ix mT, mF)
and build_max_fmla lvl acc f =
let mT,mF = fold_fmla_unsafe build_max_term build_max_fmla lvl acc f in
let ix = fold_fmla_unsafe (max_ix_term mT) (max_ix_fmla mF) 0 (-1) f in
(mT, Mfmla.add f ix mF)
(* matching modulo alpha in the pattern *)
exception NoMatch
let rec t_match m s t1 t2 =
if t1 == t2 then s else
if t1.t_ty != t2.t_ty then raise NoMatch else
match t1.t_node, t2.t_node with
| Tbvar x1, Tbvar x2 when x1 == x2 ->
s
| Tvar v1, _ ->
if Mterm.find t2 m < 0 then
try if Mvs.find v1 s == t2 then s else raise NoMatch
with Not_found -> Mvs.add v1 t2 s
else raise NoMatch
| Tapp (s1, l1), Tapp (s2, l2) when s1 == s2 ->
(* assert (List.length l1 == List.length l2); *)
List.fold_left2 (t_match m) s l1 l2
| Tlet (t1, (v1, b1)), Tlet (t2, (v2, b2)) ->
(* assert (v1.vs_ty == t1.t_ty && v2.vs_ty == t2.t_ty); *)
t_match m (t_match m s t1 t2) b1 b2
| Tcase (t1, l1), Tcase (t2, l2) ->
(try List.fold_left2 (tbranch_match m) (t_match m s t1 t2) l1 l2
with Invalid_argument _ -> raise NoMatch)
| Teps (v1, f1), Teps (v2, f2) ->
(* assert (v1.vs_ty == t1.t_ty && v2.vs_ty == t2.t_ty); *)
f_match m s f1 f2
| _ -> raise NoMatch
and f_match m s f1 f2 =
if f1 == f2 then s else
match f1.f_node, f2.f_node with
| Fapp (s1, l1), Fapp (s2, l2) when s1 == s2 ->
(* assert (List.length l1 == List.length l2); *)
List.fold_left2 (t_match m) s l1 l2
| Fquant (q1, (v1, f1)), Fquant (q2, (v2, f2))
when q1 == q2 && v1.vs_ty == v2.vs_ty ->
f_match m s f1 f2
| Fbinop (op1, f1, g1), Fbinop (op2, f2, g2) when op1 == op2 ->
f_match m (f_match m s f1 f2) g1 g2
| Fnot f1, Fnot f2 ->
f_match m s f1 f2
| Ftrue, Ftrue
| Ffalse, Ffalse ->
s
| Fif (f1, g1, h1), Fif (f2, g2, h2) ->
f_match m (f_match m (f_match m s f1 f2) g1 g2) h1 h2
| Flet (t1, (v1, f1)), Flet (t2, (v2, f2)) ->
(* assert (v1.vs_ty == t1.t_ty && v2.vs_ty == t2.t_ty); *)
f_match m (t_match m s t1 t2) f1 f2
| Fcase (t1, l1), Fcase (t2, l2) ->
(try List.fold_left2 (fbranch_match m) (t_match m s t1 t2) l1 l2
with Invalid_argument _ -> raise NoMatch)
| _ -> raise NoMatch
and tbranch_match m s (pat1, _, t1) (pat2, _, t2) =
if pat_alpha_equal pat1 pat2 then t_match m s t1 t2 else raise NoMatch
and fbranch_match m s (pat1, _, f1) (pat2, _, f2) =
if pat_alpha_equal pat1 pat2 then f_match m s f1 f2 else raise NoMatch
let t_match t1 t2 s =
let m,_ = build_max_term 0 ix_empty t2 in
try Some (t_match m s t1 t2) with NoMatch -> None
let f_match f1 f2 s =
let m,_ = build_max_fmla 0 ix_empty f2 in
try Some (f_match m s f1 f2) with NoMatch -> None
(* occurrence check *)
let rec t_occurs_term r lvl t = r == t ||
......@@ -767,6 +868,82 @@ and f_alpha_subst_fmla f1 f2 lvl f =
let f_alpha_subst_term f1 f2 t = f_alpha_subst_term f1 f2 0 t
let f_alpha_subst_fmla f1 f2 f = f_alpha_subst_fmla f1 f2 0 f
(* safe transparent map *)
let rec map_skip_term fnT fnF mT mF lvl t =
if Mterm.find t mT < 0 then fnT t
else map_term_unsafe (map_skip_term fnT fnF mT mF)
(map_skip_fmla fnT fnF mT mF) lvl t
and map_skip_fmla fnT fnF mT mF lvl f =
if Mfmla.find f mF < 0 then fnF f
else map_fmla_unsafe (map_skip_term fnT fnF mT mF)
(map_skip_fmla fnT fnF mT mF) lvl f
let map_skip_term fnT fnF lvl t =
if lvl == 0 then fnT t
else let mT,mF = build_max_term lvl ix_empty t in
map_skip_term fnT fnF mT mF lvl t
let map_skip_fmla fnT fnF lvl f =
if lvl == 0 then fnF f
else let mT,mF = build_max_fmla lvl ix_empty f in
map_skip_fmla fnT fnF mT mF lvl f
let map_trans_term fnT fnF t =
map_term_unsafe (map_skip_term fnT fnF) (map_skip_fmla fnT fnF) 0 t
let map_trans_fmla fnT fnF f =
map_fmla_unsafe (map_skip_term fnT fnF) (map_skip_fmla fnT fnF) 0 f
(* safe transparent fold *)
let rec fold_skip_term fnT fnF mT mF lvl acc t =
if Mterm.find t mT < 0 then fnT acc t
else fold_term_unsafe (fold_skip_term fnT fnF mT mF)
(fold_skip_fmla fnT fnF mT mF) lvl acc t
and fold_skip_fmla fnT fnF mT mF lvl acc f =
if Mfmla.find f mF < 0 then fnF acc f
else fold_fmla_unsafe (fold_skip_term fnT fnF mT mF)
(fold_skip_fmla fnT fnF mT mF) lvl acc f
let fold_skip_term fnT fnF lvl acc t =
if lvl == 0 then fnT acc t
else let mT,mF = build_max_term lvl ix_empty t in
fold_skip_term fnT fnF mT mF lvl acc t
let fold_skip_fmla fnT fnF lvl acc f =
if lvl == 0 then fnF acc f
else let mT,mF = build_max_fmla lvl ix_empty f in
fold_skip_fmla fnT fnF mT mF lvl acc f
let fold_trans_term fnT fnF acc t =
fold_term_unsafe (fold_skip_term fnT fnF) (fold_skip_fmla fnT fnF) 0 acc t
let fold_trans_fmla fnT fnF acc f =
fold_fmla_unsafe (fold_skip_term fnT fnF) (fold_skip_fmla fnT fnF) 0 acc f
let forall_fnT prT _ t = prT t || raise FoldSkip
let forall_fnF prF _ f = prF f || raise FoldSkip
let exists_fnT prT _ t = prT t && raise FoldSkip
let exists_fnF prF _ f = prF f && raise FoldSkip
let forall_trans_term prT prF t =
try fold_trans_term (forall_fnT prT) (forall_fnF prF) true t
with FoldSkip -> false
let forall_trans_fmla prT prF f =
try fold_trans_fmla (forall_fnT prT) (forall_fnF prF) true f
with FoldSkip -> false
let exists_trans_term prT prF t =
try fold_trans_term (exists_fnT prT) (exists_fnF prF) false t
with FoldSkip -> true
let exists_trans_fmla prT prF f =
try fold_trans_fmla (exists_fnT prT) (exists_fnF prF) false f
with FoldSkip -> true
(* smart constructors *)
......@@ -870,4 +1047,64 @@ let open_fbranch (pat, _, f) =
let vars, s, ns = substs_for_pattern pat in
(rename_pat ns pat, vars, inst_fmla s 0 f)
(* safe opening map *)
let tbranch fn b = let pat,_,t = open_tbranch b in (pat, fn t)
let map_open_term fnT fnF t = match t.t_node with
| Tbvar _ | Tvar _ -> t
| Tapp (f, tl) -> t_app f (List.map fnT tl) t.t_ty
| Tlet (t1, b) -> let u,t2 = open_bind_term b in t_let u (fnT t1) (fnT t2)
| Tcase (t1, bl) -> t_case (fnT t1) (List.map (tbranch fnT) bl) t.t_ty
| Teps b -> let u,f = open_bind_fmla b in t_eps u (fnF f)
let fbranch fn b = let pat,_,f = open_fbranch b in (pat, fn f)
let map_open_fmla fnT fnF f = match f.f_node with
| Fapp (p, tl) -> f_app p (List.map fnT tl)
| Fquant (q, b) -> let u,f1 = open_bind_fmla b in f_quant q u (fnF f1)
| Fbinop (op, f1, f2) -> f_binary op (fnF f1) (fnF f2)
| Fnot f1 -> f_not (fnF f1)
| Ftrue | Ffalse -> f
| Fif (f1, f2, f3) -> f_if (fnF f1) (fnF f2) (fnF f3)
| Flet (t, b) -> let u,f1 = open_bind_fmla b in f_let u (fnT t) (fnF f1)
| Fcase (t, bl) -> f_case (fnT t) (List.map (fbranch fnF) bl)
(* safe opening fold *)
let tbranch fn acc b = let _,_,t = open_tbranch b in fn acc t
let fbranch fn acc b = let _,_,f = open_fbranch b in fn acc f
let fold_open_term fnT fnF acc t = match t.t_node with
| Tbvar _ | Tvar _ -> acc
| Tapp (f, tl) -> List.fold_left fnT acc tl
| Tlet (t1, b) -> let _,t2 = open_bind_term b in fnT (fnT acc t1) t2
| Tcase (t1, bl) -> List.fold_left (tbranch fnT) (fnT acc t1) bl
| Teps b -> let _,f = open_bind_fmla b in fnF acc f
let fold_open_fmla fnT fnF acc f = match f.f_node with
| Fapp (p, tl) -> List.fold_left fnT acc tl
| Fquant (q, b) -> let _,f1 = open_bind_fmla b in fnF acc f1
| Fbinop (op, f1, f2) -> fnF (fnF acc f1) f2
| Fnot f1 -> fnF acc f1
| Ftrue | Ffalse -> acc
| Fif (f1, f2, f3) -> fnF (fnF (fnF acc f1) f2) f3
| Flet (t, b) -> let _,f1 = open_bind_fmla b in fnF (fnT acc t) f1
| Fcase (t, bl) -> List.fold_left (fbranch fnF) (fnT acc t) bl
let forall_open_term prT prF t =
try fold_open_term (forall_fnT prT) (forall_fnF prF) true t
with FoldSkip -> false
let forall_open_fmla prT prF f =
try fold_open_fmla (forall_fnT prT) (forall_fnF prF) true f
with FoldSkip -> false
let exists_open_term prT prF t =
try fold_open_term (exists_fnT prT) (exists_fnF prF) false t
with FoldSkip -> true
let exists_open_fmla prT prF f =
try fold_open_fmla (exists_fnT prT) (exists_fnF prF) false f
with FoldSkip -> true
......@@ -45,6 +45,8 @@ module Ty : sig
val ty_var : tvsymbol -> ty
val ty_app : tysymbol -> ty list -> ty
val ty_match : ty -> ty -> ty Name.M.t -> ty Name.M.t option
end
type tvsymbol = Ty.tvsymbol
......@@ -62,8 +64,6 @@ type vsymbol = private {
module Mvs : Map.S with type key = vsymbol
module Svs : Set.S with type elt = vsymbol
type vsymbol_set = Svs.t
val create_vsymbol : Name.t -> ty -> vsymbol
(** Function symbols *)
......@@ -162,6 +162,11 @@ and tbranch
and fbranch
module Mterm : Map.S with type key = term
module Sterm : Set.S with type elt = term
module Mfmla : Map.S with type key = fmla
module Sfmla : Set.S with type elt = fmla
(* smart constructors for term *)
val t_var : vsymbol -> term
......@@ -195,10 +200,42 @@ val f_label_add : label -> fmla -> fmla
(* bindings *)
val open_bind_term : bind_term -> vsymbol * term
val open_tbranch : tbranch -> pattern * vsymbol_set * term
val open_tbranch : tbranch -> pattern * Svs.t * term
val open_bind_fmla : bind_fmla -> vsymbol * fmla
val open_fbranch : fbranch -> pattern * vsymbol_set * fmla
val open_fbranch : fbranch -> pattern * Svs.t * fmla
(* safe opening map/fold *)
val map_open_term : (term -> term) -> (fmla -> fmla) -> term -> term
val map_open_fmla : (term -> term) -> (fmla -> fmla) -> fmla -> fmla
val fold_open_term : ('a -> term -> 'a) -> ('a -> fmla -> 'a)
-> 'a -> term -> 'a
val fold_open_fmla : ('a -> term -> 'a) -> ('a -> fmla -> 'a)
-> 'a -> fmla -> 'a
val forall_open_term : (term -> bool) -> (fmla -> bool) -> term -> bool
val forall_open_fmla : (term -> bool) -> (fmla -> bool) -> fmla -> bool
val exists_open_term : (term -> bool) -> (fmla -> bool) -> term -> bool
val exists_open_fmla : (term -> bool) -> (fmla -> bool) -> fmla -> bool
(* safe transparent map/fold *)
val map_trans_term : (term -> term) -> (fmla -> fmla) -> term -> term
val map_trans_fmla : (term -> term) -> (fmla -> fmla) -> fmla -> fmla
val fold_trans_term : ('a -> term -> 'a) -> ('a -> fmla -> 'a)
-> 'a -> term -> 'a
val fold_trans_fmla : ('a -> term -> 'a) -> ('a -> fmla -> 'a)
-> 'a -> fmla -> 'a
val forall_trans_term : (term -> bool) -> (fmla -> bool) -> term -> bool
val forall_trans_fmla : (term -> bool) -> (fmla -> bool) -> fmla -> bool
val exists_trans_term : (term -> bool) -> (fmla -> bool) -> term -> bool
val exists_trans_fmla : (term -> bool) -> (fmla -> bool) -> fmla -> bool
(* variable occurrence check *)
......@@ -221,10 +258,13 @@ val subst_fmla_single : term -> vsymbol -> fmla -> fmla
val freevars_term : term -> Svs.t
val freevars_fmla : fmla -> Svs.t
(* USE PHYSICAL EQUALITY *)
(*
(* equality *)
val t_equal : term -> term -> bool
val f_equal : fmla -> fmla -> bool
*)
(* alpha-equivalence *)
......@@ -255,3 +295,8 @@ val t_alpha_subst_fmla : term -> term -> fmla -> fmla
val f_alpha_subst_term : fmla -> fmla -> term -> term
val f_alpha_subst_fmla : fmla -> fmla -> fmla -> fmla
(* term/fmla matching modulo alpha in the pattern *)
val t_match : term -> term -> term Mvs.t -> term Mvs.t option
val f_match : fmla -> fmla -> term Mvs.t -> term Mvs.t option
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment