Commit d21588a5 authored by Clovis Eberhart's avatar Clovis Eberhart

Implemented a basic type inference algorithm.

parent db5cf682
......@@ -38,7 +38,7 @@ PREVIOUS_DIRS = ../utils
# Source files in the right order of dependance
#ML = error.ml abstract_syntax.ml interface.ml environment.ml entry.ml parser.ml lexer.ml data_parsing.ml
ML = abstract_syntax.ml lambda.ml
ML = abstract_syntax.ml lambda.ml type_inference.ml
EXE_SOURCES =
......
open Lambda
open Abstract_syntax
module type Signature_sig =
sig
type t
(** [unfold_term_definition id s] returns the term corresponding to the defined constant with id [id] if it exists *)
val unfold_term_definition : int -> t -> Lambda.term
end
module TypeInference =
struct
open Lambda
exception NotImplemented
exception NotUnifiable
exception NonLinear
exception NotWellTyped
type subst = S of (int * stype) list
let rec substitute x = function
| S [] -> Atom x
| S ((v,t)::q) -> if v = x then t else substitute x (S q)
let rec lift_subst sigma = function
| Atom v -> substitute v sigma
| LFun(t1,t2) -> LFun(lift_subst sigma t1,lift_subst sigma t2)
| _ -> raise NotImplemented
let rec occurs x = function
| Atom v -> v = x
| LFun(t1,t2) -> (occurs x t1) || (occurs x t2)
| _ -> raise NotImplemented
let unify c =
let rec aux sigma = function
| (s,t)::q when s=t -> aux sigma q
| (Atom s,t)::q when not (occurs s t) -> let (q1,q2) = List.split q and replace_s_by_t = List.map (lift_subst (S [(s,t)])) in
aux ((s,t)::sigma) (List.combine (replace_s_by_t q1) (replace_s_by_t q2))
| (s,Atom t)::q -> aux sigma ((Atom t,s)::q)
| (LFun(t1,t2),LFun(s1,s2))::q -> aux sigma ((t1,s1)::(t2,s2)::q)
| [] -> sigma
| ((DAtom _ | Fun(_,_) | Dprod(_,_,_) | Record(_,_) | Variant(_,_) | TAbs(_,_) | TApp(_,_)),_)::q -> raise NotImplemented
| _ -> raise NotUnifiable in
S(aux [] c)
let rename_vars m =
let rec aux i j c l = function
| LVar x -> let k = List.nth l x in
LVar (j+1),i,j+1,c,[(k,j+1)]
| App(n,p) -> let (n',i',j',c',r) = aux i j c l n in
let (p',i'',j'',c'',r') = aux i' j' c' l p in
App(n',p'),i'',j'',c'',r@r'
| LAbs(x,n) -> let (n',i',j',c',r) = aux (i+1) j c (i::l) n in
LAbs(x,n'),i',j',c',r
| Const x -> Const (c+1),i,j,c+1,[]
| (Abs(_,_) | Var _) -> raise NonLinear
| _ -> raise NotImplemented in
let (n,_,_,c,r) = aux 0 0 0 [] m in
(n,r,c)
(* let m0 = App(LAbs("x",LVar 0),LAbs("y",LVar 0))
let m1 = LAbs("x",LAbs("y",LAbs("z",App(LVar 2,App(LVar 1,LVar 0)))))
let m2 = App(LAbs("x",LVar 0),Const 2)
let m3 = LAbs("x",LAbs("y",App(LVar 0,LVar 1)))
let m4 = App(LAbs("x",App(LVar 0,Const 3)),Const 4) *)
let rec nb_vars = function
| LVar _ -> 1
| App(t1,t2) -> (nb_vars t1) + (nb_vars t2)
| LAbs(_,t) -> nb_vars t
| Const _ -> 0
| DConst _ -> failwith "Type not unfolded"
| (Var _ | Abs(_,_)) -> raise NonLinear
| _ -> raise NotImplemented
let fixed_point f d =
let rec aux a b =
if a=b then a else aux b (f b) in
aux d (f d)
let type_inference_aux m =
let (n,r,c) = rename_vars m in
let v = (nb_vars n) + 1 in
let fresh_type = (+) v in
let rec aux e i j = function
| [] -> e
| (LVar x, tau)::q -> aux ((Atom x,tau)::e) i j q
| (App(m,p), tau)::q -> let t = fresh_type i in
aux e (i+1) j ((m,LFun(Atom t,tau))::(p,Atom t)::q)
| (LAbs(x,m), tau)::q -> let t = fresh_type i in
aux ((tau,LFun(Atom (List.assoc j r),Atom t))::e) (i+1) (j+1) ((m,Atom t)::q)
| (Const x, tau)::q -> aux ((Atom (-x),tau)::e) i j q
| (DConst _,_)::_ -> failwith "Type not unfolded"
| (Var _,_)::_ -> raise NonLinear
| (Abs(_,_),_)::_ -> raise NonLinear
| _ -> raise NotImplemented in
try
let f = fixed_point (lift_subst (unify (aux [] 1 0 [(n, Atom 0)]))) in
let rec const = function
| 0 -> []
| i -> (Atom (-i))::(const (i-1)) in
f (Atom 0),List.rev_map f (const c)
with
NotUnifiable ->
raise NotWellTyped
let rec print_subst = function
| S [] -> ()
| S((i,ty)::q) -> (Printf.printf "%d->%s\n" i (type_to_string ty (function x -> Abstract_syntax.Default,Printf.sprintf "%d" x));
print_subst (S q))
end
(** This module is provides a type inference algorithm for linear lambda terms *)
open Lambda
(** The module that provides the type inference algorithm *)
module TypeInference :
sig
(** the list of exceptions used in the module *)
exception NotImplemented
exception NotUnifiable
exception NonLinear
(** the type of substitutions on term types used in the unify algorithm *)
type subst
(** [substitute i s] returns the type associated to atom [i] in substitution [s] *)
val substitute : int -> subst -> Lambda.stype
(** [lift_subst s ty] returns the type [ty] in which substitution [s] has been applied*)
val lift_subst : subst -> Lambda.stype -> Lambda.stype
(** [occurs i ty] returns true if [i] is an atomic type that appears in [ty] *)
val occurs : int -> Lambda.stype -> bool
(** [unify (l0,r0)...(ln,rn)] returns the most general unifier for the unification problem:
l0=r0,...,ln=rn if it exists,raises NotUnifiable otherwise *)
val unify : (Lambda.stype * Lambda.stype) list -> subst
(** [rename_vars t] returns a lambda term in which de Bruijn's indices have been replaced
by static indices and the constants made distinct,the mapping of each lambda to its new
variable and the number of constants in [t] *)
val rename_vars : Lambda.term -> Lambda.term * (int * int) list * int
(** [type_inference_aux m] returns the most general type of m in which constants have been
replaced by variables and the type of these variables*)
val type_inference_aux : Lambda.term -> Lambda.stype * Lambda.stype list
end
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment