Commit d24283f3 authored by Andrei Paskevich's avatar Andrei Paskevich
Browse files

Dexpr: specialization and environment extenstion

parent f8957d2e
......@@ -37,49 +37,45 @@ and rvar =
| Rtvs of tvsymbol * dity
| Rval of dreg
type dvty = dity list * dity (* A -> B -> C == ([A;B],C) *)
let create_dreg dity =
Rvar (ref (Rtvs (create_tvsymbol (id_fresh "rho"), dity)))
let dity_of_ity ity =
let hreg = Hreg.create 5 in
let rec dity_of_ity ity = match ity.ity_node with
let hreg = Hreg.create 3 in
let rec dity ity = match ity.ity_node with
| Ityvar tv -> Dutv tv
| Itypur (ts,tl) -> Dpur (ts, List.map dity_of_ity tl)
| Ityapp (its,tl,rl) ->
Dapp (its, List.map dity_of_ity tl, List.map dreg_of_reg rl)
and dreg_of_reg r =
try Hreg.find hreg r with Not_found ->
let dreg = create_dreg (dity_of_ity r.reg_ity) in
Hreg.add hreg r dreg;
dreg
| Ityapp (s,tl,rl) -> Dapp (s, List.map dity tl, List.map dreg rl)
| Itypur (s,tl) -> Dpur (s, List.map dity tl)
and dreg reg =
try Hreg.find hreg reg with Not_found ->
let r = create_dreg (dity reg.reg_ity) in
Hreg.add hreg reg r;
r
in
dity_of_ity ity
dity ity
let ity_of_dity ~strict dity =
let rec ity_of_dity = function
| Dvar { contents = Dval dty } -> ity_of_dity dty
let rec ity = function
| Dvar { contents = Dval t } -> ity t
| Dvar _ when strict -> Loc.errorm "undefined type variable"
| Dvar r ->
| Dvar ref ->
let tv = create_tvsymbol (id_fresh "xi") in
r := Dval (Dutv tv);
ref := Dval (Dutv tv);
ity_var tv
| Dapp (ts,dl,rl) ->
ity_app ts (List.map ity_of_dity dl) (List.map reg_of_dreg rl)
| Dpur (ts,dl) ->
ity_pur ts (List.map ity_of_dity dl)
| Dutv tv ->
ity_var tv
and reg_of_dreg = function
| Rvar { contents = Rval dreg } -> reg_of_dreg dreg
| Rvar ({ contents = Rtvs (tv,dty) } as r) ->
let reg = create_region (id_clone tv.tv_name) (ity_of_dity dty) in
r := Rval (Rreg (reg,dty));
reg
| Rreg (reg,_) -> reg
| Dutv tv -> ity_var tv
| Dapp (s,tl,rl) -> ity_app s (List.map ity tl) (List.map reg rl)
| Dpur (s,tl) -> ity_pur s (List.map ity tl)
and reg = function
| Rreg (r,_) -> r
| Rvar { contents = Rval r } -> reg r
| Rvar ({ contents = Rtvs (tv,t) } as ref) ->
let r = create_region (id_clone tv.tv_name) (ity t) in
ref := Rval (Rreg (r,t));
r
in
ity_of_dity dity
type dvty = dity list * dity (* A -> B -> C == ([A;B],C) *)
ity dity
(** Destructive type unification *)
......@@ -102,9 +98,9 @@ let rec unify d1 d2 = match d1,d2 with
r := Dval d
| Dutv tv1, Dutv tv2 when tv_equal tv1 tv2 ->
()
| Dapp (its1,dl1,_), Dapp (its2,dl2,_) when its_equal its1 its2 ->
| Dapp (s1,dl1,_), Dapp (s2,dl2,_) when its_equal s1 s2 ->
List.iter2 unify dl1 dl2
| Dpur (ts1,dl1), Dpur (ts2,dl2) when ts_equal ts1 ts2 ->
| Dpur (s1,dl1), Dpur (s2,dl2) when ts_equal s1 s2 ->
List.iter2 unify dl1 dl2
| _ -> raise Exit
......@@ -120,12 +116,12 @@ let dity_fresh () =
Dvar r
let its_app_fresh s dl =
let htv = Htv.create 5 in
let hreg = Hreg.create 5 in
let htv = Htv.create 3 in
let hreg = Hreg.create 3 in
let rec inst ity = match ity.ity_node with
| Ityvar v -> Htv.find htv v
| Itypur (s,tl) -> Dpur (s, List.map inst tl)
| Ityapp (s,tl,rl) -> Dapp (s, List.map inst tl, List.map fresh rl)
| Itypur (s,tl) -> Dpur (s, List.map inst tl)
and fresh r =
try Hreg.find hreg r with Not_found ->
let reg = create_dreg (inst r.reg_ity) in
......@@ -139,9 +135,9 @@ let its_app_fresh s dl =
let rec dity_refresh = function
| Dvar { contents = Dval dty } -> dity_refresh dty
| Dvar { contents = Dtvs _ } as dity -> dity
| Dapp (s,dl,_) -> its_app_fresh s (List.map dity_refresh dl)
| Dpur (s,dl) -> Dpur (s, List.map dity_refresh dl)
| Dutv _ as dity -> dity
| Dpur (ts,dl) -> Dpur (ts, List.map dity_refresh dl)
| Dapp (its,dl,_) -> its_app_fresh its (List.map dity_refresh dl)
let unify ?(weak=false) d1 d2 =
unify d1 d2;
......@@ -184,7 +180,19 @@ let reunify_regions () =
Queue.iter (fun (d1,d2) -> reunify d1 d2) unify_queue;
Queue.clear unify_queue
(* Pretty-printing *)
(** Chainable relations *)
let rec dity_is_bool = function
| Dvar { contents = Dval dty } -> dity_is_bool dty
| Dpur (ts,_) -> ts_equal ts ts_bool
| _ -> false
let dvty_is_chainable = function
| [t1;t2],t ->
dity_is_bool t && not (dity_is_bool t1) && not (dity_is_bool t2)
| _ -> false
(** Pretty-printing *)
let debug_print_reg_types = Debug.register_info_flag "print_reg_types"
~desc:"Print@ types@ of@ regions@ (mutable@ fields)."
......@@ -197,15 +205,15 @@ let print_dity fmt dity =
| Dvar { contents = Dtvs tv }
| Dutv tv -> Pretty.print_tv fmt tv
| Dvar { contents = Dval dty } -> print_dity inn fmt dty
| Dpur (ts,tl) when is_ts_tuple ts -> Format.fprintf fmt "(%a)"
| Dpur (s,tl) when is_ts_tuple s -> Format.fprintf fmt "(%a)"
(Pp.print_list Pp.comma (print_dity false)) tl
| Dpur (ts,[]) -> Pretty.print_ts fmt ts
| Dpur (ts,tl) -> Format.fprintf fmt (protect_on inn "%a@ %a")
Pretty.print_ts ts (Pp.print_list Pp.space (print_dity true)) tl
| Dapp (its,[],rl) -> Format.fprintf fmt (protect_on inn "%a@ <%a>")
Mlw_pretty.print_its its (Pp.print_list Pp.comma print_dreg) rl
| Dapp (its,tl,rl) -> Format.fprintf fmt (protect_on inn "%a@ <%a>@ %a")
Mlw_pretty.print_its its (Pp.print_list Pp.comma print_dreg) rl
| Dpur (s,[]) -> Pretty.print_ts fmt s
| Dpur (s,tl) -> Format.fprintf fmt (protect_on inn "%a@ %a")
Pretty.print_ts s (Pp.print_list Pp.space (print_dity true)) tl
| Dapp (s,[],rl) -> Format.fprintf fmt (protect_on inn "%a@ <%a>")
Mlw_pretty.print_its s (Pp.print_list Pp.comma print_dreg) rl
| Dapp (s,tl,rl) -> Format.fprintf fmt (protect_on inn "%a@ <%a>@ %a")
Mlw_pretty.print_its s (Pp.print_list Pp.comma print_dreg) rl
(Pp.print_list Pp.space (print_dity true)) tl
and print_dreg fmt = function
| Rreg (r,_) when Debug.test_flag debug_print_reg_types ->
......@@ -220,11 +228,105 @@ let print_dity fmt dity =
in
print_dity false fmt dity
(* Specialization of symbols *)
let specialize_scheme tvs (argl,res) =
let htv = Htv.create 3 and hreg = Htv.create 3 in
let rec spec_dity = function
| Dvar { contents = Dval dity } -> spec_dity dity
| Dvar { contents = Dtvs tv } | Dutv tv as dity -> get_tv tv dity
| Dapp (s,dl,rl) -> Dapp (s, List.map spec_dity dl, List.map spec_reg rl)
| Dpur (s,dl) -> Dpur (s, List.map spec_dity dl)
and spec_reg = function
| Rvar { contents = Rval r } -> spec_reg r
| Rvar { contents = Rtvs (tv,dity) } -> get_reg tv dity
| Rreg _ as r -> r
and get_tv tv dity = try Htv.find htv tv with Not_found ->
let v = dity_fresh () in
(* can't return dity, might differ in regions *)
if Stv.mem tv tvs then unify ~weak:true v dity;
Htv.add htv tv v;
v
and get_reg tv dity = try Htv.find hreg tv with Not_found ->
let r = create_dreg (spec_dity dity) in
Htv.add hreg tv r;
r in
List.map spec_dity argl, spec_dity res
let spec_ity htv hreg vars ity =
let get_tv tv =
assert (not (Stv.mem tv vars.vars_tv));
try Htv.find htv tv with Not_found ->
let v = dity_fresh () in
Htv.add htv tv v;
v in
let rec dity ity = match ity.ity_node with
| Ityvar tv -> get_tv tv
| Ityapp (s,tl,rl) -> Dapp (s, List.map dity tl, List.map dreg rl)
| Itypur (s,tl) -> Dpur (s, List.map dity tl)
and dreg reg = try Hreg.find hreg reg with Not_found ->
let t = dity reg.reg_ity in
let r = if reg_occurs reg vars then Rreg (reg,t) else create_dreg t in
Hreg.add hreg reg r;
r
in
dity ity
let specialize_ity ity =
let htv = Htv.create 3 and hreg = Hreg.create 3 in
spec_ity htv hreg ity.ity_vars ity
let specialize_pvsymbol pv = specialize_ity pv.pv_ity
let specialize_xsymbol xs = specialize_ity xs.xs_ity
let specialize_arrow vars aty =
let htv = Htv.create 3 and hreg = Hreg.create 3 in
let conv pv = spec_ity htv hreg vars pv.pv_ity in
let rec specialize a =
let argl = List.map conv a.aty_args in
let narg,res = match a.aty_result with
| VTvalue v -> [], spec_ity htv hreg vars v
| VTarrow a -> specialize a
in
argl @ narg, res
in
specialize aty
let specialize_psymbol ps =
specialize_arrow ps.ps_vars ps.ps_aty
let specialize_plsymbol pls =
let htv = Htv.create 3 and hreg = Hreg.create 3 in
let conv fd = spec_ity htv hreg vars_empty fd.fd_ity in
List.map conv pls.pl_args, conv pls.pl_value
let dity_of_ty htv hreg vars ty =
let rec pure ty = match ty.ty_node with
| Tyapp (ts,tl) ->
begin try ignore (restore_its ts); false
with Not_found -> List.for_all pure tl end
| Tyvar _ -> true in
if not (pure ty) then raise Exit;
spec_ity htv hreg vars (ity_of_ty ty)
let specialize_lsymbol ls =
let htv = Htv.create 3 and hreg = Hreg.create 3 in
let conv ty = dity_of_ty htv hreg vars_empty ty in
let ty = Opt.get_def ty_bool ls.ls_value in
List.map conv ls.ls_args, conv ty
let specialize_lsymbol ls =
try specialize_lsymbol ls with Exit ->
Loc.errorm "Function symbol `%a' can only be used in specification"
Pretty.print_ls ls
(** Patterns *)
type dpattern = {
dp_pat : pre_ppattern;
dp_dity : dity;
dp_vars : dity Mstr.t;
dp_loc : Loc.position option;
}
......@@ -242,7 +344,7 @@ type ghost = bool
type opaque = Stv.t
type dbinder = preid * ghost * opaque * dity
type dbinder = preid option * ghost * opaque * dity
type 'a later = vsymbol Mstr.t -> 'a
(* specification terms are parsed and typechecked after the program
......@@ -311,7 +413,7 @@ and dval_decl = preid * ghost * dtype_v
and dlet_defn = preid * ghost * dexpr
and dfun_defn = preid * ghost * dbinder list * dity * dexpr * dspec later
and dfun_defn = preid * ghost * dbinder list * dexpr * dspec later
(** Environment *)
......@@ -322,23 +424,112 @@ type denv = {
let denv_empty = { frozen = []; locals = Mstr.empty }
let denv_add_val _ = assert false (* denv -> dval_decl -> denv *)
let denv_add_let _ = assert false (* denv -> dlet_defn -> denv *)
let denv_add_fun _ = assert false (* denv -> dfun_defn -> denv *)
let denv_prepare_rec _ = assert false (* denv -> preid -> dbinder list -> dity -> denv *)
let denv_verify_rec _ = assert false (* denv -> preid -> unit *)
let denv_add_args _ = assert false (* denv -> dbinder list -> denv *)
let denv_add_pat _ = assert false (* denv -> dpattern -> denv *)
let denv_get _ = assert false (* denv -> string -> dexpr_node (** raises UnboundVar *) *)
let denv_get_opt _ = assert false (* denv -> string -> dexpr_node option *)
let is_frozen frozen tv =
try List.iter (occur_check tv) frozen; false with Exit -> true
let freeze_dvty frozen (argl,res) =
let rec add l = function
| Dvar { contents = Dval d } -> add l d
| Dvar { contents = Dtvs _ } as d -> d :: l
| Dutv _ as d -> d :: l
| Dapp (_,tl,_) | Dpur (_,tl) -> List.fold_left add l tl in
List.fold_left add (add frozen res) argl
let freeze_dtvs frozen (argl,res) =
let rec add l = function
| Dvar { contents = Dval d } -> add l d
| Dvar { contents = Dtvs _ } as d -> d :: l
| Dutv _ -> l
| Dapp (_,tl,_) | Dpur (_,tl) -> List.fold_left add l tl in
List.fold_left add (add frozen res) argl
let free_vars frozen (argl,res) =
let rec add s = function
| Dvar { contents = Dval d } -> add s d
| Dvar { contents = Dtvs tv }
| Dutv tv -> if is_frozen frozen tv then s else Stv.add tv s
| Dapp (_,tl,_) | Dpur (_,tl) -> List.fold_left add s tl in
List.fold_left add (add Stv.empty res) argl
let free_user_vars frozen (argl,res) =
let rec add s = function
| Dvar { contents = Dval d } -> add s d
| Dvar { contents = Dtvs _ } -> s
| Dutv tv -> if is_frozen frozen tv then s else Stv.add tv s
| Dapp (_,tl,_) | Dpur (_,tl) -> List.fold_left add s tl in
List.fold_left add (add Stv.empty res) argl
let denv_add_mono { frozen = frozen; locals = locals } id dvty =
let locals = Mstr.add (preid_name id) (None, dvty) locals in
{ frozen = freeze_dvty frozen dvty; locals = locals }
let denv_add_poly { frozen = frozen; locals = locals } id dvty =
let ftvs = free_vars frozen dvty in
let locals = Mstr.add (preid_name id) (Some ftvs, dvty) locals in
{ frozen = frozen; locals = locals }
let denv_add_rec { frozen = frozen; locals = locals } id dvty =
let ftvs = free_user_vars frozen dvty in
let locals = Mstr.add (preid_name id) (Some ftvs, dvty) locals in
{ frozen = freeze_dtvs frozen dvty; locals = locals }
let denv_add_val denv (id,_,dtv) =
let rec dvty argl = function
| DSpecA (bl,(dtv,_)) ->
dvty (List.fold_left (fun l (_,_,_,t) -> t::l) argl bl) dtv
| DSpecV res -> (List.rev argl, res) in
denv_add_poly denv id (dvty [] dtv)
let denv_add_let denv (id,_,{de_dvty = dvty}) =
denv_add_mono denv id dvty
let denv_add_fun denv (id,_,bl,{de_dvty = (argl,res)},_) =
let argl = List.fold_right (fun (_,_,_,t) l -> t::l) bl argl in
denv_add_poly denv id (argl, res)
let denv_prepare_rec denv l =
let add s (id,_,_) = let n = preid_name id in
Sstr.add_new (Dterm.DuplicateVar n) n s in
let _ = try List.fold_left add Sstr.empty l with
| Dterm.DuplicateVar n -> (* TODO: loc *)
Loc.errorm "duplicate function name %s" n in
let add denv (id,bl,res) =
let argl = List.map (fun (_,_,_,t) -> t) bl in
denv_add_rec denv id (argl, res) in
List.fold_left add denv l
let denv_verify_rec { frozen = frozen; locals = locals } id =
let check tv = if is_frozen frozen tv then Loc.errorm (* TODO: loc *)
"This function is expected to be polymorphic in type variable %a"
Pretty.print_tv tv in
match Mstr.find_opt (preid_name id) locals with
| Some (Some tvs, _) -> Stv.iter check tvs
| Some (None, _) -> assert false
| None -> assert false
let denv_add_args { frozen = frozen; locals = locals } bl =
let l = List.fold_left (fun l (_,_,_,t) -> t::l) frozen bl in
let add s (id,_,_,t) = match id with
| Some id -> let n = preid_name id in
Mstr.add_new (Dterm.DuplicateVar n) n (None, ([],t)) s
| None -> s in
let s = List.fold_left add Mstr.empty bl in
{ frozen = l; locals = Mstr.set_union s locals }
let denv_add_pat { frozen = frozen; locals = locals } dp =
let l = Mstr.fold (fun _ t l -> t::l) dp.dp_vars frozen in
let s = Mstr.map (fun t -> None, ([], t)) dp.dp_vars in
{ frozen = l; locals = Mstr.set_union s locals }
let mk_node n = function
| Some tvs, dvty -> DEvar (n, specialize_scheme tvs dvty)
| None, dvty -> DEvar (n, dvty)
let denv_get denv n =
mk_node n (Mstr.find_exn (Dterm.UnboundVar n) n denv.locals)
let denv_get_opt denv n =
Opt.map (mk_node n) (Mstr.find_opt n denv.locals)
(** Constructors *)
......
......@@ -27,11 +27,16 @@ val dity_of_ity : ity -> dity
type dvty = dity list * dity (* A -> B -> C == ([A;B],C) *)
val dity_is_bool : dity -> bool
val dvty_is_chainable : dvty -> bool
(** Patterns *)
type dpattern = private {
dp_pat : pre_ppattern;
dp_dity : dity;
dp_vars : dity Mstr.t;
dp_loc : Loc.position option;
}
......@@ -49,7 +54,7 @@ type ghost = bool
type opaque = Stv.t
type dbinder = preid * ghost * opaque * dity
type dbinder = preid option * ghost * opaque * dity
type 'a later = vsymbol Mstr.t -> 'a
(* specification terms are parsed and typechecked after the program
......@@ -118,7 +123,7 @@ and dval_decl = preid * ghost * dtype_v
and dlet_defn = preid * ghost * dexpr
and dfun_defn = preid * ghost * dbinder list * dity * dexpr * dspec later
and dfun_defn = preid * ghost * dbinder list * dexpr * dspec later
(** Environment *)
......@@ -132,9 +137,18 @@ val denv_add_let : denv -> dlet_defn -> denv
val denv_add_fun : denv -> dfun_defn -> denv
val denv_prepare_rec : denv -> preid -> dbinder list -> dity -> denv
val denv_verify_rec : denv -> preid -> unit
val denv_prepare_rec : denv -> (preid * dbinder list * dity) list -> denv
(* [denv_prepare_rec] adds to the environment the user-supplied
types of every function in a (mutually) recursive definition.
Every user type variable not frozen in [denv] is generalized,
and must not be unified with any outer fresh type variable. *)
val denv_verify_rec : denv -> preid -> unit
(* after a (mutually) recursive definition has been typechecked,
[denv_verify_rec] should be called for every function on the
[denv] before [denv_prepare_rec]. This function verifies that
the resulting functions are not less polymorphic than expected
according the user-supplied type annotations. *)
val denv_add_args : denv -> dbinder list -> denv
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment