Commit 919d60bf authored by Andrei Paskevich's avatar Andrei Paskevich

whyml: accept infix relation chains

In a chain "e1 op1 e2 op2 e3 op3 e4", each relation symbol is either:

- an infix symbol "=" or "<>", or

- a binary symbol whose value type is Bool.bool or Prop (for lsymbols)
  and whose arguments' types are not Bool.bool.

In other words, we interpret a chain as a conjunction only if there
is no possibility(*) to interpret it as a superposition. The exception
is only made for "=" and "<>", which are _always_ considered as
chainable, even if they are redefined with some bogus type signatures.
Notice that redefining "<>" has no effect whatsoever, since "<>" is
always treated as negated "=".

As for the evaluation order, the chain above would be equivalent to:
    let x2 = e2 in
    (e1 op1 x2) &&
        let x3 = e3 in
        (x2 op2 x3) &&
            (x3 op3 e4)
This is due to the fact that lazy conjunctions are evaluated from
left to right, function arguments are evaluated from right to left,
and no expression should be evaluated twice.

[*] well, not really, since we consider symbols ('a -> 'b -> bool)
as chainable, even though such chains could be interpreted as
superpositions(**). We could treat such symbols as unchainable,
but that would make equality and disequality doubly special cases,
and I don't like it. We'll see if the current conditions are not
enough.

[**] what also bothers me is dynamic types of locally defined
infix symbols, which can be type variables or Bool.bool depending
on the order of operations in Mlw_typing. Currently, I can't come
up with any example of bad behaviour -- we are somewhat saved by
not being able to write "let (==) = ... in ...").
parent ce73b83f
......@@ -67,7 +67,7 @@ back +-+-+-+-------------------+
let test (a: sparse_array 'a) i
requires { 0 <= i < length a }
ensures { result=True <-> is_elt a i }
= 0 <= a.index[i] && a.index[i] < a.card && a.back[a.index[i]] = i
= 0 <= a.index[i] < a.card && a.back[a.index[i]] = i
let get (a: sparse_array 'a) i
requires { 0 <= i < length a }
......
......@@ -110,7 +110,7 @@ end
let mk_infix e1 op e2 =
let id = mk_id (infix op) (floc_i 2) in
mk_expr (mk_apply_id id [e1; e2])
mk_expr (Einfix (e1, id, e2))
let mk_prefix op e1 =
let id = mk_id (prefix op) (floc_i 1) in
......@@ -1072,11 +1072,13 @@ fun_expr:
expr:
| expr_arg
{ $1 }
{ match $1.expr_desc with (* break the infix relation chain *)
| Einfix (l,o,r) -> { $1 with expr_desc = Einnfix (l,o,r) }
| _ -> $1 }
| expr EQUAL expr
{ mk_infix $1 "=" $3 }
| expr LTGT expr
{ mk_expr (Enot (mk_infix $1 "=" $3)) }
{ mk_infix $1 "<>" $3 }
| expr LARROW expr
{ match $1.expr_desc with
| Eapply (e11, e12) -> begin match e11.expr_desc with
......@@ -1131,9 +1133,9 @@ expr:
| PPptuple [] -> mk_expr (Elet (id_anonymous (), true,
{ $5 with expr_desc = Ecast ($5, PPTtuple []) }, $7))
| _ -> Loc.errorm ~loc:(floc_i 3) "`ghost' cannot come before a pattern" }
| LET lident labels fun_defn IN expr
| LET lident_rich labels fun_defn IN expr
{ mk_expr (Elet (add_lab $2 $3, false, $4, $6)) }
| LET GHOST lident labels fun_defn IN expr
| LET GHOST lident_rich labels fun_defn IN expr
{ mk_expr (Elet (add_lab $3 $4, true, $5, $7)) }
| LET REC list1_rec_defn IN expr
{ mk_expr (Eletrec ($3, $5)) }
......
......@@ -211,6 +211,8 @@ and expr_desc =
| Econstant of constant
| Eident of qualid
| Eapply of expr * expr
| Einfix of expr * ident * expr
| Einnfix of expr * ident * expr
| Efun of binder list * triple
| Elet of ident * ghost * expr * expr
| Eletrec of letrec list * expr
......
......@@ -195,6 +195,15 @@ let unify d1 d2 = unify ~weak:false d1 d2
type dvty = dity list * dity (* A -> B -> C == ([A;B],C) *)
let is_chainable dvty =
let rec is_bool = function
| Dvar { contents = Dval dty } -> is_bool dty
| Dts (ts,_) -> ts_equal ts ts_bool
| _ -> false in
match dvty with
| [t1;t2],t -> is_bool t && not (is_bool t1) && not (is_bool t2)
| _ -> false
let vty_of_dvty (argl,res) =
let vty = VTvalue (ity_of_dity res) in
let conv a = create_pvsymbol (id_fresh "x") (ity_of_dity a) in
......
......@@ -36,6 +36,8 @@ val ts_app: tysymbol -> dity list -> dity
val dity_refresh: dity -> dity (* refresh regions *)
val is_chainable: dvty -> bool (* non-bool * non-bool -> bool *)
exception DTypeMismatch of dity * dity
val unify: dity -> dity -> unit
......
......@@ -254,12 +254,12 @@ let hidden_ls ~loc ls =
(* helper functions for let-expansion *)
let test_var e = match e.de_desc with
| DElocal _ | DEglobal_pv _ -> true
| DElocal _ | DEglobal_pv _ | DEconstant _ -> true
| _ -> false
let mk_var e =
let mk_var name e =
if test_var e then e else
{ de_desc = DElocal "q";
{ de_desc = DElocal name;
de_type = e.de_type;
de_loc = e.de_loc;
de_lab = Slab.empty }
......@@ -270,11 +270,11 @@ let mk_id s loc =
let mk_dexpr desc dvty loc labs =
{ de_desc = desc; de_type = dvty; de_loc = loc; de_lab = labs }
let mk_let ~loc ~uloc e (desc,dvty) =
let mk_let name ~loc ~uloc e (desc,dvty) =
if test_var e then desc, dvty else
let loc = Opt.get_def loc uloc in
let e1 = mk_dexpr desc dvty loc Slab.empty in
DElet (mk_id "q" e.de_loc, false, e, e1), dvty
DElet (mk_id name e.de_loc, false, e, e1), dvty
(* patterns *)
......@@ -295,6 +295,25 @@ let specialize_qualid uc p = match uc_find_ps uc p with
| LS ls -> DEglobal_ls ls, Loc.try1 (qloc p) specialize_lsymbol ls
| XS xs -> errorm ~loc:(qloc p) "unexpected exception symbol %a" print_xs xs
let chainable_qualid uc p = match uc_find_ps uc p with
| PS { ps_aty = { aty_args = [pv1;pv2]; aty_result = VTvalue ity }}
| PS { ps_aty = { aty_args = [pv1]; aty_result =
VTarrow { aty_args = [pv2]; aty_result = VTvalue ity }}} ->
ity_equal ity ity_bool
&& not (ity_equal pv1.pv_ity ity_bool)
&& not (ity_equal pv2.pv_ity ity_bool)
| LS { ls_args = [ty1;ty2]; ls_value = ty } ->
Opt.fold (fun _ ty -> ty_equal ty ty_bool) true ty
&& not (ty_equal ty1 ty_bool)
&& not (ty_equal ty2 ty_bool)
| PS _ | LS _ | PL _ | PV _ | XS _ -> false
let chainable_op denv op =
op.id = "infix =" || op.id = "infix <>" ||
match Mstr.find_opt op.id denv.locals with
| Some (_, dvty) -> is_chainable dvty
| None -> chainable_qualid denv.uc (Qident op)
let find_xsymbol uc p = match uc_find_ps uc p with
| XS xs -> xs
| _ -> errorm ~loc:(qloc p) "exception symbol expected"
......@@ -468,6 +487,37 @@ and de_desc denv loc = function
let e, el = decompose_app [e2] e1 in
let el = List.map (dexpr denv) el in
de_app loc (dexpr denv e) el
| Ptree.Einfix (e12, op2, e3)
| Ptree.Einnfix (e12, op2, e3) ->
let mk_bool (d,ty) =
let de = mk_dexpr d ty (Opt.get_def loc denv.uloc) Slab.empty in
expected_type de dity_bool; de in
let make_app de1 op de2 =
let id = Ptree.Eident (Qident op) in
let e0 = { expr_desc = id; expr_loc = op.id_loc } in
de_app loc (dexpr denv e0) [de1; de2] in
let make_app de1 op de2 =
if op.id <> "infix <>" then make_app de1 op de2 else
let de12 = mk_bool (make_app de1 { op with id = "infix =" } de2) in
DEnot de12, de12.de_type in
let rec make_chain n1 n2 de1 = function
| [op,de2] ->
make_app de1 op de2
| (op,de2) :: ch ->
let v = mk_var n1 de2 in
let de12 = mk_bool (make_app de1 op v) in
let de23 = mk_bool (make_chain n2 n1 v ch) in
let d = DElazy (LazyAnd, de12, de23) in
mk_let n1 ~loc ~uloc:denv.uloc de2 (d, de12.de_type)
| [] -> assert false in
let rec get_chain e12 acc = match e12.expr_desc with
| Ptree.Einfix (e1, op1, e2) when chainable_op denv op1 ->
get_chain e1 ((op1, dexpr denv e2) :: acc)
| _ -> e12, acc in
let e1, ch = if chainable_op denv op2
then get_chain e12 [op2, dexpr denv e3]
else e12, [op2, dexpr denv e3] in
make_chain "q1 " "q2 " (dexpr denv e1) ch
| Ptree.Elet (id, gh, e1, e2) ->
let e1 = dexpr denv e1 in
let denv = match e1.de_desc with
......@@ -526,7 +576,7 @@ and de_desc denv loc = function
de_app loc (hidden_pl ~loc cs) (List.map get_val pjl)
| Ptree.Eupdate (e1, fl) when is_pure_record denv.uc fl ->
let e1 = dexpr denv e1 in
let e0 = mk_var e1 in
let e0 = mk_var "q " e1 in
let kn = Theory.get_known (get_theory denv.uc) in
let fl = List.map (find_pure_field denv.uc) fl in
let cs,pjl,flm = Loc.try2 loc Decl.parse_record kn fl in
......@@ -537,10 +587,10 @@ and de_desc denv loc = function
let d, dvty = de_app loc (hidden_ls ~loc pj) [e0] in
mk_dexpr d dvty loc Slab.empty in
let res = de_app loc (hidden_ls ~loc cs) (List.map get_val pjl) in
mk_let ~loc ~uloc:denv.uloc e1 res
mk_let "q " ~loc ~uloc:denv.uloc e1 res
| Ptree.Eupdate (e1, fl) ->
let e1 = dexpr denv e1 in
let e0 = mk_var e1 in
let e0 = mk_var "q " e1 in
let fl = List.map (find_prog_field denv.uc) fl in
let cs,pjl,flm = Loc.try2 loc parse_record denv.uc fl in
let get_val pj = match Mls.find_opt pj.pl_ls flm with
......@@ -550,7 +600,7 @@ and de_desc denv loc = function
let d, dvty = de_app loc (hidden_pl ~loc pj) [e0] in
mk_dexpr d dvty loc Slab.empty in
let res = de_app loc (hidden_pl ~loc cs) (List.map get_val pjl) in
mk_let ~loc ~uloc:denv.uloc e1 res
mk_let "q " ~loc ~uloc:denv.uloc e1 res
| Ptree.Eassign (e1, q, e2) ->
let fl = dexpr denv { expr_desc = Eident q; expr_loc = qloc q } in
let pl = match fl.de_desc with
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment