diff --git a/evaluate_flan.py b/evaluate_flan.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4a22f478f93b4b3144d72e61453eaefa4403377
--- /dev/null
+++ b/evaluate_flan.py
@@ -0,0 +1,171 @@
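+# Few-shot evaluation of FLAN-T5 models on MMLU-style multiple-choice CSVs
+# (question, choices A-D, answer letter). The model is scored by comparing
+# first-token logits for each answer letter rather than by free generation.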
+import argparse
+import os
+import torch
+import numpy as np
+import pandas as pd
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
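+# Answer letters for the four-option multiple-choice format.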
+choices = ["A", "B", "C", "D"]
+
+
+def format_subject(subject):
+    # "high_school_biology" -> "high school biology".
+    return " ".join(subject.split("_"))
+
+
+def format_example(df, idx, include_answer=True):
+    # Rows are laid out as: question, choice A..D, answer letter.
+    prompt = df.iloc[idx, 0]
+    k = df.shape[1] - 2  # number of answer choices
+    for j in range(k):
+        prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j + 1])
+    prompt += "\nAnswer:"
+    if include_answer:
+        prompt += " {}\n\n".format(df.iloc[idx, k + 1])
+    return prompt
+
+
+def gen_prompt(train_df, subject, k=-1):
+    prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(
+        format_subject(subject)
+    )
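+    # k == -1 means use every dev example as an in-context demonstration.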
+    if k == -1:
+        k = train_df.shape[0]
+    for i in range(k):
+        prompt += format_example(train_df, i)
+    return prompt
+
+
+def evaluate(args, subject, model, tokenizer, dev_df, test_df):
+    cors = []
+    all_probs = []
+
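+    # Build a k-shot prompt per question and pick the most probable letter.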
+    for i in range(test_df.shape[0]):
+        # get prompt and make sure it fits
+        k = args.ntrain
+        prompt_end = format_example(test_df, i, include_answer=False)
+        train_prompt = gen_prompt(dev_df, subject, k)
+        prompt = train_prompt + prompt_end
+
+        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
+
+        # If the prompt is too long, drop few-shot examples until it fits
+        # the assumed 2048-token context budget.
+        while input_ids.shape[-1] > 2048:
+            k -= 1
+            train_prompt = gen_prompt(dev_df, subject, k)
+            prompt = train_prompt + prompt_end
+            input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
+
+        label = test_df.iloc[i, -1]  # gold answer letter
+
+        # T5 decoder input: tokenizing "" yields just </s>, and _shift_right
+        # turns it into the decoder start token (<pad>), so the model scores
+        # only the first generated position.
+        decoder_input_ids = tokenizer("", return_tensors="pt").input_ids.cuda()
+        decoder_input_ids = model._shift_right(decoder_input_ids)
+        # logits has shape (1, 1, vocab_size); flatten it to (vocab_size,).
+        logits = model(
+            input_ids=input_ids, decoder_input_ids=decoder_input_ids
+        ).logits.flatten()
+
+        # Rank the four options by the logit of each answer letter's first
+        # token (tokenizer("A").input_ids is [letter_id, </s>], so index 0).
+        option_logits = torch.stack(
+            [logits[tokenizer(choice).input_ids[0]] for choice in choices]
+        )
+        probs = (
+            torch.nn.functional.softmax(option_logits, dim=0)
+            .detach()
+            .cpu()
+            .numpy()
+        )
+        pred = choices[int(np.argmax(probs))]
+
+        cor = pred == label
+        cors.append(cor)
+        all_probs.append(probs)
+
+    acc = np.mean(cors)
+    cors = np.array(cors)
+
+    all_probs = np.array(all_probs)
+    print("Average accuracy {:.3f} - {}".format(acc, subject))
+
+    return cors, acc, all_probs
+
+
+def main(args):
+
+    model = AutoModelForSeq2SeqLM.from_pretrained(args.model)
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
+    # Split the encoder blocks evenly across GPUs for naive model parallelism.
+    # Note: parallelize() is deprecated in recent transformers releases;
+    # device_map="auto" in from_pretrained is the usual replacement.
+    layers_per_gpu = len(model.encoder.block) // args.ngpu
+    device_map = {
+        gpu: list(range(gpu * layers_per_gpu, (gpu + 1) * layers_per_gpu))
+        for gpu in range(args.ngpu)
+    }
+    model.parallelize(device_map)
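+    # Subjects are inferred from the test split: files named "<subject>_test.csv".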
+    subjects = sorted(
+        [
+            f.split("_test.csv")[0]
+            for f in os.listdir(os.path.join(args.data_dir, "test"))
+            if "_test.csv" in f
+        ]
+    )
+
+    os.makedirs(args.save_dir, exist_ok=True)
+    os.makedirs(
+        os.path.join(args.save_dir, "results_{}".format(args.model)), exist_ok=True
+    )
+
+    all_cors = []
+
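+    # Evaluate each subject, using its dev split for the few-shot examples.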
+    for subject in subjects:
+        dev_df = pd.read_csv(
+            os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None
+        )[: args.ntrain]
+        test_df = pd.read_csv(
+            os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None
+        )
+
+        cors, acc, probs = evaluate(args, subject, model, tokenizer, dev_df, test_df)
+        all_cors.append(cors)
+
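+        # Store per-question correctness and per-choice probabilities with the rows.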
+        test_df["{}_correct".format(args.model)] = cors
+        for j in range(probs.shape[1]):
+            choice = choices[j]
+            test_df["{}_choice{}_probs".format(args.model, choice)] = probs[:, j]
+        test_df.to_csv(
+            os.path.join(
+                args.save_dir, "results_{}".format(args.model), "{}.csv".format(subject)
+            ),
+            index=None,
+        )
+
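+    # Micro-average over all questions, so larger subjects weigh more.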
+    weighted_acc = np.mean(np.concatenate(all_cors))
+    print("Average accuracy: {:.3f}".format(weighted_acc))
+
+
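+# Example invocation (defaults shown; adjust paths and GPU count as needed):
+#   python evaluate_flan.py -m google/flan-t5-small -g 2 -d data -s results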
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--ntrain", "-k", type=int, default=5)
+    parser.add_argument("--ngpu", "-g", type=int, default=2)
+    parser.add_argument("--data_dir", "-d", type=str, default="data")
+    parser.add_argument("--save_dir", "-s", type=str, default="results")
+    parser.add_argument(
+        "--model",
+        "-m",
+        type=str,
+        default="google/flan-t5-small",
+    )
+    args = parser.parse_args()
+    main(args)