Mentions légales du service

Skip to content
Snippets Groups Projects
Commit f32a430b authored by Benjamin Kiessling
Browse files

fixes for training

parent 5a89d1e5
Branches
Tags
2 merge requests: !35 Feature/load hyper, !28 Fix/kraken 3
import os
import json
import logging
import numpy as np
......@@ -25,7 +26,6 @@ from torch.utils.data import DataLoader
from users.consumers import send_event
logger = logging.getLogger(__name__)
User = get_user_model()
redis_ = redis.Redis(host=settings.REDIS_HOST,
......@@ -158,26 +158,32 @@ def segtrain(task, model_pk, document_pk, part_pks, user_pk=None):
ground_truth = list(qs.prefetch_related('lines'))
np.random.shuffle(ground_truth)
partition = max(1, int(len(ground_truth) / 10))
training_data = [{'image': part.image.path, 'baselines': {'default': part.lines.values_list('baseline', flat=True)}} for part in qs[partition:]]
validation_data = [{'image': part.image.path, 'baselines': {'default': part.lines.values_list('baseline', flat=True)}} for part in qs[:partition]]
training_data = []
evaluation_data = []
for part in qs[partition:]:
training_data.append({'image': part.image.path, 'baselines': [{'script': 'default', 'baseline': bl} for bl in part.lines.values_list('baseline', flat=True)]})
for part in qs[:partition]:
evaluation_data.append({'image': part.image.path, 'baselines': [{'script': 'default', 'baseline': bl} for bl in part.lines.values_list('baseline', flat=True)]})
DEVICE = getattr(settings, 'KRAKEN_TRAINING_DEVICE', 'cpu')
trainer = KrakenTrainer.segmentation_train_gen(output=os.path.join(os.path.split(modelpath)[0], 'version'),
format_type=None,
device=DEVICE,
load=load,
training_data=training_data,
validation_data=validation_data,
augment=True)
trainer = kraken_train.KrakenTrainer.segmentation_train_gen(output=os.path.join(os.path.split(modelpath)[0], 'version'),
format_type=None,
device=DEVICE,
load=load,
training_data=training_data,
evaluation_data=evaluation_data,
augment=True,
threads=0)
if not os.path.exists(os.path.split(modelpath)[0]):
os.makedirs(os.path.split(modelpath)[0])
def _print_eval(epoch=0, precision=0, recall=0, f1=0,
mcc=0, val_metric=0):
def _print_eval(epoch=0, accuracy=0, mean_acc=0, mean_iu=0, freq_iu=0,
val_metric=0):
model.refresh_from_db()
model.training_epoch = epoch
model.training_accuracy = float(precision)
model.training_accuracy = float(val_metric)
# model.training_total = chars
# model.training_errors = error
new_version_filename = '%s/version_%d.mlmodel' % (os.path.split(upload_to)[0], epoch)
......@@ -188,13 +194,14 @@ def segtrain(task, model_pk, document_pk, part_pks, user_pk=None):
"id": model.pk,
'versions': model.versions,
'epoch': epoch,
'accuracy': float(precision)
'accuracy': float(val_metric)
# 'chars': chars,
# 'error': error
})
trainer.run(even_callback=_draw_progressbar)
nn.save_model(path=modelpath)
trainer.run(_print_eval)
best_version = os.path.join(os.path.dirname(modelpath), f'version_{trainer.stopper.best_epoch}.mlmodel')
shutil.copy(best_version, modelpath)
except Exception as e:
send_event('document', document.pk, "training:error", {
......@@ -310,13 +317,15 @@ def train_(qs, document, transcription, model=None, user=None):
temp_file_prefix = os.path.join(fulldir, 'version')
trainer = KrakenTrainer.recognition_train_gen(device=DEVICE,
load=load,
filename_prefix=temp_file_prefix,
training_data=training_data,
evaluation_data=evaluation_data,
resize='both',
augment=True)
trainer = kraken_train.KrakenTrainer.recognition_train_gen(device=DEVICE,
load=load,
output=temp_file_prefix,
format_type=None,
training_data=training_data,
evaluation_data=evaluation_data,
resize='both',
augment=True,
threads=0)
def _print_eval(epoch=0, accuracy=0, chars=0, error=0, val_metric=0):
model.refresh_from_db()
......@@ -336,8 +345,8 @@ def train_(qs, document, transcription, model=None, user=None):
'chars': chars,
'error': error})
trainer.run(_print_eval, _progress)
best_version = os.path.join(fulldir, 'version_{}.mlmodel'.format(trainer.stopper.best_epoch))
trainer.run(_print_eval)
best_version = os.path.join(fulldir, f'version_{trainer.stopper.best_epoch}.mlmodel')
shutil.copy(best_version, modelpath)
......@@ -403,7 +412,7 @@ def transcribe(instance_pk, model_pk=None, user_pk=None, text_direction=None, **
except DocumentPart.DoesNotExist:
logger.error('Trying to transcribe innexistant DocumentPart : %d', instance_pk)
return
if user_pk:
try:
user = User.objects.get(pk=user_pk)
......@@ -411,7 +420,7 @@ def transcribe(instance_pk, model_pk=None, user_pk=None, text_direction=None, **
user = None
else:
user = None
if model_pk:
try:
OcrModel = apps.get_model('core', 'OcrModel')
......@@ -421,7 +430,7 @@ def transcribe(instance_pk, model_pk=None, user_pk=None, text_direction=None, **
model = None
else:
model = None
try:
part.transcribe(model=model)
except Exception as e:
......@@ -450,7 +459,7 @@ def before_publish_state(sender=None, body=None, **kwargs):
return
instance_id = body[0][0]
data = json.loads(redis_.get('process-%d' % instance_id) or '{}')
try:
# protects against signal race condition
if (data[sender]['task_id'] == sender.request.id and
......@@ -458,7 +467,7 @@ def before_publish_state(sender=None, body=None, **kwargs):
return
except KeyError:
pass
data[sender] = {
"task_id": kwargs['headers']['id'],
"status": 'before_task_publish'
......@@ -476,7 +485,7 @@ def done_state(sender=None, body=None, **kwargs):
if not sender.name.startswith('core.tasks') or sender.name.endswith('train'):
return
instance_id = sender.request.args[0]
try:
data = json.loads(redis_.get('process-%d' % instance_id) or '{}')
except TypeError as e:
......@@ -484,7 +493,7 @@ def done_state(sender=None, body=None, **kwargs):
return
signal_name = kwargs['signal'].name
try:
# protects against signal race condition
if (data[sender.name]['task_id'] == sender.request.id and
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment