(* SortUnification.ml *)
(******************************************************************************)
(*                                                                            *)
(*                                   Menhir                                   *)
(*                                                                            *)
(*                       François Pottier, Inria Paris                        *)
(*              Yann Régis-Gianas, PPS, Université Paris Diderot              *)
(*                                                                            *)
(*  Copyright Inria. All rights reserved. This file is distributed under the  *)
(*  terms of the GNU General Public License version 2, as described in the    *)
(*  file LICENSE.                                                             *)
(*                                                                            *)
(******************************************************************************)

(* This module implements sort inference. *)

(* -------------------------------------------------------------------------- *)

(* The syntax of sorts is:

     sort ::= (sort, ..., sort) -> *

   where the arity (the number of sorts on the left-hand side of the arrow)
   can be zero. *)

module S = struct

  (* The structure of a sort: an arrow whose domain is a list of sorts. The
     length of this list is the arity of the sort; an arrow of arity zero
     stands for the sort [*]. *)

  type 'a structure =
    | Arrow of 'a list

  (* [map f s] applies [f] to every component of the domain of [s]. *)

  let map f (Arrow xs) =
    Arrow (List.map f xs)

  (* [iter f s] applies [f], for its side effects, to every component of the
     domain of [s], from left to right. *)

  let iter f (Arrow xs) =
    List.iter f xs

  (* [Iter2] is raised by [iter2] when two arrows have different arities. *)

  exception Iter2

  (* [iter2 f s1 s2] applies [f] pairwise to the components of the domains
     of [s1] and [s2]. The lengths are compared up front, so when the arities
     differ, [Iter2] is raised before [f] is ever called. *)

  let iter2 f (Arrow xs1) (Arrow xs2) =
    if List.length xs1 = List.length xs2 then
      List.iter2 f xs1 xs2
    else
      raise Iter2

end

(* Re-export the contents of [S] -- the type ['a structure], the operations
   [map], [iter], [iter2], and the exception [Iter2] -- at the top level of
   this module. *)

include S

(* -------------------------------------------------------------------------- *)

(* Instantiate the unification algorithm with the above signature. This
   presumably brings into scope the names that the rest of this file uses
   without defining them, such as [variable], [fresh], [structure], and
   [term]. *)

include Unifier.Make(S)

(* [sort] re-exports the unifier's type [term], with its constructors made
   visible here. [TVar] is a variable; its [int] is presumably the variable's
   identifier (the printer passes it to a naming function). [TNode] carries a
   sort structure whose components are themselves sorts. The equation
   [= term] makes the compiler check that the two definitions agree, so the
   types are interchangeable. *)

type sort = term =
  | TVar of int
  | TNode of sort structure

(* -------------------------------------------------------------------------- *)

(* Sort constructors. *)

(* [arrow args] creates a fresh unification variable whose structure is the
   arrow sort whose domain is [args]. *)

let arrow (args : variable list) : variable =
  fresh (Some (Arrow args))

(* [star] is a single variable whose structure is the arrow of arity zero,
   that is, the sort [*]. It is created once, so every use of [star] refers to
   the very same variable. *)

let star : variable =
  arrow []

(* [fresh ()] creates a fresh unification variable with no structure. This
   definition shadows the unifier's [fresh], which takes an optional
   structure. [arrow] and [star] above are defined earlier, so they still
   refer to the unifier's version; keep this definition last. *)

let fresh () =
  fresh None

(* Sort accessors. *)

(* [domain x] returns [Some xs] when the unifier reports that the variable
   [x] has structure [Arrow xs]. It returns [None] when [structure x] is
   [None], presumably because the sort of [x] is not yet determined. *)

let domain (x : variable) : variable list option =
  match structure x with
  | Some (Arrow xs) ->
      Some xs
  | None ->
      None
(* -------------------------------------------------------------------------- *)

(* Converting between sorts and ground sorts. *)

(* [ground s] converts the sort [s] to a ground sort. Every variable that
   occurs in [s] is replaced with [*], the ground sort of arity zero. *)

let rec ground s =
  match s with
  | TVar _ ->
      (* All variables are replaced with [*]. *)
      GroundSort.GArrow []
  | TNode (Arrow ss) ->
      GroundSort.GArrow (List.map ground ss)

(* [unground gs] converts the ground sort [gs] back to a sort. The result is
   built from [TNode] constructors only, so it contains no variables. *)

let rec unground (GroundSort.GArrow ss) =
  TNode (Arrow (List.map unground ss))

(* -------------------------------------------------------------------------- *)

(* A name generator for unification variables. Each call [make_gensym ()]
   returns a new generator with its own counter, starting at 0. The n-th
   result is a lowercase letter, chosen by [n mod 26], followed by [n / 26]
   when that quotient is nonzero. Assuming [Misc.postincrement] returns the
   old counter value, successive results are
   ["a"], ["b"], ..., ["z"], ["a1"], ["b1"], ..., ["z1"], ["a2"], ... *)

let make_gensym () : unit -> string =
  let c = ref 0 in
  let gensym () =
    let n = Misc.postincrement c in
    Printf.sprintf "%c%s"
      (char_of_int (Char.code 'a' + n mod 26))
      (let d = n / 26 in if d = 0 then "" else string_of_int d)
  in
  gensym

(* [make_name ()] builds a fresh naming function from integers to names.
   The names come from a private generator created by [make_gensym]. The
   function is wrapped with [Memoize.Int.memoize], presumably so that
   repeated requests for the same integer yield the same name. *)

let make_name () : int -> string =
  let next_name = make_gensym () in
  Memoize.Int.memoize (fun _ -> next_name ())

(* -------------------------------------------------------------------------- *)

(* A printer. [print name b sort] appends a rendering of [sort] to the buffer
   [b]:
   - a variable [TVar x] is rendered as [name x];
   - an arrow of arity zero is rendered as [*];
   - an arrow of nonzero arity is rendered as [(s1, ..., sn) -> *]. *)

let rec print name (b : Buffer.t) (sort : sort) =
  match sort with
  | TVar x ->
      Buffer.add_string b (name x)
  | TNode (S.Arrow []) ->
      Buffer.add_char b '*'
  | TNode (S.Arrow (s :: ss)) ->
      (* Always parenthesize the domain, so there is no ambiguity. *)
      Printf.bprintf b "(%a%a) -> *"
        (print name) s
        (print_comma_sorts name) ss

(* [print_comma_sorts name b sorts] prints every element of [sorts], each
   preceded with a comma and a space. *)

and print_comma_sorts name b sorts =
  List.iter (print_comma_sort name b) sorts

(* [print_comma_sort name b sort] prints [sort], preceded with a comma and a
   space. *)

and print_comma_sort name b sort =
  Printf.bprintf b ", %a" (print name) sort

(* [print sort] renders [sort] as a string. Each call uses a new naming
   function built by [make_name ()]. Variable names are therefore consistent
   within one result but are not related across calls. This shadows the
   buffer-based [print] above, which it uses internally. *)

let print sort : string =
  let b = Buffer.create 32 in
  print (make_name ()) b sort;
  Buffer.contents b