Commit 1a78e107 authored by Andrei Paskevich's avatar Andrei Paskevich

export t_map_unsafe/t_fold_unsafe with type checks.

This is less dangerous than I previously thought,
because we still can never create an ill-typed term,
nor we can push a term with an unprotected de Bruijn
index into a context. Salut, François ;)
parent 6fbfa827
......@@ -131,11 +131,18 @@ let pat_as p v = Hp.hashcons (mk_pattern (Pas (p, v)) p.pat_ty)
(* generic traversal functions *)
let pat_map_unsafe fn pat = match pat.pat_node with
let pat_map fn pat = match pat.pat_node with
| Pwild | Pvar _ -> pat
| Papp (s, pl) -> pat_app s (List.map fn pl) pat.pat_ty
| Pas (p, v) -> pat_as (fn p) v
let protect fn pat =
let res = fn pat in
if res.pat_ty != pat.pat_ty then raise TypeMismatch;
res
let pat_map fn = pat_map (protect fn)
let pat_fold fn acc pat = match pat.pat_node with
| Pwild | Pvar _ -> acc
| Papp (_, pl) -> List.fold_left fn acc pl
......@@ -166,14 +173,8 @@ let pat_app fs pl ty =
pat_app fs pl ty
let pat_as p v =
if p.pat_ty == v.vs_ty then pat_as p v else raise TypeMismatch
(* safe map over patterns *)
let pat_map fn pat = match pat.pat_node with
| Pwild | Pvar _ -> pat
| Papp (s, pl) -> pat_app s (List.map fn pl) pat.pat_ty
| Pas (p, v) -> pat_as (fn p) v
if p.pat_ty != v.vs_ty then raise TypeMismatch;
pat_as p v
(* symbol-wise map/fold *)
......@@ -472,9 +473,12 @@ let f_label_try l f = if l == [] then f else f_label l f
(* unsafe map with level *)
exception UnboundIndex
let brlvl fn lvl (pat, nv, t) = (pat, nv, fn (lvl + nv) t)
let t_map_unsafe fnT fnF lvl t = t_label_try t.t_label (match t.t_node with
| Tbvar n when n >= lvl -> raise UnboundIndex
| Tbvar _ | Tvar _ | Tconst _ -> t
| Tapp (f, tl) -> t_app f (List.map (fnT lvl) tl) t.t_ty
| Tlet (t1, (u, t2)) -> t_let u (fnT lvl t1) (fnT (lvl + 1) t2)
......@@ -492,11 +496,20 @@ let f_map_unsafe fnT fnF lvl f = f_label_try f.f_label (match f.f_node with
| Flet (t, (u, f1)) -> f_let u (fnT lvl t) (fnF (lvl + 1) f1)
| Fcase (t, bl) -> f_case (fnT lvl t) (List.map (brlvl fnF lvl) bl))
let protect fn lvl t =
let res = fn lvl t in
if res.t_ty != t.t_ty then raise TypeMismatch;
res
let t_map_unsafe fnT = t_map_unsafe (protect fnT)
let f_map_unsafe fnT = f_map_unsafe (protect fnT)
(* unsafe fold with level *)
let brlvl fn lvl acc (_, nv, t) = fn (lvl + nv) acc t
let t_fold_unsafe fnT fnF lvl acc t = match t.t_node with
| Tbvar n when n >= lvl -> raise UnboundIndex
| Tbvar _ | Tvar _ | Tconst _ -> acc
| Tapp (f, tl) -> List.fold_left (fnT lvl) acc tl
| Tlet (t1, (u, t2)) -> fnT (lvl + 1) (fnT lvl acc t1) t2
......@@ -573,10 +586,12 @@ let f_case t bl =
f_case t bl
let t_let v t1 t2 =
if v.vs_ty == t1.t_ty then t_let v t1 t2 else raise TypeMismatch
if v.vs_ty != t1.t_ty then raise TypeMismatch;
t_let v t1 t2
let f_let v t1 f2 =
if v.vs_ty == t1.t_ty then f_let v t1 f2 else raise TypeMismatch
if v.vs_ty != t1.t_ty then raise TypeMismatch;
f_let v t1 f2
(* map over symbols *)
......@@ -714,7 +729,7 @@ module Im = Map.Make(struct type t = int let compare = Pervasives.compare end)
let rec t_inst m lvl t = match t.t_node with
| Tbvar n when n >= lvl ->
(try Im.find (n - lvl) m with Not_found -> assert false)
(try Im.find (n - lvl) m with Not_found -> raise UnboundIndex)
| _ -> t_map_unsafe (t_inst m) (f_inst m) lvl t
and f_inst m lvl f = f_map_unsafe (t_inst m) (f_inst m) lvl f
......@@ -803,7 +818,7 @@ let rec f_open_exists f = match f.f_node with
let rec pat_rename ns p = match p.pat_node with
| Pvar n -> pat_var (Mvs.find n ns)
| Pas (p, n) -> pat_as (pat_rename ns p) (Mvs.find n ns)
| _ -> pat_map_unsafe (pat_rename ns) p
| _ -> pat_map (pat_rename ns) p
let pat_substs pat =
let m, _ = pat_varmap pat in
......@@ -885,10 +900,7 @@ let f_any prT prF f =
(* map/fold over free variables *)
let rec t_v_map fn lvl t = match t.t_node with
| Tvar u ->
let v = fn u in
if v.t_ty != u.vs_ty then raise TypeMismatch;
v
| Tvar u -> fn u
| _ -> t_map_unsafe (t_v_map fn) (f_v_map fn) lvl t
and f_v_map fn lvl f = f_map_unsafe (t_v_map fn) (f_v_map fn) lvl f
......
......@@ -213,6 +213,32 @@ val f_open_quant : fmla_quant -> vsymbol list * trigger list * fmla
val f_open_forall : fmla -> vsymbol list * fmla
val f_open_exists : fmla -> vsymbol list * fmla
(* unsafe traversal with unprotected de Bruijn indices *)
val t_map_unsafe : (int -> term -> term) ->
(int -> fmla -> fmla) -> int -> term -> term
val f_map_unsafe : (int -> term -> term) ->
(int -> fmla -> fmla) -> int -> fmla -> fmla
val t_fold_unsafe : (int -> 'a -> term -> 'a) ->
(int -> 'a -> fmla -> 'a) -> int -> 'a -> term -> 'a
val f_fold_unsafe : (int -> 'a -> term -> 'a) ->
(int -> 'a -> fmla -> 'a) -> int -> 'a -> fmla -> 'a
val t_all_unsafe : (int -> term -> bool) ->
(int -> fmla -> bool) -> int -> term -> bool
val f_all_unsafe : (int -> term -> bool) ->
(int -> fmla -> bool) -> int -> fmla -> bool
val t_any_unsafe : (int -> term -> bool) ->
(int -> fmla -> bool) -> int -> term -> bool
val f_any_unsafe : (int -> term -> bool) ->
(int -> fmla -> bool) -> int -> fmla -> bool
(* generic term/fmla traversal *)
val t_map : (term -> term) -> (fmla -> fmla) -> term -> term
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment