From 3fdb9a141e9c5b55a40a6d3eadb7f44dad7832e3 Mon Sep 17 00:00:00 2001
From: Gabriel Scherer <gabriel.scherer@gmail.com>
Date: Sat, 11 Feb 2023 17:22:37 +0100
Subject: [PATCH] support 'let rec .. in ..' in the ML parser and AST

---
 client/FPrinter.ml                           |  4 +-
 client/Infer.ml                              | 16 ++++++--
 client/Infer.mli                             |  1 +
 client/ML.ml                                 |  8 ++--
 client/MLLexer.mll                           |  1 +
 client/MLParser.mly                          | 10 ++++-
 client/MLPrinter.ml                          |  6 +--
 client/P.ml                                  |  4 +-
 client/Printer.ml                            |  9 ++--
 client/test/CheckML.ml                       |  2 +-
 client/test/RandomML.ml                      | 15 ++++++-
 client/test/TestML.ml                        | 24 +++++++----
 client/test/TestMLRandom.ml                  |  8 ++--
 client/test/suite.t/letrec-list-length.midml | 12 ++++++
 client/test/suite.t/letrec-loop.midml        |  4 ++
 client/test/suite.t/nat-annot.midml          |  2 +-
 client/test/suite.t/nat.midml                |  1 +
 client/test/suite.t/run.t                    | 43 +++++++++++++++++++-
 18 files changed, 136 insertions(+), 34 deletions(-)
 create mode 100644 client/test/suite.t/letrec-list-length.midml
 create mode 100644 client/test/suite.t/letrec-loop.midml
 create mode 100644 client/test/suite.t/nat.midml

diff --git a/client/FPrinter.ml b/client/FPrinter.ml
index c48ede5..630a786 100644
--- a/client/FPrinter.ml
+++ b/client/FPrinter.ml
@@ -61,7 +61,7 @@ let rec translate_term env (t : F.nominal_term) : P.term =
   | F.App (_, t1, t2) ->
       P.App (self env t1, self env t2)
   | F.Let (_, x, t, u) ->
-      P.Let (P.PVar x, self env t, self env u)
+      P.Let (P.Non_recursive, P.PVar x, self env t, self env u)
   | F.TyAbs (_, x, t) ->
       let alpha = new_tyvar () in
       P.TyAbs (alpha, self ((x, alpha) :: env) t)
@@ -73,7 +73,7 @@ let rec translate_term env (t : F.nominal_term) : P.term =
       P.Proj (i, self env t)
   | F.LetProd (_, xs, t, u) ->
       let pat = P.PTuple (List.map (fun x -> P.PVar x) xs) in
-      P.Let (pat , self env t, self env u)
+      P.Let (P.Non_recursive, pat, self env t, self env u)
   | F.Variant (_, lbl, (tid, tys) , t) ->
       P.Variant (
           lbl,
diff --git a/client/Infer.ml b/client/Infer.ml
index 6851e1d..e632ca7 100644
--- a/client/Infer.ml
+++ b/client/Infer.ml
@@ -104,6 +104,7 @@ type error =
   | Cycle of F.nominal_type
   | VariableConflict of ML.tevar
   | VariableScopeEscape
+  | Unsupported of string
 
 exception Error of Utils.loc * error
 
@@ -278,7 +279,7 @@ let convert env params ty =
 let get_loc t =
   match t with
   | ML.Var (pos, _) | ML.Hole (pos, _) | ML.Abs (pos, _, _)
-  | ML.App (pos, _, _) | ML.Let (pos, _, _, _) | ML.Annot (pos, _, _)
+  | ML.App (pos, _, _) | ML.Let (pos, _, _, _, _) | ML.Annot (pos, _, _)
   | ML.Tuple (pos, _) | ML.LetProd (pos, _, _, _)
   | ML.Variant (pos, _, _) | ML.Match (pos, _, _)
     -> pos
@@ -352,8 +353,13 @@ let hastype (typedecl_env : ML.datatype_env) (t : ML.term) (w : variable) : F.no
       and+ t2' = hastype t2 v
       in F.App (loc, t1', t2')
 
-    (* Generalization. *)
-    | ML.Let (loc, x, t, u) ->
+      (* Generalization. *)
+    | ML.Let (loc, rec_, x, t, u) ->
+
+        begin match rec_ with
+        | ML.Non_recursive -> ()
+        | ML.Recursive -> raise (Error (loc, Unsupported "let rec"))
+        end;
 
       (* Construct a ``let'' constraint. *)
       let+ (a, (b, _), t', u') = let1 (X.Var x) (hastype t) (hastype u w) in
@@ -700,5 +706,9 @@ let emit_error loc (error : error) =
       Printf.printf
         "Scope error: variable %s is already bound in this pattern."
         x
+  | Unsupported (feature) ->
+      Printf.printf
+        "Type inference does not yet support %S."
+        feature
   end;
   flush stdout
diff --git a/client/Infer.mli b/client/Infer.mli
index 7cb977e..064ace3 100644
--- a/client/Infer.mli
+++ b/client/Infer.mli
@@ -4,6 +4,7 @@ type error =
   | Cycle of F.nominal_type
   | VariableConflict of ML.tevar
   | VariableScopeEscape
+  | Unsupported of string
 
 exception Error of Utils.loc * error
 
diff --git a/client/ML.ml b/client/ML.ml
index f91cf3c..5451119 100644
--- a/client/ML.ml
+++ b/client/ML.ml
@@ -24,18 +24,20 @@ type typ =
  [@@deriving compare]
 
 type rigidity = Flexible | Rigid
-
-let compare_rigidity r1 r2 = compare r1 r2
+  [@@deriving compare]
 
 type type_annotation = rigidity * tyvar list * typ [@@deriving compare]
   (* some <flexible vars> . <ty> *)
 
+type recursive_status = Recursive | Non_recursive
+  [@@deriving compare]
+
 type term =
   | Var of loc * tevar
   | Hole of loc * term list
   | Abs of loc * tevar * term
   | App of loc * term * term
-  | Let of loc * tevar * term * term
+  | Let of loc * recursive_status * tevar * term * term
   | Annot of loc * term * type_annotation
   | Tuple of loc * term list
   | LetProd of loc * tevar list * term * term
diff --git a/client/MLLexer.mll b/client/MLLexer.mll
index 1dd9f6b..03f711d 100644
--- a/client/MLLexer.mll
+++ b/client/MLLexer.mll
@@ -6,6 +6,7 @@
 
   let keywords = [
       "let", LET;
+      "rec", REC;
       "in", IN;
       "fun", FUN;
       "let", LET;
diff --git a/client/MLParser.mly b/client/MLParser.mly
index 13cda8c..b87017d 100644
--- a/client/MLParser.mly
+++ b/client/MLParser.mly
@@ -9,6 +9,7 @@
 %token FUN
 %token ARROW "->"
 %token LET IN
+%token REC "rec"
 %token EQ "="
 
 %token LPAR "("
@@ -72,7 +73,9 @@ let term_abs :=
       List.fold_right (fun x t -> Abs (Some $loc, x, t)) xs t
     }
 
-  | (x, t1, t2) = letin(tevar) ;          { Let (Some $loc, x, t1, t2) }
+  | (x, t1, t2) = letin(tevar) ;          { Let (Some $loc, Non_recursive, x, t1, t2) }
+
+  | (x, t1, t2) = letrecin(tevar) ;       { Let (Some $loc, Recursive, x, t1, t2) }
 
   | (xs, t1, t2) = letin(tuple(tevar)) ;  { LetProd (Some $loc, xs, t1, t2) }
 
@@ -228,5 +231,10 @@ let letin (X) :=
       t1 = term ; IN ;
       t2 = term_abs ;                     { (x, t1, t2) }
 
+let letrecin (X) :=
+  | LET ; REC; x = X ; EQ ;
+      t1 = term ; IN ;
+      t2 = term_abs ;                     { (x, t1, t2) }
+
 let pipe_separated_list (X) :=
   | option (PIPE) ; ~ = separated_list (PIPE, X) ; <>
diff --git a/client/MLPrinter.ml b/client/MLPrinter.ml
index 0cefd3c..cf3e310 100644
--- a/client/MLPrinter.ml
+++ b/client/MLPrinter.ml
@@ -37,15 +37,15 @@ let rec translate_term (t : ML.term) : P.term =
       P.Abs (P.PVar x, self t)
   | ML.App (_, t1, t2) ->
       P.App (self t1, self t2)
-  | ML.Let (_, x, t, u) ->
-      P.Let (P.PVar x, self t, self u)
+  | ML.Let (_, r, x, t, u) ->
+      P.Let (r, P.PVar x, self t, self u)
   | ML.Annot (_, t, tyannot) ->
       P.Annot (self t, translate_annot tyannot)
   | ML.Tuple (_, ts) ->
       P.Tuple (List.map translate_term ts)
   | ML.LetProd (_, xs, t, u) ->
       let pat = P.PTuple (List.map (fun x -> P.PVar x) xs) in
-      P.Let (pat, self t, self u)
+      P.Let (P.Non_recursive, pat, self t, self u)
   | ML.Variant (_, lbl, t) ->
       P.Variant (lbl, None, Option.map self t)
   | ML.Match (_, t, brs) ->
diff --git a/client/P.ml b/client/P.ml
index 495c6fe..fe1d506 100644
--- a/client/P.ml
+++ b/client/P.ml
@@ -15,6 +15,8 @@ and datatype = Datatype.tyconstr_id * typ list
 
 type rigidity = Flexible | Rigid
 
+type recursive_status = ML.recursive_status = Recursive | Non_recursive
+
 type type_annotation = rigidity * tyvar list * typ
   (* some <flexible vars> . <ty> *)
 
@@ -25,7 +27,7 @@ type term =
   | Hole of typ option * term list
   | Abs of pattern * term
   | App of term * term
-  | Let of pattern * term * term
+  | Let of recursive_status * pattern * term * term
   | Annot of term * type_annotation
   | TyAbs of tyvar * term
   | TyApp of term * typ
diff --git a/client/Printer.ml b/client/Printer.ml
index c05446e..b74de8d 100644
--- a/client/Printer.ml
+++ b/client/Printer.ml
@@ -106,8 +106,11 @@ let print_tuple print_elem elems =
           separate_map (comma ^^ break 1) print_elem elems in
     surround 2 0 lparen contents rparen
 
-let print_let_in lhs rhs body =
+let print_let_in rec_ lhs rhs body =
   string "let"
+  ^^ (match rec_ with
+      | Recursive -> space ^^ string "rec"
+      | Non_recursive -> empty)
   ^^ surround 2 1 empty lhs empty
   ^^ string "="
   ^^ surround 2 1 empty rhs empty
@@ -163,8 +166,8 @@ and print_term_abs t =
   | TyAbs _ ->
       let (xs, t) = flatten_tyabs t in
       print_nary_abstraction "FUN" print_tyvar xs (print_term_abs t)
-  | Let (p, t1, t2) ->
-      print_let_in
+  | Let (rec_, p, t1, t2) ->
+      print_let_in rec_
         (print_pattern p)
         (print_term t1)
         (print_term_abs t2)
diff --git a/client/test/CheckML.ml b/client/test/CheckML.ml
index 33e9e98..039f76b 100644
--- a/client/test/CheckML.ml
+++ b/client/test/CheckML.ml
@@ -48,7 +48,7 @@ let letify xts =
     | [(_loc, _last_var, t)] ->
         t
     | (loc, x, t) :: xts ->
-        ML.Let (loc, x, t, aux xts)
+        ML.Let (loc, ML.Non_recursive, x, t, aux xts)
   in
   aux xts
 
diff --git a/client/test/RandomML.ml b/client/test/RandomML.ml
index e320cf8..c1f4f00 100644
--- a/client/test/RandomML.ml
+++ b/client/test/RandomML.ml
@@ -34,8 +34,19 @@ let app self k n =
 
 let let_ self k n =
   let* n1, n2 = split (n - 1) in
-  let+ x, t1, t2 = triple (bind k) (self (k, n1)) (self (k + 1, n2))
-  in ML.Let (None, x, t1, t2)
+  let* rec_ =
+    frequency [
+      3, pure ML.Non_recursive;
+      (* we disable 'let rec' generation for now,
+         as the type-checker does not support it. *)
+      (* 1, pure ML.Recursive; *)
+    ] in
+  let+ x, t1, t2 =
+    let inner_k = match rec_ with
+      | ML.Non_recursive -> k
+      | ML.Recursive -> k + 1 in
+    triple (bind k) (self (inner_k, n1)) (self (k + 1, n2))
+  in ML.Let (None, rec_, x, t1, t2)
 
 let pair self k n =
   let* n1, n2 = split (n - 1) in
diff --git a/client/test/TestML.ml b/client/test/TestML.ml
index c55da53..dca063b 100644
--- a/client/test/TestML.ml
+++ b/client/test/TestML.ml
@@ -28,16 +28,20 @@ let k =
   ML.Abs (None, "x", ML.Abs (None, "y", x))
 
 let genid =
-  ML.Let (None, "x", id, x)
+  ML.Let (None, ML.Non_recursive,
+          "x", id, x)
 
 let genidid =
-  ML.Let (None, "x", id, ML.App (None, x, x))
+  ML.Let (None, ML.Non_recursive,
+          "x", id, ML.App (None, x, x))
 
 let genkidid =
-  ML.Let (None, "x", ML.App (None, k, id), ML.App (None, x, id))
+  ML.Let (None, ML.Non_recursive,
+          "x", ML.App (None, k, id), ML.App (None, x, id))
 
 let genkidid2 =
-  ML.Let (None, "x", ML.App (None, ML.App (None, k, id), id), x)
+  ML.Let (None, ML.Non_recursive,
+          "x", ML.App (None, ML.App (None, k, id), id), x)
 
 (* unused *)
 let _app_pair = (* ill-typed *)
@@ -48,8 +52,9 @@ let unit =
 
 (* "let x1 = (...[], ...[]) in ...[] x1" *)
 let regression1 =
-  ML.Let (None, "x1", ML.Tuple (None, [ ML.Hole (None, []) ;
-                                                  ML.Hole (None, []) ]),
+  ML.Let (None, ML.Non_recursive,
+          "x1", ML.Tuple (None, [ ML.Hole (None, []) ;
+                                       ML.Hole (None, []) ]),
           ML.App (None, ML.Hole (None, []), ML.Var (None, "x1")))
 
 (* "let f = fun x ->
@@ -59,11 +64,11 @@ let regression1 =
             fun x -> fun y -> f" *)
 let regression2 =
   ML.(
-    Let (None,
+    Let (None, Non_recursive,
       "f",
       Abs (None,
         "x",
-        Let (None,
+        Let (None, Non_recursive,
           "g",
           Abs (None,
             "y",
@@ -351,7 +356,8 @@ let test_abs_match_with () =
   test_ok "fun x -> match () with () -> () end" abs_match_with
 
 let test_let () =
-  test_ok "let y = fun x -> x in ()" (ML.Let(None, "y", id, unit))
+  test_ok "let y = fun x -> x in ()"
+    (ML.Let(None, ML.Non_recursive, "y", id, unit))
 
 let test_let_prod_singleton () =
   test_ok "let (y,) = (fun x -> x,) in ()"
diff --git a/client/test/TestMLRandom.ml b/client/test/TestMLRandom.ml
index 33081fb..95e193a 100644
--- a/client/test/TestMLRandom.ml
+++ b/client/test/TestMLRandom.ml
@@ -27,8 +27,8 @@ module Shrinker = struct
       ML.Abs (pos, y, subst_under [y] t x u)
     | ML.App (pos, t1, t2) ->
       ML.App (pos, subst t1 x u, subst t2 x u)
-    | ML.Let (pos, y, t1, t2) ->
-      ML.Let (pos, y, subst t1 x u, subst_under [y] t2 x u)
+    | ML.Let (pos, r, y, t1, t2) ->
+      ML.Let (pos, r, y, subst t1 x u, subst_under [y] t2 x u)
     | ML.Tuple (pos, ts) ->
       ML.Tuple (pos, List.map (fun t -> subst t x u) ts)
     | ML.Hole (pos, ts) ->
@@ -119,12 +119,12 @@ module Shrinker = struct
          in ML.App (pos, t1',t2')
        )
 
-      | ML.Let (pos, x, t, u) ->
+      | ML.Let (pos, r, x, t, u) ->
         subterms [t; remove_variable u x]
         <+> (
          let++ t' = t, shrink_term t
          and++ u' = u, shrink_term u
-         in ML.Let (pos, x, t', u')
+         in ML.Let (pos, r, x, t', u')
        )
 
       | ML.LetProd (pos, xs, t, u) ->
diff --git a/client/test/suite.t/letrec-list-length.midml b/client/test/suite.t/letrec-list-length.midml
new file mode 100644
index 0000000..0ec6498
--- /dev/null
+++ b/client/test/suite.t/letrec-list-length.midml
@@ -0,0 +1,12 @@
+#use nat.midml
+#use list.midml
+
+(* toplevel recursion is currently not supported,
+   so we wrap a "let rec .. in .." *)
+let length =
+  let rec length = fun xs ->
+    match xs with
+    | Nil -> Zero
+    | Cons (_, rest) -> Succ (length rest)
+    end
+  in length
diff --git a/client/test/suite.t/letrec-loop.midml b/client/test/suite.t/letrec-loop.midml
new file mode 100644
index 0000000..d9d62a8
--- /dev/null
+++ b/client/test/suite.t/letrec-loop.midml
@@ -0,0 +1,4 @@
+(* expected type: 'a 'b. 'a -> 'b *)
+let loop =
+  let rec loop = fun x -> loop x in
+  loop
diff --git a/client/test/suite.t/nat-annot.midml b/client/test/suite.t/nat-annot.midml
index be47836..a041cd6 100644
--- a/client/test/suite.t/nat-annot.midml
+++ b/client/test/suite.t/nat-annot.midml
@@ -1,2 +1,2 @@
-type nat = | Zero | Succ of nat
+#use nat.midml
 let f = (fun x -> x : nat -> nat)
\ No newline at end of file
diff --git a/client/test/suite.t/nat.midml b/client/test/suite.t/nat.midml
new file mode 100644
index 0000000..7351495
--- /dev/null
+++ b/client/test/suite.t/nat.midml
@@ -0,0 +1 @@
+type nat = | Zero | Succ of nat
diff --git a/client/test/suite.t/run.t b/client/test/suite.t/run.t
index 8b8ef78..653f658 100644
--- a/client/test/suite.t/run.t
+++ b/client/test/suite.t/run.t
@@ -253,7 +253,7 @@ Variable scope escape
 # Annotations
 
   $ cat nat-annot.midml
-  type nat = | Zero | Succ of nat
+  #use nat.midml
   let f = (fun x -> x : nat -> nat)
 
   $ midml nat-annot.midml
@@ -302,7 +302,48 @@ Variable scope escape
   Converting the System F term to de Bruijn style...
   Type-checking the System F term...
 
+# Recursion
 
+For now there is only support in the parser, so the examples below
+parse correctly but fail to type-check.
+
+Mutual recursion is not yet supported.
+
+Polymorphic recursion is not yet supported.
+
+A simple let-rec function: List.length
+
+  $ cat letrec-list-length.midml
+  #use nat.midml
+  #use list.midml
+  
+  (* toplevel recursion is currently not supported,
+     so we wrap a "let rec .. in .." *)
+  let length =
+    let rec length = fun xs ->
+      match xs with
+      | Nil -> Zero
+      | Cons (_, rest) -> Succ (length rest)
+      end
+    in length
+
+  $ midml letrec-list-length.midml
+  Type inference and translation to System F...
+  File "test", line 7, characters 2-127:
+  Type inference does not yet support "let rec".
+
+Using recursion to define (loop : 'a -> 'b)
+
+  $ cat letrec-loop.midml
+  (* expected type: 'a 'b. 'a -> 'b *)
+  let loop =
+    let rec loop = fun x -> loop x in
+    loop
+
+  $ midml letrec-loop.midml
+  Type inference and translation to System F...
+  File "test", line 3, characters 2-42:
+  Type inference does not yet support "let rec".
 
 # Rigid, flexible variables.
 
-- 
GitLab