Attention une mise à jour du service Gitlab va être effectuée le mardi 18 janvier (et non lundi 17 comme annoncé précédemment) entre 18h00 et 18h30. Cette mise à jour va générer une interruption du service dont nous ne maîtrisons pas complètement la durée mais qui ne devrait pas excéder quelques minutes.

Commit 202cd51d authored by BERNIER Fabien's avatar BERNIER Fabien
Browse files

[+] parity check for hate speech example

parent 116f4fcb
This diff is collapsed.
This diff is collapsed.
......@@ -6,6 +6,7 @@ from nltk import pos_tag
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from time import time
......@@ -79,21 +80,49 @@ vectorizer = TfidfVectorizer(
rf = RandomForestClassifier(n_estimators=500)
lr = LogisticRegression(class_weight='balanced')
# training the model
model = make_pipeline(vectorizer, rf)
model = make_pipeline(vectorizer, lr), y_train)
# evaluating the model
pred = model.predict(X_test)
print("Accuracy:", accuracy_score(y_test, pred))
print(class_names[model.predict(["piece of cake", "piece of shit"])])
# ~~~~~~~~~~~~~~~~~~~~~~~~~~ checking parity ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def get_dataset_predictions(dataset_filename) :
variant = pd.read_csv(dataset_filename)
tweets = variant["tweet"].to_numpy()
for i in range(tweets.shape[0]): tweets[i] = tweets[i][1:-1]
groups = variant["group"]
return model.predict(tweets), tweets, groups
def check_parity(preds, tweets, groups, word=None):
if word is not None :
I = list(map(lambda t: word in t, tweets))
groups = groups[I]
preds = preds[I]
preds_aa = preds[groups == "AA"]
preds_sa = preds[groups == "SA"]
hate_rate_aa = len(preds_aa[preds_aa == 1]) / len(preds_aa)
hate_rate_sa = len(preds_sa[preds_sa == 1]) / len(preds_sa)
print(f"[{word}] P(hate_speech | AA) = {hate_rate_aa}")
print(f"[{word}] P(hate_speech | SA) = {hate_rate_sa}")
preds, tweets, groups = get_dataset_predictions("datasets/english_variant.csv")
check_parity(preds, tweets, groups)
btch_preds, btch_tweets, btch_groups = get_dataset_predictions("datasets/english_variant_btch.csv")
check_parity(btch_preds, btch_tweets, btch_groups)
ngga_preds, ngga_tweets, ngga_groups = get_dataset_predictions("datasets/english_variant_ngga.csv")
check_parity(ngga_preds, ngga_tweets, ngga_groups)
# explaining the model
vocab = list(model[0].vocabulary_.keys())
fixout = FixOutText(X, y, vocab, to_drop=["black", "white", "bitch"], algo=model)
fixout = FixOutText(X, y, vocab, to_drop=["black", "white", "bitch"], algo=model, max_features=-1)
t0 = time()
actual_sensitive, is_fair_flag, ans_data, accuracy, threshold = fixout.is_fair()
print("took", time()-t0, "seconds")
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment