"""
Implements LIME_Global. Verifies if sensitive features have high contributions.
"""
from collections import Counter

import numpy as np
import pandas as pd

from lime import submodular_pick
from lime.lime_text import LimeTextExplainer

clusters = 50
def features_contributions(predict_fn, train, class_names, sample_size, kernel_width=5):
    """Explain a random sample of training texts with LIME.

    Parameters
    ----------
    predict_fn : callable
        Probability function handed to LIME (list of texts -> class probabilities).
    train : sequence of str
        Training corpus to draw instances from.
    class_names : list of str
        Class labels passed to the explainer.
    sample_size : int
        Number of instances to explain; clamped to ``len(train)``.
    kernel_width : float, optional
        LIME kernel width (default 5).

    Returns
    -------
    tuple
        ``(explainer, explanations)`` — the LimeTextExplainer and one
        explanation per sampled instance.
    """
    explainer = LimeTextExplainer(class_names=class_names, kernel_width=kernel_width)

    sample_size = min(sample_size, len(train))
    # Draw DISTINCT indices over the whole training set. The previous
    # range(sample_size) only ever considered the first sample_size rows,
    # and sampling with replacement could explain the same row twice.
    indexes = np.random.choice(len(train), size=sample_size, replace=False)

    explanations = [
        explainer.explain_instance(train[i].lower(), predict_fn, num_features=1000)
        for i in indexes
    ]

    # NOTE: submodular_pick.SubmodularPick(explainer, train, predict_fn, ...)
    # is a more principled (but slower) way to pick representative instances.
    return explainer, explanations
def fairness_eval(model, train, max_features, sensitive_features, feature_names, class_names, sample_size, threshold=None):
    """Evaluate whether sensitive words dominate the model's explanations.

    Aggregates LIME word contributions over a random sample of ``train`` and
    checks how many sensitive words rank among the strongest contributors.

    Parameters
    ----------
    model : object
        Must expose a ``prob`` method usable as a LIME prediction function.
    train : sequence of str
        Training corpus.
    max_features : int or None
        Top-k cutoff used when no threshold is given (None/negative = all).
    sensitive_features : collection of str
        Words considered sensitive.
    feature_names : mapping
        Forwarded to the threshold-based check (unused by the top-k check).
    class_names : list of str
        Class labels for the explainer.
    sample_size : int
        Number of instances to explain.
    threshold : float, optional
        If given and > 0, use the normalized-threshold check instead of top-k.

    Returns
    -------
    tuple
        ``(actual_sensitive, is_fair, df, explainer)``.
    """
    explainer, explanations = features_contributions(model.prob, train, class_names, sample_size)

    # Sum each word's weight across every explained instance.
    contributions = Counter()
    for exp in explanations:
        # Map LIME's word indices back to the actual words of this document.
        # local_exp[1] holds the (index, weight) pairs for class label 1.
        vocab = exp.domain_mapper.indexed_string.inverse_vocab
        words_weights = {vocab[i]: weight for i, weight in exp.local_exp[1]}
        contributions.update(Counter(words_weights))

    if threshold is not None and threshold > 0:
        actual_sensitive, is_fair, df = fairness_valid_threshold(contributions, feature_names, sensitive_features, threshold)
    else:
        actual_sensitive, is_fair, df = fairness_valid_top(contributions, feature_names, sensitive_features, max_features)

    return actual_sensitive, is_fair, df, explainer

def fairness_valid_top(contributions, feature_names, sensitive_features, max_features):
    """Check whether sensitive words appear among the top-k contributions.

    Parameters
    ----------
    contributions : mapping
        word -> signed contribution weight.
    feature_names : object
        Unused here; kept so the signature parallels fairness_valid_threshold.
    sensitive_features : collection of str
        Words considered sensitive.
    max_features : int or None
        How many of the strongest words to inspect; None or a negative
        value means "all of them".

    Returns
    -------
    tuple
        ``(actual_sensitive, is_fair, df)`` where ``is_fair`` is True when
        fewer than 2 sensitive words rank in the inspected top.
    """
    actual_sensitive = []
    ans_data = []
    # Rank words by absolute contribution, strongest first.
    sorted_dict = sorted(contributions.items(), key=lambda x: abs(x[1]), reverse=True)

    if max_features is None or max_features < 0:
        max_features = len(sorted_dict)
    # Clamp so a request larger than the vocabulary no longer raises IndexError.
    max_features = min(max_features, len(sorted_dict))

    for i in range(max_features):
        feature, value = sorted_dict[i]
        ans_data.append([i, feature, value])
        if feature in sensitive_features:
            actual_sensitive.append(feature)

    df = pd.DataFrame(ans_data, columns=["Index", "Word", "Contribution"])
    return actual_sensitive, len(actual_sensitive) < 2, df

def fairness_valid_threshold(contributions, feature_names, sensitive_features, threshold):
    """Check whether sensitive features contribute above a normalized threshold.

    Contribution magnitudes are min-max normalized (sign preserved), then every
    feature whose |normalized weight| reaches ``threshold`` is retained. The
    result is deemed fair when fewer than 2 retained keys are sensitive.

    Returns
    -------
    tuple
        ``(actual_sensitive, is_fair, df)``.
    """
    scaled = normalize(contributions)

    # Strongest magnitude first; the scan stops at the first one below threshold.
    ranked = sorted(scaled.items(), key=lambda kv: abs(kv[1]), reverse=True)

    kept = []
    for key, weight in ranked:
        if abs(weight) < threshold:
            break
        kept.append([key, feature_names[key], weight])

    actual_sensitive = [row[0] for row in kept if row[0] in sensitive_features]

    df = pd.DataFrame(kept, columns=["Index", "Feature", "Contribution"])
    return actual_sensitive, len(actual_sensitive) < 2, df

def normalize(b):
    """Min-max normalize contribution magnitudes, preserving sign.

    Each value ``v`` maps to ``±(|v| - min) / (max - min)`` where min/max are
    taken over the absolute values of ``b``. The input mapping is not mutated;
    a normalized copy is returned.

    Edge cases (both crashed before): an empty mapping is returned as-is, and
    when all magnitudes are equal (max == min) every value maps to (signed)
    0.0 instead of dividing by zero.
    """
    a = b.copy()
    if not a:
        return a

    magnitudes = [abs(v) for v in a.values()]
    minv = min(magnitudes)
    maxv = max(magnitudes)
    span = maxv - minv

    for key, v in a.items():
        scaled = (abs(v) - minv) / span if span else 0.0
        # Re-attach the original sign.
        a[key] = scaled if v >= 0 else -scaled

    return a