diff --git a/app/apps/core/tasks.py b/app/apps/core/tasks.py index 65d564ed5cdf7dff9e3414350725bda24781e185..e0b37852df101755f1c90a5450114252ade8c9dc 100644 --- a/app/apps/core/tasks.py +++ b/app/apps/core/tasks.py @@ -17,7 +17,6 @@ from celery import shared_task from celery.signals import before_task_publish, task_prerun, task_success, task_failure from django_redis import get_redis_connection from easy_thumbnails.files import get_thumbnailer -from kraken.lib import default_specs from kraken.lib import train as kraken_train @@ -196,6 +195,7 @@ def segtrain(task, model_pk, part_pks, user_pk=None): threads=LOAD_THREADS, augment=True, resize='both', + hyper_params={'epochs': 20}, load_hyper_parameters=True) if not os.path.exists(os.path.split(modelpath)[0]): @@ -222,9 +222,15 @@ def segtrain(task, model_pk, part_pks, user_pk=None): }) trainer.run(_print_eval) + best_version = os.path.join(os.path.dirname(modelpath), f'version_{trainer.stopper.best_epoch}.mlmodel') - shutil.copy(best_version, modelpath) + + try: + shutil.copy(best_version, modelpath) + except FileNotFoundError: + user.notify(_("Training didn't get better results than base model!"), + id="seg-no-gain-error", level='danger') except Exception as e: send_event('document', document.pk, "training:error", { @@ -240,7 +246,6 @@ def segtrain(task, model_pk, part_pks, user_pk=None): user.notify(_("Training finished!"), id="training-success", level='success') - user.notify(report) finally: model.training = False model.save() @@ -361,8 +366,6 @@ def train_(qs, document, transcription, model=None, user=None): DEVICE = getattr(settings, 'KRAKEN_TRAINING_DEVICE', 'cpu') LOAD_THREADS = getattr(settings, 'KRAKEN_TRAINING_LOAD_THREADS', 0) - hyper_params = default_specs.RECOGNITION_HYPER_PARAMS.copy() - hyper_params['batch_size'] = 1 trainer = (kraken_train.KrakenTrainer .recognition_train_gen(device=DEVICE, load=load, @@ -373,7 +376,7 @@ def train_(qs, document, transcription, model=None, user=None): resize='add', threads=LOAD_THREADS, augment=True, - hyper_params=hyper_params, + hyper_params={'batch_size': 1}, load_hyper_parameters=True)) def _print_eval(epoch=0, accuracy=0, chars=0, error=0, val_metric=0):