(* max_matrix.mlw *)

(* Given an n x n matrix m of nonnegative integers, we want to pick exactly
   one element in each row and each column, so that their sum is maximal.

   We generalize the problem as follows: f(i,c) is the maximum for rows >= i
   and columns in set c. Thus the solution is f(0,{0,1,...,n-1}).

   f is easily defined recursively, as we have

      f(i,c) = max{j in c} m[i][j] + f(i+1, c\{j})

   with f(n,{}) = 0. As such, it would still be a brute force approach (of
   complexity n!), but we can memoize f: since i = n - |c| is determined by
   c, there are at most 2^n distinct subproblems, one per subset c of the
   columns.

   The following code implements such a solution. Sets of integers are
   provided in theory Bitset. Hash tables for memoization are provided
   in module HashTable (see file hash_tables.mlw for an implementation).
   Code for f is in module MaxMatrixMemo (mutually recursive functions
   maximum and memo).
*)

theory Bitset "sets of small integers"

  use import int.Int

  (* Elements are intended to belong to 0..size-1. With the bit-vector
     implementation suggested below, size would presumably be the number
     of usable bits of the underlying integer. *)
  constant size : int

  (* Abstract type of sets of integers, meant to be represented as bit
     vectors. *)
  type set

  (* Membership.
     [mem i s] can be implemented as [s land (1 lsl i) <> 0]. *)
  predicate mem int set

  (* Removal: the members of [remove i s] are those of [s] except [i].
     Removing a non-member leaves the membership unchanged.
     [remove i s] can be implemented as [s land (lnot (1 lsl i))];
     the cheaper [s - (1 lsl i)] is only correct when [mem i s] holds. *)
  function remove int set : set

  axiom remove_def1:
    forall x y: int, s: set.
    mem x (remove y s) <-> x <> y /\ mem x s

  (* The set {0,1,...,n-1}, for 0 <= n <= size.
     [below n] can be implemented as [(1 lsl n) - 1]. *)
  function below int : set

  axiom below_def:
    forall x n: int. 0 <= n <= size ->
    mem x (below n) <-> 0 <= x < n

  (* Number of elements of a set. *)
  function cardinal set : int

  (* A set has cardinal 0 exactly when it has no member. *)
  axiom cardinal_empty:
    forall s: set. cardinal s = 0 <-> (forall x: int. not (mem x s))

  (* Removing a member decreases the cardinal by one. *)
  axiom cardinal_remove:
    forall x: int. forall s: set.
    mem x s -> cardinal s = 1 + cardinal (remove x s)

  (* The hypothesis 0 <= n makes the cardinal of below n exactly n. *)
  axiom cardinal_below:
    forall n: int. 0 <= n <= size ->
    cardinal (below n) = n

end

module HashTable

  use import option.Option
  use import int.Int
  use import map.Map

  (* Mutable hash table, modeled as a total map from keys to options:
     Some v for a bound key, None for an unbound one. *)
  type t 'a 'b model { mutable contents: map 'a (option 'b) }

  (* h[k] is the binding of k in the model of h. *)
  function ([]) (h: t 'a 'b) (k: 'a) : option 'b = Map.get h.contents k

  (* Fresh empty table. The argument n must be positive; it does not
     appear in the model (presumably an initial capacity hint). *)
  val create (n:int) : t 'a 'b
    requires { 0 < n } ensures { forall k: 'a. result[k] = None }

  (* Remove every binding of h. *)
  val clear (h: t 'a 'b) : unit writes {h}
    ensures { forall k: 'a. h[k] = None }

  (* Bind k to v, replacing any previous binding of k;
     bindings of all other keys are unchanged. *)
  val add (h: t 'a 'b) (k: 'a) (v: 'b) : unit writes {h}
    ensures { h[k] = Some v /\ forall k': 'a. k' <> k -> h[k'] = (old h)[k'] }

  (* Raised by find when the key is not bound. *)
  exception Not_found

  (* Return the value bound to k, or raise Not_found when k is unbound.
     The table is not written. *)
  val find (h: t 'a 'b) (k: 'a) : 'b
    ensures { h[k] = Some result } raises { Not_found -> h[k] = None }

end

module MaxMatrixMemo

  use import int.Int
  use import int.MinMax
  use import map.Map
  use map.Const
  use import ref.Ref

  (* Dimension of the square input matrix. *)
  constant n : int
  axiom n_nonneg: 0 <= n

  use import Bitset
  (* Column sets 0..n-1 must fit in a Bitset. *)
  axiom integer_size: n <= size

  (* The input matrix: m[i][j] is the element at row i, column j.
     Only entries with row and column in 0..n-1 are constrained. *)
  constant m : map int (map int int)
  axiom m_pos: forall i j: int. 0 <= i < n -> 0 <= j < n -> 0 <= m[i][j]

  (* solution s i: s assigns to each row k in i..n-1 a column s[k] in
     0..n-1, and no column is assigned to two different rows. *)
  predicate solution (s: map int int) (i: int) =
    (forall k: int. i <= k < n -> 0 <= s[k] < n) /\
    (forall k1 k2: int. i <= k1 < k2 < n -> s[k1] <> s[k2])

  (* A complete assignment: one distinct column for every row. *)
  predicate permutation (s: map int int) = solution s 0

  type mapii = map int int
  (* Contribution of row i under assignment s. *)
  function f (s: map int int) (i: int) : int = m[i][s[i]]
  (* Instantiates the generic sum theory with f, so that sum s i j adds up
     the contributions f s k of the rows k in the range i..j-1. *)
  clone import sum.Sum with type container = mapii, function f = f

  (* Unfolding the sum at row i after assigning column j to row i. *)
  lemma sum_ind:
    forall i: int. i < n -> forall j: int.
    forall s: map int int. sum s[i <- j] i n = m[i][j] + sum s (i+1) n

  use import option.Option
  use HashTable as H

  (* Memoization key: the first remaining row i and the set c of
     columns still available. *)
  type key = (int, set)
  (* Memoized value: the best sum and an assignment achieving it. *)
  type value = (int, mapii)

  (* Well-formed subproblem: rows i..n-1 remain to be assigned and c
     holds exactly n-i columns, all of them in 0..n-1. *)
  predicate pre (k: key) =
    let (i, c) = k in
    0 <= i <= n /\ cardinal c = n-i /\ (forall k: int. mem k c -> 0 <= k < n)

  (* v = (r, sol) is an optimal answer to subproblem (i, c): sol assigns
     distinct columns taken from c to rows i..n-1, its sum is r, and no
     other such assignment has a larger sum. *)
  predicate post (k: key) (v: value) =
    let (i, c) = k in
    let (r, sol) = v in
    0 <= r /\ solution sol i /\
    (forall k: int. i <= k < n -> mem sol[k] c) /\
    r = sum sol i n /\
    (forall s: map int int.
       solution s i -> (forall k: int. i <= k < n -> mem s[k] c) ->
       r >= sum s i n)

  type table = H.t key value

  (* Global memo table shared by maximum and memo. *)
  val table: table

  (* Every cached entry is an optimal answer for its key. *)
  predicate inv (t: table) =
    forall k: key, v: value. H.([]) t k = Some v -> post k v

  (* maximum i c computes f(i,c) directly. For each column j in c, it adds
     m[i][j] to the memoized optimum for rows i+1..n-1 over c without j,
     and keeps the best. The value -1 for r means that no column of c has
     been tried yet. All actual sums are nonnegative since m_pos holds. *)
  let rec maximum (i:int) (c: set) : (int, map int int) variant {2*n-2*i}
    requires { pre (i, c) /\ inv table }
    ensures { post (i,c) result /\ inv table }
  = if i = n then
      (0, Const.const 0)
    else begin
      let r = ref (-1) in
      let sol = ref (Const.const 0) in
      (* Either no column below j is in c (r is still -1), or (r, sol) is
         optimal among assignments whose column for row i is below j. *)
      for j = 0 to n-1 do
        invariant {
          inv table /\
          (  (!r = -1 /\ forall k: int. 0 <= k < j -> not (mem k c))
          \/
            (0 <= !r /\ solution !sol i /\
              (forall k: int. i <= k < n -> mem !sol[k] c) /\
              !r = sum !sol i n /\
              (forall s: map int int.
                 solution s i -> (forall k: int. i <= k < n -> mem s[k] c) ->
                 mem s[i] c -> s[i] < j -> !r >= sum s i n)))
        }
        if mem j c then
          let (r', sol') = memo (i+1) (remove j c) in
          let x = m[i][j] + r' in
          if x > !r then begin r := x; sol := sol'[i <- j] end
      done;
      (* Since i < n, c has n-i > 0 columns, so at least one was tried. *)
      assert { 0 <= !r };
      (!r, !sol)
    end

  (* memo i c: cached version of maximum. On a cache miss, the result is
     computed by maximum and stored in the table. The variant of memo at
     row i (2n-2i+1) exceeds that of maximum at row i (2n-2i). That in turn
     exceeds the variant of memo at row i+1 (2n-2i-1), so the mutual
     recursion terminates. *)
  with memo (i:int) (c: set) : (int, map int int) variant {2*n-2*i+1}
    requires { pre (i,c) /\ inv table }
    ensures { post (i,c) result /\ inv table }
  = try  H.find table (i,c)
    with H.Not_found -> let r = maximum i c in H.add table (i,c) r; r end

  (* Maximal sum over all permutations. The table is cleared first so that
     inv table holds trivially before the search starts. *)
  let maxmat ()
    ensures { exists s: map int int. permutation s /\ result =  sum s 0 n }
    ensures { forall s: map int int. permutation s -> result >= sum s 0 n }
  = H.clear table;
    assert { inv table };
    let (r, _) = maximum 0 (below n) in r

end