coqBackend.ml 20.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
(******************************************************************************)
(*                                                                            *)
(*                                   Menhir                                   *)
(*                                                                            *)
(*                       François Pottier, Inria Paris                        *)
(*              Yann Régis-Gianas, PPS, Université Paris Diderot              *)
(*                                                                            *)
(*  Copyright Inria. All rights reserved. This file is distributed under the  *)
(*  terms of the GNU General Public License version 2, as described in the    *)
(*  file LICENSE.                                                             *)
(*                                                                            *)
(******************************************************************************)

14 15 16 17 18 19 20
open Printf
open Grammar

module Run (T: sig end) = struct

  (* [print_term t] is the name under which the terminal symbol [t] appears
     in the generated Coq code: its printed name followed by the suffix
     ['t]. The assertion rejects any [t] for which [Terminal.pseudo] holds. *)
  let print_term t =
    assert (not (Terminal.pseudo t));
    Terminal.print t ^ "'t"
22 23

  (* [print_nterm nt] is the name under which the nonterminal symbol [nt]
     appears in the generated Coq code: the result of
     [Nonterminal.print true nt] followed by the suffix ['nt]. *)
  let print_nterm nt =
    Nonterminal.print true nt ^ "'nt"
25 26 27 28 29 30 31 32 33 34

  (* [print_symbol symbol] prints a grammar symbol as a Coq term: a terminal
     is wrapped in the constructor [T], a nonterminal in [NT]. *)
  let print_symbol symbol =
    match symbol with
    | Symbol.T t -> "T " ^ print_term t
    | Symbol.N nt -> "NT " ^ print_nterm nt

  (* [print_type ty] is the text of the Coq type associated with the
     optional type annotation [ty].

     - When [Settings.coq_no_actions] is set, every type is ["unit"].
     - A declared type yields the content of its stretch.
     - An inferred type is rejected with [assert false]: Coq types cannot
       be inferred.

     Raises [Not_found] when [ty] is [None]; callers that can tolerate a
     missing annotation are expected to catch it. *)
  let print_type ty =
    if Settings.coq_no_actions then "unit"
    else
      match ty with
      | Some (Stretch.Declared s) -> s.Stretch.stretch_content
      | Some (Stretch.Inferred _) -> assert false
      | None -> raise Not_found
39 40

  (* [is_final_state node] holds when [node] has a default reduction whose
     production is a start production. *)
  let is_final_state node =
    match Default.has_default_reduction node with
    | None -> false
    | Some (prod, _) -> Production.is_start prod

  (* Same as [Lr1.iter], except that [f] is not applied to the nodes for
     which [is_final_state] holds. *)
  let lr1_iter_nonfinal f =
    Lr1.iter (fun node -> if is_final_state node then () else f node)

  (* Same as [Lr1.iterx], except that [f] is not applied to the nodes for
     which [is_final_state] holds. *)
  let lr1_iterx_nonfinal f =
    Lr1.iterx (fun node -> if is_final_state node then () else f node)

  (* Same as [Lr1.foldx], except that the accumulator is passed through
     unchanged at the nodes for which [is_final_state] holds. *)
  let lr1_foldx_nonfinal f =
    Lr1.foldx (fun accu node ->
      if is_final_state node then accu else f accu node)

  (* [print_nis nis] is the Coq constructor name of the non-initial state
     [nis]: the prefix [Nis'] followed by the number of the node. *)
  let print_nis nis =
    "Nis'" ^ string_of_int (Lr1.number nis)
56 57

  (* [print_init init] is the Coq constructor name of the initial state
     [init]: the prefix [Init'] followed by the number of the node. *)
  let print_init init =
    "Init'" ^ string_of_int (Lr1.number init)
59 60 61 62 63 64 65 66 67 68 69 70 71 72

  (* [print_st st] prints the state [st] as a Coq term. A node without an
     incoming symbol is printed as an initial state ([Init ...]); any other
     node is printed as a non-initial state ([Ninit ...]). *)
  let print_st st =
    match Lr1.incoming_symbol st with
    | None -> "Init " ^ print_init st
    | Some _ -> "Ninit " ^ print_nis st

  (* [prod_ids] maps every production visited by [Production.foldx] to its
     rank among the productions that share its left-hand side, counting
     from 0 in the order of the traversal. The second component of the
     folded pair holds, per left-hand side, the next free rank; it is
     discarded once the fold is over. *)
  let (prod_ids, _) =
    let number_production p (ids, counters) =
      let key = Symbol.N (Production.nt p) in
      let rank = try SymbolMap.find key counters with Not_found -> 0 in
      (ProductionMap.add p rank ids, SymbolMap.add key (rank + 1) counters)
    in
    Production.foldx number_production (ProductionMap.empty, SymbolMap.empty)

  (* [print_prod p] is the Coq constructor name of the production [p]:
     [Prod'], the name of its left-hand side, a quote, and the rank
     recorded for [p] in [prod_ids]. Raises [Not_found] if [p] has no entry
     in [prod_ids]. *)
  let print_prod p =
    let rank = ProductionMap.find p prod_ids in
    let lhs = Nonterminal.print true (Production.nt p) in
    sprintf "Prod'%s'%d" lhs rank
74 75 76 77 78 79

  (* Checks performed when the functor is applied. Each of them reports a
     problem through [Error.error]:
     - unless [Settings.coq_no_actions] is set, every nonterminal symbol
       must carry a type and no semantic action may use a $ keyword;
     - no production may mention the [error] token in its right-hand side;
     - the grammar must not have any parameters. *)
  let () =
    if not Settings.coq_no_actions then begin
      Nonterminal.iterx (fun nt ->
        match Nonterminal.ocamltype nt with
        | Some _ -> ()
        | None ->
            Error.error [] "I don't know the type of the nonterminal symbol %s."
              (Nonterminal.print false nt));
      Production.iterx (fun prod ->
        let keywords = Action.keywords (Production.action prod) in
        if not (Keyword.KeywordSet.is_empty keywords) then
          Error.error [] "The Coq back-end supports none of the $ keywords.")
    end;

    Production.iterx (fun prod ->
      Array.iter (function
        | Symbol.T t when t = Terminal.error ->
            Error.error [] "the Coq back-end does not support the error token."
        | Symbol.T _ | Symbol.N _ -> ())
        (Production.rhs prod));

    if Front.grammar.UnparameterizedSyntax.parameters <> [] then
      Error.error [] "the Coq back-end does not support %%parameter."
100

POTTIER Francois's avatar
POTTIER Francois committed
101
  (* [write_optimized_int31 f n] writes on the channel [f] a Coq term that
     denotes the integer [n]: [Int31.On] for 0, [Int31.In] for 1, and
     otherwise [(twice t)] or [(twice_plus_one t)], depending on the lowest
     bit of [n], where [t] is the term written for [n lsr 1].

     Rationale given by the original author: provided some constants are
     extracted to the right OCaml terms, OCaml's inlining and constant
     unfolding replace the extracted term by the actual constant. *)
  let rec write_optimized_int31 f n =
    if n = 0 then fprintf f "Int31.On"
    else if n = 1 then fprintf f "Int31.In"
    else begin
      let wrapper = if n land 1 = 0 then "twice" else "twice_plus_one" in
      fprintf f "(%s " wrapper;
      write_optimized_int31 f (n lsr 1);
      fprintf f ")"
    end
115 116 117

  (* [write_inductive_alphabet f name constrs] writes on the channel [f]:
     - an inductive type [name'] whose constructors are the strings in
       [constrs], together with a definition of [name] as an alias for it;
     - when [constrs] is nonempty, a [Numbered] instance [nameNum] that
       maps the k-th constructor (counting from 0) to the integer k and
       back; integers that match no constructor are mapped back to the
       first constructor;
     - when [constrs] is empty, an [Alphabet] instance [nameAlph] built
       from an empty match and an empty list of elements. *)
  let write_inductive_alphabet f name constrs =
    fprintf f "Inductive %s' : Set :=" name;
    List.iter (fun constr -> fprintf f "\n| %s" constr) constrs;
    fprintf f ".\n";
    fprintf f "Definition %s := %s'.\n\n" name name;
    match constrs with
    | first :: _ ->
        fprintf f "Program Instance %sNum : Numbered %s :=\n" name name;
        fprintf f "  { inj := fun x => match x return _ with ";
        List.iteri (fun k constr ->
          fprintf f "| %s => " constr;
          write_optimized_int31 f k;
          fprintf f " ") constrs;
        fprintf f "end;\n";
        fprintf f "    surj := (fun n => match n return _ with ";
        List.iteri (fun k constr -> fprintf f "| %d => %s " k constr) constrs;
        fprintf f "| _ => %s end)%%int31;\n" first;
        fprintf f "    inj_bound := %d%%int31 }.\n" (List.length constrs)
    | [] ->
        fprintf f "Program Instance %sAlph : Alphabet %s :=\n" name name;
        fprintf f "  { AlphabetComparable := {| compare := fun x y =>\n";
        fprintf f "      match x, y return comparison with end |};\n";
        (* The final period must be followed by a newline, as in the other
           branch: the text that callers write next would otherwise be
           glued to it. *)
        fprintf f "    AlphabetEnumerable := {| all_list := [] |} }.\n"

  (* Writes the inductive type [terminal], whose constructors are the
     terminal symbols for which [Terminal.pseudo] does not hold, followed
     by the declaration of the instance [TerminalAlph]. *)
  let write_terminals f =
    let collect t names =
      if Terminal.pseudo t then names else print_term t :: names
    in
    write_inductive_alphabet f "terminal" (Terminal.fold collect []);
    output_string f "Instance TerminalAlph : Alphabet terminal := _.\n\n"

  (* Writes the inductive type [nonterminal], whose constructors are the
     nonterminal symbols enumerated by [Nonterminal.foldx], followed by the
     declaration of the instance [NonTerminalAlph]. *)
  let write_nonterminals f =
    let names =
      Nonterminal.foldx (fun nt acc -> print_nterm nt :: acc) []
    in
    write_inductive_alphabet f "nonterminal" names;
    output_string f "Instance NonTerminalAlph : Alphabet nonterminal := _.\n\n"

  (* Writes three Coq definitions that map symbols to the types of their
     semantic values:
     - [terminal_semantic_type], with one case per non-pseudo terminal; a
       terminal for which [print_type] raises [Not_found] gets type [unit];
     - [nonterminal_semantic_type], with one case per nonterminal;
     - [symbol_semantic_type], which dispatches to the two others. *)
  let write_symbol_semantic_type f =
    let say = output_string f in
    say "Definition terminal_semantic_type (t:terminal) : Type:=\n";
    say "  match t with\n";
    Terminal.iter (fun t ->
      if not (Terminal.pseudo t) then begin
        let ty =
          try print_type (Terminal.ocamltype t) with Not_found -> "unit"
        in
        fprintf f "  | %s => %s%%type\n" (print_term t) ty
      end);
    say "  end.\n\n";

    say "Definition nonterminal_semantic_type (nt:nonterminal) : Type:=\n";
    say "  match nt with\n";
    Nonterminal.iterx (fun nt ->
      let ty = print_type (Nonterminal.ocamltype nt) in
      fprintf f "  | %s => %s%%type\n" (print_nterm nt) ty);
    say "  end.\n\n";

    say "Definition symbol_semantic_type (s:symbol) : Type:=\n";
    say "  match s with\n";
    say "  | T t => terminal_semantic_type t\n";
    say "  | NT nt => nonterminal_semantic_type nt\n";
    say "  end.\n\n"

  (* Writes the inductive type [production], whose constructors are the
     productions enumerated by [Production.foldx], followed by the
     declaration of the instance [ProductionAlph]. *)
  let write_productions f =
    let names =
      Production.foldx (fun prod acc -> print_prod prod :: acc) []
    in
    write_inductive_alphabet f "production" names;
    output_string f "Instance ProductionAlph : Alphabet production := _.\n\n"

  (* Writes the Coq definition [prod_contents], which associates with every
     production its left-hand side, its right-hand side in reverse order,
     and its semantic action, followed by the three projections [prod_lhs],
     [prod_rhs_rev] and [prod_action].

     The semantic action is a function of the identifiers of the
     production, also taken in reverse order. Its body is [()] when
     [Settings.coq_no_actions] is set, and the printed action otherwise. *)
  let write_productions_contents f =
    let reversed a = List.rev (Array.to_list a) in
    List.iter (output_string f) [
      "Definition prod_contents (p:production) :\n";
      "  { p:nonterminal * list symbol &\n";
      "    arrows_left (map symbol_semantic_type (rev (snd p)))\n";
      "                (symbol_semantic_type (NT (fst p))) }\n";
      " :=\n";
      "  let box := existT (fun p =>\n";
      "    arrows_left (map symbol_semantic_type (rev (snd p)))\n";
      "                (symbol_semantic_type (NT (fst p))))\n";
      "  in\n";
      "  match p with\n";
    ];
    Production.iterx (fun prod ->
      fprintf f "  | %s => box\n" (print_prod prod);
      let rhs =
        String.concat "; "
          (List.map print_symbol (reversed (Production.rhs prod)))
      in
      fprintf f "    (%s, [%s])\n" (print_nterm (Production.nt prod)) rhs;
      (* A production with an empty right-hand side has no arguments. *)
      if Production.length prod = 0 then
        output_string f "    (\n"
      else begin
        let binders =
          String.concat " " (reversed (Production.identifiers prod))
        in
        fprintf f "    (fun %s =>\n" binders
      end;
      if Settings.coq_no_actions then
        output_string f "()"
      else
        Printer.print_expr f (Action.to_il_expr (Production.action prod));
      output_string f "\n)\n");
    List.iter (output_string f) [
      "  end.\n\n";
      "Definition prod_lhs (p:production) :=\n";
      "  fst (projT1 (prod_contents p)).\n";
      "Definition prod_rhs_rev (p:production) :=\n";
      "  snd (projT1 (prod_contents p)).\n";
      "Definition prod_action (p:production) :=\n";
      "  projT2 (prod_contents p).\n\n";
    ]

  (* Writes two Coq definitions indexed by nonterminal symbols:
     [nullable_nterm], which records the result of [Analysis.nullable], and
     [first_nterm], which lists the terminals of the set computed by
     [Analysis.first]. *)
  let write_nullable_first f =
    output_string f "Definition nullable_nterm (nt:nonterminal) : bool :=\n";
    output_string f "  match nt with\n";
    Nonterminal.iterx (fun nt ->
      let nullable = Analysis.nullable nt in
      fprintf f "  | %s => %b\n" (print_nterm nt) nullable);
    output_string f "  end.\n\n";

    output_string f "Definition first_nterm (nt:nonterminal) : list terminal :=\n";
    output_string f "  match nt with\n";
    Nonterminal.iterx (fun nt ->
      let first_set = Analysis.first nt in
      fprintf f "  | %s => [" (print_nterm nt);
      (* The separator is empty in front of the first element only. *)
      let separator = ref "" in
      TerminalSet.iter (fun t ->
        output_string f !separator;
        separator := "; ";
        output_string f (print_term t)) first_set;
      output_string f "]\n");
    output_string f "  end.\n\n"

  (* Writes the Coq module [Gram], which describes the grammar: terminal
     and nonterminal symbols, the types of their semantic values, and the
     productions with their contents. *)
  let write_grammar f =
    let say = output_string f in
    say "Module Import Gram <: Grammar.T.\n\n";
    say "Local Obligation Tactic := let x := fresh in intro x; case x; reflexivity.\n\n";
    write_terminals f;
    write_nonterminals f;
    say "Include Grammar.Symbol.\n\n";
    List.iter (fun write -> write f) [
      write_symbol_semantic_type;
      write_productions;
      write_productions_contents;
    ];
    say "Include Grammar.Defs.\n\n";
    say "End Gram.\n\n"

  (* Writes the inductive type [noninitstate], whose constructors are the
     nodes enumerated by [lr1_foldx_nonfinal], followed by the declaration
     of the instance [NonInitStateAlph]. *)
  let write_nis f =
    let names =
      lr1_foldx_nonfinal (fun acc node -> print_nis node :: acc) []
    in
    write_inductive_alphabet f "noninitstate" names;
    output_string f "Instance NonInitStateAlph : Alphabet noninitstate := _.\n\n"

  (* Writes the inductive type [initstate], whose constructors are the
     nodes found in [Lr1.entry], followed by the declaration of the
     instance [InitStateAlph]. *)
  let write_init f =
    let names =
      ProductionMap.fold (fun _prod node acc -> print_init node :: acc)
        Lr1.entry []
    in
    write_inductive_alphabet f "initstate" names;
    output_string f "Instance InitStateAlph : Alphabet initstate := _.\n\n"

  (* Writes the Coq definition [start_nt], which maps every initial state
     enumerated by [Lr1.fold_entry] to its start nonterminal symbol. *)
  let write_start_nt f =
    output_string f "Definition start_nt (init:initstate) : nonterminal :=\n";
    output_string f "  match init with\n";
    let write_case _prod node startnt _t () =
      fprintf f "  | %s => %s\n" (print_init node) (print_nterm startnt)
    in
    Lr1.fold_entry write_case ();
    output_string f "  end.\n\n"

  (* Writes the Coq definition [action_table], with one case per non-final
     state.

     A state that has a default reduction yields [Default_reduce_act].
     Any other state yields [Lookahead_act] applied to a function with one
     case per non-pseudo terminal: [Shift_act] if the state has a
     transition along this terminal, otherwise [Reduce_act] if it has a
     reduction on it. If at least one terminal has neither, a final
     catch-all case yields [Fail_act]. *)
  let write_actions f =
    output_string f "Definition action_table (state:state) : action :=\n";
    output_string f "  match state with\n";
    let write_lookahead_table node =
      output_string f "Lookahead_act (fun terminal:terminal =>\n";
      output_string f "    match terminal return lookahead_action terminal with\n";
      let has_fail = ref false in
      Terminal.iter (fun t ->
        if not (Terminal.pseudo t) then
          begin
            (* Each [try] encloses both the lookup and the output that
               depends on it. *)
            try
              let target =
                SymbolMap.find (Symbol.T t) (Lr1.transitions node)
              in
              fprintf f "    | %s => Shift_act %s (eq_refl _)\n"
                (print_term t) (print_nis target)
            with Not_found ->
              try
                let prod =
                  Misc.single (TerminalMap.find t (Lr1.reductions node))
                in
                fprintf f "    | %s => Reduce_act %s\n"
                  (print_term t) (print_prod prod)
              with Not_found -> has_fail := true
          end);
      if !has_fail then
        output_string f "    | _ => Fail_act\n";
      output_string f "    end)\n"
    in
    lr1_iter_nonfinal (fun node ->
      fprintf f "  | %s => " (print_st node);
      match Default.has_default_reduction node with
      | Some (prod, _) ->
          fprintf f "Default_reduce_act %s\n" (print_prod prod)
      | None ->
          write_lookahead_table node);
    output_string f "  end.\n\n"

  (* Writes the Coq definition [goto_table], with one case per pair of a
     non-final state and a nonterminal symbol such that the state has a
     transition along this symbol. The case yields [None] when the target
     of the transition is a final state, and otherwise [Some] of the target
     paired with a proof obligation solved by [eq_refl]. If at least one
     pair has no transition, a final catch-all case yields [None]. *)
  let write_gotos f =
    fprintf f "Definition goto_table (state:state) (nt:nonterminal) :=\n";
    fprintf f "  match state, nt return option { s:noninitstate | NT nt = last_symb_of_non_init_state s } with\n";
    let has_none = ref false in
    lr1_iter_nonfinal (fun node ->
      Nonterminal.iterx (fun nt ->
        try
          let target = SymbolMap.find (Symbol.N nt) (Lr1.transitions node) in
          fprintf f "  | %s, %s => " (print_st node) (print_nterm nt);
          (* Both alternatives end the line, so that every case of the
             match sits on a line of its own. *)
          if is_final_state target then fprintf f "None\n"
          else fprintf f "Some (exist _ %s (eq_refl _))\n" (print_nis target)
        with Not_found -> has_none := true));
    if !has_none then fprintf f "  | _, _ => None\n";
    fprintf f "  end.\n\n"

  (* Writes the Coq definition [last_symb_of_non_init_state], which maps
     every node enumerated by [lr1_iterx_nonfinal] to its incoming symbol.
     Such a node is expected to have an incoming symbol; the absence of one
     triggers [assert false]. *)
  let write_last_symb f =
    output_string f "Definition last_symb_of_non_init_state (noninitstate:noninitstate) : symbol :=\n";
    output_string f "  match noninitstate with\n";
    lr1_iterx_nonfinal (fun node ->
      let symbol =
        match Lr1.incoming_symbol node with
        | None -> assert false
        | Some s -> print_symbol s
      in
      fprintf f "  | %s => %s\n" (print_nis node) symbol);
    output_string f "  end.\n\n"

  (* Writes the Coq definition [past_symb_of_non_init_state]. For every
     node enumerated by [lr1_iterx_nonfinal], the symbols found in
     [Invariant.stack node] are accumulated in a list by [Invariant.fold];
     the head of that list is dropped and the remaining symbols are
     printed. The definition is extracted to a function that always
     fails. *)
  let write_past_symb f =
    output_string f "Definition past_symb_of_non_init_state (noninitstate:noninitstate) : list symbol :=\n";
    output_string f "  match noninitstate with\n";
    lr1_iterx_nonfinal (fun node ->
      let symbols =
        Invariant.fold (fun acc _ symb _ -> print_symbol symb :: acc)
          [] (Invariant.stack node)
      in
      let past = String.concat "; " (List.tl symbols) in
      fprintf f "  | %s => [%s]\n" (print_nis node) past);
    output_string f "  end.\n";
    output_string f "Extract Constant past_symb_of_non_init_state => \"fun _ -> assert false\".\n\n"

344
  module NodeSetMap = Map.Make(Lr1.NodeSet)

  (* Writes the Coq definition [past_state_of_non_init_state], which maps
     every node enumerated by [lr1_iterx_nonfinal] to a list of predicates
     on states, one per cell of [Invariant.stack node].

     Each distinct set of states is turned into a named predicate
     [state_set_N], defined only once. These auxiliary definitions are
     written directly on [f] as soon as they are needed, whereas the main
     definition is accumulated in a buffer and flushed at the end; the
     auxiliary definitions therefore precede it in the output. *)
  let write_past_states f =
    let known = ref NodeSetMap.empty in
    let counter = ref 1 in
    (* Allocates a fresh name for [stateset], records it, and writes the
       definition of the corresponding predicate. *)
    let define stateset =
      let id = sprintf "state_set_%d" !counter in
      known := NodeSetMap.add stateset id !known;
      incr counter;
      fprintf f "Definition %s (s:state) : bool :=\n" id;
      output_string f "  match s with\n";
      output_string f "  ";
      Lr1.NodeSet.iter (fun st -> fprintf f "| %s " (print_st st)) stateset;
      output_string f "=> true\n";
      output_string f "  | _ => false\n";
      output_string f "  end.\n";
      fprintf f "Extract Inlined Constant %s => \"assert false\".\n\n" id;
      id
    in
    let get_stateset_id stateset =
      try NodeSetMap.find stateset !known
      with Not_found -> define stateset
    in
    let b = Buffer.create 256 in
    Buffer.add_string b "Definition past_state_of_non_init_state (s:noninitstate) : list (state -> bool) :=\n";
    Buffer.add_string b "  match s with\n";
    lr1_iterx_nonfinal (fun node ->
      let ids =
        Invariant.fold (fun accu _ _ states ->
          let id = get_stateset_id states in
          id :: accu)
          [] (Invariant.stack node)
      in
      bprintf b "  | %s => [ %s ]\n" (print_nis node) (String.concat "; " ids));
    Buffer.add_string b "  end.\n";
    Buffer.output_buffer f b;
    output_string f "Extract Constant past_state_of_non_init_state => \"fun _ -> assert false\".\n\n"

380
  module TerminalSetMap = Map.Make(TerminalSet)

  (* Writes the Coq definition [items_of_state].

     When [Settings.coq_no_complete] is set, every state is mapped to the
     empty list. Otherwise, every non-final state is mapped to a constant
     [items_of_state_N] that lists the items of the closure of its LR(0)
     state, except those whose production is a start production. Each item
     refers to its lookahead set through a named constant
     [lookahead_set_N], defined only once per distinct set; a set that
     contains [Terminal.sharp] is replaced by [TerminalSet.universe].

     The [lookahead_set_N] definitions are written directly on [f] as soon
     as they are needed, whereas the [items_of_state_N] definitions are
     accumulated in a buffer and flushed afterwards; the former therefore
     precede the latter in the output. *)
  let write_items f =
    if Settings.coq_no_complete then
      output_string f "Definition items_of_state (s:state): list item := [].\n"
    else begin
      let known = ref TerminalSetMap.empty in
      let counter = ref 1 in
      let get_lookaheadset_id lookaheadset =
        let lookaheadset =
          if TerminalSet.mem Terminal.sharp lookaheadset
          then TerminalSet.universe
          else lookaheadset
        in
        try TerminalSetMap.find lookaheadset !known
        with Not_found ->
          let id = sprintf "lookahead_set_%d" !counter in
          known := TerminalSetMap.add lookaheadset id !known;
          incr counter;
          fprintf f "Definition %s : list terminal :=\n  [" id;
          (* The separator is empty in front of the first element only. *)
          let separator = ref "" in
          TerminalSet.iter (fun lookahead ->
            output_string f !separator;
            separator := "; ";
            output_string f (print_term lookahead)) lookaheadset;
          fprintf f "].\nExtract Inlined Constant %s => \"assert false\".\n\n" id;
          id
      in
      let b = Buffer.create 256 in
      lr1_iter_nonfinal (fun node ->
        let number = Lr1.number node in
        bprintf b "Definition items_of_state_%d : list item :=\n" number;
        Buffer.add_string b "  [ ";
        let separator = ref "" in
        Item.Map.iter (fun item lookaheads ->
          let prod, pos = Item.export item in
          if not (Production.is_start prod) then begin
            Buffer.add_string b !separator;
            separator := ";\n    ";
            let set_id = get_lookaheadset_id lookaheads in
            bprintf b "{| prod_item := %s; dot_pos_item := %d; lookaheads_item := %s |}"
              (print_prod prod) pos set_id
          end)
          (Lr0.closure (Lr0.export (Lr1.state node)));
        Buffer.add_string b " ].\n";
        bprintf b "Extract Inlined Constant items_of_state_%d => \"assert false\".\n\n" number);
      Buffer.output_buffer f b;

      output_string f "Definition items_of_state (s:state) : list item :=\n";
      output_string f "  match s with\n";
      lr1_iter_nonfinal (fun node ->
        fprintf f "  | %s => items_of_state_%d\n" (print_st node) (Lr1.number node));
      output_string f "  end.\n"
    end;
    output_string f "Extract Constant items_of_state => \"fun _ -> assert false\".\n\n"

  (* Writes the Coq module [Aut], which describes the automaton: nullable
     and first information, states, start symbols, the action and goto
     tables, and the data attached to every state (past symbols, past
     states, items). *)
  let write_automaton f =
    let say = output_string f in
    say "Module Aut <: Automaton.T.\n\n";
    say "Local Obligation Tactic := let x := fresh in intro x; case x; reflexivity.\n\n";
    say "Module Gram := Gram.\n";
    say "Module GramDefs := Gram.\n\n";
    List.iter (fun write -> write f) [
      write_nullable_first;
      write_nis;
      write_last_symb;
      write_init;
    ];
    say "Include Automaton.Types.\n\n";
    List.iter (fun write -> write f) [
      write_start_nt;
      write_actions;
      write_gotos;
      write_past_symb;
      write_past_states;
      write_items;
    ];
    say "End Aut.\n\n"

  (* Writes the instantiation of the parser ([Module Parser]) and the
     statements that accompany it:
     - the theorem [safe], and, unless [Settings.coq_no_complete] is set,
       the theorem [complete];
     - for every entry point enumerated by [Lr1.fold_entry], a parsing
       function named after the start symbol, a theorem [<name>_correct],
       and, unless [Settings.coq_no_complete] is set, a theorem
       [<name>_complete]. *)
  let write_theorems f =
    let say = output_string f in
    let with_completeness = not Settings.coq_no_complete in
    say "Require Import Main.\n\n";

    say "Module Parser := Main.Make Aut.\n";

    say "Theorem safe:\n";
    say "  Parser.safe_validator () = true.\n";
    say "Proof eq_refl true<:Parser.safe_validator () = true.\n\n";

    if with_completeness then begin
      say "Theorem complete:\n";
      say "  Parser.complete_validator () = true.\n";
      say "Proof eq_refl true<:Parser.complete_validator () = true.\n\n"
    end;

    let write_entry _prod node startnt _t () =
      let fun_name = Nonterminal.print true startnt in
      let init = print_init node in
      let start_symbol = print_symbol (Symbol.N startnt) in
      fprintf f "Definition %s := Parser.parse safe Aut.%s.\n\n" fun_name init;

      fprintf f "Theorem %s_correct iterator buffer:\n" fun_name;
      fprintf f "  match %s iterator buffer with\n" fun_name;
      say "  | Parser.Inter.Parsed_pr sem buffer_new =>\n";
      say "    exists word,\n";
      say "      buffer = Parser.Inter.app_str word buffer_new /\\\n";
      fprintf f "      inhabited (Gram.parse_tree (%s) word sem)\n" start_symbol;
      say "  | _ => True\n";
      say "  end.\n";
      say "Proof. apply Parser.parse_correct. Qed.\n\n";

      if with_completeness then begin
        let output_type = print_type (Nonterminal.ocamltype startnt) in
        fprintf f "Theorem %s_complete (iterator:nat) word buffer_end (output:%s):\n"
          fun_name output_type;
        fprintf f "  forall tree:Gram.parse_tree (%s) word output,\n" start_symbol;
        fprintf f "  match %s iterator (Parser.Inter.app_str word buffer_end) with\n" fun_name;
        say "  | Parser.Inter.Fail_pr => False\n";
        say "  | Parser.Inter.Parsed_pr output_res buffer_end_res =>\n";
        say "    output_res = output /\\ buffer_end_res = buffer_end  /\\\n";
        say "    le (Gram.pt_size tree) iterator\n";
        say "  | Parser.Inter.Timeout_pr => lt iterator (Gram.pt_size tree)\n";
        say "  end.\n";
        fprintf f "Proof. apply Parser.parse_complete with (init:=Aut.%s); exact complete. Qed.\n\n" init
      end
    in
    Lr1.fold_entry write_entry ()
500 501 502 503 504 505 506 507 508 509 510 511 512

  (* [write_all f] writes the whole Coq file on the channel [f]: the
     preludes of the grammar, the [Require] directives, the modules [Gram]
     and [Aut], the theorems, and finally the postludes of the grammar.
     Preludes and postludes are omitted when [Settings.coq_no_actions] is
     set. *)
  let write_all f =
    let with_actions = not Settings.coq_no_actions in
    if with_actions then
      List.iter (fun prelude ->
        output_string f prelude.Stretch.stretch_content;
        output_string f "\n\n")
        Front.grammar.UnparameterizedSyntax.preludes;

    List.iter (fun library -> fprintf f "Require Import %s.\n" library)
      [ "List"; "Int31"; "Syntax"; "Tuples"; "Alphabet" ];
    output_string f "Require Grammar.\n";
    output_string f "Require Automaton.\n\n";
    output_string f "Unset Elimination Schemes.\n\n";
    write_grammar f;
    write_automaton f;
    write_theorems f;

    if with_actions then
      List.iter (fun postlude ->
        output_string f "\n\n";
        output_string f postlude.Stretch.stretch_raw_content)
        Front.grammar.UnparameterizedSyntax.postludes
end