Commit 2585bb44 authored by Andrei Paskevich's avatar Andrei Paskevich
Browse files

add core support for the "if" construction in terms.

It's not more expressive but much nicer than epsilon.
parent 9cf074db
......@@ -187,6 +187,9 @@ and print_tnode opl opr fmt t = match t.t_node with
| Tapp (fs, tl) ->
fprintf fmt (protect_on opl "%a%a:%a")
print_ls fs (print_paren_r print_term) tl print_ty t.t_ty
| Tif (f,t1,t2) ->
fprintf fmt (protect_on opr "if %a@ then %a@ else %a")
print_fmla f print_term t1 print_opl_term t2
| Tlet (t1,tb) ->
let v,t2 = t_open_bound tb in
fprintf fmt (protect_on opr "let %a =@ %a in@ %a")
......@@ -223,6 +226,9 @@ and print_fnode opl opr fmt f = match f.f_node with
print_opr_fmla f1 print_binop b print_opl_fmla f2
| Fnot f ->
fprintf fmt (protect_on opr "not %a") print_opl_fmla f
| Fif (f1,f2,f3) ->
fprintf fmt (protect_on opr "if %a@ then %a@ else %a")
print_fmla f1 print_fmla f2 print_opl_fmla f3
| Flet (t,f) ->
let v,f = f_open_bound f in
fprintf fmt (protect_on opr "let %a =@ %a in@ %a")
......@@ -232,9 +238,6 @@ and print_fnode opl opr fmt f = match f.f_node with
fprintf fmt "match %a with@\n@[<hov>%a@]@\nend"
(print_list comma print_term) tl
(print_list newline print_fbranch) bl
| Fif (f1,f2,f3) ->
fprintf fmt (protect_on opr "if %a@ then %a@ else %a")
print_fmla f1 print_fmla f2 print_opl_fmla f3
and print_tbranch fmt br =
let pl,t = t_open_branch br in
......
......@@ -238,6 +238,7 @@ and term_node =
| Tvar of vsymbol
| Tconst of constant
| Tapp of lsymbol * term list
| Tif of fmla * term * term
| Tlet of term * term_bound
| Tcase of term list * term_branch list
| Teps of fmla_bound
......@@ -293,6 +294,7 @@ module Hsterm = Hashcons.Make (struct
| Tvar v1, Tvar v2 -> v1 == v2
| Tconst c1, Tconst c2 -> c1 = c2
| Tapp (s1, l1), Tapp (s2, l2) -> s1 == s2 && List.for_all2 (==) l1 l2
| Tif (f1, t1, e1), Tif (f2, t2, e2) -> f1 == f2 && t1 == t2 && e2 == e2
| Tlet (t1, b1), Tlet (t2, b2) -> t1 == t2 && t_eq_bound b1 b2
| Tcase (tl1, bl1), Tcase (tl2, bl2) ->
list_all2 (==) tl1 tl2 && list_all2 t_eq_branch bl1 bl2
......@@ -318,6 +320,7 @@ module Hsterm = Hashcons.Make (struct
| Tvar v -> v.vs_name.id_tag
| Tconst c -> Hashtbl.hash c
| Tapp (f, tl) -> Hashcons.combine_list t_hash (f.ls_name.id_tag) tl
| Tif (f, t, e) -> Hashcons.combine2 f.f_tag t.t_tag e.t_tag
| Tlet (t, bt) -> Hashcons.combine t.t_tag (t_hash_bound bt)
| Tcase (tl, bl) -> let ht = Hashcons.combine_list t_hash 17 tl in
Hashcons.combine_list t_hash_branch ht bl
......@@ -433,6 +436,7 @@ let t_bvar n ty = Hsterm.hashcons (mk_term (Tbvar n) ty)
let t_var v = Hsterm.hashcons (mk_term (Tvar v) v.vs_ty)
let t_const c ty = Hsterm.hashcons (mk_term (Tconst c) ty)
let t_app f tl ty = Hsterm.hashcons (mk_term (Tapp (f, tl)) ty)
let t_if f t1 t2 = Hsterm.hashcons (mk_term (Tif (f, t1, t2)) t2.t_ty)
let t_let v t1 t2 = Hsterm.hashcons (mk_term (Tlet (t1, (v, t2))) t2.t_ty)
let t_case tl bl ty = Hsterm.hashcons (mk_term (Tcase (tl, bl)) ty)
let t_eps u f = Hsterm.hashcons (mk_term (Teps (u, f)) u.vs_ty)
......@@ -494,6 +498,7 @@ let t_map_unsafe fnT fnF lvl t = t_label_copy t (match t.t_node with
| Tbvar n when n >= lvl -> raise UnboundIndex
| Tbvar _ | Tvar _ | Tconst _ -> t
| Tapp (f, tl) -> t_app f (List.map (fnT lvl) tl) t.t_ty
| Tif (f, t1, t2) -> t_if (fnF lvl f) (fnT lvl t1) (fnT lvl t2)
| Tlet (t1, (u, t2)) -> t_let u (fnT lvl t1) (fnT (lvl + 1) t2)
| Tcase (tl, bl) ->
t_case (List.map (fnT lvl) tl) (List.map (brlvl fnT lvl) bl) t.t_ty
......@@ -527,6 +532,7 @@ let t_fold_unsafe fnT fnF lvl acc t = match t.t_node with
| Tbvar n when n >= lvl -> raise UnboundIndex
| Tbvar _ | Tvar _ | Tconst _ -> acc
| Tapp (f, tl) -> List.fold_left (fnT lvl) acc tl
| Tif (f, t1, t2) -> fnT lvl (fnT lvl (fnF lvl acc f) t1) t2
| Tlet (t1, (u, t2)) -> fnT (lvl + 1) (fnT lvl acc t1) t2
| Tcase (tl, bl) ->
List.fold_left (brlvl fnT lvl) (List.fold_left (fnT lvl) acc tl) bl
......@@ -620,6 +626,10 @@ let f_case tl bl =
List.iter f_check_branch bl;
f_case tl bl
let t_if f t1 t2 =
if t1.t_ty != t2.t_ty then raise TypeMismatch;
t_if f t1 t2
let t_let v t1 t2 =
if v.vs_ty != t1.t_ty then raise TypeMismatch;
t_let v t1 t2
......@@ -639,6 +649,7 @@ let rec t_s_map fnT fnV fnL t =
| Tvar v -> t_var (fnV v ty)
| Tconst _ -> t
| Tapp (f, tl) -> t_app (fnL f) (List.map fn_t tl) ty
| Tif (f, t1, t2) -> t_if (fn_f f) (fn_t t1) (fn_t t2)
| Tlet (t1, (u, t2)) ->
let t1 = fn_t t1 in t_let (fnV u t1.t_ty) t1 (fn_t t2)
| Tcase (tl, bl) ->
......@@ -704,6 +715,7 @@ let rec t_s_fold fnT fnL acc t =
match t.t_node with
| Tbvar _ | Tvar _ | Tconst _ -> acc
| Tapp (f, tl) -> List.fold_left fn_t (fnL acc f) tl
| Tif (f, t1, t2) -> fn_t (fn_t (fn_f acc f) t1) t2
| Tlet (t1, (_,t2)) -> fn_t (fn_t acc t1) t2
| Tcase (tl, bl) ->
List.fold_left (t_branch fnT fnL) (List.fold_left fn_t acc tl) bl
......@@ -903,6 +915,7 @@ let t_map fnT fnF t = t_label_copy t (match t.t_node with
| Tbvar _ -> raise UnboundIndex
| Tvar _ | Tconst _ -> t
| Tapp (f, tl) -> t_app_unsafe f (List.map fnT tl) t.t_ty
| Tif (f, t1, t2) -> t_if (fnF f) (fnT t1) (fnT t2)
| Tlet (t1, b) -> let u,t2 = t_open_bound b in
let t1' = fnT t1 in let t2' = fnT t2 in
if t1' == t1 && t2' == t2 then t else t_let u t1' t2'
......@@ -947,6 +960,7 @@ let t_fold fnT fnF acc t = match t.t_node with
| Tbvar _ -> raise UnboundIndex
| Tvar _ | Tconst _ -> acc
| Tapp (f, tl) -> List.fold_left fnT acc tl
| Tif (f, t1, t2) -> fnT (fnT (fnF acc f) t1) t2
| Tlet (t1, b) -> let _,t2 = t_open_bound b in fnT (fnT acc t1) t2
| Tcase (tl, bl) ->
List.fold_left (t_branch fnT) (List.fold_left fnT acc tl) bl
......@@ -1041,6 +1055,8 @@ let rec t_equal_alpha t1 t2 =
c1 = c2
| Tapp (s1, l1), Tapp (s2, l2) ->
s1 == s2 && List.for_all2 t_equal_alpha l1 l2
| Tif (f1, t1, e1), Tif (f2, t2, e2) ->
f_equal_alpha f1 f2 && t_equal_alpha t1 t2 && t_equal_alpha e1 e2
| Tlet (t1, tb1), Tlet (t2, tb2) ->
let v1, b1 = tb1 in
let v2, b2 = tb2 in
......@@ -1121,6 +1137,8 @@ let rec t_match s t1 t2 =
then Mvs.add v1 t2 s else raise NoMatch)
| Tapp (s1, l1), Tapp (s2, l2) when s1 == s2 ->
List.fold_left2 t_match s l1 l2
| Tif (f1, t1, e1), Tif (f2, t2, e2) ->
t_match (t_match (f_match s f1 f2) t1 t2) e1 e2
| Tlet (t1, tb1), Tlet (t2, tb2) ->
let v1, b1 = tb1 in
let v2, b2 = tb2 in
......@@ -1325,6 +1343,11 @@ let f_binary_simp op = match op with
| Fimplies -> f_implies_simp
| Fiff -> f_iff_simp
let t_if_simp f t1 t2 = match f.f_node with
| Ftrue -> t1
| Ffalse -> t2
| _ -> t_if f t1 t2
let f_if_simp f1 f2 f3 = match f1.f_node, f2.f_node, f3.f_node with
| Ftrue, _, _ -> f2
| Ffalse, _, _ -> f3
......
......@@ -126,6 +126,7 @@ and term_node = private
| Tvar of vsymbol
| Tconst of constant
| Tapp of lsymbol * term list
| Tif of fmla * term * term
| Tlet of term * term_bound
| Tcase of term list * term_branch list
| Teps of fmla_bound
......@@ -167,6 +168,7 @@ module Sfmla : Set.S with type elt = fmla
val t_var : vsymbol -> term
val t_const : constant -> ty -> term
val t_app : lsymbol -> term list -> ty -> term
val t_if : fmla -> term -> term -> term
val t_let : vsymbol -> term -> term -> term
val t_case : term list -> (pattern list * term) list -> ty -> term
val t_eps : vsymbol -> fmla -> term
......@@ -210,6 +212,7 @@ val f_implies_simp : fmla -> fmla -> fmla
val f_iff_simp : fmla -> fmla -> fmla
val f_binary_simp : binop -> fmla -> fmla -> fmla
val f_not_simp : fmla -> fmla
val t_if_simp : fmla -> term -> term -> term
val f_if_simp : fmla -> fmla -> fmla -> fmla
val f_let_simp : vsymbol -> term -> fmla -> fmla
......
......@@ -87,6 +87,8 @@ let rec print_term drv fmt t = match t.t_node with
fprintf fmt "@[(let %a = %a@ in %a)@]" print_ident v.vs_name
(print_term drv) t1 (print_term drv) t2;
forget_var v
| Tif _ ->
assert false
| Tcase _ ->
assert false
| Teps _ ->
......
......@@ -165,6 +165,9 @@ and print_tnode opl opr drv fmt t = match t.t_node with
| Tconst (ConstReal c) ->
Print_real.print_with_integers
"(%s)%%R" "(%s * %s)%%R" "(%s / %s)%%R" fmt c
| Tif (f,t1,t2) ->
fprintf fmt (protect_on opr "if %a@ then %a@ else %a")
(print_fmla drv) f (print_term drv) t1 (print_opl_term drv) t2
| Tlet (t1,tb) ->
let v,t2 = t_open_bound tb in
fprintf fmt (protect_on opr "let %a :=@ %a in@ %a")
......
......@@ -95,6 +95,8 @@ let rec print_term drv fmt t = match t.t_node with
fprintf fmt "@[(let %a = %a@ in %a)@]" print_ident v.vs_name
(print_term drv) t1 (print_term drv) t2;
forget_var v
| Tif _ ->
assert false
| Tcase _ ->
assert false
| Teps _ ->
......
......@@ -170,6 +170,9 @@ and print_tnode opl opr drv fmt t = match t.t_node with
print_vs fmt v
| Tconst c ->
Pretty.print_const fmt c
| Tif (f,t1,t2) ->
fprintf fmt (protect_on opr "if %a@ then %a@ else %a")
(print_fmla drv) f (print_term drv) t1 (print_opl_term drv) t2
| Tlet (t1,tb) ->
let v,t2 = t_open_bound tb in
fprintf fmt (protect_on opr "let %a =@ %a in@ %a")
......@@ -210,6 +213,9 @@ and print_fnode opl opr drv fmt f = match f.f_node with
(print_opr_fmla drv) f1 print_binop b (print_opl_fmla drv) f2
| Fnot f ->
fprintf fmt (protect_on opr "not %a") (print_opl_fmla drv) f
| Fif (f1,f2,f3) ->
fprintf fmt (protect_on opr "if %a@ then %a@ else %a")
(print_fmla drv) f1 (print_fmla drv) f2 (print_opl_fmla drv) f3
| Flet (t,f) ->
let v,f = f_open_bound f in
fprintf fmt (protect_on opr "let %a =@ %a in@ %a")
......@@ -219,9 +225,6 @@ and print_fnode opl opr drv fmt f = match f.f_node with
fprintf fmt "match %a with@\n@[<hov>%a@]@\nend"
(print_list comma (print_term drv)) tl
(print_list newline (print_fbranch drv)) bl
| Fif (f1,f2,f3) ->
fprintf fmt (protect_on opr "if %a@ then %a@ else %a")
(print_fmla drv) f1 (print_fmla drv) f2 (print_opl_fmla drv) f3
| Fapp (ps, tl) ->
begin match drv ps.ls_name with
| Syntax s -> syntax_arguments s (print_term drv) fmt tl
......
......@@ -190,6 +190,7 @@ let conv_res_app tenv tvar p tl ty =
let rec rewrite_term tenv tvar vsvar t =
let fnT = rewrite_term tenv tvar vsvar in
let fnF = rewrite_fmla tenv tvar vsvar in
match t.t_node with
| Tconst _ -> t
| Tvar x -> Mvs.find x vsvar
......@@ -198,6 +199,8 @@ let rec rewrite_term tenv tvar vsvar t =
let p = Hls.find tenv.trans_lsymbol p in
let tl = List.map2 (conv_arg tenv tvar) tl p.ls_args in
conv_res_app tenv tvar p tl t.t_ty
| Tif (f, t1, t2) ->
t_if (fnF f) (fnT t1) (fnT t2)
| Tlet (t1, b) -> let u,t2 = t_open_bound b in
let t1' = fnT t1 in let t2' = fnT t2 in
if t1' == t1 && t2' == t2 then t else t_let u t1' t2'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment