Commit d2ae8618 authored by BERNIER Fabien

[+] lime_text_global, an adaptation of lime_global

parent c61de2ac
"""
Implements LIME_Global. Verifies if sensitives features have high contributions.
"""
from collections import Counter
import pandas as pd
import numpy as np
import sys
from lime import lime_tabular, submodular_pick
from lime.lime_text import LimeTextExplainer
from scipy import stats
clusters=50
def features_contributions(predict_fn, train, class_names, sample_size, kernel_width=3):
    """Fits a LimeTextExplainer on `train` and selects representative
    explanations via submodular pick."""
    explainer = LimeTextExplainer(class_names=class_names, kernel_width=kernel_width)
    # `clusters` is forwarded through SubmodularPick's **kwargs; this assumes
    # the lime version used by the project accepts it for text explainers.
    sp_obj = submodular_pick.SubmodularPick(explainer, train, predict_fn,
                                            sample_size=sample_size,
                                            num_features=1000, clusters=CLUSTERS)
    return explainer, sp_obj
def fairness_eval(model, train, max_features, sensitive_features, feature_names,
                  class_names, categorical_features, categorical_names, sample_size,
                  threshold=None):
    """Aggregates word contributions over the picked explanations and checks
    whether sensitive features rank among the most influential ones.
    `categorical_features` and `categorical_names` are unused here; they are
    kept for signature compatibility with the tabular version."""
    explainer, sp_obj = features_contributions(model.prob, train, class_names, sample_size)
    # Sum the local contributions for class 1 over the selected explanations.
    contributions = Counter()
    for i in sp_obj.V:
        exp = sp_obj.explanations[i]
        contributions.update(Counter(dict(exp.local_exp[1])))
    if threshold is not None and threshold > 0:
        actual_sensitive, is_fair, df = fairness_valid_threshold(
            contributions, feature_names, sensitive_features, threshold)
    else:
        actual_sensitive, is_fair, df = fairness_valid_top(
            contributions, feature_names, sensitive_features, max_features)
    return actual_sensitive, is_fair, df, explainer
def fairness_valid_top(contributions, feature_names, sensitive_features, max_features):
    """Flags the model as unfair when two or more sensitive features appear
    among the `max_features` largest contributions (in absolute value)."""
    actual_sensitive = []
    counter_top = 0
    ans_data = []
    sorted_dict = sorted(contributions.items(), key=lambda x: abs(x[1]), reverse=True)
    for key, value in sorted_dict:
        ans_data.append([key, feature_names[key], value])
        if key in sensitive_features:
            actual_sensitive.append(key)
        counter_top += 1
        if counter_top >= max_features:
            break
    df = pd.DataFrame(ans_data, columns=["Index", "Feature", "Contribution"])
    return actual_sensitive, len(actual_sensitive) < 2, df
def fairness_valid_threshold(contributions, feature_names, sensitive_features, threshold):
    """Same check as fairness_valid_top, but keeps every feature whose
    normalized absolute contribution reaches `threshold`."""
    actual_sensitive = []
    ans_data = []
    n_contributions = normalize(contributions)
    sorted_dict = sorted(n_contributions.items(), key=lambda x: abs(x[1]), reverse=True)
    for key, value in sorted_dict:
        if abs(value) < threshold:
            break
        ans_data.append([key, feature_names[key], value])
        if key in sensitive_features:
            actual_sensitive.append(key)
    df = pd.DataFrame(ans_data, columns=["Index", "Feature", "Contribution"])
    return actual_sensitive, len(actual_sensitive) < 2, df
def normalize(b):
    """Min-max normalizes absolute contributions to [0, 1], preserving
    the sign of each contribution."""
    a = b.copy()
    values = [abs(v) for v in a.values()]
    minv = np.min(values)
    maxv = np.max(values)
    span = maxv - minv
    for key, v in a.items():
        # Guard against a zero span when all contributions share the same magnitude.
        normalized = (abs(v) - minv) / span if span > 0 else 0.0
        a[key] = normalized if v >= 0 else -normalized
    return a
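For context, here is a minimal usage sketch (not part of the commit). It assumes the lime version used by the project accepts the `clusters` argument, and it uses a hypothetical `TextModel` wrapper exposing the `prob` method that `fairness_eval` expects; the toy corpus, labels, and the `feature_names` mapping are placeholders.

from collections import defaultdict
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

class TextModel:
    """Hypothetical wrapper exposing the `prob` interface used by fairness_eval."""
    def __init__(self, pipeline):
        self.pipeline = pipeline

    def prob(self, texts):
        return self.pipeline.predict_proba(list(texts))

texts = ["good product", "awful service", "great support", "bad experience"]
labels = [1, 0, 1, 0]  # toy sentiment labels
pipeline = make_pipeline(TfidfVectorizer(), LogisticRegression()).fit(texts, labels)

# Placeholder id -> word mapping; in practice it must agree with the word ids
# the explainer assigns.
feature_names = defaultdict(lambda: "<word>")

actual_sensitive, is_fair, df, explainer = fairness_eval(
    TextModel(pipeline), texts, max_features=5,
    sensitive_features=[], feature_names=feature_names,
    class_names=["neg", "pos"], categorical_features=None,
    categorical_names=None, sample_size=4)
print(is_fair)
print(df)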