Commit 9ec2f5d8 authored by Andrei Paskevich's avatar Andrei Paskevich

Mlw: type inference with overloaded symbols

parent e9093a2d
......@@ -178,7 +178,7 @@ LIB_DRIVER = prove_client call_provers driver_ast driver_parser driver_lexer dri
collect_data_model parse_smtv2_model_lexer parse_smtv2_model \
parse_smtv2_model
LIB_MLW = ity expr dexpr pdecl eval_match typeinv vc pmodule \
LIB_MLW = ity expr pdecl eval_match typeinv vc pmodule dexpr \
pinterp compile pdriver cprinter ocaml_printer
LIB_PARSER = ptree glob typing parser lexer
......
......@@ -15,6 +15,7 @@ open Ty
open Term
open Ity
open Expr
open Pmodule
(** Program types *)
......@@ -404,8 +405,7 @@ type dexpr = {
and dexpr_node =
| DEvar of string * dvty
| DEpv of pvsymbol
| DErs of rsymbol
| DEsym of prog_symbol
| DEls of lsymbol
| DEconst of Number.constant * dity
| DEapp of dexpr * dexpr
......@@ -508,7 +508,7 @@ let denv_add_let denv (id,_,_,({de_dvty = dvty} as de)) =
if fst dvty = [] then denv_add_mono denv id dvty else
let rec is_value de = match de.de_node with
| DEghost de | DEuloc (de,_) | DElabel (de,_) -> is_value de
| DEvar _ | DErs _ | DEls _ | DEfun _ | DEany _ -> true
| DEvar _ | DEsym _ | DEls _ | DEfun _ | DEany _ -> true
| _ -> false in
if is_value de
then denv_add_poly denv id dvty
......@@ -661,10 +661,18 @@ let dexpr ?loc node =
let get_dvty = function
| DEvar (_,dvty) ->
dvty
| DEpv pv ->
| DEsym (PV pv) ->
[], specialize_pv pv
| DErs rs ->
| DEsym (RS rs) ->
specialize_rs rs
| DEsym (OO ss) ->
let dt = dity_fresh () in
let ot = overload_of_rs (Srs.choose ss) in
begin match ot with
| UnOp -> [dt], dt
| BinOp -> [dt;dt], dt
| BinRel -> [dt;dt], dity_bool
| NoOver -> assert false end
| DEls ls ->
specialize_ls ls
| DEconst (_, ity) -> [],ity
......@@ -1109,6 +1117,21 @@ and try_cexp uloc env ({de_dvty = argl,res} as de0) lpl =
let al = List.map (fun v -> v.pv_ghost) s.rs_cty.cty_args in
let gh = env.ghs || env.lgh || rs_ghost s || all_ghost al lpl in
apply c_app gh s al lpl in
let c_oop s lpl =
let al = (Srs.choose s).rs_cty.cty_args in
let al = List.map (fun _ -> false) al in
let gh = env.ghs || env.lgh || all_ghost al lpl in
let loc = Opt.get_def de0.de_loc uloc in
let app s vl al res =
let app s cl = try Expr.c_app s vl al res :: cl with
(* TODO: are there other valid exceptions here? *)
| TypeMismatch _ -> cl in
match Srs.fold app s [] with
| [c] -> c
| [] -> Loc.errorm ?loc "No suitable symbol found"
(* TODO: show types or locations for ambiguity *)
| _cl -> Loc.errorm ?loc "Ambiguous notation" in
apply app gh s al lpl in
let c_pur s lpl =
apply c_pur true s (List.map Util.ttrue s.ls_args) lpl in
let proxy c =
......@@ -1122,7 +1145,8 @@ and try_cexp uloc env ({de_dvty = argl,res} as de0) lpl =
c_app s (LD ld :: lpl) in
match de0.de_node with
| DEvar (n,_) -> c_app (get_rs env n) lpl
| DErs s -> c_app s lpl
| DEsym (RS s) -> c_app s lpl
| DEsym (OO s) -> c_oop s lpl
| DEls s -> c_pur s lpl
| DEapp (de1,de2) ->
let e2 = e_ghostify env.cgh (expr uloc env de2) in
......@@ -1152,7 +1176,7 @@ and try_cexp uloc env ({de_dvty = argl,res} as de0) lpl =
cexp uloc env de (LD ld :: lpl)
| DEmark _ ->
Loc.errorm "Marks are not allowed over higher-order expressions"
| DEpv _ | DEconst _ | DEnot _ | DEand _ | DEor _ | DEif _ | DEcase _
| DEsym _ | DEconst _ | DEnot _ | DEand _ | DEor _ | DEif _ | DEcase _
| DEassign _ | DEwhile _ | DEfor _ | DEtry _ | DEraise _ | DEassert _
| DEpure _ | DEabsurd | DEtrue | DEfalse -> assert false (* expr-only *)
| DEcast _ | DEuloc _ | DElabel _ -> assert false (* already stripped *)
......@@ -1161,7 +1185,7 @@ and try_expr uloc env ({de_dvty = argl,res} as de0) =
match de0.de_node with
| DEvar (n,_) when argl = [] ->
e_var (get_pv env n)
| DEpv v ->
| DEsym (PV v) ->
e_var v
| DEconst(c,dity) ->
e_const c (ity_of_dity dity)
......@@ -1169,7 +1193,7 @@ and try_expr uloc env ({de_dvty = argl,res} as de0) =
let e1 = expr uloc env de1 in
let e2 = expr uloc env de2 in
e_app rs_func_app [e1; e2] [] (ity_of_dity res)
| DEvar _ | DErs _ | DEls _ | DEapp _ | DEfun _ | DEany _ ->
| DEvar _ | DEsym _ | DEls _ | DEapp _ | DEfun _ | DEany _ ->
let cgh,ldl,c = try_cexp uloc env de0 [] in
let e = e_ghostify cgh (e_exec c) in
List.fold_left e_let_check e ldl
......
......@@ -14,6 +14,7 @@ open Ident
open Term
open Ity
open Expr
open Pmodule
(** Program types *)
......@@ -94,8 +95,7 @@ type dexpr = private {
and dexpr_node =
| DEvar of string * dvty
| DEpv of pvsymbol
| DErs of rsymbol
| DEsym of prog_symbol
| DEls of lsymbol
| DEconst of Number.constant * dity
| DEapp of dexpr * dexpr
......
......@@ -1142,10 +1142,12 @@ let print_rs fmt ({rs_name = {id_string = nm}} as s) =
if nm = "mixfix [.._]" then pp_print_string fmt "([.._])" else
if nm = "mixfix [_.._]" then pp_print_string fmt "([_.._])" else
match extract_op s.rs_name, s.rs_logic with
| Some s, _ ->
let s = if Strings.has_prefix "*" s then " " ^ s else s in
let s = if Strings.has_suffix "*" s then s ^ " " else s in
fprintf fmt "(%s)" s
| Some x, _ ->
fprintf fmt "(%s%s%s)"
(if Strings.has_prefix "*" x then " " else "")
x
(if List.length s.rs_cty.cty_args = 1 then "_" else
if Strings.has_suffix "*" x then " " else "")
| _, RLnone | _, RLlemma ->
pp_print_string fmt (id_unique sprinter s.rs_name)
| _, RLpv v -> print_pv fmt v
......
......@@ -164,11 +164,12 @@ let load_driver env file extra_files =
try match ns_find_prog_symbol m.mod_export q with
| PV pv -> pv.Ity.pv_vs.vs_name
| RS rs -> rs.Expr.rs_name
with Not_found -> raise (Loc.Located (loc, UnknownVal (!qualid,q)))
| OO _ -> raise Not_found (* TODO: proper error message *)
with Not_found -> Loc.error ~loc (UnknownVal (!qualid,q))
in
let find_xs m (loc,q) =
try ns_find_xs m.mod_export q
with Not_found -> raise (Loc.Located (loc, UnknownExn (!qualid,q)))
with Not_found -> Loc.error ~loc (UnknownExn (!qualid,q))
in
let add_local_module loc m = function
| MRexception (q,s) ->
......
......@@ -24,6 +24,7 @@ open Pdecl
type prog_symbol =
| PV of pvsymbol
| RS of rsymbol
| OO of Srs.t
type namespace = {
ns_ts : itysymbol Mstr.t; (* type symbols *)
......@@ -47,7 +48,34 @@ let ns_replace eq chk x vo vn =
let merge_ts = ns_replace its_equal
let merge_xs = ns_replace xs_equal
type overload =
| UnOp (* t -> t *)
| BinOp (* t -> t -> t *)
| BinRel (* t -> t -> bool *)
| NoOver (* none of the above *)
let overload_of_rs {rs_cty = cty} =
if cty.cty_effect.eff_ghost then NoOver else
if cty.cty_mask <> MaskVisible then NoOver else
match cty.cty_args with
| [a;b] when ity_equal a.pv_ity b.pv_ity &&
ity_equal cty.cty_result ity_bool &&
not a.pv_ghost && not b.pv_ghost -> BinRel
| [a;b] when ity_equal a.pv_ity b.pv_ity &&
ity_equal cty.cty_result a.pv_ity &&
not a.pv_ghost && not b.pv_ghost -> BinOp
| [a] when ity_equal cty.cty_result a.pv_ity &&
not a.pv_ghost -> UnOp
| _ -> NoOver
exception IncompatibleNotation of string
let merge_ps chk x vo vn = match vo, vn with
| OO s1, OO s2 ->
let o1 = overload_of_rs (Srs.choose s1) in
let o2 = overload_of_rs (Srs.choose s2) in
if o1 <> o2 then raise (IncompatibleNotation x);
OO (Srs.union s1 s2)
| _ when not chk -> vn
| PV v1, PV v2 when pv_equal v1 v2 -> vo
| RS r1, RS r2 when rs_equal r1 r2 -> vo
......@@ -79,10 +107,18 @@ let rec ns_find get_map ns = function
| [a] -> Mstr.find a (get_map ns)
| a::l -> ns_find get_map (Mstr.find a ns.ns_ns) l
let ns_find_prog_symbol = ns_find (fun ns -> ns.ns_ps)
let ns_find_ns = ns_find (fun ns -> ns.ns_ns)
let ns_find_xs = ns_find (fun ns -> ns.ns_xs)
let ns_find_its = ns_find (fun ns -> ns.ns_ts)
let ns_find_its = ns_find (fun ns -> ns.ns_ts)
let ns_find_xs = ns_find (fun ns -> ns.ns_xs)
let ns_find_ns = ns_find (fun ns -> ns.ns_ns)
let ns_find_prog_symbol ns s =
let ps = ns_find (fun ns -> ns.ns_ps) ns s in
match ps with
| RS _ | PV _ -> ps
| OO ss ->
let rs1 = Expr.Srs.min_elt ss in
let rs2 = Expr.Srs.max_elt ss in
if Expr.rs_equal rs1 rs2 then RS rs1 else ps
let ns_find_pv ns s = match ns_find_prog_symbol ns s with
| PV pv -> pv | _ -> raise Not_found
......@@ -1127,7 +1163,20 @@ let print_module fmt m = Format.fprintf fmt
"@[<hov 2>module %s@\n%a@]@\nend" m.mod_theory.th_name.id_string
(Pp.print_list Pp.newline2 print_unit) m.mod_units
let get_rs_name nm =
if nm = "mixfix []" then "([])" else
if nm = "mixfix []<-" then "([]<-)" else
if nm = "mixfix [<-]" then "([<-])" else
if nm = "mixfix [_..]" then "([_..])" else
if nm = "mixfix [.._]" then "([.._])" else
if nm = "mixfix [_.._]" then "([_.._])" else
try "(" ^ Strings.remove_prefix "infix " nm ^ ")" with Not_found ->
try "(" ^ Strings.remove_prefix "prefix " nm ^ "_)" with Not_found ->
nm
let () = Exn_printer.register (fun fmt e -> match e with
| IncompatibleNotation nm -> Format.fprintf fmt
"Incombatible type signatures for notation '%s'" (get_rs_name nm)
| ModuleNotFound (sl,s) -> Format.fprintf fmt
"Module %s not found in library %a" s print_path sl
| _ -> raise e)
......@@ -24,6 +24,7 @@ open Pdecl
type prog_symbol =
| PV of pvsymbol
| RS of rsymbol
| OO of Srs.t
type namespace = {
ns_ts : itysymbol Mstr.t; (* type symbols *)
......@@ -32,16 +33,24 @@ type namespace = {
ns_ns : namespace Mstr.t; (* inner namespaces *)
}
val ns_find_its : namespace -> string list -> itysymbol
val ns_find_prog_symbol : namespace -> string list -> prog_symbol
val ns_find_its : namespace -> string list -> itysymbol
val ns_find_pv : namespace -> string list -> pvsymbol
val ns_find_rs : namespace -> string list -> rsymbol
val ns_find_xs : namespace -> string list -> xsymbol
val ns_find_ns : namespace -> string list -> namespace
type overload =
| UnOp (* t -> t *)
| BinOp (* t -> t -> t *)
| BinRel (* t -> t -> bool *)
| NoOver (* none of the above *)
val overload_of_rs : rsymbol -> overload
exception IncompatibleNotation of string
(** {2 Module} *)
type pmodule = private {
......
......@@ -96,7 +96,10 @@ let find_xsymbol_ns ns q =
let find_prog_symbol_ns ns p =
let get_id_ps = function
| PV pv -> pv.pv_vs.vs_name
| RS rs -> rs.rs_name in
| RS rs -> rs.rs_name
(* FIXME: this is incorrect, but we cannot
know the correct symbol at this stage *)
| OO ss -> (Srs.choose ss).rs_name in
find_qualid get_id_ps ns_find_prog_symbol ns p
let get_namespace muc = List.hd muc.Pmodule.muc_import
......@@ -550,7 +553,7 @@ let dbinder muc (_,id,gh,pty) = dbinder muc id gh pty
(* expressions *)
let is_reusable de = match de.de_node with
| DEvar _ | DEpv _ -> true | _ -> false
| DEvar _ | DEsym _ -> true | _ -> false
let mk_var n de =
Dexpr.dexpr ?loc:de.de_loc (DEvar (n, de.de_dvty))
......@@ -574,8 +577,7 @@ let rec dexpr muc denv {expr_desc = desc; expr_loc = loc} =
DEapp (Dexpr.dexpr ~loc e1, e2)) e el
in
let qualid_app loc q el =
let e = try match find_prog_symbol muc q with
| PV pv -> DEpv pv | RS rs -> DErs rs with
let e = try DEsym (find_prog_symbol muc q) with
| _ -> DEls (find_lsymbol muc.muc_theory q) in
expr_app loc e el
in
......@@ -594,7 +596,7 @@ let rec dexpr muc denv {expr_desc = desc; expr_loc = loc} =
| Ptree.Eapply (e1, e2) ->
DEapp (dexpr muc denv e1, dexpr muc denv e2)
| Ptree.Etuple el ->
let e = DErs (rs_tuple (List.length el)) in
let e = DEsym (RS (rs_tuple (List.length el))) in
expr_app loc e (List.map (dexpr muc denv) el)
| Ptree.Einfix (e1, op1, e23)
| Ptree.Einnfix (e1, op1, e23) ->
......@@ -627,18 +629,18 @@ let rec dexpr muc denv {expr_desc = desc; expr_loc = loc} =
| None -> Loc.error ~loc (Decl.RecordFieldMissing (ls_of_rs pj))
| Some e -> dexpr muc denv e in
let cs,fl = parse_record ~loc muc get_val fl in
expr_app loc (DErs cs) fl
expr_app loc (DEsym (RS cs)) fl
| Ptree.Eupdate (e1, fl) ->
let e1 = dexpr muc denv e1 in
let re = is_reusable e1 in
let v = if re then e1 else mk_var "q " e1 in
let get_val _ pj = function
| None ->
let pj = Dexpr.dexpr ~loc (DErs pj) in
let pj = Dexpr.dexpr ~loc (DEsym (RS pj)) in
Dexpr.dexpr ~loc (DEapp (pj, v))
| Some e -> dexpr muc denv e in
let cs,fl = parse_record ~loc muc get_val fl in
let d = expr_app loc (DErs cs) fl in
let d = expr_app loc (DEsym (RS cs)) fl in
if re then d else mk_let ~loc "q " e1 d
| Ptree.Elet (id, gh, kind, e1, e2) ->
let e1 = update_any kind e1 in
......@@ -720,7 +722,7 @@ let rec dexpr muc denv {expr_desc = desc; expr_loc = loc} =
let e1 = match e1 with
| Some e1 -> dexpr muc denv e1
| None when ity_equal xs.xs_ity ity_unit ->
Dexpr.dexpr ~loc (DErs rs_void)
Dexpr.dexpr ~loc (DEsym (RS rs_void))
| _ -> Loc.errorm ~loc "exception argument expected" in
DEraise (xs, e1)
| Ptree.Etry (e1, cl) ->
......@@ -1085,6 +1087,8 @@ let type_inst ({muc_theory = tuc} as muc) ({mod_theory = t} as m) s =
Loc.errorm ~loc:(qloc q) "program constant expected"
| RS _, PV _ ->
Loc.errorm ~loc:(qloc q) "program function expected"
| OO _, _ | _, OO _ ->
Loc.errorm ~loc:(qloc q) "ambiguous notation"
end
| CSxsym (p,q) ->
let xs1 = find_xsymbol_ns m.mod_export p in
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment