mlw_ocaml.ml 34.4 KB
Newer Older
Andrei Paskevich's avatar
Andrei Paskevich committed
1
2
3
(********************************************************************)
(*                                                                  *)
(*  The Why3 Verification Platform   /   The Why3 Development Team  *)
4
(*  Copyright 2010-2013   --   INRIA - CNRS - Paris-Sud University  *)
Andrei Paskevich's avatar
Andrei Paskevich committed
5
6
7
8
9
10
(*                                                                  *)
(*  This software is distributed under the terms of the GNU Lesser  *)
(*  General Public License version 2.1, with the special exception  *)
(*  on linking described in file LICENSE.                           *)
(*                                                                  *)
(********************************************************************)
11
12
13

open Format
open Pp
14

15
open Stdlib
16
open Number
17
18
19
20
21
open Ident
open Ty
open Term
open Decl
open Theory
22
open Printer
23

24
(* Debug flag "extraction", described as printing details of program
   extraction. *)
let debug =
  Debug.register_info_flag
    ~desc:"Print@ details@ of@ program@ extraction."
    "extraction"
27

28
29
30
31
(* [clean_fname fname] drops the directory part and the last extension (if
   any) of [fname], e.g. "dir/foo.mlw" gives "foo". *)
let clean_fname fname =
  let fname = Filename.basename fname in
  (* [Filename.chop_extension] raises [Invalid_argument] when [fname] has no
     extension; only that case falls back to the bare base name, instead of
     silently swallowing any exception. *)
  try Filename.chop_extension fname with Invalid_argument _ -> fname

32
33
(* [modulename ?fname path t] is the base name of the OCaml unit generated
   for theory/module [t] of library [path]: "<prefix>__<t>", where the
   prefix is the cleaned [fname] when given, "why3" for an empty path, and
   the path components joined with "__" otherwise. *)
let modulename ?fname path t =
  let prefix =
    match fname with
    | Some f -> clean_fname f
    | None ->
        begin match path with
        | [] -> "why3"
        | _ :: _ -> String.concat "__" path
        end
  in
  String.concat "__" [prefix; t]
39
40

(* Name of the ".ml" file produced for theory [th]. *)
let extract_filename ?fname th =
  let base = modulename ?fname th.th_path th.th_name.Ident.id_string in
  base ^ ".ml"
42

43
44
45
(* let modulename path t = *)
(*   String.capitalize *)
(*     (if path = [] then "why3__" ^ t else String.concat "__" path ^ "__" ^ t) *)
46

47
48
(** Printers *)

49
50
51
52
53
54
55
56
57
58
59
60
(* Names that generated OCaml code must not use as identifiers: the OCaml
   keywords, plus "raise", which generated code must not shadow. The list
   is reserved in the identifier printers below and checked by
   [is_ocaml_keyword]. "nonrec" (OCaml >= 4.02) and "effect" (OCaml >= 5.3)
   are included so that extracted code still compiles with recent
   compilers. *)
let ocaml_keywords =
  ["and"; "as"; "assert"; "asr"; "begin";
   "class"; "constraint"; "do"; "done"; "downto"; "else"; "end";
   "exception"; "external"; "false"; "for"; "fun"; "function";
   "functor"; "if"; "in"; "include"; "inherit"; "initializer";
   "land"; "lazy"; "let"; "lor"; "lsl"; "lsr"; "lxor"; "match";
   "method"; "mod"; "module"; "mutable"; "new"; "nonrec"; "object"; "of";
   "open"; "or"; "private"; "rec"; "sig"; "struct"; "then"; "to";
   "true"; "try"; "type"; "val"; "virtual"; "when"; "while"; "with";
   "effect"; "raise";]

(* [is_ocaml_keyword s] tells whether [s] is one of [ocaml_keywords]. The
   lookup table is filled once, when the module is initialised. *)
let is_ocaml_keyword =
  let table = Hstr.create 17 in
  List.iter (fun kw -> Hstr.add table kw ()) ocaml_keywords;
  fun s -> Hstr.mem table s
64

65
(* Identifier printers, all reserving [ocaml_keywords]: [iprinter] for
   ordinary identifiers, [aprinter] for type variables; the last two are
   currently unused. *)
let iprinter, aprinter, _tprinter, _pprinter =
  let upper = sanitizer char_to_alpha char_to_alnumus in
  let lower = sanitizer char_to_lalpha char_to_alnumus in
  let mk san = create_ident_printer ocaml_keywords ~sanitizer:san in
  mk upper, mk lower, mk lower, mk upper
72
73
74
75

(* Release every name handed out by [aprinter] (type variables); done after
   printing each declaration. *)
let forget_tvs () : unit =
  forget_all aprinter

76
(* dead code
77
78
79
80
81
let forget_all () =
  forget_all iprinter;
  forget_all aprinter;
  forget_all tprinter;
  forget_all pprinter
82
*)
83

84
85
86
87
(* info *)

(* Context shared by all the printers of this file. *)
type info = {
  info_syn: syntax_map;                     (* driver syntax overrides *)
  current_theory: Theory.theory;            (* theory being extracted *)
  current_module: Mlw_module.modul option;  (* module being extracted, if any *)
  th_known_map: Decl.known_map;             (* known logic declarations *)
  mo_known_map: Mlw_decl.known_map;         (* known program declarations *)
  fname: string option;                     (* cleaned source file name *)
}

96
97
(* [is_constructor info ls] holds when [ls] is a constructor of an
   algebraic type found in [info.th_known_map]. *)
let is_constructor info ls =
  let declares_ls (_, constrs) =
    List.exists (fun (cs, _) -> ls_equal cs ls) constrs in
  match Mid.find_opt ls.ls_name info.th_known_map with
  | Some { d_node = Ddata dl } -> List.exists declares_ls dl
  | Some _ | None -> false

(* [get_record info cs] returns the projections of [cs], in argument order,
   when [cs] is the only constructor of its type and every argument has a
   projection (i.e. the type is printed as an OCaml record); it returns []
   in every other case. *)
let get_record info ls =
  (* [Some l] when every element is [Some _]; replaces a catch-all handler
     around [Opt.get], which could hide unrelated exceptions *)
  let rec all_projections acc = function
    | [] -> Some (List.rev acc)
    | Some pj :: rest -> all_projections (pj :: acc) rest
    | None :: _ -> None
  in
  let rec lookup = function
    | [] -> []
    | (_, [cs, pjl]) :: _ when ls_equal cs ls ->
        begin match all_projections [] pjl with
        | Some fields -> fields
        | None -> []
        end
    | _ :: dl -> lookup dl
  in
  match Mid.find_opt ls.ls_name info.th_known_map with
  | Some { d_node = Ddata dl } -> lookup dl
  | Some _ | None -> []
116

117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
(* type variables always start with a quote *)
let print_tv fmt tv =
  pp_print_char fmt '\'';
  pp_print_string fmt (id_unique aprinter tv.tv_name)

(* logic variables always start with a lower case letter *)
let print_vs fmt vs =
  pp_print_string fmt
    (id_unique iprinter ~sanitizer:String.uncapitalize vs.vs_name)

(* release the names given to logic variables *)
let forget_var vs = forget_id iprinter vs.vs_name
let forget_vars vl = List.iter forget_var vl

(* a local identifier, under its unique name *)
let print_ident fmt id =
  pp_print_string fmt (id_unique iprinter id)

133
134
135
136
(* a dotted path "a.b.c" *)
let print_path fmt path = print_list dot pp_print_string fmt path

(* [print_qident ~sanitizer info fmt id] prints [id] as an OCaml name.
   When [id] can be traced back to a library path, its name is built from
   that path; it is qualified by the module generated for it unless [id] is
   local to the theory/module being extracted. Otherwise the local unique
   name from [iprinter] is used. *)
let print_qident ~sanitizer info fmt id =
  (* Only the path lookups may legitimately raise [Not_found]; restricting
     the handler to them avoids turning an unrelated [Not_found], raised
     after some output was produced, into the fallback printing. *)
  let restored =
    try
      Some (try Mlw_module.restore_path id
            with Not_found -> Theory.restore_path id)
    with Not_found -> None
  in
  match restored with
  | None ->
      fprintf fmt "%s" (id_unique ~sanitizer iprinter id)
  | Some (lp, t, p) ->
      let s = String.concat "__" p in
      let s = Ident.sanitizer char_to_alpha char_to_alnumus s in
      let s = sanitizer s in
      let s = if is_ocaml_keyword s then s ^ "_renamed" else s in
      let is_local =
        Sid.mem id info.current_theory.th_local ||
        Opt.fold (fun _ m -> Sid.mem id m.Mlw_module.mod_local)
          false info.current_module in
      if is_local then
        fprintf fmt "%s" s
      else
        let fname = if lp = [] then info.fname else None in
        let m = String.capitalize (modulename ?fname lp t) in
        fprintf fmt "%s.%s" m s

(* value names start lowercase, constructor names uppercase *)
let print_lident info fmt id =
  print_qident ~sanitizer:String.uncapitalize info fmt id
let print_uident info fmt id =
  print_qident ~sanitizer:String.capitalize info fmt id

(* function symbols, constructors and type symbols *)
let print_ls info fmt ls = print_lident info fmt ls.ls_name
let print_cs info fmt cs = print_uident info fmt cs.ls_name
let print_ts info fmt ts = print_lident info fmt ts.ts_name

(* "id" or "a.b.id" *)
let print_path_id fmt (path, id) =
  match path with
  | [] -> print_ident fmt id
  | _ :: _ -> fprintf fmt "%a.%a" print_path path print_ident id

let print_theory_name fmt th =
  print_path_id fmt (th.th_path, th.th_name)
let print_module_name fmt m =
  print_theory_name fmt m.Mlw_module.mod_theory
170
171
172
173

(* An OCaml expression that fails at run time; [s] goes into a trailing
   comment saying what is missing. *)
let to_be_implemented fmt s =
  pp_print_string fmt "failwith \"to be implemented\" (* ";
  pp_print_string fmt s;
  pp_print_string fmt " *)"

(* Same, as a format: the format [s] is embedded in the comment. *)
let tbi s =
  format_of_string "failwith \"to be implemented\" (* "
  ^^ s ^^ format_of_string " *)"

176
177
178
179
180
181
(** Types *)

(* wrap the format [s] in parentheses when [x] holds *)
let protect_on x s =
  if not x then s else "(" ^^ s ^^ ")"

(* separator between the components of a tuple type *)
let star fmt () =
  pp_print_string fmt " *";
  pp_print_space fmt ()

182
183
(* whether the driver gives a syntax for [id] *)
let has_syntax { info_syn = syn } id = Mid.mem id syn

184
(* [print_ty_node inn info fmt ty] prints logic type [ty] in OCaml syntax;
   [inn] requests parentheses around a type application, as needed when
   [ty] is itself the argument of a type constructor. *)
let rec print_ty_node inn info fmt ty =
  match ty.ty_node with
  | Tyvar v ->
      print_tv fmt v
  | Tyapp (ts, args) when is_ts_tuple ts ->
      begin match args with
      | [] -> fprintf fmt "unit"
      | _ -> fprintf fmt "(%a)" (print_list star (print_ty_node false info)) args
      end
  | Tyapp (ts, args) ->
      begin match query_syntax info.info_syn ts.ts_name, args with
      | Some s, _ ->
          syntax_arguments s (print_ty_node true info) fmt args
      | None, [] ->
          print_ts info fmt ts
      | None, [arg] ->
          fprintf fmt (protect_on inn "%a@ %a")
            (print_ty_node true info) arg (print_ts info) ts
      | None, _ ->
          fprintf fmt (protect_on inn "(%a)@ %a")
            (print_list comma (print_ty_node false info)) args
            (print_ts info) ts
      end

let print_ty info fmt ty = print_ty_node false info fmt ty

210
211
(* "x: ty" *)
let print_vsty info fmt v =
  fprintf fmt "%a:@ %a" print_vs v (print_ty info) v.vs_ty

let print_tv_arg = print_tv

(* type parameters of a declaration: "", "'a " or "('a, 'b) " *)
let print_tv_args fmt tvl =
  match tvl with
  | [] -> ()
  | [tv] -> fprintf fmt "%a@ " print_tv_arg tv
  | _ -> fprintf fmt "(%a)@ " (print_list comma print_tv_arg) tvl

(* a type in argument position, and a "(x: ty)" parameter *)
let print_ty_arg info fmt ty = print_ty_node true info fmt ty
let print_vs_arg info fmt vs = fprintf fmt "(%a)" (print_vsty info) vs
221

222
223
224
225
226
227
(* one case "| C" or "| C of t1 * t2" of a variant type *)
let print_constr info fmt (cs, _) =
  let args = cs.ls_args in
  match args with
  | _ :: _ ->
      fprintf fmt "@[<hov 4>| %a of %a@]" (print_cs info) cs
        (print_list star (print_ty_arg info)) args
  | [] ->
      fprintf fmt "@[<hov 4>| %a@]" (print_cs info) cs
228

229
(* abstract or alias logic type declaration *)
let print_type_decl info fmt ts =
  match ts.ts_def with
  | Some ty ->
      fprintf fmt "@[<hov 2>type %a%a =@ %a@]"
        print_tv_args ts.ts_args (print_ts info) ts (print_ty info) ty
  | None ->
      fprintf fmt
        "@[<hov 2>type %a%a (* to be defined (uninterpreted type) *)@]"
        print_tv_args ts.ts_args (print_ts info) ts

(* same, unless the driver overrides the type; type-variable names are
   released afterwards *)
let print_type_decl info fmt ts =
  if not (has_syntax info ts.ts_name) then begin
    print_type_decl info fmt ts;
    forget_tvs ()
  end else
    fprintf fmt "(* type %a is overridden by driver *)"
      (print_lident info) ts.ts_name
243

244
(* [print_data_decl info fst fmt (ts, csl)] prints one algebraic type of a
   group ("type" when [fst], "and" otherwise). A single-constructor type
   whose arguments all have projections becomes an OCaml record. *)
let print_data_decl info fst fmt (ts, csl) =
  let print_variants fmt () =
    print_list newline (print_constr info) fmt csl in
  let print_field fmt ls =
    fprintf fmt "%a: %a"
      (print_ls info) ls (print_ty info) (Opt.get ls.ls_value) in
  let print_defn fmt constrs =
    let fields =
      match constrs with
      | [cs, _] -> get_record info cs
      | _ -> [] in
    match fields with
    | [] -> print_variants fmt ()
    | _ -> fprintf fmt "{ %a }" (print_list semi print_field) fields
  in
  let keyword = if fst then "type" else "and" in
  fprintf fmt "@[<hov 2>%s %a%a =@\n@[<hov>%a@]@]"
    keyword print_tv_args ts.ts_args (print_ts info) ts print_defn csl

(* same, unless the driver overrides the type *)
let print_data_decl info first fmt ((ts, _) as d) =
  if has_syntax info ts.ts_name then
    fprintf fmt "(* type %a is overridden by driver *)"
      (print_lident info) ts.ts_name
  else begin
    print_data_decl info first fmt d;
    forget_tvs ()
  end
266

267
268
269
270
271
272
(* a single constructor whose every argument has a projection *)
let is_record = function
  | _, [_, pjl] -> List.for_all (fun pj -> pj <> None) pjl
  | _ -> false

(* Each projection [p] of the first constructor is printed as a function
   "let p = function | C (_, x) -> x | ..." over all the constructors. *)
let print_projections info fmt (_, csl) =
  let projections =
    List.fold_right
      (fun pj acc -> match pj with Some ls -> ls :: acc | None -> acc)
      (snd (List.hd csl)) [] in
  let print_one ls =
    let print_arg fmt = function
      | Some ls' when ls_equal ls' ls -> fprintf fmt "x"
      | Some _ | None -> fprintf fmt "_" in
    let print_branch fmt (cs, args) =
      fprintf fmt "| %a (%a) -> x"
        (print_cs info) cs (print_list comma print_arg) args in
    fprintf fmt "@[<hov 2>let %a = function@\n" (print_ls info) ls;
    print_list newline print_branch fmt csl;
    fprintf fmt "@]@\n@\n"
  in
  List.iter print_one projections

(* skipped for driver-defined types and for records (native fields) *)
let print_projections info fmt ((ts, _) as d) =
  let skip = has_syntax info ts.ts_name || is_record d in
  if not skip then begin
    print_projections info fmt d;
    forget_tvs ()
  end

293
294
295
296
297
298
299
(** Inductive *)

(* One fresh variable named "x" per type in [l], to name the parameters of
   a generated function. The former counter was incremented but never read,
   and has been removed. *)
let name_args l =
  List.map (fun ty -> create_vsymbol (id_fresh "x") ty) l

300
301
302
303
(* An inductive predicate becomes a stub "let rec"/"and" function returning
   [bool] whose body fails, followed by the definition in a comment. *)
let print_ind_decl info sign fst fmt ((ps, _) as d) =
  let print_ind fmt d =
    if fst then Pretty.print_ind_decl fmt sign d
    else Pretty.print_next_ind_decl fmt d in
  let vars = name_args ps.ls_args in
  let head = if fst then "let rec" else "and" in
  fprintf fmt "@[<hov 2>%s %a %a : bool =@ @[<hov>%a@\n(* @[%a@] *)@]@]"
    head (print_ls info) ps
    (print_list space (print_vs_arg info)) vars
    to_be_implemented "inductive"
    print_ind d;
  forget_vars vars

(* same, unless the driver overrides the predicate *)
let print_ind_decl info sign first fmt ((ls, _) as d) =
  if not (has_syntax info ls.ls_name) then begin
    print_ind_decl info sign first fmt d;
    forget_tvs ()
  end else
    fprintf fmt "(* inductive %a is overridden by driver *)"
      (print_lident info) ls.ls_name
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343

(** Functions/Predicates *)

(* [is_exec_term t] holds when [t] contains no quantifier and no epsilon
   term, so that it can be printed as OCaml code. *)
let rec is_exec_term t =
  match t.t_node with
  | Teps _ | Tquant _ -> false (* TODO: improve? *)
  | Tvar _ | Tconst _ | Ttrue | Tfalse -> true
  | Tnot t1 -> is_exec_term t1
  | Tbinop (_, t1, t2) -> is_exec_term t1 && is_exec_term t2
  | Tapp (_, tl) -> List.for_all is_exec_term tl
  | Tif (t1, t2, t3) -> List.for_all is_exec_term [t1; t2; t3]
  | Tlet (t1, b2) ->
      is_exec_term t1 &&
      (let _, t2 = t_open_bound b2 in is_exec_term t2)
  | Tcase (t1, bl) ->
      is_exec_term t1 && List.for_all is_exec_branch bl

and is_exec_branch b =
  let _, body = t_open_branch b in
  is_exec_term body

344
(* Decimal integer literals go through [Why3__BuiltIn.int_constant]; all
   other literals are printed as "to be implemented" stubs. *)
let print_const fmt c =
  match c with
  | ConstInt (IConstDec s) ->
      fprintf fmt "(Why3__BuiltIn.int_constant \"%s\")" s
  | ConstInt (IConstHex s) -> fprintf fmt (tbi "0x%s") s
  | ConstInt (IConstOct s) -> fprintf fmt (tbi "0o%s") s
  | ConstInt (IConstBin s) -> fprintf fmt (tbi "0b%s") s
  | ConstReal (RConstDec (i, f, None)) -> fprintf fmt (tbi "%s.%s") i f
  | ConstReal (RConstDec (i, f, Some e)) -> fprintf fmt (tbi "%s.%se%s") i f e
  | ConstReal (RConstHex (i, f, None)) -> fprintf fmt (tbi "0x%s.%s") i f
  | ConstReal (RConstHex (i, f, Some e)) -> fprintf fmt (tbi "0x%s.%sp%s") i f e
354
355
356
357
358
359
360
361
362
363
364
365

(* can the type of a value be derived from the type of the arguments? *)
(* can the type of a value be derived from the type of the arguments?
   i.e. does every type variable of the result type occur in some argument
   type? Predicates (no result type) always qualify. *)
let unambig_fs fs =
  let rec occurs v ty =
    match ty.ty_node with
    | Tyvar u when tv_equal u v -> true
    | _ -> ty_any (occurs v) ty in
  let in_args v = List.exists (occurs v) fs.ls_args in
  let rec determined ty =
    match ty.ty_node with
    | Tyvar u -> in_args u
    | _ -> ty_all determined ty in
  match fs.ls_value with
  | None -> true
  | Some ty -> determined ty
367
368
369

(** Patterns, terms, and formulas *)

370
371
372
373
374
(* [filter_ghost ls def al] replaces by [def] each element of [al] standing
   at a ghost field of the program constructor associated with [ls]; [al]
   is returned as is when [ls] has no program counterpart. *)
let filter_ghost ls def al =
  let keep fd arg = if fd.Mlw_expr.fd_ghost then def else arg in
  match (try Some (Mlw_expr.restore_pl ls) with Not_found -> None) with
  | None -> al
  | Some pl -> List.map2 keep pl.Mlw_expr.pl_args al

375
(* [print_pat_node pri info fmt p] prints pattern [p] in OCaml syntax. [pri]
   is the strength required by the context (e.g. 1 for the operand of "as"
   and tuple components, 2 for constructor arguments): "as" patterns are
   parenthesized when [pri > 1], or-patterns when [pri > 0]. *)
let rec print_pat_node pri info fmt p =
  match p.pat_node with
  | Term.Pwild ->
      fprintf fmt "_"
  | Term.Pvar v ->
      print_vs fmt v
  | Term.Pas (p, v) ->
      fprintf fmt (protect_on (pri > 1) "%a as %a")
        (print_pat_node 1 info) p print_vs v
  | Term.Por (p, q) ->
      fprintf fmt (protect_on (pri > 0) "%a | %a")
        (print_pat_node 0 info) p (print_pat_node 0 info) q
  | Term.Papp (cs, pl) when is_fs_tuple cs ->
      fprintf fmt "(%a)" (print_list comma (print_pat_node 1 info)) pl
  | Term.Papp (cs, pl) ->
      begin match query_syntax info.info_syn cs.ls_name, pl with
      | Some s, _ ->
          syntax_arguments s (print_pat_node 0 info) fmt pl
      | None, [] ->
          print_cs info fmt cs
      | None, _ :: _ ->
          (* ghost arguments are replaced by the void pattern *)
          let pat_void = Term.pat_app Mlw_expr.fs_void [] Mlw_ty.ty_unit in
          let args = filter_ghost cs pat_void pl in
          begin match get_record info cs with
          | [] ->
              fprintf fmt (protect_on (pri > 1) "%a@ (%a)")
                (print_cs info) cs
                (print_list comma (print_pat_node 2 info)) args
          | fields ->
              let print_field fmt (ls, p) =
                fprintf fmt "%a = %a"
                  (print_ls info) ls (print_pat_node 0 info) p in
              fprintf fmt "{ %a }" (print_list semi print_field)
                (List.combine fields args)
          end
      end

let print_pat info fmt p = print_pat_node 0 info fmt p

(* OCaml operator for a logical connective; implication has no direct
   counterpart and is handled separately by the term printer. *)
let print_binop fmt = function
  | Tand -> fprintf fmt "&&"
  | Tor -> fprintf fmt "||"
  | Tiff -> fprintf fmt "="
  | Timplies -> assert false

(* Precedence of a connective, used to decide on parentheses. It must agree
   with OCaml's precedences for the operators printed by [print_binop]:
   "=" binds tighter than "&&", which binds tighter than "||". Equivalence
   must therefore rank above [Tand]; with a lower rank,
   [Tiff (Tand (a, b), c)] would print as "a && b = c", which OCaml reads as
   "a && (b = c)". *)
let prio_binop = function
  | Tand -> 3
  | Tor -> 2
  | Tiff -> 4
  | Timplies -> 1

421
422
423
424
425
426
427
(* [print_term info fmt t] prints the executable logic term [t] as an OCaml
   expression, in a context needing no parentheses. *)
let rec print_term info fmt t = print_lterm 0 info fmt t

(* [print_lterm pri info fmt t]: same, in a context of precedence [pri]
   (higher values force parentheses on more constructs). *)
and print_lterm pri info fmt t = print_tnode pri info fmt t

(* Application of [ls] to [tl]: constructors become OCaml constructors (or
   record literals), projections of record-like types become field
   accesses, other symbols become curried applications. *)
and print_app pri ls info fmt tl =
  let isconstr = is_constructor info ls in
  (* one constructor whose arguments all have projections, [ls] among them *)
  let projects_record (_, constrs) =
    match constrs with
    | [_, pjl] ->
        List.for_all (fun pj -> pj <> None) pjl &&
        List.exists
          (function Some ls' -> ls_equal ls ls' | None -> false) pjl
    | _ -> false in
  let isfield =
    not isconstr &&
    (match Mid.find_opt ls.ls_name info.th_known_map with
     | Some { d_node = Ddata dl } -> List.exists projects_record dl
     | _ -> false) in
  match tl with
  | [] ->
      if isconstr then print_cs info fmt ls else print_ls info fmt ls
  | _ when isconstr ->
      let args = filter_ghost ls Mlw_expr.t_void tl in
      begin match get_record info ls with
      | [] ->
          fprintf fmt (protect_on (pri > 5) "@[<hov 1>%a@ (%a)@]")
            (print_cs info) ls (print_list comma (print_lterm 6 info)) args
      | fields ->
          let print_field fmt (pj, t) =
            fprintf fmt "%a = %a" (print_ls info) pj (print_term info) t in
          fprintf fmt "@[<hov 1>{ %a }@]" (print_list semi print_field)
            (List.combine fields args)
      end
  | [t1] when isfield ->
      fprintf fmt "(%a).%a" (print_term info) t1 (print_ls info) ls
  | [t1] ->
      fprintf fmt (protect_on (pri > 4) "%a %a")
        (print_ls info) ls (print_lterm 5 info) t1
  | _ ->
      fprintf fmt (protect_on (pri > 5) "@[<hov 1>%a@ %a@]")
        (print_ls info) ls (print_list space (print_lterm 6 info)) tl

and print_tnode pri info fmt t =
  match t.t_node with
  | Tvar v ->
      (* ghost program variables are erased into the unit value *)
      let ghost =
        try (Mlw_ty.restore_pv v).Mlw_ty.pv_ghost with Not_found -> false in
      if ghost then fprintf fmt "()" else print_vs fmt v
  | Tconst c ->
      print_const fmt c
  | Tapp (fs, tl) when is_fs_tuple fs ->
      fprintf fmt "(%a)" (print_list comma (print_term info)) tl
  | Tapp (fs, tl) ->
      begin match query_syntax info.info_syn fs.ls_name with
      | Some s -> syntax_arguments s (print_term info) fmt tl
      | None when unambig_fs fs -> print_app pri fs info fmt tl
      | None ->
          (* the arguments do not fix the result type: annotate it *)
          fprintf fmt (protect_on (pri > 0) "@[<hov 2>(%a:@ %a)@]")
            (print_app 5 fs info) tl (print_ty info) (t_type t)
      end
  | Tif (f, t1, t2) ->
      fprintf fmt (protect_on (pri > 0) "if @[%a@] then %a@ else %a")
        (print_term info) f (print_term info) t1 (print_term info) t2
  | Tlet (t1, tb) ->
      let v, t2 = t_open_bound tb in
      fprintf fmt (protect_on (pri > 0) "let %a = @[%a@] in@ %a")
        print_vs v (print_lterm 4 info) t1 (print_term info) t2;
      forget_var v
  | Tcase (t1, bl) ->
      fprintf fmt "@[(match @[%a@] with@\n@[<hov>%a@])@]"
        (print_term info) t1 (print_list newline (print_tbranch info)) bl
  | Teps _ | Tquant _ ->
      assert false
  | Ttrue ->
      fprintf fmt "true"
  | Tfalse ->
      fprintf fmt "false"
  | Tbinop (Timplies, f1, f2) ->
      fprintf fmt "(not (%a) || (%a))" (print_term info) f1 (print_term info) f2
  | Tbinop (b, f1, f2) ->
      let p = prio_binop b in
      fprintf fmt (protect_on (pri > p) "@[<hov 1>%a %a@ %a@]")
        (print_lterm (p + 1) info) f1 print_binop b (print_lterm p info) f2
  | Tnot f ->
      fprintf fmt (protect_on (pri > 4) "not %a") (print_lterm 5 info) f

(* one "| pattern -> term" case; the pattern variables are then released *)
and print_tbranch info fmt br =
  let p, t = t_open_branch br in
  fprintf fmt "@[<hov 4>| %a ->@ %a@]" (print_pat info) p (print_term info) t;
  Svs.iter forget_var p.pat_vars

511
(* dead code
512
and print_tl info fmt tl =
513
  if tl = [] then () else fprintf fmt "@ [%a]"
514
    (print_list alt (print_list comma (print_term info))) tl
515
*)
516

517
(* result type of a symbol; predicates return [bool] *)
let print_ls_type info fmt ty_opt =
  match ty_opt with
  | Some ty -> print_ty info fmt ty
  | None -> fprintf fmt "bool"

(* body of a logic definition; a non-executable one becomes a stub
   followed by the original term in a comment *)
