diff --git a/src/whyml/mlw_expr.ml b/src/whyml/mlw_expr.ml index 5a6ed00da539e4b6c331023ff08f54196663ed2a..c049df5b1e7b10809c53162a54dcb313f75f11dc 100644 --- a/src/whyml/mlw_expr.ml +++ b/src/whyml/mlw_expr.ml @@ -1141,3 +1141,50 @@ let e_fold fn acc e = match e.e_node with | Eabstr (e,_) -> fn acc e | Elogic _ | Evalue _ | Earrow _ | Eany _ | Eassert _ | Eabsurd -> acc + +let t_void = fs_app (fs_tuple 0) [] ty_unit + +let spec_purify sp = + let vs, f = Mlw_ty.open_post sp.c_post in + match f.t_node with + | Tapp (ps, [{t_node = Tvar us}; t]) + when ls_equal ps ps_equ && vs_equal vs us && not (Mvs.mem vs t.t_vars) -> + t + | Tbinop (Tiff, {t_node = Tapp (ps,[{t_node = Tvar us};{t_node = Ttrue}])},f) + when ls_equal ps ps_equ && vs_equal vs us && not (Mvs.mem vs f.t_vars) -> + t_if f t_bool_true t_bool_false + | _ -> raise Exit + +let rec e_purify e = + let t = match e.e_node with + | Elogic f when f.t_ty = None -> + t_if f t_bool_true t_bool_false + | Elogic t -> t + | Evalue pv -> t_var pv.pv_vs + | Earrow _ | Eassert _ -> t_void + | Eapp (_,_,sp) -> spec_purify sp + | Elet ({ let_sym = LetV pv; let_expr = e1 }, e2) -> + t_let_close_simp pv.pv_vs (e_purify e1) (e_purify e2) + | Elet ({ let_sym = LetA _ }, e1) + | Erec (_,e1) | Eghost e1 -> + e_purify e1 + | Eif (e1,e2,e3) -> + t_if_simp (t_equ_simp (e_purify e1) t_bool_true) + (e_purify e2) (e_purify e3) + | Ecase (e1,bl) -> + let conv (p,e) = t_close_branch p.ppat_pattern (e_purify e) in + t_case (e_purify e1) (List.map conv bl) + | Eany sp | Eabstr (_,sp) -> spec_purify sp + | Eassign _ | Eloop _ | Efor _ + | Eraise _ | Etry _ | Eabsurd -> raise Exit + in + let loc = if t.t_loc = None then e.e_loc else t.t_loc in + t_label ?loc (Slab.union e.e_label t.t_label) t + +let e_purify e = + if Sreg.is_empty e.e_effect.eff_writes && + Sreg.is_empty e.e_effect.eff_ghostw && + Sexn.is_empty e.e_effect.eff_raises && + Sexn.is_empty e.e_effect.eff_ghostx + then try Some (e_purify e) with Exit -> None + else None diff --git a/src/whyml/mlw_expr.mli b/src/whyml/mlw_expr.mli index 14ec8840f4904557500b44bfcb8305ee10aeadee..34d4da9d252306609b560a20f85220d9449cdd6b 100644 --- a/src/whyml/mlw_expr.mli +++ b/src/whyml/mlw_expr.mli @@ -253,3 +253,5 @@ val e_absurd : ity -> expr (** expression traversal *) val e_fold : ('a -> expr -> 'a) -> 'a -> expr -> 'a + +val e_purify : expr -> term option diff --git a/src/whyml/mlw_typing.ml b/src/whyml/mlw_typing.ml index 9f815ad5b2846e3ad898a0caebd3b1700c454750..fdcc52a59b7ac6bf7c0a4ce717f9cb7272c10cf3 100644 --- a/src/whyml/mlw_typing.ml +++ b/src/whyml/mlw_typing.ml @@ -88,6 +88,8 @@ let () = Exn_printer.register (fun fmt e -> match e with | _ -> raise e) (* TODO: let type_only = Debug.test_flag Typing.debug_type_only in *) +let implicit_post = Debug.register_flag "implicit_post" + ~desc:"Generate@ a@ postcondition@ for@ pure@ functions@ without@ one." type denv = { uc : module_uc; @@ -1203,11 +1205,23 @@ and expr_rec lenv dfdl = List.iter2 check_user_effect fdl dfdl; fdl -and expr_fun lenv x gh bl tr = +and expr_fun lenv x gh bl (_, dsp as tr) = let lam = expr_lam lenv gh (binders bl) tr in if lam.l_spec.c_variant <> [] then Loc.errorm "variants are not allowed in a non-recursive definition"; - check_user_effect lenv lam.l_expr (snd tr); + check_user_effect lenv lam.l_expr dsp; + let lam = + if Debug.nottest_flag implicit_post || dsp.ds_post <> [] || + oty_equal lam.l_spec.c_post.t_ty (Some ty_unit) then lam + else match e_purify lam.l_expr with + | None -> lam + | Some t -> + let vs, f = Mlw_ty.open_post lam.l_spec.c_post in + let f = t_and_simp (t_equ_simp (t_var vs) t) f in + let f = t_label_add Split_goal.stop_split f in + let post = Mlw_ty.create_post vs f in + let spec = { lam.l_spec with c_post = post } in + { lam with l_spec = spec } in let pvs = l_pvset Spv.empty lam in let lam = lambda_invariant lenv pvs lam.l_expr.e_effect lam in create_fun_defn (Denv.create_user_id x) lam