(** A tiny register allocator for tree expressions.

    Authors: Martin Clochard (École Normale Supérieure)
             Jean-Christophe Filliâtre (CNRS)
 *)

module Spec

  use int.Int

  (** Memory addresses (abstract). *)
  type addr

  (** Source language: arithmetic expressions over memory cells. *)
  type expr =
  | Evar addr
  | Eneg expr
  | Eadd expr expr

  (** A memory maps every address to an integer value. *)
  type memory = addr -> int

  (** `eval m e` is the value of expression `e` in memory `m`. *)
  function eval (m: memory) (e: expr) : int =
    match e with
    | Evar x     -> m x
    | Eneg e     -> - (eval m e)
    | Eadd e1 e2 -> eval m e1 + eval m e2
    end

  (** Registers are identified by integers. *)
  type register = int

  (** Target instructions (see `exec` below for their semantics):
      - `Iload x r`  loads memory cell `x` into register `r`
      - `Ineg r`     negates register `r` in place
      - `Iadd r1 r2` stores `r1 + r2` into `r2` (the second operand is the target)
      - `Ipush r`    pushes the contents of register `r` on the stack
      - `Ipop r`     pops the top of the stack into register `r` *)
  type instr =
    | Iload addr register
    | Ineg  register
    | Iadd  register register
    | Ipush register
    | Ipop  register

  (** A register file maps every register to its contents. *)
  type registers = register -> int

  (** `update reg r v` is `reg` with register `r` set to `v`,
      all other registers unchanged. *)
  function update (reg: registers) (r: register) (v: int) : registers =
    fun r' -> if r' = r then v else reg r'

  use list.List

  type stack = list int

  (** Machine state: memory, register file and stack. *)
  type state = {
    mem: memory;
    reg: registers;
    st : stack;
  }

  (** `exec i s` is the state obtained by executing instruction `i` in `s`. *)
  function exec (i: instr) (s: state) : state =
    match i with
    | Iload x r   -> { s with reg = update s.reg r (s.mem x) }
    | Ineg  r     -> { s with reg = update s.reg r (- s.reg r) }
    | Iadd  r1 r2 -> { s with reg = update s.reg r2 (s.reg r1 + s.reg r2) }
    | Ipush r     -> { s with st = Cons (s.reg r) s.st }
    | Ipop  r     -> match s.st with
                     | Nil       -> s (* fails: the state is left unchanged *)
                     | Cons v st -> { s with reg = update s.reg r v; st = st }
                     end
    end
  (* make the definition of `exec` available to Why3 computation/rewriting *)
  meta rewrite_def function exec

  type code = list instr

  (** `exec_list c s` runs the instructions of `c` from left to right,
      starting in state `s`. *)
  function exec_list (c: code) (s: state) : state =
    match c with
    | Nil      -> s
    | Cons i l -> exec_list l (exec i s)
    end

  use list.Append

  (** Running `c1 ++ c2` is running `c1`, then `c2` on the resulting state.
      Proved by structural induction on `c1`. *)
  let rec lemma exec_append (c1 c2: code) (s: state) : unit
    ensures { exec_list (c1 ++ c2) s = exec_list c2 (exec_list c1 s) }
    variant { c1 }
  = match c1 with
    | Nil        -> ()
    | Cons i1 l1 -> exec_append l1 c2 (exec i1 s)
    end

  (** specification of the forthcoming compilation:
      - value of expression `e` lies in register `r` in final state
      - all registers smaller than `r` are preserved
      - memory and stack are preserved *)
  function expr_post (e: expr) (r: register) : state -> state -> bool =
    fun s s' -> s'.mem = s.mem /\ s'.reg r = eval s.mem e /\ s'.st = s.st /\
      forall r'. r' < r -> s'.reg r' = s.reg r'
  meta rewrite_def function expr_post

end

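(** Double WP technique

    Code fragments are paired with ghost specifications: an `hcode` carries
    a post-condition, a `wcode` a predicate transformer, and `wcode`s are
    combined by composing their transformers.

    If you read French, see https://hal.inria.fr/hal-01094488

    See also this other Why3 proof, from which this technique originates:
    http://toccata.lri.fr/gallery/double_wp.en.html
*)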
module DWP

  use list.List
  use list.Append
  use Spec

  (* upper bound on the number of steps of Why3 compute transformations *)
  meta compute_max_steps 0x10000

  (** `x --> y` is plain equality `x = y`; the `[@rewrite]` attribute and
      the `rewrite_def` meta are there for Why3 computation/rewriting. *)
  predicate (-->) (x y: 'a) = [@rewrite] x = y
  meta rewrite_def predicate (-->)

  (** A post-condition relates an initial state to a final state. *)
  type post = state -> state -> bool

  (** Code paired with a ghost post-condition. *)
  type hcode = {
    hcode : code;
    ghost post : post;
  }

  (** `hc.post` holds between every state `s` and the state reached by
      running `hc.hcode` from `s`. *)
  predicate hcode_ok (hc: hcode) = forall s. hc.post s (exec_list hc.hcode s)

  (** A predicate transformer maps a predicate on final states to a
      predicate on initial states. *)
  type trans = (state -> bool) -> state -> bool

  (** Code paired with a ghost predicate transformer. *)
  type wcode = {
    ghost trans : trans;
    wcode : code;
  }

  (** `wc.trans` is sound for `wc.wcode`: whenever `wc.trans q` holds in a
      state, `q` holds after running `wc.wcode` from it. *)
  predicate wcode_ok (wc: wcode) = forall q s.
    wc.trans q s -> q (exec_list wc.wcode s)

  (** Weakest pre-condition induced by a post-condition: `q` must hold in
      every final state allowed by `pst`. *)
  function to_wp (pst: post) : trans = fun q s1 -> forall s2. pst s1 s2 -> q s2
  meta rewrite_def function to_wp

  (** Reverse composition: `rcompose f g x = g (f x)`. *)
  function rcompose : ('a -> 'b) -> ('b -> 'c) -> 'a -> 'c = fun f g x -> g (f x)
  meta rewrite_def function rcompose

  (** Eta-expanded `exec i`; used in the body of `cons` below, whose
      specification refers to `exec i` directly. *)
  function exec_closure (i: instr) : state -> state = fun s -> exec i s
  function id : 'a -> 'a = fun x -> x

  (** Turn an `hcode` into a `wcode` whose transformer is the weakest
      pre-condition of its post-condition. *)
  let ($_) (hc: hcode) : wcode
    requires { hcode_ok hc }
    ensures { wcode_ok result }
    ensures { result.trans --> to_wp hc.post }
  = { wcode = hc.hcode; trans = to_wp hc.post }

  (** Turn a `wcode` back into an `hcode` with post-condition `pst`.
      The transformer of `wc`, applied to `pst x`, must hold in every
      initial state `x`. *)
  let wrap (wc: wcode) (ghost pst: post) : hcode
    requires { wcode_ok wc }
    requires { forall x. wc.trans (pst x) x }
    ensures { hcode_ok result }
    ensures { result.post --> pst }
  = { hcode = wc.wcode; post = pst }

  (** Sequential composition: run `w1`, then `w2`.
      The transformer maps `q` to `w1.trans (w2.trans q)`. *)
  let (--) (w1 w2: wcode) : wcode
    requires { wcode_ok w1 /\ wcode_ok w2 }
    ensures { wcode_ok result }
    ensures { result.trans --> rcompose w2.trans w1.trans }
  = { wcode = w1.wcode ++ w2.wcode; trans = rcompose w2.trans w1.trans }

  (** Prepend the single instruction `i` to `w`. *)
  let cons (i: instr) (w: wcode) : wcode
    requires { wcode_ok w }
    ensures { wcode_ok result }
    ensures { result.trans --> rcompose w.trans (rcompose (exec i)) }
  = { wcode = Cons i w.wcode;
      trans = rcompose w.trans (rcompose (exec_closure i)) }

  (** Empty code; its transformer is the identity. *)
  let nil () : wcode
    ensures { wcode_ok result }
    ensures { result.trans --> fun q -> q }
  = { wcode = Nil; trans = id }

end

module InfinityOfRegisters

  use int.Int
  use list.List
  use list.Append
  use Spec
  use DWP

  (** `compile e r` returns code, paired with its ghost post-condition,
      that stores the value of `e` in register `r` without modifying any
      register `r' < r` (memory and stack are preserved as well).
      Registers above `r` are used freely: there is no upper bound. *)
  let rec compile (e: expr) (r: register) : hcode
    variant { e }
    ensures { hcode_ok result }
    ensures { result.post --> expr_post e r }
  = wrap (
      match e with
      | Evar x -> cons (Iload x r) (nil ())
      | Eneg e -> $ compile e r -- cons (Ineg r) (nil ())
      | Eadd e1 e2 ->
          (* e1 into r, e2 into r+1, then r receives (r+1) + r *)
          $ compile e1 r -- $ compile e2 (r + 1) -- cons (Iadd (r+1) r) (nil ())
      end) (expr_post e r)

  (** Sanity check: the specification of `compile` implies the usual,
      fully unfolded specification of the generated code. *)
  let ghost recover (e: expr) (r: register) (h: hcode) : unit
    requires { hcode_ok h /\ h.post --> expr_post e r }
    ensures  { forall s. let s' = exec_list h.hcode s in
               s'.mem = s.mem /\
               s'.reg r = eval s.mem e /\
               s'.st = s.st /\
               forall r'. r' < r -> s'.reg r' = s.reg r' }
  = ()

end

module FiniteNumberOfRegisters

  use int.Int
  use list.List
  use list.Append
  use Spec
  use DWP

  (** we have `k` registers, namely `0,1,...,k-1`,
      and there are at least two of them, otherwise we can't add *)
  val constant k: int
    ensures { 2 <= result }

  (** `compile e r` returns code, paired with its ghost post-condition,
      that stores the value of `e` in register `r` without modifying any
      register `r' < r` (memory and stack are preserved as well).
      When `r = k-1` there is no free register above `r`, so register
      `k-2` is temporarily saved on the stack. *)
  let rec compile (e: expr) (r: register) : hcode
    requires { 0 <= r < k }
    variant  { e }
    ensures  { hcode_ok result }
    ensures  { result.post --> expr_post e r }
  = wrap (
      match e with
      | Evar x -> cons (Iload x r) (nil ())
      | Eneg e -> $ compile e r -- cons (Ineg r) (nil ())
      | Eadd e1 e2 ->
          if r < k-1 then
            $ compile e1 r -- $ compile e2 (r + 1) --
            cons (Iadd (r + 1) r) (nil ())
          else
            (* here r = k-1: save k-2 on the stack, evaluate e1 into k-2
               and e2 into k-1, add into k-1 (that is, r), restore k-2 *)
            cons (Ipush (k - 2)) (
            $ compile e1 (k - 2) -- $ compile e2 (k - 1) --
            cons (Iadd (k - 2) (k - 1)) (
            cons (Ipop (k - 2)) (nil ())))
      end) (expr_post e r)

end

module OptimalNumberOfRegisters

  use int.Int
  use int.MinMax
  use list.List
  use list.Append
  use Spec
  use DWP

  (** we have `k` registers, namely `0,1,...,k-1`,
      and there are at least two of them, otherwise we can't add *)
  val constant k: int
    ensures { 2 <= result }

  (** the minimal number of registers needed to evaluate `e`
      (its Ershov number, as in Sethi-Ullman register allocation) *)
  let rec function n (e: expr) : int
  variant { e }
  = match e with
    | Evar _     -> 1
    | Eneg e     -> n e
    | Eadd e1 e2 -> let n1 = n e1 in let n2 = n e2 in
                    if n1 = n2 then 1 + n1 else max n1 n2
    end

  (** Note: It is of course inefficient to recompute function `n` many
      times. A realistic implementation would compute `n e` once for
      each sub-expression `e`, either with a first pass of tree decoration,
      or with function `compile` returning the value of `n e` as well,
      in a bottom-up way *)

  (** Termination measure for `compile`. An `Eadd e1 e2` node with
      `n e1 < n e2` costs one more than the swapped node `Eadd e2 e1`,
      so the recursive call `compile (Eadd e2 e1) r`, which swaps the
      operands, decreases the measure. *)
  function measure (e: expr) : int =
    match e with
    | Evar _     -> 0
    | Eneg e     -> 1 + measure e
    | Eadd e1 e2 -> 1 + if n e1 >= n e2 then measure e1 + measure e2
                        else 1 + measure e1 + measure e2
    end

  (* helps discharge the variant obligations of `compile` below *)
  lemma measure_nonneg: forall e. measure e >= 0

  (** `compile e r` returns code, paired with its ghost post-condition,
      that stores the value of `e` in register `r` without modifying any
      register `r' < r` (memory and stack are preserved as well).
      The operand needing more registers (according to `n`) is compiled
      first. Only correctness is specified here: the register usage bound
      given by `n e` is not part of the contract. *)
  let rec compile (e: expr) (r: register) : hcode
    requires { 0 <= r < k }
    variant  { measure e }
    ensures  { hcode_ok result }
    ensures  { result.post --> expr_post e r }
  = wrap (
      match e with
      | Evar x -> cons (Iload x r) (nil ())
      | Eneg e -> $ compile e r -- cons (Ineg r) (nil ())
      | Eadd e1 e2 ->
          if n e1 >= n e2 then (* we must compile e1 first *)
            if r < k-1 then
              $ compile e1 r -- $ compile e2 (r + 1) --
              cons (Iadd (r + 1) r) (nil ())
            else
              (* here r = k-1: save k-2 on the stack, evaluate e1 into k-2
                 and e2 into k-1, add into k-1 (that is, r), restore k-2 *)
              cons (Ipush (k - 2)) (
              $ compile e1 (k - 2) -- $ compile e2 (k - 1) --
              cons (Iadd (k - 2) (k - 1)) (
              cons (Ipop (k - 2)) (nil ())))
          else
            $ compile (Eadd e2 e1) r (* compile e2 first *)
      end) (expr_post e r)

end