Commit 2585bb44 authored by Andrei Paskevich's avatar Andrei Paskevich

add core support for the "if" construction in terms.

It's not more expressive but much nicer than epsilon.
parent 9cf074db
......@@ -187,6 +187,9 @@ and print_tnode opl opr fmt t = match t.t_node with
| Tapp (fs, tl) ->
fprintf fmt (protect_on opl "%a%a:%a")
print_ls fs (print_paren_r print_term) tl print_ty t.t_ty
| Tif (f,t1,t2) ->
fprintf fmt (protect_on opr "if %a@ then %a@ else %a")
print_fmla f print_term t1 print_opl_term t2
| Tlet (t1,tb) ->
let v,t2 = t_open_bound tb in
fprintf fmt (protect_on opr "let %a =@ %a in@ %a")
......@@ -223,6 +226,9 @@ and print_fnode opl opr fmt f = match f.f_node with
print_opr_fmla f1 print_binop b print_opl_fmla f2
| Fnot f ->
fprintf fmt (protect_on opr "not %a") print_opl_fmla f
| Fif (f1,f2,f3) ->
fprintf fmt (protect_on opr "if %a@ then %a@ else %a")
print_fmla f1 print_fmla f2 print_opl_fmla f3
| Flet (t,f) ->
let v,f = f_open_bound f in
fprintf fmt (protect_on opr "let %a =@ %a in@ %a")
......@@ -232,9 +238,6 @@ and print_fnode opl opr fmt f = match f.f_node with
fprintf fmt "match %a with@\n@[<hov>%a@]@\nend"
(print_list comma print_term) tl
(print_list newline print_fbranch) bl
| Fif (f1,f2,f3) ->
fprintf fmt (protect_on opr "if %a@ then %a@ else %a")
print_fmla f1 print_fmla f2 print_opl_fmla f3
and print_tbranch fmt br =
let pl,t = t_open_branch br in
......
......@@ -238,6 +238,7 @@ and term_node =
| Tvar of vsymbol
| Tconst of constant
| Tapp of lsymbol * term list
| Tif of fmla * term * term
| Tlet of term * term_bound
| Tcase of term list * term_branch list
| Teps of fmla_bound
......@@ -293,6 +294,7 @@ module Hsterm = Hashcons.Make (struct
| Tvar v1, Tvar v2 -> v1 == v2
| Tconst c1, Tconst c2 -> c1 = c2
| Tapp (s1, l1), Tapp (s2, l2) -> s1 == s2 && List.for_all2 (==) l1 l2
| Tif (f1, t1, e1), Tif (f2, t2, e2) -> f1 == f2 && t1 == t2 && e2 == e2
| Tlet (t1, b1), Tlet (t2, b2) -> t1 == t2 && t_eq_bound b1 b2
| Tcase (tl1, bl1), Tcase (tl2, bl2) ->
list_all2 (==) tl1 tl2 && list_all2 t_eq_branch bl1 bl2
......@@ -318,6 +320,7 @@ module Hsterm = Hashcons.Make (struct
| Tvar v -> v.vs_name.id_tag
| Tconst c -> Hashtbl.hash c
| Tapp (f, tl) -> Hashcons.combine_list t_hash (f.ls_name.id_tag) tl
| Tif (f, t, e) -> Hashcons.combine2 f.f_tag t.t_tag e.t_tag
| Tlet (t, bt) -> Hashcons.combine t.t_tag (t_hash_bound bt)
| Tcase (tl, bl) -> let ht = Hashcons.combine_list t_hash 17 tl in
Hashcons.combine_list t_hash_branch ht bl
......@@ -433,6 +436,7 @@ let t_bvar n ty = Hsterm.hashcons (mk_term (Tbvar n) ty)
let t_var v = Hsterm.hashcons (mk_term (Tvar v) v.vs_ty)
let t_const c ty = Hsterm.hashcons (mk_term (Tconst c) ty)
let t_app f tl ty = Hsterm.hashcons (mk_term (Tapp (f, tl)) ty)
let t_if f t1 t2 = Hsterm.hashcons (mk_term (Tif (f, t1, t2)) t2.t_ty)
let t_let v t1 t2 = Hsterm.hashcons (mk_term (Tlet (t1, (v, t2))) t2.t_ty)
let t_case tl bl ty = Hsterm.hashcons (mk_term (Tcase (tl, bl)) ty)
let t_eps u f = Hsterm.hashcons (mk_term (Teps (u, f)) u.vs_ty)
......@@ -494,6 +498,7 @@ let t_map_unsafe fnT fnF lvl t = t_label_copy t (match t.t_node with
| Tbvar n when n >= lvl -> raise UnboundIndex
| Tbvar _ | Tvar _ | Tconst _ -> t
| Tapp (f, tl) -> t_app f (List.map (fnT lvl) tl) t.t_ty
| Tif (f, t1, t2) -> t_if (fnF lvl f) (fnT lvl t1) (fnT lvl t2)
| Tlet (t1, (u, t2)) -> t_let u (fnT lvl t1) (fnT (lvl + 1) t2)
| Tcase (tl, bl) ->
t_case (List.map (fnT lvl) tl) (List.map (brlvl fnT lvl) bl) t.t_ty
......@@ -527,6 +532,7 @@ let t_fold_unsafe fnT fnF lvl acc t = match t.t_node with
| Tbvar n when n >= lvl -> raise UnboundIndex
| Tbvar _ | Tvar _ | Tconst _ -> acc
| Tapp (f, tl) -> List.fold_left (fnT lvl) acc tl
| Tif (f, t1, t2) -> fnT lvl (fnT lvl (fnF lvl acc f) t1) t2
| Tlet (t1, (u, t2)) -> fnT (lvl + 1) (fnT lvl acc t1) t2
| Tcase (tl, bl) ->
List.fold_left (brlvl fnT lvl) (List.fold_left (fnT lvl) acc tl) bl
......@@ -620,6 +626,10 @@ let f_case tl bl =
List.iter f_check_branch bl;
f_case tl bl
let t_if f t1 t2 =
if t1.t_ty != t2.t_ty then raise TypeMismatch;
t_if f t1 t2
let t_let v t1 t2 =
if v.vs_ty != t1.t_ty then raise TypeMismatch;
t_let v t1 t2
......@@ -639,6 +649,7 @@ let rec t_s_map fnT fnV fnL t =
| Tvar v -> t_var (fnV v ty)
| Tconst _ -> t
| Tapp (f, tl) -> t_app (fnL f) (List.map fn_t tl) ty
| Tif (f, t1, t2) -> t_if (fn_f f) (fn_t t1) (fn_t t2)
| Tlet (t1, (u, t2)) ->
let t1 = fn_t t1 in t_let (fnV u t1.t_ty) t1 (fn_t t2)
| Tcase (tl, bl) ->
......@@ -704,6 +715,7 @@ let rec t_s_fold fnT fnL acc t =
match t.t_node with
| Tbvar _ | Tvar _ | Tconst _ -> acc
| Tapp (f, tl) -> List.fold_left fn_t (fnL acc f) tl
| Tif (f, t1, t2) -> fn_t (fn_t (fn_f acc f) t1) t2
| Tlet (t1, (_,t2)) -> fn_t (fn_t acc t1) t2
| Tcase (tl, bl) ->
List.fold_left (t_branch fnT fnL) (List.fold_left fn_t acc tl) bl
......@@ -903,6 +915,7 @@ let t_map fnT fnF t = t_label_copy t (match t.t_node with
| Tbvar _ -> raise UnboundIndex
| Tvar _ | Tconst _ -> t
| Tapp (f, tl) -> t_app_unsafe f (List.map fnT tl) t.t_ty
| Tif (f, t1, t2) -> t_if (fnF f) (fnT t1) (fnT t2)
| Tlet (t1, b) -> let u,t2 = t_open_bound b in
let t1' = fnT t1 in let t2' = fnT t2 in
if t1' == t1 && t2' == t2 then t else t_let u t1' t2'
......@@ -947,6 +960,7 @@ let t_fold fnT fnF acc t = match t.t_node with
| Tbvar _ -> raise UnboundIndex
| Tvar _ | Tconst _ -> acc
| Tapp (f, tl) -> List.fold_left fnT acc tl
| Tif (f, t1, t2) -> fnT (fnT (fnF acc f) t1) t2
| Tlet (t1, b) -> let _,t2 = t_open_bound b in fnT (fnT acc t1) t2
| Tcase (tl, bl) ->
List.fold_left (t_branch fnT) (List.fold_left fnT acc tl) bl
......@@ -1041,6 +1055,8 @@ let rec t_equal_alpha t1 t2 =
c1 = c2
| Tapp (s1, l1), Tapp (s2, l2) ->
s1 == s2 && List.for_all2 t_equal_alpha l1 l2
| Tif (f1, t1, e1), Tif (f2, t2, e2) ->
f_equal_alpha f1 f2 && t_equal_alpha t1 t2 && t_equal_alpha e1 e2
| Tlet (t1, tb1), Tlet (t2, tb2) ->
let v1, b1 = tb1 in
let v2, b2 = tb2 in
......@@ -1121,6 +1137,8 @@ let rec t_match s t1 t2 =
then Mvs.add v1 t2 s else raise NoMatch)
| Tapp (s1, l1), Tapp (s2, l2) when s1 == s2 ->
List.fold_left2 t_match s l1 l2
| Tif (f1, t1, e1), Tif (f2, t2, e2) ->
t_match (t_match (f_match s f1 f2) t1 t2) e1 e2
| Tlet (t1, tb1), Tlet (t2, tb2) ->
let v1, b1 = tb1 in
let v2, b2 = tb2 in
......@@ -1325,6 +1343,11 @@ let f_binary_simp op = match op with
| Fimplies -> f_implies_simp
| Fiff -> f_iff_simp
let t_if_simp f t1 t2 = match f.f_node with
| Ftrue -> t1
| Ffalse -> t2
| _ -> t_if f t1 t2
let f_if_simp f1 f2 f3 = match f1.f_node, f2.f_node, f3.f_node with
| Ftrue, _, _ -> f2
| Ffalse, _, _ -> f3
......
......@@ -126,6 +126,7 @@ and term_node = private
| Tvar of vsymbol
| Tconst of constant
| Tapp of lsymbol * term list
| Tif of fmla * term * term
| Tlet of term * term_bound
| Tcase of term list * term_branch list
| Teps of fmla_bound
......@@ -167,6 +168,7 @@ module Sfmla : Set.S with type elt = fmla
val t_var : vsymbol -> term
val t_const : constant -> ty -> term
val t_app : lsymbol -> term list -> ty -> term
val t_if : fmla -> term -> term -> term
val t_let : vsymbol -> term -> term -> term
val t_case : term list -> (pattern list * term) list -> ty -> term
val t_eps : vsymbol -> fmla -> term
......@@ -210,6 +212,7 @@ val f_implies_simp : fmla -> fmla -> fmla
val f_iff_simp : fmla -> fmla -> fmla
val f_binary_simp : binop -> fmla -> fmla -> fmla
val f_not_simp : fmla -> fmla
val t_if_simp : fmla -> term -> term -> term
val f_if_simp : fmla -> fmla -> fmla -> fmla
val f_let_simp : vsymbol -> term -> fmla -> fmla
......
......@@ -87,6 +87,8 @@ let rec print_term drv fmt t = match t.t_node with
fprintf fmt "@[(let %a = %a@ in %a)@]" print_ident v.vs_name
(print_term drv) t1 (print_term drv) t2;
forget_var v
| Tif _ ->
assert false
| Tcase _ ->
assert false
| Teps _ ->
......
......@@ -165,6 +165,9 @@ and print_tnode opl opr drv fmt t = match t.t_node with
| Tconst (ConstReal c) ->
Print_real.print_with_integers
"(%s)%%R" "(%s * %s)%%R" "(%s / %s)%%R" fmt c
| Tif (f,t1,t2) ->
fprintf fmt (protect_on opr "if %a@ then %a@ else %a")
(print_fmla drv) f (print_term drv) t1 (print_opl_term drv) t2
| Tlet (t1,tb) ->
let v,t2 = t_open_bound tb in
fprintf fmt (protect_on opr "let %a :=@ %a in@ %a")
......
......@@ -95,6 +95,8 @@ let rec print_term drv fmt t = match t.t_node with
fprintf fmt "@[(let %a = %a@ in %a)@]" print_ident v.vs_name
(print_term drv) t1 (print_term drv) t2;
forget_var v
| Tif _ ->
assert false
| Tcase _ ->
assert false
| Teps _ ->
......
......@@ -170,6 +170,9 @@ and print_tnode opl opr drv fmt t = match t.t_node with
print_vs fmt v
| Tconst c ->
Pretty.print_const fmt c
| Tif (f,t1,t2) ->
fprintf fmt (protect_on opr "if %a@ then %a@ else %a")
(print_fmla drv) f (print_term drv) t1 (print_opl_term drv) t2
| Tlet (t1,tb) ->
let v,t2 = t_open_bound tb in
fprintf fmt (protect_on opr "let %a =@ %a in@ %a")
......@@ -210,6 +213,9 @@ and print_fnode opl opr drv fmt f = match f.f_node with
(print_opr_fmla drv) f1 print_binop b (print_opl_fmla drv) f2
| Fnot f ->
fprintf fmt (protect_on opr "not %a") (print_opl_fmla drv) f
| Fif (f1,f2,f3) ->
fprintf fmt (protect_on opr "if %a@ then %a@ else %a")
(print_fmla drv) f1 (print_fmla drv) f2 (print_opl_fmla drv) f3
| Flet (t,f) ->
let v,f = f_open_bound f in
fprintf fmt (protect_on opr "let %a =@ %a in@ %a")
......@@ -219,9 +225,6 @@ and print_fnode opl opr drv fmt f = match f.f_node with
fprintf fmt "match %a with@\n@[<hov>%a@]@\nend"
(print_list comma (print_term drv)) tl
(print_list newline (print_fbranch drv)) bl
| Fif (f1,f2,f3) ->
fprintf fmt (protect_on opr "if %a@ then %a@ else %a")
(print_fmla drv) f1 (print_fmla drv) f2 (print_opl_fmla drv) f3
| Fapp (ps, tl) ->
begin match drv ps.ls_name with
| Syntax s -> syntax_arguments s (print_term drv) fmt tl
......
......@@ -190,6 +190,7 @@ let conv_res_app tenv tvar p tl ty =
let rec rewrite_term tenv tvar vsvar t =
let fnT = rewrite_term tenv tvar vsvar in
let fnF = rewrite_fmla tenv tvar vsvar in
match t.t_node with
| Tconst _ -> t
| Tvar x -> Mvs.find x vsvar
......@@ -198,6 +199,8 @@ let rec rewrite_term tenv tvar vsvar t =
let p = Hls.find tenv.trans_lsymbol p in
let tl = List.map2 (conv_arg tenv tvar) tl p.ls_args in
conv_res_app tenv tvar p tl t.t_ty
| Tif (f, t1, t2) ->
t_if (fnF f) (fnT t1) (fnT t2)
| Tlet (t1, b) -> let u,t2 = t_open_bound b in
let t1' = fnT t1 in let t2' = fnT t2 in
if t1' == t1 && t2' == t2 then t else t_let u t1' t2'
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment