Mentions légales du service

Skip to content
Snippets Groups Projects

Various improvements to REST service

Merged Ian Roberts requested to merge elg-in-process into elg
1 unresolved thread
Files
4
+ 77
44
@@ -12,12 +12,12 @@ from flair.data import Sentence
from flair.models import SequenceTagger
import logging
import numpy.random as npr
import os
logger = logging.getLogger(__name__)

# Replacement tables live next to this module so the service works
# regardless of the current working directory.
# (The old CWD-relative 'data/transformation/' assignment was dead code:
# it was immediately overwritten by the __file__-relative path below.)
transf_dir = os.path.join(os.path.dirname(__file__), 'data', 'transformation')
def read_ne_dict(data_dir_file):
with open(data_dir_file) as f:
@@ -31,6 +31,17 @@ def read_ne_dict(data_dir_file):
return token_cnt_probs
def get_named_entities(data_dir = transf_dir):
    """Load the named-entity tables for *data_dir*.

    Reads named_entity_to_idx.json (mapping entity name -> list index)
    and, for each entity, its ``<name>.tsv`` replacement table via
    read_ne_dict.  Returns the ``(nes_to_idxs, ne_table_list)`` pair,
    where ne_table_list[idx] is the table for the entity mapped to idx.
    """
    index_path = os.path.join(data_dir, 'named_entity_to_idx.json')
    with open(index_path) as fp:
        nes_to_idxs = json.load(fp)
    # Pre-size the list so each table can be placed at its mapped index.
    ne_table_list = [''] * len(nes_to_idxs)
    for entity_name, position in nes_to_idxs.items():
        table_path = os.path.join(data_dir, entity_name + '.tsv')
        ne_table_list[position] = read_ne_dict(table_path)
    return nes_to_idxs, ne_table_list
def get_tagged_sentence(sent):
tagged_sentence = ''
@@ -356,20 +367,72 @@ def transform_private_tokens(sentence, ne_table_list, nes_to_idxs, i=0, p=1.0):
return tagged_string
class TextTransformer(object):
    """Wraps one or more flair sequence taggers plus the named-entity
    replacement tables, and transforms sentences by replacing private
    (named-entity) tokens."""

    def __init__(self, models, data_dir = transf_dir):
        """Create a TextTransformer object wrapping one or more models.

        models -- list of strings, where each entry is either a plain
            model name "model" or a pair "key=model" - "model" is the
            flair model name, "key" is the parameter key we use to refer
            to it; if no separate key is specified then the model name
            itself is used as the key.  For example ["ner", "fr=fr-ner"]
            would load the "ner" model under the label "ner" and the
            "fr-ner" model under the label "fr".  The first model in the
            list is the default, used when no model key is provided to
            the transform method.
        data_dir -- directory holding named_entity_to_idx.json and the
            per-entity .tsv replacement tables.
        """
        self._models = { m[0]: SequenceTagger.load(m[-1])
                         for m in (n.split('=', 1) for n in models) }
        # First listed model is the default
        self._models['_default'] = self._models[models[0].split('=', 1)[0]]
        # Load the replacement tables once, up front, instead of per call.
        self._nes_to_idxs, self._ne_table_list = get_named_entities(data_dir)

    def identify_private_tokens(self, sents, model = '_default'):
        """Takes an iterable of sentences and returns a generator yielding the
        tagged sentence and coNLL format tags for each sentence in turn."""
        sentences = [Sentence(sent.strip()) for sent in sents]
        # Batch-predict all sentences in one call for efficiency.
        self._models[model].predict(sentences)
        for sent in sentences:
            tagged_sent = sent.to_tagged_string()
            tagged_words = tagged_sent.split()
            coNLL_format_tags = []
            for k, tagged_word in enumerate(tagged_words):
                token = tagged_word
                # Noise/filler tokens are never private: tag as outside.
                if token in ["<unk>", "[noise1]", "[noise2]", "umm", "ahh"]:
                    coNLL_format_tags.append([token, 'O'])
                    continue
                # Skip the inline <TAG> markers themselves.
                if token.startswith('<') and token.endswith('>'):
                    continue
                if k < len(tagged_words) - 1:
                    next_token = tagged_words[k + 1]
                else:
                    next_token = ''
                # In flair's tagged-string output the <TAG> marker follows
                # the token it applies to.
                if next_token.startswith('<') and next_token.endswith('>'):
                    tag = next_token[1:-1]
                else:
                    tag = 'O'
                if tag != 'O':
                    # Drop the "B-"/"I-" prefix; take_to_bio_format
                    # re-derives the BIO scheme below.
                    coNLL_format_tags.append([token, tag[2:]])
                else:
                    coNLL_format_tags.append([token, tag])
            if coNLL_format_tags:
                unzip = list(zip(*coNLL_format_tags))
                tags = take_to_bio_format(list(unzip[1]))
                coNLL_format_tags = list(zip(list(unzip[0]), tags))
            # else: sentence produced no tokens at all - yield an empty
            # tag list rather than crashing on zip(*[]).
            yield (tagged_sent, coNLL_format_tags)

    def transform(self, sents, replace_type = "FULL", replace_prob = 1.0, model = '_default'):
        """Takes an iterable of sentences and returns a generator yielding the
        transformed text for each sentence"""
        for tagged_sent, tokens_tags in self.identify_private_tokens(sents, model):
            yield transform_private_tokens(tokens_tags,
                                           self._ne_table_list,
                                           self._nes_to_idxs,
                                           i=replace_type,
                                           p=replace_prob)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='Transform texts removing sensitive words, named entities')
parser.add_argument("-l", "--log",
@@ -418,44 +481,14 @@ if __name__ == "__main__":
level=getattr(logging,
args.logLevel))
# Get named entities
nes_to_idxs, ne_table_list = get_named_entities()
text_transformer = TextTransformer(models=[model])
logger.info("File : {}\n".format(args.input))
model = SequenceTagger.load(model)
'''
for sent in args.input:
logger.info("Sentence: {}\n".format(sent))
tagged_sent, tokens_tags = identify_private_tokens(sent, model)
logger.info("Tagged sentence: {}\n".format(tagged_sent))
tagged_string = transform_private_tokens(tokens_tags,
ne_table_list,
nes_to_idxs,
i=args.replace_type)
for tagged_string in text_transformer.transform(args.input,
replace_type=args.replace_type,
replace_prob = args.replace_prob):
args.output.write(f"{tagged_string}\n")
'''
all_sent = []
for sent in args.input:
all_sent.append(sent)
all_tagged_sents, all_sents_tokens_tags = identify_private_tokens_in_sentences(all_sent, model)
for i in range(len(all_tagged_sents)):
tagged_sent = all_tagged_sents[i]
tokens_tags = all_sents_tokens_tags[i]
logger.info("Tagged sentence: {}\n".format(tagged_sent))
tagged_string = transform_private_tokens(tokens_tags,
ne_table_list,
nes_to_idxs,
i=args.replace_type, p=args.replace_prob)
args.output.write(f"{tagged_string}\n")
args.output.close()
Loading