let print_defn info fmt e =
  if not (is_exec_term e) then
    fprintf fmt "@[<hov>%a@ @[(* %a *)@]@]"
      to_be_implemented "not executable" Pretty.print_term e
  else
    print_term info fmt e
525

526
(* An abstract logic symbol: its stub definition is emitted inside an OCaml
   comment (see the format string), since there is nothing to extract. *)
let print_param_decl info fmt ls =
  if not (has_syntax info ls.ls_name) then begin
    let vars = name_args ls.ls_args in
    fprintf fmt "@[<hov 2>(*let %a %a : %a =@ %a*)@]"
      (print_ls info) ls
      (print_list space (print_vs_arg info)) vars
      (print_ls_type info) ls.ls_value
      to_be_implemented "uninterpreted symbol";
    forget_vars vars;
    forget_tvs ()
  end else
    fprintf fmt "(* parameter %a is overridden by driver *)"
      (print_lident info) ls.ls_name
540

541
(* [print_logic_decl info isrec fst fmt (ls, ld)] prints a defined logic
   symbol as an OCaml function, introduced by "let rec" or "let" when [fst]
   (depending on [isrec]) and by "and" otherwise. *)
let print_logic_decl info isrec fst fmt (ls, ld) =
  if has_syntax info ls.ls_name then
    fprintf fmt "(* symbol %a is overridden by driver *)"
      (print_lident info) ls.ls_name
  else begin
    let params, body = open_ls_defn ld in
    let kw =
      match fst, isrec with
      | true, true -> "let rec"
      | true, false -> "let"
      | false, _ -> "and" in
    fprintf fmt "@[<hov 2>%s %a %a : %a =@ %a@]"
      kw (print_ls info) ls
      (print_list space (print_vs_arg info)) params
      (print_ls_type info) ls.ls_value (print_defn info) body;
    forget_vars params;
    forget_tvs ()
  end
555
556
557
558
559
560
561
562
563
564
565
566

(** Logic Declarations *)

(* Like [print_list], but [print] receives [true] for the first element and
   [false] for the others (e.g. to choose between "let" and "and"). *)
let print_list_next sep print fmt l =
  match l with
  | [] -> ()
  | first :: rest ->
      print true fmt first;
      begin match rest with
      | [] -> ()
      | _ :: _ ->
          sep fmt ();
          print_list sep (print false) fmt rest
      end

567
(* One logic declaration, followed by a blank line; propositions
   ([Dprop]) produce nothing. *)
let logic_decl info fmt d =
  let blank_line () = fprintf fmt "@\n@\n" in
  match d.d_node with
  | Dtype ts ->
      print_type_decl info fmt ts;
      blank_line ()
  | Ddata tl ->
      print_list_next newline (print_data_decl info) fmt tl;
      blank_line ();
      print_list nothing (print_projections info) fmt tl
  | Decl.Dparam ls ->
      print_param_decl info fmt ls;
      blank_line ()
  | Dlogic [(ls, _) as ld] ->
      (* a lone definition is recursive when it refers to itself *)
      let isrec = Sid.mem ls.ls_name d.d_syms in
      print_logic_decl info isrec true fmt ld;
      blank_line ()
  | Dlogic ll ->
      print_list_next newline (print_logic_decl info true) fmt ll;
      blank_line ()
  | Dind (s, il) ->
      print_list_next newline (print_ind_decl info s) fmt il;
      blank_line ()
  | Dprop (_pk, _pr, _) ->
      ()
      (* fprintf fmt "(* %a %a *)" Pretty.print_pkind pk Pretty.print_pr pr *)
591

592
(* A theory item: only [Decl] items are printed, and only when none of the
   symbols they use or introduce is in [info.mo_known_map] (presumably
   those are printed with the program declarations instead). *)
let logic_decl info fmt td =
  match td.td_node with
  | Use _ | Clone _ | Meta _ -> ()
  | Decl d ->
      let syms = Sid.union d.d_syms d.d_news in
      let in_module = Mid.set_inter syms info.mo_known_map in
      if Sid.is_empty in_module then logic_decl info fmt d
599
600
601

(** Theories *)

602
(* [extract_theory drv ?old ?fname fmt th] writes the OCaml translation of
   theory [th] on [fmt], honouring the syntax overrides of driver [drv].
   [fname], when given, is the source file name; it names the OCaml units
   of theories with an empty library path. [old] is currently unused.
   (The previous [ignore (fname)] was misleading, since [fname] is used.) *)
let extract_theory drv ?old ?fname fmt th =
  ignore old;
  let info = {
    info_syn = drv.Mlw_driver.drv_syntax;
    current_theory = th;
    current_module = None;
    th_known_map = th.th_known;
    mo_known_map = Mid.empty;
    fname = Opt.map clean_fname fname; } in
  fprintf fmt
    "(* This file has been generated from Why3 theory %a *)@\n@\n"
    print_theory_name th;
  fprintf fmt "open Why3extract@\n@\n";
  print_list nothing (logic_decl info) fmt th.th_decls;
  fprintf fmt "@."
619

620
(** Programs *)
621

622
623
624
625
626
627
open Mlw_ty
open Mlw_ty.T
open Mlw_expr
open Mlw_decl
open Mlw_module

628
(* type symbol underlying a program type symbol *)
let print_its info fmt its = print_ts info fmt its.its_ts

(* Ghost program variables and functions are erased: they print as the unit
   value, with their name kept in a comment. *)
let print_pv info fmt pv =
  let id = pv.pv_vs.vs_name in
  if not pv.pv_ghost then print_lident info fmt id
  else fprintf fmt "((* ghost %a *))" (print_lident info) id

let print_ps info fmt ps =
  let id = ps.ps_name in
  if not ps.ps_ghost then print_lident info fmt id
  else fprintf fmt "((* ghost %a *))" (print_lident info) id
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
(* symbol bound by a program let *)
let print_lv info fmt lv =
  match lv with
  | LetV pv -> print_pv info fmt pv
  | LetA ps -> print_ps info fmt ps

(* exceptions are OCaml exception constructors, hence capitalized *)
let print_xs info fmt xs = print_uident info fmt xs.xs_name

(* release the names given to program symbols *)
let forget_ps ps = forget_id iprinter ps.ps_name
let forget_pv pv = forget_var pv.pv_vs
let forget_lv lv =
  match lv with
  | LetV pv -> forget_pv pv
  | LetA ps -> forget_ps ps

(* [print_ity_node inn info fmt ity] prints program type [ity]; [inn] asks
   for parentheses around a type application. *)
let rec print_ity_node inn info fmt ity =
  (* "t", "a t" or "(a, b) t", given the printer of the head constructor;
     shared by pure and mutable type applications *)
  let print_applied print_head fmt args =
    match args with
    | [] -> print_head fmt ()
    | [arg] ->
        fprintf fmt (protect_on inn "%a@ %a")
          (print_ity_node true info) arg print_head ()
    | _ ->
        fprintf fmt (protect_on inn "(%a)@ %a")
          (print_list comma (print_ity_node false info)) args print_head ()
  in
  match ity.ity_node with
  | Ityvar v ->
      print_tv fmt v
  | Itypur (ts, []) when is_ts_tuple ts ->
      fprintf fmt "unit"
  | Itypur (ts, tl) when is_ts_tuple ts ->
      fprintf fmt "(%a)" (print_list star (print_ity_node false info)) tl
  | Itypur (ts, tl) ->
      begin match query_syntax info.info_syn ts.ts_name with
      | Some s -> syntax_arguments s (print_ity_node true info) fmt tl
      | None -> print_applied (fun fmt () -> print_ts info fmt ts) fmt tl
      end
  | Ityapp (its, tl, _) ->
      begin match query_syntax info.info_syn its.its_ts.ts_name with
      | Some s -> syntax_arguments s (print_ity_node true info) fmt tl
      | None -> print_applied (fun fmt () -> print_its info fmt its) fmt tl
      end

let print_ity info fmt ity = print_ity_node false info fmt ity
684

685
(* a function parameter "(x: ty)"; a ghost parameter becomes "()" *)
let print_pvty info fmt pv =
  if not pv.pv_ghost then
    fprintf fmt "@[(%a:@ %a)@]"
      (print_lident info) pv.pv_vs.vs_name (print_ity info) pv.pv_ity
  else
    fprintf fmt "((* ghost *))"

(* arrow type "(t1 -> ... -> r)" of a program function *)
let rec print_aty info fmt aty =
  let print_arg fmt pv = print_ity info fmt pv.pv_ity in
  fprintf fmt "(%a -> %a)" (print_list Pp.arrow print_arg) aty.aty_args
    (print_vty info) aty.aty_result

and print_vty info fmt vty =
  match vty with
  | VTarrow aty -> print_aty info fmt aty
  | VTvalue ity -> print_ity info fmt ity
698

699
(* A group of functions is printed with "let rec" unless it is a single
   function whose [c_letrec] is 0 (presumably: not recursive). *)
let is_letrec fdl =
  match fdl with
  | [fd] -> fd.fun_lambda.l_spec.c_letrec <> 0
  | [] | _ :: _ :: _ -> true

(* type of the "mark" values of Mlw_wp; lets binding them are skipped by
   the expression printer *)
let ity_mark = ity_pur Mlw_wp.ts_mark []

(* [print_expr info fmt e] prints program expression [e] as OCaml code, in a
   context needing no parentheses. *)
let rec print_expr info fmt e = print_lexpr 0 info fmt e

(* [print_lexpr pri info fmt e]: same, in a context of precedence [pri];
   any [pri > 0] asks for parentheses around let/if/sequence/assignment/
   let-rec forms. Ghost code is erased into the unit value. *)
and print_lexpr pri info fmt e =
  if e.e_ghost then
    fprintf fmt "((* ghost *))"
  else match e.e_node with
  | Elogic t ->
      fprintf fmt "(%a)" (print_term info) t
  | Evalue v ->
      print_pv info fmt v
  | Earrow a ->
      begin match query_syntax info.info_syn a.ps_name with
        | Some s -> syntax_arguments s (print_expr info) fmt []
        | None   -> print_ps info fmt a end
  | Eapp (e,v,_) ->
      (* the function part must be atomic: "(if c then f else g x)" would
         apply only [g] *)
      fprintf fmt "(%a@ %a)" (print_lexpr 1 info) e (print_pv info) v
  | Elet ({ let_expr = e1 }, e2) when e1.e_ghost ->
      (* [e2] stands for the whole let: keep the context precedence *)
      print_lexpr pri info fmt e2
  | Elet ({ let_sym = LetV pv }, e2)
    when ity_equal pv.pv_ity ity_mark ->
      print_lexpr pri info fmt e2
  | Elet ({ let_sym = LetV pv ; let_expr = e1 }, e2)
    when pv.pv_vs.vs_name.id_string = "_" &&
         ity_equal pv.pv_ity ity_unit ->
      (* [e1] is parenthesized when it is a let: otherwise its binding would
         scope over [e2] as well, and since its name is released once [e1]
         is printed, a variable of [e2] could get the same name and be
         captured *)
      fprintf fmt (protect_on (pri > 0) "@[begin %a;@ %a end@]")
        (print_lexpr 1 info) e1 (print_expr info) e2
  | Elet ({ let_sym = lv ; let_expr = e1 }, e2) ->
      fprintf fmt (protect_on (pri > 0) "@[<hov 2>let %a =@ %a@ in@]@\n%a")
        (print_lv info) lv (print_lexpr 4 info) e1 (print_expr info) e2;
      forget_lv lv
  | Eif (e0,e1,e2) ->
      fprintf fmt (protect_on (pri > 0)
                     "@[<hv>if %a@ @[<hov 2>then %a@]@ @[<hov 2>else %a@]@]")
        (print_expr info) e0 (print_expr info) e1 (print_expr info) e2
  | Eassign (pl,e,_,pv) ->
      (* the record must be atomic: "if c then a else b.f <- v" would only
         assign in the else branch *)
      fprintf fmt (protect_on (pri > 0) "%a.%a <- %a")
        (print_lexpr 1 info) e (print_ls info) pl.pl_ls (print_pv info) pv
  | Eloop (_,_,e) ->
      fprintf fmt "@[while true do@ %a@ done@]" (print_expr info) e
  | Efor (pv,(pvfrom,dir,pvto),_,e) ->
      fprintf fmt
        "@[<hov 2>(Int__Int.for_loop_%s %a %a@ (fun %a -> %a))@]"
        (if dir = To then "to" else "downto")
        (print_pv info) pvfrom (print_pv info) pvto
        (print_pv info) pv (print_expr info) e
  | Eraise (xs,e) ->
      begin match query_syntax info.info_syn xs.xs_name with
        | Some s -> syntax_arguments s (print_expr info) fmt [e]
        | None when ity_equal xs.xs_ity ity_unit ->
            fprintf fmt "raise %a" (print_xs info) xs
        | None ->
            (* a constructor argument must be atomic: "E if ..." or
               "E let ..." is a syntax error *)
            fprintf fmt "raise (%a %a)" (print_xs info) xs
              (print_lexpr 1 info) e
      end
  | Etry (e,bl) ->
      fprintf fmt "@[(try %a with@\n@[<hov>%a@])@]"
        (print_expr info) e (print_list newline (print_xbranch info)) bl
  | Eabstr (e,_) ->
      print_lexpr pri info fmt e
  | Eabsurd ->
      fprintf fmt "assert false (* absurd *)"
  | Eassert _ ->
      fprintf fmt "((* assert *))"
  | Eghost _ ->
      assert false
  | Eany _ ->
      fprintf fmt "@[(%a :@ %a)@]" to_be_implemented "any"
        (print_vty info) e.e_vty
  | Ecase (e1, [_,e2]) when e1.e_ghost ->
      print_lexpr pri info fmt e2
  | Ecase (e1, bl) ->
      fprintf fmt "@[(match @[%a@] with@\n@[<hov>%a@])@]"
        (print_expr info) e1 (print_list newline (print_ebranch info)) bl
  | Erec (fdl, e) ->
      (* print non-ghost first *)
      let cmp {fun_ps=ps1} {fun_ps=ps2} =
        Pervasives.compare ps1.ps_ghost ps2.ps_ghost in
      let fdl = List.sort cmp fdl in
      (* parenthesized in non-top contexts, like the other let forms *)
      fprintf fmt (protect_on (pri > 0) "@[<v>%a@\nin@\n%a@]")
        (print_list_next newline (print_rec_decl (is_letrec fdl) info)) fdl
        (print_expr info) e

(* One function of a (possibly recursive) group: "let rec"/"let" for the
   first one, "and" for the others. A ghost function is bound to unit. *)
and print_rec lr info fst fmt { fun_ps = ps ; fun_lambda = lam } =
  (* the keyword after the first function used to be "with" for ghost
     functions, which is not OCaml syntax *)
  let kw = if fst then if lr then "let rec" else "let" else "and" in
  if ps.ps_ghost then
    fprintf fmt "@[<hov 2>%s %a = ()@]" kw (print_ps info) ps
  else
    let print_arg fmt pv = fprintf fmt "@[%a@]" (print_pvty info) pv in
    fprintf fmt "@[<hov 2>%s %a %a =@ %a@]"
      kw (print_ps info) ps (print_list space print_arg) lam.l_args
      (print_expr info) lam.l_expr

(* one "| pattern -> expr" case of a match *)
and print_ebranch info fmt ({ppat_pattern=p}, e) =
  fprintf fmt "@[<hov 4>| %a ->@ %a@]" (print_pat info) p (print_expr info) e;
  Svs.iter forget_var p.pat_vars

(* One exception handler. When the driver gives a syntax for the exception,
   that syntax is used for the pattern, but the handler body must still be
   printed (it used to be dropped). *)
and print_xbranch info fmt (xs, pv, e) =
  begin match query_syntax info.info_syn xs.xs_name with
    | Some s ->
        fprintf fmt "@[<hov 4>| %a ->@ %a@]"
          (syntax_arguments s (print_pv info)) [pv] (print_expr info) e
    | None when ity_equal xs.xs_ity ity_unit ->
        fprintf fmt "@[<hov 4>| %a ->@ %a@]"
          (print_xs info) xs (print_expr info) e
    | None ->
        fprintf fmt "@[<hov 4>| %a %a ->@ %a@]"
          (print_xs info) xs (print_pv info) pv (print_expr info) e
  end;
  forget_pv pv

and print_rec_decl lr info fst fmt fd =
  print_rec lr info fst fmt fd;
  forget_tvs ()
817

818
819
(* Same as the expression-level version, except that a function defined by
   the driver is replaced by a comment. *)
let print_rec_decl lr info fst fmt fd =
  let id = fd.fun_ps.ps_name in
  if not (has_syntax info id) then
    print_rec_decl lr info fst fmt fd
  else
    fprintf fmt "(* symbol %a is overridden by driver *)" (print_lident info) id
824

825
826
827
(* Print a top-level [let lv = e] binding, then reset the type-variable
   printer state. *)
let print_let_decl info fmt ld =
  let sym = ld.let_sym and body = ld.let_expr in
  fprintf fmt "@[<hov 2>let %a =@ %a@]"
    (print_lv info) sym (print_expr info) body;
  forget_tvs ()
828

829
830
831
832
833
(* Identifier of the symbol bound by a let: the underlying [vs_name] of a
   program variable, or the name of a program symbol. *)
let lv_name lv =
  match lv with
  | LetA ps -> ps.ps_name
  | LetV pv -> pv.pv_vs.vs_name

(* Whether the symbol bound by a let is ghost (and so must not be
   extracted as executable code). *)
let is_ghost_lv lv =
  match lv with
  | LetA ps -> ps.ps_ghost
  | LetV pv -> pv.pv_ghost
836
837
838
839
840
841
842
843
844
845
846
847
848
849

(* Ghost lets are emitted only as a comment naming the symbol; other lets
   are printed in full.  Both forms end with a blank-line separator. *)
let print_let_decl info fmt ld =
  let sym = ld.let_sym in
  if not (is_ghost_lv sym) then
    fprintf fmt "%a@\n@\n" (print_let_decl info) ld
  else
    fprintf fmt "(* let %a *)@\n@\n" (print_lv info) sym

(* Lets whose symbol has driver syntax are replaced by a comment; the others
   go through the ghost-aware printer above. *)
let print_let_decl info fmt ld =
  let id = lv_name ld.let_sym in
  if has_syntax info id then
    (* End with the same blank-line separator as the other branches: [pdecl]
       adds nothing after a [PDlet], so without it the next declaration
       would be printed on the same line as this comment. *)
    fprintf fmt "(* symbol %a is overridden by driver *)@\n@\n"
      (print_lident info) id
  else
    print_let_decl info fmt ld

850
851
(* [extract_aty_args acc aty] walks the (possibly curried) arrow type [aty]
   and returns all parameter variables, in order, together with the final
   value type.  [acc] holds already-collected variables in reverse order;
   callers start with [[]]. *)
let rec extract_aty_args args aty =
  let args =
    List.fold_left (fun acc pv -> pv.pv_vs :: acc) args aty.aty_args in
  match aty.aty_result with
  | VTarrow next -> extract_aty_args args next
  | VTvalue ity -> List.rev args, ity
856

857
(* Parameters and result type of a let-bound symbol: a program variable has
   no parameters; a program symbol's arrow type is flattened. *)
let extract_lv_args lv =
  match lv with
  | LetA ps -> extract_aty_args [] ps.ps_aty
  | LetV pv -> [], pv.pv_ity
860

861
(* Print a [val] declaration as a typed [let] stub whose body is produced by
   [to_be_implemented], then release the names given to its parameters and
   type variables. *)
let print_val_decl info fmt lv =
  let vars, ity = extract_lv_args lv in
  let print_params = print_list space (print_vs_arg info) in
  fprintf fmt "@[<hov 2>let %a %a : %a =@ %a@]"
    (print_lv info) lv print_params vars (print_ity info) ity
    to_be_implemented "val";
  forget_vars vars;
  forget_tvs ()
870

871
872
873
874
875
876
877
878
879
880
881
882
883
(* Ghost vals become a comment; the others get a stub.  Both forms end with
   a blank-line separator. *)
let print_val_decl info fmt lv =
  if not (is_ghost_lv lv) then
    fprintf fmt "%a@\n@\n" (print_val_decl info) lv
  else
    fprintf fmt "(* val %a *)@\n@\n" (print_lv info) lv

(* Vals whose symbol has driver syntax are replaced by a comment. *)
let print_val_decl info fmt lv =
  let id = lv_name lv in
  if not (has_syntax info id) then print_val_decl info fmt lv
  else
    fprintf fmt "(* symbol %a is overridden by driver *)" (print_lident info) id

884
885
886
887
(* Print a type declaration: an alias when [its] has a definition, otherwise
   an abstract type marked as still to be defined. *)
let print_type_decl info fmt its =
  let args = its.its_ts.ts_args in
  match its.its_def with
  | Some ty ->
      fprintf fmt "@[<hov 2>type %a%a =@ %a@]@\n@\n"
        print_tv_args args (print_its info) its (print_ity info) ty
  | None ->
      fprintf fmt
        "@[<hov 2>type %a%a (* to be defined (uninterpreted type) *)@]@\n@\n"
        print_tv_args args (print_its info) its
893
894

(* Types with driver syntax are replaced by a comment; otherwise the type is
   printed and the type-variable printer state is reset. *)
let print_type_decl info fmt its =
  let name = its.its_ts.ts_name in
  if not (has_syntax info name) then begin
    print_type_decl info fmt its;
    forget_tvs ()
  end else
    fprintf fmt "(* type %a is overridden by driver *)"
      (print_lident info) name

900
901
902
903
904
905
906
907
908
909
910
911
912
913
(* Print an OCaml [exception] declaration; an exception carrying [unit] is
   declared without an [of] argument. *)
let print_exn_decl info fmt xs =
  if not (ity_equal xs.xs_ity ity_unit) then
    (* NOTE(review): this branch names the exception with [print_uident]
       while the unit branch uses [print_xs]; confirm both produce the same
       (unqualified) name in a declaration. *)
    fprintf fmt "exception %a of %a@\n@\n"
      (print_uident info) xs.xs_name (print_ity info) xs.xs_ity
  else
    fprintf fmt "exception %a@\n@\n" (print_xs info) xs

(* Exceptions whose name has driver syntax are replaced by a comment;
   others are declared normally. *)
let print_exn_decl info fmt xs =
  if has_syntax info xs.xs_name then
    (* End with the same blank-line separator as the regular declarations:
       [pdecl] adds nothing after a [PDexn], so without it the next
       declaration would be printed on the same line as this comment. *)
    fprintf fmt "(* symbol %a is overridden by driver *)@\n@\n"
      (print_lident info) xs.xs_name
  else
    print_exn_decl info fmt xs

914
915
916
(* Type of a constructor field; ghost fields are erased to [unit]. *)
let print_field info fmt fd =
  if not fd.fd_ghost then print_ity info fmt fd.fd_ity
  else fprintf fmt "unit"

917
918
919
920
921
(* One [| C] or [| C of t1 * ... * tn] line of a variant type. *)
let print_pconstr info fmt (cs, _) =
  let print_name fmt () = print_cs info fmt cs.pl_ls in
  match cs.pl_args with
  | _ :: _ as fields ->
      fprintf fmt "@[<hov 4>| %a of %a@]" print_name ()
        (print_list star (print_field info)) fields
  | [] ->
      fprintf fmt "@[<hov 4>| %a@]" print_name ()
923
924
925

(* Print one algebraic type of a (possibly mutually recursive) group, opened
   with [type] when [fst] holds and [and] otherwise.  A single-constructor
   type for which [get_record] returns projections is printed as an OCaml
   record; every other type is printed as a list of variants. *)
let print_pdata_decl info fst fmt (its, csl, _) =
  let print_variants fmt () =
    print_list newline (print_pconstr info) fmt csl in
  (* [label: type] entry of a record, with [mutable] for mutable fields *)
  let print_record_field fmt (ls, fd) =
    let prefix = match fd.fd_mut with Some _ -> "mutable " | None -> "" in
    fprintf fmt "%s%a: %a" prefix (print_ls info) ls (print_field info) fd in
  let print_defn fmt constructors =
    match constructors with
    | [cs, _] ->
        begin match get_record info cs.pl_ls with
        | [] -> print_variants fmt ()
        | pjl ->
            fprintf fmt "{ %a }" (print_list semi print_record_field)
              (List.combine pjl cs.pl_args)
        end
    | _ -> print_variants fmt ()
  in
  let keyword = if fst then "type" else "and" in
  fprintf fmt "@[<hov 2>%s %a%a =@\n@[<hov>%a@]@]"
    keyword print_tv_args its.its_ts.ts_args (print_its info) its
    print_defn csl
943
944

(* Data types with driver syntax are replaced by a comment; otherwise the
   type is printed and the type-variable printer state is reset.
   NOTE(review): if the first type of a group is overridden, the next one is
   still printed with [and] and no opening [type]; confirm this cannot occur. *)
let print_pdata_decl info first fmt ((its, _, _) as d) =
  let name = its.its_ts.ts_name in
  if not (has_syntax info name) then begin
    print_pdata_decl info first fmt d;
    forget_tvs ()
  end else
    fprintf fmt "(* type %a is overridden by driver *)"
      (print_lident info) name

(* [is_record d] holds when the data declaration [d] has exactly one
   constructor and every entry of that constructor's projection list is
   [Some _]. *)
let is_record (_, constructors, _) =
  match constructors with
  | [ (_, projections) ] ->
      List.for_all (function Some _ -> true | None -> false) projections
  | _ -> false

(* For each projection of the first constructor, print a standalone
   projection function [let f = function | C (_, x, ...) -> x | ...] that
   matches every constructor of the type. *)
let print_pprojections info fmt (_, csl, _) =
  let projections =
    List.fold_right
      (fun pj acc -> match pj with Some ls -> ls :: acc | None -> acc)
      (snd (List.hd csl)) [] in
  let print_projection ls =
    (* [| C (args) -> x], binding [x] at the position of [ls] *)
    let print_branch fmt (cs, pjl) =
      let print_arg fmt pj =
        match pj with
        | Some ls' when pl_equal ls' ls -> fprintf fmt "x"
        | Some _ | None -> fprintf fmt "_" in
      fprintf fmt "| %a (%a) -> x"
        (print_cs info) cs.pl_ls (print_list comma print_arg) pjl in
    fprintf fmt "@[<hov 2>let %a = function@\n" (print_ls info) ls.pl_ls;
    print_list newline print_branch fmt csl;
    fprintf fmt "@]@\n@\n" in
  List.iter print_projection projections

(* Projection functions are only needed for types that are neither
   overridden by the driver nor printed as OCaml records (record labels
   already act as projections). *)
let print_pprojections info fmt ((ts, _, _) as d) =
  if has_syntax info ts.its_ts.ts_name || is_record d then ()
  else begin
    print_pprojections info fmt d;
    forget_tvs ()
  end

976
(* Print one program declaration of a module. *)
let pdecl info fmt pd = match pd.pd_node with
  | PDtype ts ->
      print_type_decl info fmt ts;
      fprintf fmt "@\n@\n"
  | PDdata tl ->
      print_list_next newline (print_pdata_decl info) fmt tl;
      fprintf fmt "@\n@\n";
      print_list nothing (print_pprojections info) fmt tl
  | PDval lv ->
      print_val_decl info fmt lv;
      fprintf fmt "@\n@\n"
  | PDlet ld ->
      print_let_decl info fmt ld
  | PDrec fdl ->
      (* Order the bindings as: defined non-ghost, defined ghost, then
         driver-overridden.  Only the first binding opens the group (with
         [let]/[let rec]); the following ones are printed as continuations,
         and an overridden binding is printed only as a comment.  Sorting
         overridden bindings last (they used to share a key with ghost ones)
         ensures that the group is opened by a real binding whenever it
         contains one. *)
      let key {fun_ps = ps} = has_syntax info ps.ps_name, ps.ps_ghost in
      let cmp fd1 fd2 = Pervasives.compare (key fd1) (key fd2) in
      let fdl = List.sort cmp fdl in
      print_list_next newline (print_rec_decl (is_letrec fdl) info) fmt fdl;
      fprintf fmt "@\n@\n"
  | PDexn xs ->
      print_exn_decl info fmt xs
1000
1001
1002

(** Modules *)

1003
(* Extract WhyML module [m] as OCaml code on [fmt]: a header comment, the
   [Why3extract] runtime import, the logic declarations of the module's
   theory, then its program declarations. *)
let extract_module drv ?old ?fname fmt m =
  (* [old] is accepted for interface compatibility but not used here. *)
  ignore old;
  let th = m.mod_theory in
  let info = {
    info_syn       = drv.Mlw_driver.drv_syntax;
    current_theory = th;
    current_module = Some m;
    th_known_map   = th.th_known;
    mo_known_map   = m.mod_known;
    fname          = Opt.map clean_fname fname;
  } in
  fprintf fmt
    "(* This file has been generated from Why3 module %a *)@\n@\n"
    print_module_name m;
  fprintf fmt "open Why3extract@\n@\n";
  print_list nothing (logic_decl info) fmt th.th_decls;
  print_list nothing (pdecl info) fmt m.mod_decls;
  fprintf fmt "@."


(*
Local Variables:
compile-command: "unset LANG; make -C ../.."
End:
*)