Commit 6642419a authored by Robin Tissot's avatar Robin Tissot
Browse files

Merge branch 'develop'

parents f22c1ee5 08e4e0e2
......@@ -17,7 +17,6 @@ from celery import shared_task
from celery.signals import before_task_publish, task_prerun, task_success, task_failure
from django_redis import get_redis_connection
from easy_thumbnails.files import get_thumbnailer
from kraken.lib import default_specs
from kraken.lib import train as kraken_train
......@@ -196,6 +195,7 @@ def segtrain(task, model_pk, part_pks, user_pk=None):
threads=LOAD_THREADS,
augment=True,
resize='both',
hyper_params={'epochs': 20},
load_hyper_parameters=True)
if not os.path.exists(os.path.split(modelpath)[0]):
......@@ -222,9 +222,15 @@ def segtrain(task, model_pk, part_pks, user_pk=None):
})
trainer.run(_print_eval)
best_version = os.path.join(os.path.dirname(modelpath),
f'version_{trainer.stopper.best_epoch}.mlmodel')
shutil.copy(best_version, modelpath)
try:
shutil.copy(best_version, modelpath)
except FileNotFoundError:
user.notify(_("Training didn't get better results than base model!"),
id="seg-no-gain-error", level='danger')
except Exception as e:
send_event('document', document.pk, "training:error", {
......@@ -240,7 +246,6 @@ def segtrain(task, model_pk, part_pks, user_pk=None):
user.notify(_("Training finished!"),
id="training-success",
level='success')
user.notify(report)
finally:
model.training = False
model.save()
......@@ -361,8 +366,6 @@ def train_(qs, document, transcription, model=None, user=None):
DEVICE = getattr(settings, 'KRAKEN_TRAINING_DEVICE', 'cpu')
LOAD_THREADS = getattr(settings, 'KRAKEN_TRAINING_LOAD_THREADS', 0)
hyper_params = default_specs.RECOGNITION_HYPER_PARAMS.copy()
hyper_params['batch_size'] = 1
trainer = (kraken_train.KrakenTrainer
.recognition_train_gen(device=DEVICE,
load=load,
......@@ -373,7 +376,7 @@ def train_(qs, document, transcription, model=None, user=None):
resize='add',
threads=LOAD_THREADS,
augment=True,
hyper_params=hyper_params,
hyper_params={'batch_size': 1},
load_hyper_parameters=True))
def _print_eval(epoch=0, accuracy=0, chars=0, error=0, val_metric=0):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment