better typing of match branches

a branch is typed *after* unifying the type of the pattern with the
type of the matched value

fixes issue #124 (return type and coercions)

done in both logic and programs
parent e0cd1156
......@@ -221,45 +221,10 @@ and dterm_node =
| DTuloc of dterm * Loc.position
| DTlabel of dterm * Slab.t
(** Environment *)
type denv = dterm_node Mstr.t
(** Unification tools *)
exception TermExpected
exception FmlaExpected
exception DuplicateVar of string
exception UnboundVar of string
let denv_get denv n = Mstr.find_exn (UnboundVar n) n denv
let denv_get_opt denv n = Mstr.find_opt n denv
let dty_of_dterm dt = Opt.get_def dty_bool dt.dt_dty
let denv_empty = Mstr.empty
let denv_add_var denv {pre_name = n} dty =
Mstr.add n (DTvar (n, dty)) denv
let denv_add_let denv dt {pre_name = n} =
Mstr.add n (DTvar (n, dty_of_dterm dt)) denv
let denv_add_quant denv vl =
let add acc (id,dty,_) = match id with
| Some ({pre_name = n} as id) ->
let exn = match id.pre_loc with
| Some loc -> Loc.Located (loc, DuplicateVar n)
| None -> DuplicateVar n in
Mstr.add_new exn n (DTvar (n,dty)) acc
| None -> acc in
let s = List.fold_left add Mstr.empty vl in
Mstr.set_union s denv
let denv_add_pat denv dp =
let s = Mstr.mapi (fun n dty -> DTvar (n, dty)) dp.dp_vars in
Mstr.set_union s denv
(** Unification tools *)
let dty_unify_app ls unify (l1: 'a list) (l2: dty list) =
try List.iter2 unify l1 l2 with Invalid_argument _ ->
......@@ -294,6 +259,46 @@ let dexpr_expected_type dt dty = match dty with
| Some dty -> dterm_expected_type dt dty
| None -> dfmla_expected_type dt
(** Environment *)
type denv = dterm_node Mstr.t
exception DuplicateVar of string
exception UnboundVar of string
let denv_get denv n = Mstr.find_exn (UnboundVar n) n denv
let denv_get_opt denv n = Mstr.find_opt n denv
let dty_of_dterm dt = Opt.get_def dty_bool dt.dt_dty
let denv_empty = Mstr.empty
let denv_add_var denv {pre_name = n} dty =
Mstr.add n (DTvar (n, dty)) denv
let denv_add_let denv dt {pre_name = n} =
Mstr.add n (DTvar (n, dty_of_dterm dt)) denv
let denv_add_quant denv vl =
let add acc (id,dty,_) = match id with
| Some ({pre_name = n} as id) ->
let exn = match id.pre_loc with
| Some loc -> Loc.Located (loc, DuplicateVar n)
| None -> DuplicateVar n in
Mstr.add_new exn n (DTvar (n,dty)) acc
| None -> acc in
let s = List.fold_left add Mstr.empty vl in
Mstr.set_union s denv
let denv_add_pat denv dp dty =
dpat_expected_type dp dty;
let s = Mstr.mapi (fun n dty -> DTvar (n, dty)) dp.dp_vars in
Mstr.set_union s denv
let denv_add_term_pat denv dp dt =
denv_add_pat denv dp (Opt.get_def dty_bool dt.dt_dty)
(** Constructors *)
let dpattern ?loc node =
......@@ -455,11 +460,8 @@ let dterm crcmap ?loc node =
mk_dty df.dt_dty
| DTcase (_,[]) ->
raise EmptyCase
| DTcase (dt,(dp1,df1)::bl) ->
dterm_expected_type dt dp1.dp_dty;
let check (dp,df) =
dpat_expected_type dp dp1.dp_dty;
dexpr_expected_type df df1.dt_dty in
| DTcase (_,(_,df1)::bl) ->
let check (_,df) = dexpr_expected_type df df1.dt_dty in
List.iter check bl;
let is_fmla (_,df) = df.dt_dty = None in
if List.exists is_fmla bl then mk_dty None else mk_dty df1.dt_dty
......
......@@ -102,7 +102,8 @@ val denv_add_let : denv -> dterm -> preid -> denv
val denv_add_quant : denv -> dbinder list -> denv
val denv_add_pat : denv -> dpattern -> denv
val denv_add_pat : denv -> dpattern -> dty -> denv
val denv_add_term_pat : denv -> dpattern -> dterm -> denv
val denv_get : denv -> string -> dterm_node (** raises UnboundVar *)
......
......@@ -339,6 +339,14 @@ let specialize_ls {ls_args = args; ls_value = res} =
let spec_arg ty = dity_sim (spec_val () ty) in
List.map spec_arg args, Opt.fold spec_val dity_bool res
type dxsymbol =
| DElexn of string * dity
| DEgexn of xsymbol
let specialize_dxs = function
| DEgexn xs -> specialize_single xs.xs_ity
| DElexn (_,dity) -> dity
(** Patterns *)
type dpattern = {
......@@ -391,10 +399,6 @@ let old_mark = "'Old"
type dinvariant = term list
type dxsymbol =
| DElexn of string * dity
| DEgexn of xsymbol
type dexpr = {
de_node : dexpr_node;
de_dvty : dvty;
......@@ -446,6 +450,29 @@ and drec_defn = { fds : dfun_defn list }
and dfun_defn = preid * ghost * rs_kind * dbinder list *
dity * mask * dspec later * variant list later * dexpr
(** Unification tools *)
let dity_unify_app ls fn (l1: 'a list) (l2: dity list) =
try List.iter2 fn l1 l2 with Invalid_argument _ ->
raise (BadArity (ls, List.length l1))
let dpat_expected_type {dp_dity = dp_dity; dp_loc = loc} dity =
try dity_unify dp_dity dity with Exit -> Loc.errorm ?loc
"This pattern has type %a,@ but is expected to have type %a"
print_dity dp_dity print_dity dity
let dexpr_expected_type {de_dvty = dvty; de_loc = loc} dity =
let res = dity_of_dvty dvty in
try dity_unify res dity with Exit -> Loc.errorm ?loc
"This expression has type %a,@ but is expected to have type %a"
print_dity res print_dity dity
let dexpr_expected_type_weak {de_dvty = dvty; de_loc = loc} dity =
let res = dity_of_dvty dvty in
try dity_unify_weak res dity with Exit -> Loc.errorm ?loc
"This expression has type %a,@ but is expected to have type %a"
print_dity res print_dity dity
(** Environment *)
type denv = {
......@@ -537,11 +564,18 @@ let denv_add_args { frozen = fz; locals = ls; excpts = xs } bl =
let s = List.fold_left add Mstr.empty bl in
{ frozen = l; locals = Mstr.set_union s ls; excpts = xs }
let denv_add_pat { frozen = fz; locals = ls; excpts = xs } dp =
let denv_add_pat { frozen = fz; locals = ls; excpts = xs } dp dity =
dpat_expected_type dp dity;
let l = Mstr.fold (fun _ t l -> t::l) dp.dp_vars fz in
let s = Mstr.map (fun t -> false, None, ([], t)) dp.dp_vars in
{ frozen = l; locals = Mstr.set_union s ls; excpts = xs }
let denv_add_expr_pat denv dp de =
denv_add_pat denv dp (dity_of_dvty de.de_dvty)
let denv_add_exn_pat denv dp dxs =
denv_add_pat denv dp (specialize_dxs dxs)
let mk_node n = function
| _, Some tvs, dvty -> DEvar (n, specialize_scheme tvs dvty)
| _, None, dvty -> DEvar (n, dvty)
......@@ -591,29 +625,6 @@ let denv_pure denv get_dty =
let d = dity_fresh () in Hint.add hi i d; d in
dity_pur (Dterm.dty_fold fnS fnV fnI dty)
(** Unification tools *)
let dity_unify_app ls fn (l1: 'a list) (l2: dity list) =
try List.iter2 fn l1 l2 with Invalid_argument _ ->
raise (BadArity (ls, List.length l1))
let dpat_expected_type {dp_dity = dp_dity; dp_loc = loc} dity =
try dity_unify dp_dity dity with Exit -> Loc.errorm ?loc
"This pattern has type %a,@ but is expected to have type %a"
print_dity dp_dity print_dity dity
let dexpr_expected_type {de_dvty = dvty; de_loc = loc} dity =
let res = dity_of_dvty dvty in
try dity_unify res dity with Exit -> Loc.errorm ?loc
"This expression has type %a,@ but is expected to have type %a"
print_dity res print_dity dity
let dexpr_expected_type_weak {de_dvty = dvty; de_loc = loc} dity =
let res = dity_of_dvty dvty in
try dity_unify_weak res dity with Exit -> Loc.errorm ?loc
"This expression has type %a,@ but is expected to have type %a"
print_dity res print_dity dity
(** Generation of letrec blocks *)
type pre_fun_defn = preid * ghost * rs_kind * dbinder list *
......@@ -688,10 +699,6 @@ let dpattern ?loc node =
in
Loc.try1 ?loc dpat node
let specialize_dxs = function
| DEgexn xs -> specialize_single xs.xs_ity
| DElexn (_,dity) -> dity
let dexpr ?loc node =
let get_dvty = function
| DEvar (_,dvty) ->
......@@ -763,15 +770,9 @@ let dexpr ?loc node =
invalid_arg "Dexpr.dexpr: empty branch list in DEmatch"
| DEmatch (de,bl,xl) ->
let res = dity_fresh () in
let ety = if bl = [] then
res else dity_fresh () in
dexpr_expected_type de ety;
List.iter (fun (dp,de) ->
dpat_expected_type dp ety;
dexpr_expected_type de res) bl;
List.iter (fun (xs,dp,de) ->
dpat_expected_type dp (specialize_dxs xs);
dexpr_expected_type de res) xl;
if bl = [] then dexpr_expected_type de res;
List.iter (fun (_,de) -> dexpr_expected_type de res) bl;
List.iter (fun (_,_,de) -> dexpr_expected_type de res) xl;
[], res
| DEassign al ->
List.iter (fun (de1,rs,de2) ->
......
......@@ -156,7 +156,9 @@ val denv_add_let : denv -> dlet_defn -> denv
val denv_add_args : denv -> dbinder list -> denv
val denv_add_pat : denv -> dpattern -> denv
val denv_add_pat : denv -> dpattern -> dity -> denv
val denv_add_expr_pat : denv -> dpattern -> dexpr -> denv
val denv_add_exn_pat : denv -> dpattern -> dxsymbol -> denv
val denv_add_for_index : denv -> preid -> dvty -> denv
......
......@@ -294,7 +294,7 @@ let rec dterm ns km crcmap gvars at denv {term_desc = desc; term_loc = loc} =
let e1 = dterm ns km crcmap gvars at denv e1 in
let branch (p, e) =
let p = dpattern ns km p in
let denv = denv_add_pat denv p in
let denv = denv_add_term_pat denv p e1 in
p, dterm ns km crcmap gvars at denv e in
DTcase (e1, List.map branch bl)
| Ptree.Tif (e1, e2, e3) ->
......@@ -834,7 +834,7 @@ let rec dexpr muc denv {expr_desc = desc; expr_loc = loc} =
let e1 = dexpr muc denv e1 in
let rbranch (pp, e) =
let pp = dpattern muc pp in
let denv = denv_add_pat denv pp in
let denv = denv_add_expr_pat denv pp e1 in
pp, dexpr muc denv e in
let xbranch (q, pp, e) =
let xs = find_dxsymbol q in
......@@ -845,7 +845,7 @@ let rec dexpr muc denv {expr_desc = desc; expr_loc = loc} =
| Some pp -> dpattern muc pp
| None when mb_unit -> Dexpr.dpattern ~loc (DPapp (rs_void, []))
| _ -> Loc.errorm ~loc "exception argument expected" in
let denv = denv_add_pat denv pp in
let denv = denv_add_exn_pat denv pp xs in
let e = dexpr muc denv e in
xs, pp, e in
DEmatch (e1, List.map rbranch bl, List.map xbranch xl)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment