(** {1 N-queens problem}

Verification of the following 2-lines C program solving the N-queens problem:

{h <pre>
   t(a,b,c){int d=0,e=a&~b&~c,f=1;if(a)for(f=0;d=(e-=d)&-e;f+=t(a-d,(b+d)*2,(
   c+d)/2));return f;}main(q){scanf("%d",&q);printf("%d\n",t(~(~0<<q),0,0));}
</pre>}

*)

(**

{2 finite sets of integers, with succ and pred operations }

*)

theory S

  use export set.Fsetint

  (** [succ s] is [s] shifted up by one: an integer [x] belongs to it
      exactly when [x] is at least 1 and [x - 1] belongs to [s] *)
  function succ (set int) : set int

  axiom succ_def:
    forall elts: set int, x: int.
      mem x (succ elts) <-> 1 <= x /\ mem (x - 1) elts

  (** [pred s] is [s] shifted down by one: an integer [x] belongs to it
      exactly when [x] is non-negative and [x + 1] belongs to [s] *)
  function pred (set int) : set int

  axiom pred_def:
    forall elts: set int, x: int.
      mem x (pred elts) <-> 0 <= x /\ mem (x + 1) elts

end

(** {2 Formalization of the set of solutions of the N-queens problem} *)

theory Solution

  use import int.Int
  use export map.Map

  (** the number of queens, which also bounds the column indices *)
  constant n : int

  (** a placement: maps a row index to the column of the queen on that row *)
  type solution = map int int

  (** maps [t] and [u] agree on every index of [[0..i[] *)
  predicate eq_prefix (t u: map int 'a) (i: int) =
    forall idx: int. 0 <= idx < i -> t[idx] = u[idx]

  (** two placements are identified when they agree on the first [n] rows *)
  predicate eq_sol (t u: solution) = eq_prefix t u n

  (** the rows [0..k-1] of [s] hold queens whose columns lie in [[0..n[]
      and that pairwise differ in column and in both diagonals *)
  predicate partial_solution (k: int) (s: solution) =
    forall row: int. 0 <= row < k ->
      0 <= s[row] < n /\
      (forall above: int. 0 <= above < row ->
         s[row] <> s[above] /\
         s[row] - s[above] <> row - above /\
         s[row] - s[above] <> above - row)

  (** a solution is a partial solution covering all [n] rows *)
  predicate solution (s: solution) = partial_solution n s

  (** being a partial solution only depends on the first [k] rows *)
  lemma partial_solution_eq_prefix:
    forall copy orig: solution, len: int.
      partial_solution len orig ->
      eq_prefix orig copy len ->
      partial_solution len copy

  (** lexicographic order: [s1] and [s2] agree up to some row below [n]
      on which [s1] has the strictly smaller column *)
  predicate lt_sol (s1 s2: solution) =
    exists row: int.
      0 <= row < n /\ eq_prefix s1 s2 row /\ s1[row] < s2[row]

  (** an indexed collection of placements *)
  type solutions = map int solution

  (** [s[a..b[] is sorted for [lt_sol] *)
  predicate sorted (s: solutions) (a b: int) =
    forall lo hi: int. a <= lo < hi < b -> lt_sol s[lo] s[hi]

  (** a sorted array of solutions contains no duplicate *)
  lemma no_duplicate:
    forall s: solutions, a b: int.
      sorted s a b ->
      forall lo hi: int. a <= lo < hi < b -> not (eq_sol s[lo] s[hi])

end

(** {2 More realistic code with bitwise operations} *)

(* Abstract interface of the bit-set operations used by the N-queens code.
Every operation is only declared (val) and specified through the ghost
model field mdl, a finite set of integers. *)
module BitsSpec

use import S

(* upper bound, exclusive, on the elements a model may contain *)
constant size : int = 32

(* A bit set, seen here only through its ghost model. The invariant keeps
every element of the model within 0 <= i < size. *)
type t = {
ghost mdl: set int;
}
invariant { forall i. mem i self.mdl -> 0 <= i < size }

(* returns a value whose model is the empty set *)
val empty () : t
ensures { is_empty result.mdl }

(* the boolean result is true exactly when the model of x is empty *)
val is_empty (x:t) : bool
ensures { result  <-> is_empty x.mdl }

(* Removes from a the single element held by b.
Requires the model of b to be the singleton of its minimum and that
element to belong to the model of a. *)
val remove_singleton (a b: t) : t
requires { b.mdl = singleton (min_elt b.mdl) }
requires { mem (min_elt b.mdl) a.mdl }
ensures  { result.mdl = remove (min_elt b.mdl) a.mdl }

(* Adds to a the single element held by b.
Requires the model of b to be the singleton of its minimum. *)
val add_singleton (a b: t) : t
requires { b.mdl = singleton (min_elt b.mdl) }
(* requires { not (mem (min_elt b.mdl) a.mdl) } *)
(* this is not required if the implementation uses or instead of add *)
ensures  { result.mdl = S.add (min_elt b.mdl) a.mdl }

(* The model of the result is succ of the model of a, with the element
size removed so that the type invariant still holds. *)
val mul2 (a: t) : t
ensures { result.mdl = remove size (succ a.mdl) }

(* the model of the result is pred of the model of a *)
val div2 (a: t) : t
ensures { result.mdl = pred a.mdl }

(* the model of the result is diff applied to the models of a and b *)
val diff (a b: t) : t
ensures { result.mdl = diff a.mdl b.mdl }

use import ref.Ref
(* Isolates the smallest element of a non-empty a.
The ghost reference min is overwritten with that smallest element, and
the model of the result is the singleton of it. *)
val rightmost_bit_trick (a: t) (ghost min : ref int) : t
requires { not (is_empty a.mdl) }
writes   { min }
ensures  { !min = min_elt a.mdl }
ensures  { result.mdl = singleton !min }

use bv.BV32
(* Given a 32-bit vector n that is at most 32 (unsigned comparison),
returns a value whose model is interval 0 (to_uint n). *)
val below (n: BV32.t) : t
requires { BV32.ule n (BV32.of_int 32) }
ensures  { result.mdl = interval 0 (BV32.to_uint n) }
end

(* Implementation of the bit-set operations on 32-bit vectors.
Each value pairs a bit vector bv with a ghost model mdl, and each function
below builds both components of its result explicitly. *)
module Bits "the 1-bits of an integer, as a set of integers"

use import S
use import bv.BV32

(* number of bit positions considered by the invariant *)
constant size : int = 32

(* The invariant ties the two fields together: i belongs to mdl exactly
when 0 <= i < size and bit i of bv is set. *)
type t = {
bv : BV32.t;
ghost mdl: set int;
}
invariant {
forall i: int. (0 <= i < size /\ nth self.bv i) <-> mem i self.mdl
}

(* the all-zero vector, whose model is the empty set *)
let empty () : t
ensures { is_empty result.mdl }
=
{ bv = zeros; mdl = empty }

(* Tests emptiness by comparing bv with zeros. The assertion states that
an empty model forces bv to equal zeros; presumably a hint for the
provers. *)
let is_empty (x:t) : bool
ensures { result  <-> is_empty x.mdl }
=
assert {is_empty x.mdl -> BV32.eq x.bv zeros};
x.bv = zeros

(* Removes from a the single element held by b by and-ing a.bv with the
complement of b.bv. The commented-out variant below uses a subtraction
instead, like the a-d of the C program quoted in the file header. *)
let remove_singleton (a b: t) : t
requires { b.mdl = singleton (min_elt b.mdl) }
requires { mem (min_elt b.mdl) a.mdl }
ensures  { result.mdl = remove (min_elt b.mdl) a.mdl }
=
{ bv = bw_and a.bv (bw_not b.bv);
mdl = remove (min_elt b.mdl) a.mdl }
(* { bv = sub a.bv b.bv; mdl = remove (min_elt b.mdl) a.mdl } *)

(* Adds to a the single element held by b with a bitwise or. The
commented-out variant below uses an addition instead, like the b+d and
c+d of the C program quoted in the file header. *)
let add_singleton (a b: t) : t
requires { b.mdl = singleton (min_elt b.mdl) }
(* requires { not (mem (min_elt b.mdl) a.mdl) } *)
(* this is not required if the implementation uses or instead of add *)
ensures  { result.mdl = S.add (min_elt b.mdl) a.mdl }
=
{ bv = bw_or a.bv b.bv;
mdl = S.add (min_elt b.mdl) a.mdl }
(* { bv = add a.bv b.bv; mdl = add (min_elt b.mdl) a.mdl } *)

(* logical shift left by one; the model is succ of the model of a, minus
the element size *)
let mul2 (a: t) : t
ensures { result.mdl = remove size (succ a.mdl) }
=
{ bv = lsl_bv a.bv (of_int 1); mdl = remove size (succ a.mdl) }

(* logical shift right by one; the model is pred of the model of a *)
let div2 (a: t) : t
ensures { result.mdl = pred a.mdl }
=
{ bv = lsr_bv a.bv (of_int 1); mdl = pred a.mdl }

(* and-not of the two vectors; the model is diff of the two models *)
let diff (a b: t) : t
ensures { result.mdl = diff a.mdl b.mdl }
=
{ bv = bw_and a.bv (bw_not b.bv);
mdl = diff a.mdl b.mdl }

(* The four predicates below compare a sub-range of a with the constant
vectors zeros or ones, through eq_sub_bv (bit-vector arguments) or eq_sub
(integer arguments) from the bv library.
NOTE(review): i and n look like a start position and a length; confirm
against the bv library. *)
predicate bits_interval_is_zeros_bv (a:BV32.t) (i:BV32.t) (n:BV32.t) =
eq_sub_bv a zeros i n

predicate bits_interval_is_one_bv (a:BV32.t) (i:BV32.t) (n:BV32.t) =
eq_sub_bv a ones i n

predicate bits_interval_is_zeros (a:BV32.t) (i : int) (n : int) =
eq_sub a zeros i n

predicate bits_interval_is_one (a:BV32.t) (i : int) (n : int) =
eq_sub a ones i n

(* Isolates the smallest element of a non-empty a: the result vector is
a.bv and-ed with its negation, the e & -e of the C program quoted in the
file header, and the result model is the singleton of min_elt a.mdl.
Unlike the declaration of the same name in BitsSpec, this version takes
no ghost min reference. *)
let rightmost_bit_trick (a: t) : t
requires { not (is_empty a.mdl) }
ensures  { result.mdl = singleton (min_elt a.mdl) }
=
(* n is the smallest element of the model, n_bv its bit-vector encoding *)
let ghost n = min_elt a.mdl in
let ghost n_bv = of_int n in
(* facts about the operands: a range of a.bv described by zeros and n_bv
is zero, and bit n is set both in a.bv and in its negation *)
assert {bits_interval_is_zeros_bv a.bv zeros n_bv};
assert {nth_bv a.bv n_bv};
assert {nth_bv (neg a.bv) n_bv};
let res = bw_and a.bv (neg a.bv) in
(* facts about the result: no bit below n is set, and the range described
by n + 1 and 31 - n is zero, stated on bit vectors and then on integers *)
assert {forall i. 0 <= i < n -> not (nth res i)};
assert {bits_interval_is_zeros_bv res (add n_bv (of_int 1)) (sub (of_int 31) n_bv )};
assert {bits_interval_is_zeros res (n + 1) (31 - n)};
{ bv = res;
mdl = singleton n }

(* For n at most 32 (unsigned comparison), returns the complement of ones
shifted left by n; its model is interval 0 (to_uint n). This is the
~(~0 shifted left by q) shape of the initial call in the C program. *)
let below (n: BV32.t) : t
requires { BV32.ule n (BV32.of_int 32) }
ensures  { result.mdl = interval 0 (to_uint n) }
=
{ bv = bw_not (lsl_bv ones n);
mdl = interval 0 (to_uint n) }

end

(* The N-queens search itself, written against the abstract BitsSpec
   interface, together with ghost state recording every solution found. *)
module NQueensBits

  use import BitsSpec
  use import ref.Ref
  use import Solution
  use import int.Int

  val ghost col: ref solution  (* solution under construction *)
  (* val ghost k  : ref int       (\* current row in the current solution *\) *)

  val ghost sol: ref solutions (* all solutions *)
  val ghost s  : ref int       (* next slot for a solution =
                                  number of solutions *)

  (* [t a b c k] enumerates the solutions that extend the first [k] rows
     of [col] and returns how many were found.

     - [a] holds the columns of [0..n[ not used by the rows below [k];
     - [b] and [c] hold the columns of row [k] attacked along the two
       diagonals by the queens of the rows below [k];
     - [k] is the ghost index of the row being filled.

     Each solution found is stored in [sol] from slot [old !s] on, in
     [lt_sol] order, and [s] is advanced accordingly. The first [k] rows of
     [col] and the first [old !s] slots of [sol] are left unchanged. *)
  let rec t (a b c: BitsSpec.t) (ghost k : int)
    requires { n <= size }
    requires { 0 <= k }
    requires { k + S.cardinal a.mdl = n }
    requires { !s >= 0 }
    requires { forall i: int. S.mem i a.mdl <->
      ( 0 <= i < n /\ forall j: int. 0 <= j < k ->  !col[j] <> i) }
    requires { forall i: int.  0 <= i < size ->
      ( not (S.mem i b.mdl) <->
        forall j: int. 0 <= j < k -> i - !col[j] <> k - j) }
    requires { forall i: int. 0 <= i < size ->
      ( not (S.mem i c.mdl) <->
        forall j: int. 0 <= j < k -> i - !col[j] <> j - k ) }
    requires { partial_solution k !col }
    variant { S.cardinal a.mdl }
    ensures { result = !s - old !s >= 0 }
    ensures { sorted !sol (old !s) !s }
    ensures { forall i:int. old !s <= i < !s -> solution !sol[i] /\ eq_prefix !col !sol[i] k }
    ensures { forall u: solution.
      solution u /\ eq_prefix !col u k ->
      exists i: int. old !s <= i < !s /\ eq_sol u !sol[i] }
    (* assigns *)
    ensures { eq_prefix (old !col) !col k }
    ensures { eq_prefix (old !sol) !sol (old !s) }
  = if not (is_empty a) then begin
      (* e: the free columns of a that neither diagonal attacks *)
      let e = ref (diff (diff a b) c) in
      (* first, you show that if u is a solution with the same k-prefix as col, then
         u[k] (the position of the queen on the row k) must belong to e *)
      assert { forall u:solution. solution u /\ eq_prefix !col u k -> S.mem u[k] !e.mdl };
      (* f counts the solutions found so far; min is the last column tried *)
      let f = ref 0 in
      let ghost min = ref (-1) in
      'L:
      while not (is_empty !e) do
        variant { S.cardinal !e.mdl }
        invariant { not (S.is_empty !e.mdl) -> !min < S.min_elt !e.mdl }
        invariant { !f = !s - at !s 'L >= 0 }
        invariant { S.subset !e.mdl (S.diff (S.diff a.mdl b.mdl) c.mdl) }
        invariant { partial_solution k !col }
        invariant { sorted !sol (at !s 'L) !s }
        invariant { forall i: int. S.mem i (at !e.mdl 'L) /\ not (S.mem i !e.mdl) -> i <= !min }
        invariant { forall i:int. at !s 'L <= i < !s -> solution !sol[i] /\ eq_prefix !col !sol[i] k /\ 0 <= !sol[i][k] <= !min }
        invariant { forall u: solution.
          (solution u /\ eq_prefix !col u k /\ 0 <= u[k] <= !min)
          ->
          S.mem u[k] (at !e.mdl 'L) && not (S.mem u[k] !e.mdl) &&
          exists i: int. (at !s 'L) <= i < !s /\ eq_sol u !sol[i] }
        (* assigns *)
        invariant { eq_prefix (at !col 'L) !col k }
        invariant { eq_prefix (at !sol 'L) !sol (at !s 'L) }
        'N:
        (* d: the smallest candidate column, recorded in min and put on row k *)
        let d = rightmost_bit_trick !e min in
        ghost col := !col[k <- !min];
        assert { 0 <= !col[k] < size };
        assert { not (S.mem !col[k] b.mdl) };
        assert { not (S.mem !col[k] c.mdl) };

        assert { eq_prefix (at !col 'L) !col k };
        assert { forall i: int. S.mem i a.mdl -> (forall j: int. 0 <= j < k ->  !col[j] <> i) };
        assert { forall i: int. S.mem i (S.remove (S.min_elt d.mdl) a.mdl) <-> (0 <= i < n /\ (forall j: int. 0 <= j < k + 1 ->  !col[j] <> i)) };

        (* b': the first diagonal set, extended with d and shifted for row k + 1 *)
        let ghost b' = mul2 (add_singleton b d) in
        assert { forall i: int.  0 <= i < size ->
          S.mem i b'.mdl -> (i = !min + 1 \/ S.mem (i - 1) b.mdl) && not (forall j:int. 0 <= j < k + 1 -> i - !col[j] <> k + 1 - j) };

        (* c': the second diagonal set, extended with d and shifted for row k + 1 *)
        let ghost c' = div2 (add_singleton c d) in
        assert { forall i: int.  0 <= i < size ->
          S.mem i c'.mdl -> (i = !min - 1 \/ (i + 1 < size /\ S.mem (i + 1) c.mdl)) && not (forall j:int. 0 <= j < k + 1 -> i - !col[j] <> j - k - 1) };

        'M:
        (* Recurse on row k + 1. t takes four arguments: the column set
           without d, the two shifted diagonal sets computed above, and
           the next row index. The call must also be sequenced with the
           assertion that follows. *)
        f := !f + t (remove_singleton a d) b' c' (k + 1);

        assert { forall i j. (at !s 'L) <= i < at !s 'M <= j < !s -> (eq_prefix !sol[i] !sol[j] k /\ !sol[i][k] <= at !min 'N < !min = !sol[j][k]) && lt_sol !sol[i] !sol[j]};

        (* this candidate is done: drop it from e *)
        e := remove_singleton !e d
      done;
      assert { forall u:solution. solution u /\ eq_prefix !col u k -> S.mem u[k] (at !e.mdl 'L) && 0 <= u[k] <= !min  };
      !f
    end else begin
      (* a is empty, so k = n by the precondition on the cardinal of a:
         col is a complete solution; store it in the next slot of sol *)
      ghost sol := !sol[!s <- !col];
      ghost s := !s + 1;
      1
    end

  (* [queens q] runs the search from the empty board for [n] queens, [n]
     being the unsigned value of [q], which must not exceed the bit-vector
     size. Starting from [!s = 0], the result is the final value of [s];
     [sol] is sorted on [0..result[ and a placement is a solution exactly
     when it is [eq_sol] to one of those slots. *)
  let queens (q: BV32.t)
    requires { BV32.to_uint q = n }
    requires { BV32.ule q BV32.size_bv }
    requires { !s = 0 }
    ensures  { result = !s }
    ensures  { sorted !sol 0 !s }
    ensures  { forall u: solution.
      solution u <-> exists i: int. 0 <= i < result /\ eq_sol u !sol[i] }
  =
    t (below q) (empty()) (empty()) 0

  (* instance for 8 queens, under the assumptions size = 32 and n = 8:
     resets the solution counter and calls [queens] *)
  let test8 ()
    requires { size = 32 }
    requires { n = 8 }
  =
    s := 0;
    queens (BV32.of_int 8)

end
```