Commit 48c29c43 authored by Andrei Paskevich's avatar Andrei Paskevich

keep user-supplied projections in algebraic types

parent 8ff52ea6
......@@ -13,6 +13,8 @@ transformation "eliminate_non_struct_recursion"
transformation "eliminate_if"
transformation "eliminate_projections"
transformation "simplify_formula"
(*transformation "simplify_trivial_quantification_in_goal"*)
......@@ -96,7 +98,7 @@ end
(* removed: Coq Zdiv is NOT true Euclidean division:
Zmod can be negative, in fact (Zmod x y) has the same sign as y,
which is not the usual convention of programming language either.
which is not the usual convention of programming language either.
theory int.EuclideanDivision
......
......@@ -15,6 +15,7 @@ transformation "eliminate_non_struct_recursion"
(* PVS only has simple patterns *)
transformation "compile_match"
transformation "eliminate_projections"
transformation "simplify_formula"
......
......@@ -25,9 +25,12 @@ open Term
(** Type declaration *)
type constructor = lsymbol * lsymbol option list
(** constructor symbol with the list of projections *)
type ty_defn =
| Tabstract
| Talgebraic of lsymbol list
| Talgebraic of constructor list
type ty_decl = tysymbol * ty_defn
......@@ -319,9 +322,12 @@ module Hsdecl = Hashcons.Make (struct
type t = decl
let cs_equal (cs1,pl1) (cs2,pl2) =
ls_equal cs1 cs2 && list_all2 (option_eq ls_equal) pl1 pl2
let eq_td (ts1,td1) (ts2,td2) = ts_equal ts1 ts2 && match td1,td2 with
| Tabstract, Tabstract -> true
| Talgebraic l1, Talgebraic l2 -> list_all2 ls_equal l1 l2
| Talgebraic l1, Talgebraic l2 -> list_all2 cs_equal l1 l2
| _ -> false
let eq_ld (ls1,ld1) (ls2,ld2) = ls_equal ls1 ls2 && match ld1,ld2 with
......@@ -343,9 +349,12 @@ module Hsdecl = Hashcons.Make (struct
k1 = k2 && pr_equal pr1 pr2 && t_equal f1 f2
| _,_ -> false
let cs_hash (cs,pl) =
Hashcons.combine_list (Hashcons.combine_option ls_hash) (ls_hash cs) pl
let hs_td (ts,td) = match td with
| Tabstract -> ts_hash ts
| Talgebraic l -> 1 + Hashcons.combine_list ls_hash (ts_hash ts) l
| Talgebraic l -> 1 + Hashcons.combine_list cs_hash (ts_hash ts) l
let hs_ld (ls,ld) = Hashcons.combine (ls_hash ls)
(Hashcons.combine_option (fun (_,f) -> t_hash f) ld)
......@@ -392,6 +401,11 @@ let mk_decl node syms news = Hsdecl.hashcons {
exception IllegalTypeAlias of tysymbol
exception ClashIdent of ident
exception BadLogicDecl of lsymbol * lsymbol
exception BadConstructor of lsymbol
exception BadRecordField of lsymbol
exception RecordFieldMissing of lsymbol
exception DuplicateRecordField of lsymbol
exception EmptyDecl
exception EmptyAlgDecl of tysymbol
......@@ -411,8 +425,21 @@ let create_ty_decl tdl =
if tdl = [] then raise EmptyDecl;
let add s (ts,_) = Sts.add ts s in
let tss = List.fold_left add Sts.empty tdl in
let check_constr tys ty (syms,news) fs =
ty_equal_check ty (of_option fs.ls_value);
let check_proj tyv s tya ls = match ls with
| None -> s
| Some ({ ls_args = [ptyv]; ls_value = Some ptya } as ls) ->
ty_equal_check tyv ptyv;
ty_equal_check tya ptya;
Sls.add_new (DuplicateRecordField ls) ls s
| Some ls -> raise (BadRecordField ls)
in
let check_constr tys ty pjs (syms,news) (fs,pl) =
ty_equal_check ty (exn_option (BadConstructor fs) fs.ls_value);
let fs_pjs =
try List.fold_left2 (check_proj ty) Sls.empty fs.ls_args pl
with Invalid_argument _ -> raise (BadConstructor fs) in
if not (Sls.equal pjs fs_pjs) then
raise (RecordFieldMissing (Sls.choose (Sls.diff pjs fs_pjs)));
let vs = ty_freevars Stv.empty ty in
let rec check seen ty = match ty.ty_node with
| Tyvar v when Stv.mem v vs -> ()
......@@ -435,8 +462,11 @@ let create_ty_decl tdl =
if cl = [] then raise (EmptyAlgDecl ts);
if ts.ts_def <> None then raise (IllegalTypeAlias ts);
let news = news_id news ts.ts_name in
let pjs = List.fold_left (fun s (_,pl) -> List.fold_left
(option_fold (fun s ls -> Sls.add ls s)) s pl) Sls.empty cl in
let news = Sls.fold (fun pj s -> news_id s pj.ls_name) pjs news in
let ty = ty_app ts (List.map ty_var ts.ts_args) in
List.fold_left (check_constr ts ty) (syms,news) cl
List.fold_left (check_constr ts ty pjs) (syms,news) cl
in
let (syms,news) = List.fold_left check_decl (Sid.empty,Sid.empty) tdl in
mk_decl (Dtype tdl) syms news
......@@ -658,6 +688,7 @@ exception NonExhaustiveCase of pattern list * term
let rec check_matchT kn () t = match t.t_node with
| Tcase (t1,bl) ->
let bl = List.map (fun b -> let p,t = t_open_branch b in [p],t) bl in
let find_constructors kn ts = List.map fst (find_constructors kn ts) in
ignore (try Pattern.CompileTerm.compile (find_constructors kn) [t1] bl
with Pattern.NonExhaustive p -> raise (NonExhaustiveCase (p,t)));
t_fold (check_matchT kn) () t
......@@ -679,7 +710,7 @@ let rec check_foundness kn d =
we can build a value of this type *)
let tss = Sts.add ts tss in
List.exists (check_constr tss tvs) cl
and check_constr tss tvs ls =
and check_constr tss tvs (ls,_) =
(* we can construct a value iff every
argument is of an inhabited type *)
List.for_all (check_type tss tvs) ls.ls_args
......@@ -718,7 +749,7 @@ let rec ts_extract_pos kn sts ts =
if pos then get_ty acc else ty_freevars acc in
List.fold_left2 get stv (ts_extract_pos kn sts ts) tl
in
let get_cs acc ls = List.fold_left get_ty acc ls.ls_args in
let get_cs acc (ls,_) = List.fold_left get_ty acc ls.ls_args in
let negs = List.fold_left get_cs Stv.empty csl in
List.map (fun v -> not (Stv.mem v negs)) ts.ts_args
......@@ -726,7 +757,7 @@ let check_positivity kn d = match d.d_node with
| Dtype tdl ->
let add s (ts,_) = Sts.add ts s in
let tss = List.fold_left add Sts.empty tdl in
let check_constr tys cs =
let check_constr tys (cs,_) =
let rec check_ty ty = match ty.ty_node with
| Tyvar _ -> ()
| Tyapp (ts,tl) ->
......@@ -753,3 +784,44 @@ let known_add_decl kn d =
check_match kn d;
kn
(** Records *)
exception EmptyRecord
let parse_record kn fll =
let fs = match fll with
| [] -> raise EmptyRecord
| (fs,_)::_ -> fs in
let ts = match fs.ls_args with
| [{ ty_node = Tyapp (ts,_) }] -> ts
| _ -> raise (BadRecordField fs) in
let cs, pjl = match find_constructors kn ts with
| [cs,pjl] -> cs, List.map (exn_option (BadRecordField fs)) pjl
| _ -> raise (BadRecordField fs) in
let pjs = List.fold_left (fun s pj -> Sls.add pj s) Sls.empty pjl in
let flm = List.fold_left (fun m (pj,v) ->
if not (Sls.mem pj pjs) then raise (BadRecordField pj) else
Mls.add_new (DuplicateRecordField pj) pj v m) Mls.empty fll in
cs,pjl,flm
let make_record kn fll ty =
let cs,pjl,flm = parse_record kn fll in
let get_arg pj = Mls.find_exn (RecordFieldMissing pj) pj flm in
fs_app cs (List.map get_arg pjl) ty
let make_record_update kn t fll ty =
let cs,pjl,flm = parse_record kn fll in
let get_arg pj = match Mls.find_opt pj flm with
| Some v -> v
| None -> t_app_infer pj [t] in
fs_app cs (List.map get_arg pjl) ty
let make_record_pattern kn fll ty =
let cs,pjl,flm = parse_record kn fll in
let s = ty_match Mtv.empty (of_option cs.ls_value) ty in
let get_arg pj = match Mls.find_opt pj flm with
| Some v -> v
| None -> pat_wild (ty_inst s (of_option pj.ls_value))
in
pat_app cs (List.map get_arg pjl) ty
......@@ -27,9 +27,12 @@ open Term
(** {2 Type declaration} *)
type constructor = lsymbol * lsymbol option list
(** constructor symbol with the list of projections *)
type ty_defn =
| Tabstract
| Talgebraic of lsymbol list
| Talgebraic of constructor list
type ty_decl = tysymbol * ty_defn
......@@ -135,6 +138,11 @@ exception EmptyDecl
exception EmptyAlgDecl of tysymbol
exception EmptyIndDecl of lsymbol
exception BadConstructor of lsymbol
exception BadRecordField of lsymbol
exception RecordFieldMissing of lsymbol
exception DuplicateRecordField of lsymbol
(** {2 Utilities} *)
val decl_map : (term -> term) -> decl -> decl
......@@ -168,9 +176,29 @@ exception NonExhaustiveCase of pattern list * term
exception NonFoundedTypeDecl of tysymbol
val find_type_definition : known_map -> tysymbol -> ty_defn
val find_constructors : known_map -> tysymbol -> lsymbol list
val find_constructors : known_map -> tysymbol -> constructor list
val find_inductive_cases : known_map -> lsymbol -> (prsymbol * term) list
val find_logic_definition : known_map -> lsymbol -> ls_defn option
val find_prop : known_map -> prsymbol -> term
val find_prop_decl : known_map -> prsymbol -> prop_kind * term
(** Records *)
exception EmptyRecord
val parse_record :
known_map -> (lsymbol * 'a) list -> lsymbol * lsymbol list * 'a Mls.t
(** [parse_record kn field_list] takes a list of record field assignments,
checks it for well-formedness and returns the corresponding constructor,
the full list of projection symbols, and the map from projection symbols
to assigned values. *)
val make_record :
known_map -> (lsymbol * term) list -> ty -> term
val make_record_update :
known_map -> term -> (lsymbol * term) list -> ty -> term
val make_record_pattern :
known_map -> (lsymbol * pattern) list -> ty -> pattern
......@@ -296,13 +296,16 @@ let print_tv_arg fmt tv = fprintf fmt "@ %a" print_tv tv
let print_ty_arg fmt ty = fprintf fmt "@ %a" (print_ty_node true) ty
let print_vs_arg fmt vs = fprintf fmt "@ (%a)" print_vsty vs
let print_constr ty fmt cs =
let ty_val = of_option cs.ls_value in
let m = ty_match Mtv.empty ty_val ty in
let tl = List.map (ty_inst m) cs.ls_args in
let print_constr fmt (cs,pjl) =
let add_pj pj ty pjl = (pj,ty)::pjl in
let print_pj fmt (pj,ty) = match pj with
| Some ls -> fprintf fmt "@ (%a:@,%a)" print_ls ls print_ty ty
| None -> print_ty_arg fmt ty
in
fprintf fmt "@[<hov 4>| %a%a%a@]" print_cs cs
print_ident_labels cs.ls_name
(print_list nothing print_ty_arg) tl
(print_list nothing print_pj)
(List.fold_right2 add_pj pjl cs.ls_args [])
let print_type_decl fst fmt (ts,def) = match def with
| Tabstract -> begin match ts.ts_def with
......@@ -318,12 +321,11 @@ let print_type_decl fst fmt (ts,def) = match def with
(print_list nothing print_tv_arg) ts.ts_args print_ty ty
end
| Talgebraic csl ->
let ty = ty_app ts (List.map ty_var ts.ts_args) in
fprintf fmt "@[<hov 2>%s %a%a%a =@\n@[<hov>%a@]@]"
(if fst then "type" else "with") print_ts ts
print_ident_labels ts.ts_name
(print_list nothing print_tv_arg) ts.ts_args
(print_list newline (print_constr ty)) csl
(print_list newline print_constr) csl
let print_type_decl first fmt d =
print_type_decl first fmt d; forget_tvs ()
......@@ -534,6 +536,14 @@ let () = Exn_printer.register
| Pattern.NonExhaustive pl ->
fprintf fmt "Non-exhaustive pattern list:@\n@[<hov 2>%a@]"
(print_list newline print_pat) pl
| Decl.BadConstructor ls ->
fprintf fmt "Bad constructor symbol: %a" print_ls ls
| Decl.BadRecordField ls ->
fprintf fmt "Not a record field: %a" print_ls ls
| Decl.RecordFieldMissing ls ->
fprintf fmt "Record field missing: %a" print_ls ls
| Decl.DuplicateRecordField ls ->
fprintf fmt "Duplicate record field: %a" print_ls ls
| Decl.IllegalTypeAlias ts ->
fprintf fmt
"Type symbol %a is a type alias and cannot be declared as algebraic"
......
......@@ -352,7 +352,12 @@ let add_symbol add id v uc =
| _ -> assert false
let add_type uc (ts,def) =
let add_constr uc fs = add_symbol add_ls fs.ls_name fs uc in
let add_proj uc = function
| Some pj -> add_symbol add_ls pj.ls_name pj uc
| None -> uc in
let add_constr uc (fs,pl) =
let uc = add_symbol add_ls fs.ls_name fs uc in
List.fold_left add_proj uc pl in
let uc = add_symbol add_ts ts.ts_name ts uc in
match def with
| Tabstract -> uc
......@@ -513,11 +518,14 @@ let cl_init th inst =
(* clone declarations *)
let cl_type cl inst tdl =
let add_constr ls =
let add_ls ls =
if Mls.mem ls inst.inst_ls
then raise (CannotInstantiate ls.ls_name)
else cl_find_ls cl ls
in
let add_constr (ls,pl) =
add_ls ls, List.map (option_map add_ls) pl
in
let add_type (ts,td) acc =
if Mts.mem ts inst.inst_ts then
if ts.ts_def = None && td = Tabstract then acc
......@@ -749,7 +757,7 @@ let create_theory ?(path=[]) n =
let bool_theory =
let uc = empty_theory (id_fresh "Bool") [] in
let uc = add_ty_decl uc [ts_bool, Talgebraic [fs_true; fs_false]] in
let uc = add_ty_decl uc [ts_bool, Talgebraic [fs_true,[]; fs_false,[]]] in
close_theory uc
let highord_theory =
......@@ -761,8 +769,10 @@ let highord_theory =
close_theory uc
let tuple_theory = Util.memo_int 17 (fun n ->
let ts = ts_tuple n and fs = fs_tuple n in
let pl = List.map (fun _ -> None) ts.ts_args in
let uc = empty_theory (id_fresh ("Tuple" ^ string_of_int n)) [] in
let uc = add_ty_decl uc [ts_tuple n, Talgebraic [fs_tuple n]] in
let uc = add_ty_decl uc [ts, Talgebraic [fs,pl]] in
close_theory uc)
let tuple_theory_name s =
......
This diff is collapsed.
......@@ -53,7 +53,7 @@ val specialize_psymbol :
Ptree.qualid -> theory_uc -> lsymbol * Denv.dty list
val specialize_tysymbol :
Loc.position -> Ptree.qualid -> theory_uc -> Ty.tysymbol * int
Loc.position -> Ptree.qualid -> theory_uc -> Ty.tysymbol
type denv
......@@ -84,6 +84,7 @@ val split_qualid : Ptree.qualid -> string list * string
val string_list_of_qualid : string list -> Ptree.qualid -> string list
val qloc : Ptree.qualid -> Loc.position
(*
val is_projection : theory_uc -> lsymbol -> (tysymbol * lsymbol * int) option
(** [is_projection uc ls] returns
- [Some (ts, lsc, i)] if [ls] is the i-th projection of an
......@@ -94,5 +95,6 @@ val list_fields: theory_uc ->
(Ptree.qualid * 'a) list -> tysymbol * lsymbol * (Ptree.loc * 'a) option list
(** check that the given fields all belong to the same record type
and do not appear several times *)
*)
val type_inst: theory_uc -> theory -> Ptree.clone_subst list -> th_inst
......@@ -345,7 +345,7 @@ let print_expr info fmt =
(** Declarations *)
let print_constr info ts fmt cs =
let print_constr info ts fmt (cs,_) =
match cs.ls_args with
| [] ->
fprintf fmt "@[<hov 4>| %a : %a %a@]" print_ls cs
......@@ -559,7 +559,7 @@ let print_type_decl ~old info fmt (ts,def) =
name (print_list space print_tv_binder) ts.ts_args
(print_list newline (print_constr info ts)) csl;
List.iter
(fun cs ->
(fun (cs,_) ->
let ty_vars_args, ty_vars_value, all_ty_params = ls_ty_vars cs in
print_implicits fmt cs ty_vars_args ty_vars_value all_ty_params)
csl;
......
......@@ -443,7 +443,7 @@ let print_expr info fmt =
(** Declarations *)
let print_constr info _ts fmt cs =
let print_constr info _ts fmt (cs,_) =
match cs.ls_args with
| [] ->
fprintf fmt "@[<hov 4>%a: %a?@]" print_ls cs print_ls cs
......
......@@ -237,13 +237,16 @@ let print_tv_arg fmt tv = fprintf fmt "@ %a" print_tv tv
let print_ty_arg fmt ty = fprintf fmt "@ %a" (print_ty_node true) ty
let print_vs_arg fmt vs = fprintf fmt "@ (%a)" print_vsty vs
let print_constr ty fmt cs =
let ty_val = of_option cs.ls_value in
let m = ty_match Mtv.empty ty_val ty in
let tl = List.map (ty_inst m) cs.ls_args in
let print_constr fmt (cs,pjl) =
let add_pj pj ty pjl = (pj,ty)::pjl in
let print_pj fmt (pj,ty) = match pj with
| Some ls -> fprintf fmt "@ (%a:@,%a)" print_ls ls print_ty ty
| None -> print_ty_arg fmt ty
in
fprintf fmt "@[<hov 4>| %a%a%a@]" print_cs cs
print_ident_labels cs.ls_name
(print_list nothing print_ty_arg) tl
(print_list nothing print_pj)
(List.fold_right2 add_pj pjl cs.ls_args [])
let print_type_decl fst fmt (ts,def) = match def with
| Tabstract -> begin match ts.ts_def with
......@@ -259,12 +262,11 @@ let print_type_decl fst fmt (ts,def) = match def with
(print_list nothing print_tv_arg) ts.ts_args print_ty ty
end
| Talgebraic csl ->
let ty = ty_app ts (List.map ty_var ts.ts_args) in
fprintf fmt "@[<hov 2>%s %a%a%a =@\n@[<hov>%a@]@]@\n@\n"
(if fst then "type" else "with") print_ts ts
print_ident_labels ts.ts_name
(print_list nothing print_tv_arg) ts.ts_args
(print_list newline (print_constr ty)) csl
(print_list newline print_constr) csl
let print_type_decl first fmt d =
if not (query_remove (fst d).ts_name) then
......
......@@ -143,7 +143,8 @@ let print_tv_args fmt = function
let print_ty_arg fmt ty = fprintf fmt "%a" (print_ty_node true) ty
let print_vs_arg fmt vs = fprintf fmt "(%a)" print_vsty vs
let print_constr ty fmt cs =
(* FIXME: print projections! *)
let print_constr ty fmt (cs,_) =
let ty_val = of_option cs.ls_value in
let m = ty_match Mtv.empty ty_val ty in
let tl = List.map (ty_inst m) cs.ls_args in
......
......@@ -36,26 +36,8 @@ open Pgm_module
let debug = Debug.register_flag "program_typing"
let is_debug () = Debug.test_flag debug
exception Message of string
let error ?loc e = match loc with
| None -> raise e
| Some loc -> raise (Loc.Located (loc, e))
let errorm ?loc f =
let buf = Buffer.create 512 in
let fmt = Format.formatter_of_buffer buf in
Format.kfprintf
(fun _ ->
Format.pp_print_flush fmt ();
let s = Buffer.contents buf in
Buffer.clear buf;
error ?loc (Message s))
fmt f
let () = Exn_printer.register (fun fmt e -> match e with
| Message s -> fprintf fmt "%s" s
| _ -> raise e)
let error = Loc.error
let errorm = Loc.errorm
let id_result = "result"
......@@ -292,7 +274,8 @@ let rec dtype ~user env = function
tyvar (Typing.find_user_type_var x env.denv)
| PPTtyapp (p, x) ->
let loc = Typing.qloc x in
let ts, a = Typing.specialize_tysymbol loc x (impure_uc env.uc) in
let ts = Typing.specialize_tysymbol loc x (impure_uc env.uc) in
let a = List.length ts.ts_args in
let mt = get_mtsymbol ts in
let np = List.length p in
if np <> a - mt.mt_regions then
......@@ -368,6 +351,35 @@ let rec extract_labels labs loc e = match e.Ptree.expr_desc with
labs, loc, Ptree.Ecast ({ e with Ptree.expr_desc = d }, ty)
| e -> List.rev labs, loc, e
(* compatibility functions from Typing *)
let find_qualid_ls uc p =
let loc = Typing.qloc p in
let sl = Typing.string_list_of_qualid [] p in
try ns_find_ls (get_namespace uc) sl with Not_found ->
errorm ~loc "unbound symbol %a" print_qualid p
let is_projection uc ls =
try
let ts = match ls.ls_args with
| [{ty_node = Ty.Tyapp (ts,_)}] -> ts
| _ -> raise Exit in
match Decl.find_constructors (get_known uc) ts with
| [cs,pjl] ->
let find (i,r) = function
| Some pj when ls_equal ls pj -> (succ i, i)
| _ -> (succ i, r) in
let (_,r) = List.fold_left find (0,-1) pjl in
if r < 0 then None else Some (ts,cs,r)
| _ -> None
with Exit -> None
let list_fields uc fl =
let field (q,e) = find_qualid_ls uc q, (Typing.qloc q, e) in
let cs,pjl,flm = Decl.parse_record (get_known uc) (List.map field fl) in
cs, List.map (fun pj -> Mls.find_opt pj flm) pjl
(* [dexpr] translates ptree into dexpr *)
let rec dexpr ~ghost ~userloc env e =
......@@ -494,7 +506,7 @@ and dexpr_desc ~ghost ~userloc env loc = function
let e = List.fold_left2 apply e el tyl in
e.dexpr_desc, ty
| Ptree.Erecord fl ->
let _, cs, fl = Typing.list_fields (impure_uc env.uc) fl in
let cs, fl = list_fields (impure_uc env.uc) fl in
new_regions_vars ();
let tyl, ty = specialize_lsymbol ~loc (Htv.create 17) cs in
let ty = of_option ty in
......@@ -519,7 +531,7 @@ and dexpr_desc ~ghost ~userloc env loc = function
d.dexpr_desc, ty
| Ptree.Eupdate (e1, fl) ->
let e1 = dexpr ~ghost ~userloc env e1 in
let _, cs, fl = Typing.list_fields (impure_uc env.uc) fl in
let cs, fl = list_fields (impure_uc env.uc) fl in
let tyl, ty = Denv.specialize_lsymbol ~loc cs in
let ty = of_option ty in
expected_type e1 ty;
......@@ -589,7 +601,7 @@ and dexpr_desc ~ghost ~userloc env loc = function
| _ ->
assert false
end;
begin match Typing.is_projection (impure_uc env.uc) ls with
begin match is_projection (impure_uc env.uc) ls with
| Some (ts, _, i) ->
let mt = get_mtsymbol ts in
let j =
......@@ -844,7 +856,7 @@ let iuregion env ({ pp_loc = loc; pp_desc = d } as t) = match d with
| PPapp (f, [t]) ->
let th = effect_uc env.i_uc in
let ls, _, _ = Typing.specialize_lsymbol f th in
begin match Typing.is_projection th ls with
begin match is_projection th ls with
| Some (ts, _, i) ->
let j =
try
......@@ -2189,7 +2201,7 @@ let add_types uc dl =
begin match Decl.find_constructors km ts with
| [] -> (* abstract *)
()
| [ls] -> (* record *)
| [ls,_] -> (* record *)
add_logic_ps ~nofail:true uc ls.ls_name.id_string;
let field i ty =
if Hashtbl.mem mutable_field (x, i) then
......@@ -2198,7 +2210,7 @@ let add_types uc dl =
in
list_iteri field ls.ls_args
| cl -> (* algebraic *)
let constructor ls =
let constructor (ls,_) =
add_logic_ps ~nofail:true uc ls.ls_name.id_string;
List.iter visit_type ls.ls_args
in
......
......@@ -164,7 +164,7 @@ let rec update env mreg x ty =
if cl = [] then failwith "WP: cannot update a value of this type";
(* TODO: print the type *)
let s = get_ty_subst ty in
let branch cs =
let branch (cs,_) =
let cs_pure = (get_psymbol cs).ps_pure in
let mk_var ty =
let ty = ty_inst s ty in
......
This diff is collapsed.
......@@ -34,11 +34,38 @@ let unfold def tl ty =
t_ty_subst mt mv e
let is_constructor kn ls = match Mid.find ls.ls_name kn with
| { d_node = Dtype _ } -> true
| { d_node = Dtype dl } ->
let constr = function
| _, Talgebraic csl -> List.exists (fun (cs,_) -> ls_equal cs ls) csl
| _, Tabstract -> false in
List.exists constr dl
| _ -> false
let is_projection kn ls = match Mid.find ls.ls_name kn with
| { d_node = Dtype dl } ->
let constr = function
| _, Talgebraic csl -> List.exists (fun (cs,_) -> ls_equal cs ls) csl
| _, Tabstract -> false in
not (List.exists constr dl)
| _ -> false
let apply_projection kn ls t = t_label_copy t (match t.t_node with
| Tapp (cs,tl) ->
let ts = match cs.ls_value with
| Some { ty_node = Tyapp (ts,_) } -> ts
| _ -> assert false in
let pjl =
try List.assq cs (find_constructors kn ts)
with Not_found -> assert false in
let find acc v = function
| Some pj when ls_equal pj ls -> v
| _ -> acc in
List.fold_left2 find t_true tl pjl
| _ -> assert false)
let make_flat_case kn t bl =
let mk_b b = let p,t = t_open_branch b in [p],t in
let find_constructors kn ts = List.map fst (find_constructors kn ts) in
Pattern.CompileTerm.compile (find_constructors kn) [t] (List.map mk_b bl)
let rec add_quant kn (vl,tl,f) v =
......@@ -48,7 +75,7 @@ let rec add_quant kn (vl,tl,f) v =
| _ -> []
in
match cl with
| [ls] ->
| [ls,_] ->
let s = ty_match Mtv.empty (Util.of_option ls.ls_value) ty in
let mk_v ty = create_vsymbol (id_clone v.vs_name) (ty_inst s ty) in
let nvl = List.map mk_v ls.ls_args in
......@@ -82,6 +109,11 @@ let dive_to_constructor kn fn env t =
let eval_match ~inline kn t =
let rec eval env t = t_label_copy t (match t.t_node with
| Tapp (ls, [t1]) when is_projection kn ls ->
let t1 = eval env t1 in
let fn _env t = apply_projection kn ls t in
begin try dive_to_constructor kn fn env t1
with Exit -> t_app ls [t1] t.t_ty end
| Tapp (ls, tl) when inline kn ls (List.map t_type tl) t.t_ty ->
begin match find_logic_definition kn ls with
| None -> t_map (eval env) t
......
......@@ -115,7 +115,7 @@ let elt d =
end
| Talgebraic l ->
List.fold_left
(fun acc {ls_args = tyl; ls_value = ty} ->
(fun acc ({ls_args = tyl; ls_value = ty},_) ->
let ty = of_option ty in
List.fold_left
(fun acc ty -> ty_s_fold tyoccurences acc ty)
......
......@@ -227,7 +227,8 @@ let add_pdecl uc d =
match d.pd_node with
| PDtype dl ->
let uc = List.fold_left add_type uc dl in
let constructor (ps, _) = ps.ls in
let projection = option_map (fun ps -> ps.ls) in
let constructor (ps,pjl) = ps.ls, List.map projection pjl in
let defn = function
| ITabstract -> Decl.Tabstract
| ITalgebraic cl -> Decl.Talgebraic (List.map constructor cl)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment