Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 6642419a authored by Robin Tissot's avatar Robin Tissot
Browse files

Merge branch 'develop'

parents f22c1ee5 08e4e0e2
No related branches found
No related tags found
No related merge requests found
...@@ -17,7 +17,6 @@ from celery import shared_task ...@@ -17,7 +17,6 @@ from celery import shared_task
from celery.signals import before_task_publish, task_prerun, task_success, task_failure from celery.signals import before_task_publish, task_prerun, task_success, task_failure
from django_redis import get_redis_connection from django_redis import get_redis_connection
from easy_thumbnails.files import get_thumbnailer from easy_thumbnails.files import get_thumbnailer
from kraken.lib import default_specs
from kraken.lib import train as kraken_train from kraken.lib import train as kraken_train
...@@ -196,6 +195,7 @@ def segtrain(task, model_pk, part_pks, user_pk=None): ...@@ -196,6 +195,7 @@ def segtrain(task, model_pk, part_pks, user_pk=None):
threads=LOAD_THREADS, threads=LOAD_THREADS,
augment=True, augment=True,
resize='both', resize='both',
hyper_params={'epochs': 20},
load_hyper_parameters=True) load_hyper_parameters=True)
if not os.path.exists(os.path.split(modelpath)[0]): if not os.path.exists(os.path.split(modelpath)[0]):
...@@ -222,9 +222,15 @@ def segtrain(task, model_pk, part_pks, user_pk=None): ...@@ -222,9 +222,15 @@ def segtrain(task, model_pk, part_pks, user_pk=None):
}) })
trainer.run(_print_eval) trainer.run(_print_eval)
best_version = os.path.join(os.path.dirname(modelpath), best_version = os.path.join(os.path.dirname(modelpath),
f'version_{trainer.stopper.best_epoch}.mlmodel') f'version_{trainer.stopper.best_epoch}.mlmodel')
shutil.copy(best_version, modelpath)
try:
shutil.copy(best_version, modelpath)
except FileNotFoundError:
user.notify(_("Training didn't get better results than base model!"),
id="seg-no-gain-error", level='danger')
except Exception as e: except Exception as e:
send_event('document', document.pk, "training:error", { send_event('document', document.pk, "training:error", {
...@@ -240,7 +246,6 @@ def segtrain(task, model_pk, part_pks, user_pk=None): ...@@ -240,7 +246,6 @@ def segtrain(task, model_pk, part_pks, user_pk=None):
user.notify(_("Training finished!"), user.notify(_("Training finished!"),
id="training-success", id="training-success",
level='success') level='success')
user.notify(report)
finally: finally:
model.training = False model.training = False
model.save() model.save()
...@@ -361,8 +366,6 @@ def train_(qs, document, transcription, model=None, user=None): ...@@ -361,8 +366,6 @@ def train_(qs, document, transcription, model=None, user=None):
DEVICE = getattr(settings, 'KRAKEN_TRAINING_DEVICE', 'cpu') DEVICE = getattr(settings, 'KRAKEN_TRAINING_DEVICE', 'cpu')
LOAD_THREADS = getattr(settings, 'KRAKEN_TRAINING_LOAD_THREADS', 0) LOAD_THREADS = getattr(settings, 'KRAKEN_TRAINING_LOAD_THREADS', 0)
hyper_params = default_specs.RECOGNITION_HYPER_PARAMS.copy()
hyper_params['batch_size'] = 1
trainer = (kraken_train.KrakenTrainer trainer = (kraken_train.KrakenTrainer
.recognition_train_gen(device=DEVICE, .recognition_train_gen(device=DEVICE,
load=load, load=load,
...@@ -373,7 +376,7 @@ def train_(qs, document, transcription, model=None, user=None): ...@@ -373,7 +376,7 @@ def train_(qs, document, transcription, model=None, user=None):
resize='add', resize='add',
threads=LOAD_THREADS, threads=LOAD_THREADS,
augment=True, augment=True,
hyper_params=hyper_params, hyper_params={'batch_size': 1},
load_hyper_parameters=True)) load_hyper_parameters=True))
def _print_eval(epoch=0, accuracy=0, chars=0, error=0, val_metric=0): def _print_eval(epoch=0, accuracy=0, chars=0, error=0, val_metric=0):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment