Commit 116f4fcb authored by BERNIER Fabien's avatar BERNIER Fabien
Browse files

[+] Exception handling in Lime Text

parent 12851916
...@@ -412,9 +412,19 @@ class LimeTextExplainer(object): ...@@ -412,9 +412,19 @@ class LimeTextExplainer(object):
split_expression=self.split_expression, split_expression=self.split_expression,
mask_string=self.mask_string)) mask_string=self.mask_string))
domain_mapper = TextDomainMapper(indexed_string) domain_mapper = TextDomainMapper(indexed_string)
data, yss, distances = self.__data_labels_distances( try:
indexed_string, classifier_fn, num_samples, data, yss, distances = self.__data_labels_distances(
distance_metric=distance_metric) indexed_string, classifier_fn, num_samples,
distance_metric=distance_metric)
except:
ret_exp = explanation.Explanation(domain_mapper=domain_mapper,
class_names=self.class_names,
random_state=self.random_state)
ret_exp.predict_proba = classifier_fn([text_instance])[0]
ret_exp.local_pred = [ret_exp.predict_proba[1]]
ret_exp.intercept[1] = ret_exp.predict_proba[1]
ret_exp.local_exp[1] = [(0,0)]
return ret_exp
if self.class_names is None: if self.class_names is None:
self.class_names = [str(x) for x in range(yss[0].shape[0])] self.class_names = [str(x) for x in range(yss[0].shape[0])]
ret_exp = explanation.Explanation(domain_mapper=domain_mapper, ret_exp = explanation.Explanation(domain_mapper=domain_mapper,
......
...@@ -5,9 +5,11 @@ from collections import Counter ...@@ -5,9 +5,11 @@ from collections import Counter
import pandas as pd import pandas as pd
import numpy as np import numpy as np
from lime import lime_tabular, submodular_pick from lime import submodular_pick
from lime.lime_text import LimeTextExplainer from lime.lime_text import LimeTextExplainer
clusters = 50
def features_contributions(predict_fn, train, class_names, sample_size, kernel_width=3): def features_contributions(predict_fn, train, class_names, sample_size, kernel_width=3):
...@@ -16,6 +18,9 @@ def features_contributions(predict_fn, train, class_names, sample_size, kernel_w ...@@ -16,6 +18,9 @@ def features_contributions(predict_fn, train, class_names, sample_size, kernel_w
sample_size = len(train) sample_size = len(train)
indexes = np.random.choice(range(sample_size), sample_size) indexes = np.random.choice(range(sample_size), sample_size)
explanations = [explainer.explain_instance(train[i], predict_fn, num_features=1000) for i in indexes] explanations = [explainer.explain_instance(train[i], predict_fn, num_features=1000) for i in indexes]
# sp_obj = submodular_pick.SubmodularPick(explainer, train, predict_fn, sample_size=sample_size,
# num_features=1000, clusters=clusters)
# explanations = sp_obj.sp_explanations
return explainer, explanations return explainer, explanations
...@@ -42,6 +47,9 @@ def fairness_valid_top(contributions, feature_names, sensitive_features, max_fea ...@@ -42,6 +47,9 @@ def fairness_valid_top(contributions, feature_names, sensitive_features, max_fea
actual_sensitive = [] actual_sensitive = []
ans_data = [] ans_data = []
sorted_dict = sorted(contributions.items(), key=lambda x: abs(x[1]), reverse=True) sorted_dict = sorted(contributions.items(), key=lambda x: abs(x[1]), reverse=True)
if max_features is None or max_features < 0 :
max_features = len(sorted_dict)
for i in range(max_features): for i in range(max_features):
key, value = sorted_dict[i] key, value = sorted_dict[i]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment