python: functions and predicates

parent 26e65150
......@@ -57,10 +57,14 @@ and stmt_desc =
| Sbreak
| Slabel of ident
and block = stmt list
and block = decl list
and def = ident * ident list * Ptree.spec * block
and decl =
| Dimport of ident * ident list
| Ddef of ident * ident list * Ptree.spec * block
| Dstmt of stmt
| Dlogic of bool (*is_func*) * ident * ident list
and file = def list * block
type file = block
......@@ -40,7 +40,7 @@
["invariant", INVARIANT; "variant", VARIANT;
"assert", ASSERT; "assume", ASSUME; "check", CHECK;
"requires", REQUIRES; "ensures", ENSURES;
"label", LABEL;
"label", LABEL; "function", FUNCTION; "predicate", PREDICATE;
];
fun s -> try Hashtbl.find h s with Not_found ->
raise (Lexing_error ("no such annotation '" ^ s ^ "'"))
......@@ -80,6 +80,8 @@ rule next_tokens = parse
| '\n' { newline lexbuf; update_stack (indentation lexbuf) }
| (space | comment)+
{ next_tokens lexbuf }
| "\\" space* '\n' space* "#@"?
{ next_tokens lexbuf }
| "#@" space* (ident as id)
{ [annotation id] }
| "#@" { raise (Lexing_error "expecting an annotation") }
......
......@@ -105,21 +105,36 @@ let loop_annotation env a =
let add_loop_invariant i a =
{ a with loop_invariant = i :: a.loop_invariant }
let rec has_break s = match s.stmt_desc with
| Sbreak -> true
| Sreturn _ | Sassign _ | Slabel _
| Seval _ | Sset _ | Sassert _ | Swhile _ -> false
| Sif (_, bl1, bl2) -> has_breakl bl1 || has_breakl bl2
| Sfor (_, _, _, bl) -> has_breakl bl
and has_breakl bl = List.exists has_break bl
let rec has_return s = match s.stmt_desc with
| Sreturn _ -> true
| Sbreak | Sassign _ | Slabel _
| Seval _ | Sset _ | Sassert _ -> false
| Sif (_, bl1, bl2) -> has_returnl bl1 || has_returnl bl2
| Swhile (_, _, bl) | Sfor (_, _, _, bl) -> has_returnl bl
and has_returnl bl = List.exists has_return bl
let stmt_of_decl = function Dstmt s -> s | _ -> invalid_arg "not a statement"
let rec has_stmt p = function
| Dstmt s -> p s || begin match s.stmt_desc with
| Sbreak | Sreturn _ | Sassign _ | Slabel _
| Seval _ | Sset _ | Sassert _ | Swhile _ -> false
| Sif (_, bl1, bl2) -> has_stmtl p bl1 || has_stmtl p bl2
| Sfor (_, _, _, bl) -> has_stmtl p bl end
| _ -> false
and has_stmtl p bl = List.exists (has_stmt p) bl
let has_breakl = has_stmtl (fun s -> s.stmt_desc = Sbreak)
let has_returnl = has_stmtl (function { stmt_desc = Sreturn _ } -> true | _ -> false)
let rec expr_has_call id e = match e.Py_ast.expr_desc with
| Enone | Ebool _ | Eint _ | Estring _ | Py_ast.Eident _ -> false
| Emake (e1, e2) | Eget (e1, e2) | Ebinop (_, e1, e2) ->
expr_has_call id e1 || expr_has_call id e2
| Eunop (_, e1) -> expr_has_call id e1
| Ecall (f, el) -> id.id_str = f.id_str || List.exists (expr_has_call id) el
| Elist el -> List.exists (expr_has_call id) el
let rec stmt_has_call id s = match s.stmt_desc with
| Sbreak | Slabel _ | Sassert _ -> false
| Sreturn e | Sassign (_, e) | Seval e -> expr_has_call id e
| Sset (e1, e2, e3) ->
expr_has_call id e1 || expr_has_call id e2 || expr_has_call id e3
| Sif (e, s1, s2) -> expr_has_call id e || block_has_call id s1 || block_has_call id s2
| Sfor (_, e, _, s) | Swhile (e, _, s) -> expr_has_call id e || block_has_call id s
and block_has_call id = has_stmtl (stmt_has_call id)
let rec expr env {Py_ast.expr_loc = loc; Py_ast.expr_desc = d } = match d with
| Py_ast.Enone ->
......@@ -180,6 +195,18 @@ let rec expr env {Py_ast.expr_loc = loc; Py_ast.expr_desc = d } = match d with
| Py_ast.Eget (e1, e2) ->
mk_expr ~loc (Eidapp (mixfix ~loc "[]", [expr env e1; expr env e2]))
let post env (loc, l) =
loc, List.map (fun (pat, t) -> pat, deref env t) l
let spec env sp =
assert (sp.sp_xpost = [] && sp.sp_reads = [] && sp.sp_writes = []
&& sp.sp_variant = []);
{ sp with
sp_pre = List.map (deref env) sp.sp_pre;
sp_post = List.map (post env) sp.sp_post }
let no_params ~loc = [loc, None, false, Some (PTtuple [])]
let rec stmt env ({Py_ast.stmt_loc = loc; Py_ast.stmt_desc = d } as s) =
match d with
| Py_ast.Seval e ->
......@@ -194,7 +221,7 @@ let rec stmt env ({Py_ast.stmt_loc = loc; Py_ast.stmt_desc = d } as s) =
let x = let loc = id.id_loc in mk_expr ~loc (Eident (Qident id)) in
mk_expr ~loc (Einfix (x, mk_id ~loc "infix :=", e))
else
block env ~loc [s]
block env ~loc [Dstmt s]
| Py_ast.Sset (e1, e2, e3) ->
array_set ~loc (expr env e1) (expr env e2) (expr env e3)
| Py_ast.Sassert (k, t) ->
......@@ -255,35 +282,30 @@ let rec stmt env ({Py_ast.stmt_loc = loc; Py_ast.stmt_desc = d } as s) =
and block env ~loc = function
| [] ->
mk_unit ~loc
| { stmt_loc = loc; stmt_desc = Slabel id } :: sl ->
| Dstmt { stmt_loc = loc; stmt_desc = Slabel id } :: sl ->
mk_expr ~loc (Emark (id, block env ~loc sl))
| { Py_ast.stmt_loc = loc; stmt_desc = Py_ast.Sassign (id, e) } :: sl
| Dstmt { Py_ast.stmt_loc = loc; stmt_desc = Py_ast.Sassign (id, e) } :: sl
when not (Mstr.mem id.id_str env.vars) ->
let e = expr env e in (* check e *before* adding id to environment *)
let env = add_var env id in
mk_expr ~loc (Elet (id, Gnone, mk_ref ~loc e, block env ~loc sl))
| { Py_ast.stmt_loc = loc } as s :: sl ->
| Dstmt ({ Py_ast.stmt_loc = loc } as s) :: sl ->
let s = stmt env s in
if sl = [] then s else mk_expr ~loc (Esequence (s, block env ~loc sl))
let post env (loc, l) =
loc, List.map (fun (pat, t) -> pat, deref env t) l
let spec env sp =
assert (sp.sp_xpost = [] && sp.sp_reads = [] && sp.sp_writes = []
&& sp.sp_variant = []);
{ sp with
sp_pre = List.map (deref env) sp.sp_pre;
sp_post = List.map (post env) sp.sp_post }
let no_params ~loc = [loc, None, false, Some (PTtuple [])]
| Ddef (id, idl, sp, bl) :: sl ->
let lam = def (id, idl, sp, bl) in
let s = block env ~loc sl in
if block_has_call id bl then mk_expr ~loc (Erec ([id, Gnone, lam], s))
else mk_expr ~loc (Efun (id, Gnone, lam, s))
| (Dimport _ | Py_ast.Dlogic _) :: sl ->
block env ~loc sl
(* f(x1,...,xn): body
let f x1 ... xn =
let x1 = ref x1 in ... let xn = ref xn in
try body with Return x -> x *)
let def inc (id, idl, sp, bl) =
and def (id, idl, sp, bl) =
let loc = id.id_loc in
let env = empty_env in
let env = List.fold_left add_var env idl in
......@@ -296,13 +318,29 @@ let def inc (id, idl, sp, bl) =
let body = List.fold_left local body idl in
let param id = id.id_loc, Some id, false, None in
let params = if idl = [] then no_params ~loc else List.map param idl in
let fd = (params, None, body, sp) in
let d = Dfun (id, Gnone, fd) in
inc.new_pdecl id.id_loc d
params, None, body, sp
let fresh_type_var =
let r = ref 0 in
fun loc -> incr r;
PTtyvar ({ id_str = "a" ^ string_of_int !r; id_loc = loc; id_lab = [] }, false)
let logic_param id =
id.id_loc, Some id, false, fresh_type_var id.id_loc
let logic inc = function
| Py_ast.Dlogic (func, id, idl) ->
let d = { ld_loc = id.id_loc;
ld_ident = id;
ld_params = List.map logic_param idl;
ld_type = if func then Some (fresh_type_var id.id_loc) else None;
ld_def = None } in
inc.new_decl id.id_loc (Dlogic [d])
| _ -> ()
let translate ~loc inc (dl, s) =
List.iter (def inc) dl;
let fd = (no_params ~loc, None, block empty_env ~loc s, empty_spec) in
let translate ~loc inc dl =
List.iter (logic inc) dl;
let fd = (no_params ~loc, None, block empty_env ~loc dl, empty_spec) in
let main = Dfun (mk_id ~loc "main", Gnone, fd) in
inc.new_pdecl loc main
......
......@@ -23,7 +23,7 @@
let mk_pat d s e = { pat_desc = d; pat_loc = floc s e }
let mk_term d s e = { term_desc = d; term_loc = floc s e }
let mk_expr loc d = { expr_desc = d; expr_loc = loc }
let mk_stmt loc d = { stmt_desc = d; stmt_loc = loc }
let mk_stmt loc d = Dstmt { stmt_desc = d; stmt_loc = loc }
let variant_union v1 v2 = match v1, v2 with
| _, [] -> v1
......@@ -71,6 +71,7 @@
%token PLUS MINUS TIMES DIV MOD
(* annotations *)
%token INVARIANT VARIANT ASSUME ASSERT CHECK REQUIRES ENSURES LABEL
%token FUNCTION PREDICATE
%token ARROW LARROW LRARROW FORALL EXISTS DOT THEN LET
(* precedences *)
......@@ -90,22 +91,37 @@
%start file
%type <Py_ast.file> file
%type <Py_ast.decl> stmt
%%
file:
| NEWLINE? import* dl=list(def) b=list(stmt) EOF
{ dl, b }
| NEWLINE* EOF
{ [] }
| NEWLINE? dl=nonempty_list(decl) NEWLINE? EOF
{ dl }
;
decl:
| import { $1 }
| def { $1 }
| stmt { $1 }
| func { $1 }
import:
| FROM _m=ident IMPORT _f=ident NEWLINE
{ () (* FIXME: check legal imports *) }
| FROM m=ident IMPORT l=separated_list(COMMA, ident) NEWLINE
{ Dimport (m, l) }
func:
| FUNCTION id=ident LEFTPAR l=separated_list(COMMA, ident) RIGHTPAR NEWLINE
{ Dlogic (true, id, l) }
| PREDICATE id=ident LEFTPAR l=separated_list(COMMA, ident) RIGHTPAR NEWLINE
{ Dlogic (false, id, l) }
def:
| DEF f = ident LEFTPAR x = separated_list(COMMA, ident) RIGHTPAR
COLON NEWLINE BEGIN s=spec l=nonempty_list(stmt) END
{ f, x, s, l }
{ Ddef (f, x, s, l) }
;
spec:
......@@ -117,6 +133,8 @@ single_spec:
{ { empty_spec with sp_pre = [t] } }
| ENSURES e=ensures NEWLINE
{ { empty_spec with sp_post = [floc $startpos(e) $endpos(e), e] } }
| variant
{ { empty_spec with sp_variant = $1 } }
ensures:
| term
......@@ -182,11 +200,11 @@ suite:
{ l }
;
stmt: located(stmt_desc) { $1 };
stmt:
| located(stmt_desc) { $1 }
| s = simple_stmt NEWLINE { s }
stmt_desc:
| s = simple_stmt NEWLINE
{ s.stmt_desc }
| IF c = expr COLON s1 = suite s2=else_branch
{ Sif (c, s1, s2) }
| WHILE e = expr COLON b=loop_body
......
def fact(n):
#@ ensures result >= 1
#@ variant n
if n <= 1:
return 1
return n * fact(n - 1)
......@@ -10,6 +10,7 @@ while s <= n:
#@ invariant 0 <= r
#@ invariant r * r <= n
#@ invariant s == (r+1) * (r+1)
#@ variant n - s
r = r + 1
s = s + 2 * r + 1
......
......@@ -10,7 +10,7 @@ def isqrt(n):
#@ invariant 0 <= r
#@ invariant r * r <= n
#@ invariant s == (r+1) * (r+1)
#@ variant n - s
#@ variant n - s
r = r + 1
s = s + 2 * r + 1
return r
#@ predicate win(n)
#@ predicate lose(n)
#@ assume lose(1)
#@ assume forall n. n >= 1 and lose(n) -> win(n+1) \
#@ and win(n+2) and win(n+3)
#@ assume forall n. n >= 1 and win(n) and win(n+1) \
#@ and win(n+2) -> lose(n+3)
#@ assume forall n. n >= 1 -> lose(n) -> win(n) -> False
#@ assume forall n. n >= 1 -> (lose(n) <-> n % 4 == 1)
start = int(input("start = "))
#@ assume start >= 1
n = start
while n > 0:
#@ invariant lose(start) -> lose(n)
print(n, " matches")
k = int(input("your turn: "))
#@ assume k == 1 or k == 2 or k == 3
#@ assume k <= n
n = n - k
if n == 0:
print("you lose")
break
if n == 1:
#@ assert win(start)
print("you win")
break
m = n % 4
if m == 1:
k = 1
elif m == 0:
k = 3
else:
k = m - 1
#@ assert win(n) -> lose(n - k)
n = n - k
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment