Commit 202cd51d authored by BERNIER Fabien

[+] parity check for hate speech example

parent 116f4fcb
@@ -6,6 +6,7 @@ from nltk import pos_tag
 from sklearn.metrics import accuracy_score
 from sklearn.model_selection import train_test_split
 from sklearn.ensemble import RandomForestClassifier
+from sklearn.linear_model import LogisticRegression
 from sklearn.pipeline import make_pipeline
 from sklearn.feature_extraction.text import TfidfVectorizer
 from time import time
@@ -79,21 +80,49 @@ vectorizer = TfidfVectorizer(
     max_df=0.75
 )
 rf = RandomForestClassifier(n_estimators=500)
+lr = LogisticRegression(class_weight='balanced')
 
 # training the model
-model = make_pipeline(vectorizer, rf)
+model = make_pipeline(vectorizer, lr)
 model.fit(X_train, y_train)
 
 # evaluating the model
 pred = model.predict(X_test)
 print("Accuracy:", accuracy_score(y_test, pred))
 print(class_names[model.predict(["piece of cake", "piece of shit"])])
 
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~ checking parity ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+def get_dataset_predictions(dataset_filename):
+    variant = pd.read_csv(dataset_filename)
+    tweets = variant["tweet"].to_numpy()
+    for i in range(tweets.shape[0]): tweets[i] = tweets[i][1:-1]
+    groups = variant["group"]
+    return model.predict(tweets), tweets, groups
+
+def check_parity(preds, tweets, groups, word=None):
+    if word is not None:
+        I = list(map(lambda t: word in t, tweets))
+        groups = groups[I]
+        preds = preds[I]
+    preds_aa = preds[groups == "AA"]
+    preds_sa = preds[groups == "SA"]
+    hate_rate_aa = len(preds_aa[preds_aa == 1]) / len(preds_aa)
+    hate_rate_sa = len(preds_sa[preds_sa == 1]) / len(preds_sa)
+    print(f"[{word}] P(hate_speech | AA) = {hate_rate_aa}")
+    print(f"[{word}] P(hate_speech | SA) = {hate_rate_sa}")
+
+preds, tweets, groups = get_dataset_predictions("datasets/english_variant.csv")
+check_parity(preds, tweets, groups)
+
+btch_preds, btch_tweets, btch_groups = get_dataset_predictions("datasets/english_variant_btch.csv")
+check_parity(btch_preds, btch_tweets, btch_groups)
+
+ngga_preds, ngga_tweets, ngga_groups = get_dataset_predictions("datasets/english_variant_ngga.csv")
+check_parity(ngga_preds, ngga_tweets, ngga_groups)
 
 # explaining the model
 vocab = list(model[0].vocabulary_.keys())
-fixout = FixOutText(X, y, vocab, to_drop=["black", "white", "bitch"], algo=model)
+fixout = FixOutText(X, y, vocab, to_drop=["black", "white", "bitch"], algo=model, max_features=-1)
 t0 = time()
 actual_sensitive, is_fair_flag, ans_data, accuracy, threshold = fixout.is_fair()
 print("took", time()-t0, "seconds")
\ No newline at end of file
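
The parity check added by this commit reduces to comparing the model's positive-prediction rate across the two dialect groups ("AA" and "SA"). The sketch below is a minimal, self-contained variant of that computation, not part of the commit: it assumes a fitted scikit-learn text pipeline like the model trained in the script, and a CSV with the same tweet/group layout as datasets/english_variant.csv. The parity_gap helper and the commented-out usage are illustrative names introduced here.

# Minimal, standalone version of the statistical-parity check (illustrative only).
# Assumes a fitted scikit-learn text pipeline `model` like the one trained above,
# and a CSV with "tweet" and "group" columns whose group values are "AA" and "SA".
import pandas as pd

def parity_gap(model, dataset_filename, positive_label=1):
    # Load the dataset and strip the surrounding quote characters, as done in the commit.
    data = pd.read_csv(dataset_filename)
    tweets = data["tweet"].str[1:-1].to_numpy()
    preds = model.predict(tweets)
    rates = {}
    for group in ("AA", "SA"):
        group_mask = (data["group"] == group).to_numpy()
        # Positive-prediction rate within the group: P(hate_speech | group).
        rates[group] = float((preds[group_mask] == positive_label).mean())
    return rates["AA"], rates["SA"], abs(rates["AA"] - rates["SA"])

# Example (hypothetical) usage with the dataset from the commit:
# p_aa, p_sa, gap = parity_gap(model, "datasets/english_variant.csv")
# print(f"P(hate_speech | AA) = {p_aa:.3f}, P(hate_speech | SA) = {p_sa:.3f}, gap = {gap:.3f}")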