
Commit 202cd51d authored by BERNIER Fabien

[+] parity check for hate speech example

parent 116f4fcb
@@ -6,6 +6,7 @@ from nltk import pos_tag
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from time import time
@@ -79,21 +80,49 @@ vectorizer = TfidfVectorizer(
    max_df=0.75
)
rf = RandomForestClassifier(n_estimators=500)
lr = LogisticRegression(class_weight='balanced')
# training the model
model = make_pipeline(vectorizer, rf)
model = make_pipeline(vectorizer, lr)
model.fit(X_train, y_train)
# evaluating the model
pred = model.predict(X_test)
print("Accuracy:", accuracy_score(y_test, pred))
print(class_names[model.predict(["piece of cake", "piece of shit"])])
# ~~~~~~~~~~~~~~~~~~~~~~~~~~ checking parity ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def get_dataset_predictions(dataset_filename):
    # load a tweet variant dataset and return the model's predictions,
    # the tweets themselves, and their group labels
    variant = pd.read_csv(dataset_filename)
    tweets = variant["tweet"].to_numpy()
    # drop the first and last character of each tweet
    for i in range(tweets.shape[0]):
        tweets[i] = tweets[i][1:-1]
    groups = variant["group"]
    return model.predict(tweets), tweets, groups

def check_parity(preds, tweets, groups, word=None):
    # optionally restrict the comparison to tweets containing a given word
    if word is not None:
        I = list(map(lambda t: word in t, tweets))
        groups = groups[I]
        preds = preds[I]
    # rate at which each group's tweets are predicted as hate speech (label 1)
    preds_aa = preds[groups == "AA"]
    preds_sa = preds[groups == "SA"]
    hate_rate_aa = len(preds_aa[preds_aa == 1]) / len(preds_aa)
    hate_rate_sa = len(preds_sa[preds_sa == 1]) / len(preds_sa)
    print(f"[{word}] P(hate_speech | AA) = {hate_rate_aa}")
    print(f"[{word}] P(hate_speech | SA) = {hate_rate_sa}")
preds, tweets, groups = get_dataset_predictions("datasets/english_variant.csv")
check_parity(preds, tweets, groups)
btch_preds, btch_tweets, btch_groups = get_dataset_predictions("datasets/english_variant_btch.csv")
check_parity(btch_preds, btch_tweets, btch_groups)
ngga_preds, ngga_tweets, ngga_groups = get_dataset_predictions("datasets/english_variant_ngga.csv")
check_parity(ngga_preds, ngga_tweets, ngga_groups)
# explaining the model
vocab = list(model[0].vocabulary_.keys())
fixout = FixOutText(X, y, vocab, to_drop=["black", "white", "bitch"], algo=model)
fixout = FixOutText(X, y, vocab, to_drop=["black", "white", "bitch"], algo=model, max_features=-1)
t0 = time()
actual_sensitive, is_fair_flag, ans_data, accuracy, threshold = fixout.is_fair()
print("took", time()-t0, "seconds")
\ No newline at end of file
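
The check_parity function above only prints the two conditional rates. As a minimal sketch of how those rates could be reduced to a single statistical parity gap, the helper below is an assumption and not part of the commit (the name parity_difference and the use of numpy are illustrative only):

import numpy as np

def parity_difference(preds, groups, positive_label=1):
    # hypothetical helper: absolute gap between the positive-prediction
    # rates of the two groups compared in check_parity
    rate_aa = np.mean(preds[groups == "AA"] == positive_label)
    rate_sa = np.mean(preds[groups == "SA"] == positive_label)
    return abs(rate_aa - rate_sa)

# e.g. parity_difference(preds, groups) close to 0 would indicate
# statistical parity between the AA and SA groups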