Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 9a2de4fa authored by Helw150's avatar Helw150
Browse files

Evaluate Flan

parent 17180990
No related branches found
No related tags found
No related merge requests found
import argparse
import os
import torch
import numpy as np
import pandas as pd
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import time
# Answer option labels shared by the prompt formatter and the prediction decoder.
choices = ["A", "B", "C", "D"]
def format_subject(subject):
    """Turn an underscore-separated subject id into readable text.

    Each underscore-delimited piece is emitted with a leading space, so the
    result always starts with a space (e.g. "high_school_biology" ->
    " high school biology").
    """
    return "".join(" " + part for part in subject.split("_"))
def format_example(df, idx, include_answer=True):
    """Render row *idx* of *df* as a multiple-choice question prompt.

    Column 0 holds the question text, columns 1..n-2 hold the answer options
    (labelled from the module-level `choices`), and the last column holds the
    gold answer letter. When *include_answer* is True the gold answer and a
    blank line are appended after "Answer:".
    """
    n_options = df.shape[1] - 2
    parts = [df.iloc[idx, 0]]
    for opt in range(n_options):
        parts.append("\n{}. {}".format(choices[opt], df.iloc[idx, opt + 1]))
    parts.append("\nAnswer:")
    if include_answer:
        parts.append(" {}\n\n".format(df.iloc[idx, n_options + 1]))
    return "".join(parts)
def gen_prompt(train_df, subject, k=-1):
    """Build a few-shot prompt for *subject*.

    Produces an instruction header followed by the first *k* fully answered
    examples from *train_df*; k == -1 means use every row.
    """
    if k == -1:
        k = train_df.shape[0]
    header = "The following are multiple choice questions (with answers) about {}.\n\n".format(
        format_subject(subject)
    )
    return header + "".join(format_example(train_df, i) for i in range(k))
def eval(args, subject, model, tokenizer, dev_df, test_df):
    """Evaluate *model* on one MMLU *subject*.

    For every row of *test_df*, builds a few-shot prompt from *dev_df*,
    scores the four answer letters A-D in a single decoder step, and takes
    the highest-probability letter as the prediction.

    Returns:
        cors: np.ndarray of per-question booleans (prediction == gold label).
        acc: float, mean accuracy over the subject.
        all_probs: np.ndarray of shape (num_questions, 4) with the softmax
            probabilities over the answer letters.
    """
    cors = []
    all_probs = []
    # NOTE(review): `answers` is computed but never used below.
    answers = choices[: test_df.shape[1] - 2]
    for i in range(test_df.shape[0]):
        # get prompt and make sure it fits
        k = args.ntrain
        prompt_end = format_example(test_df, i, include_answer=False)
        train_prompt = gen_prompt(dev_df, subject, k)
        prompt = train_prompt + prompt_end
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
        # Drop few-shot examples one at a time until the tokenized prompt
        # fits in a 2048-token context window.
        while input_ids.shape[-1] > 2048:
            k -= 1
            train_prompt = gen_prompt(dev_df, subject, k)
            prompt = train_prompt + prompt_end
            input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
        print(input_ids.shape[-1])
        # Gold answer letter is in the last column of the test frame.
        label = test_df.iloc[i, test_df.shape[1] - 1]
        # Single decoder step: an empty, right-shifted decoder input makes the
        # returned logits correspond to the first token the model would emit.
        decoder_input_ids = tokenizer("", return_tensors="pt").input_ids.cuda()
        decoder_input_ids = model._shift_right(decoder_input_ids)
        logits = model(
            input_ids=input_ids, decoder_input_ids=decoder_input_ids
        ).logits.flatten()
        # Softmax restricted to the first token id of each answer letter.
        probs = (
            torch.nn.functional.softmax(
                torch.tensor(
                    [
                        logits[tokenizer("A").input_ids[0]],
                        logits[tokenizer("B").input_ids[0]],
                        logits[tokenizer("C").input_ids[0]],
                        logits[tokenizer("D").input_ids[0]],
                    ]
                ),
                dim=0,
            )
            .detach()
            .cpu()
            .numpy()
        )
        # Map the argmax position back to its answer letter.
        pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)]
        cor = pred == label
        cors.append(cor)
        all_probs.append(probs)
    acc = np.mean(cors)
    cors = np.array(cors)
    all_probs = np.array(all_probs)
    print("Average accuracy {:.3f} - {}".format(acc, subject))
    return cors, acc, all_probs
def main(args):
    """Load the model, shard it across GPUs, evaluate every MMLU subject under
    args.data_dir, and write per-subject result CSVs plus overall accuracy."""
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    # Split the encoder blocks into args.ngpu contiguous, equal-sized ranges
    # of layer indices for model.parallelize().
    heads_per_gpu = len(model.encoder.block) // args.ngpu
    device_map = {
        gpu: list(
            range(
                0 + (gpu * heads_per_gpu),
                (0 + (gpu * heads_per_gpu)) + heads_per_gpu,
            )
        )
        for gpu in range(args.ngpu)
    }
    model.parallelize(device_map)
    # Subjects are inferred from the *_test.csv filenames in the test split.
    subjects = sorted(
        [
            f.split("_test.csv")[0]
            for f in os.listdir(os.path.join(args.data_dir, "test"))
            if "_test.csv" in f
        ]
    )
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    # NOTE(review): a model id containing "/" (e.g. "google/flan-t5-small")
    # yields a nested results directory — confirm this is intended.
    if not os.path.exists(os.path.join(args.save_dir, "results_{}".format(args.model))):
        os.makedirs(os.path.join(args.save_dir, "results_{}".format(args.model)))
    all_cors = []
    for subject in subjects:
        # Few-shot examples come from the first ntrain rows of the dev split.
        dev_df = pd.read_csv(
            os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None
        )[: args.ntrain]
        test_df = pd.read_csv(
            os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None
        )
        cors, acc, probs = eval(args, subject, model, tokenizer, dev_df, test_df)
        all_cors.append(cors)
        # Augment the test frame with per-question correctness and the
        # per-choice probabilities before saving.
        test_df["{}_correct".format(args.model)] = cors
        for j in range(probs.shape[1]):
            choice = choices[j]
            test_df["{}_choice{}_probs".format(args.model, choice)] = probs[:, j]
        test_df.to_csv(
            os.path.join(
                args.save_dir, "results_{}".format(args.model), "{}.csv".format(subject)
            ),
            index=None,
        )
    # Overall accuracy across subjects, implicitly weighted by question count.
    weighted_acc = np.mean(np.concatenate(all_cors))
    print("Average accuracy: {:.3f}".format(weighted_acc))
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Number of few-shot (dev) examples prepended to each test question.
    parser.add_argument("--ntrain", "-k", type=int, default=5)
    # Number of GPUs to shard the encoder across via model.parallelize().
    parser.add_argument("--ngpu", "-g", type=int, default=2)
    # Directory containing dev/ and test/ subject CSVs.
    parser.add_argument("--data_dir", "-d", type=str, default="data")
    # Directory where per-subject result CSVs are written.
    parser.add_argument("--save_dir", "-s", type=str, default="results")
    # Hugging Face model id of the seq2seq model to evaluate.
    parser.add_argument(
        "--model",
        "-m",
        type=str,
        default="google/flan-t5-small",
    )
    args = parser.parse_args()
    main(args)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment