Commit 3da7f128 authored by Robin Tissot's avatar Robin Tissot
Browse files

Merge branch 'develop'

parents 29c4fc85 355dd20e
......@@ -54,6 +54,13 @@ class DocumentViewSet(ModelViewSet):
else:
return Response({'status': 'already canceled'}, status=400)
@action(detail=True, methods=['post'])
def cancel_training(self, request, pk=None):
    """Cancel the training of this document's most recent in-training model.

    Returns a 400 JSON error when no model is currently training instead of
    crashing with an AttributeError on ``None.cancel_training()``.
    """
    document = self.get_object()
    model = document.ocr_models.filter(training=True).last()
    if model is None:
        # .last() on an empty queryset yields None; mirror the error shape
        # used by the cancel() action above ('already canceled', status 400).
        return Response({'status': 'no model is training'}, status=400)
    model.cancel_training()
    return Response({'status': 'canceled'})
@action(detail=True, methods=['get'])
def export(self, request, pk=None):
format_ = request.GET.get('as', 'text')
......@@ -114,7 +121,6 @@ class DocumentViewSet(ModelViewSet):
class PartViewSet(ModelViewSet):
queryset = DocumentPart.objects.all().select_related('document')
paginate_by = 50
def get_queryset(self):
qs = self.queryset.filter(document=self.kwargs.get('document_pk'))
......
......@@ -182,9 +182,35 @@ class DocumentProcessForm(BootstrapFormMixin, forms.Form):
if model and model.training:
raise AlreadyProcessingException
return model
def clean(self):
    """Resolve which OcrModel the requested task should operate on.

    Precedence: a model explicitly picked for retraining, then an
    uploaded model file (wrapped in a fresh OcrModel row), then a brand
    new file-less model (its file is produced later by the training
    task), then a model picked for transcription; otherwise no model.
    The result is stored under ``cleaned_data['model']``.
    """
    data = super().clean()
    chosen = data.get('train_model')
    if not chosen:
        uploaded = data.get('upload_model')
        if uploaded:
            # persist the uploaded file as a new model record
            chosen = OcrModel.objects.create(
                document=self.parts[0].document,
                owner=self.user,
                name=uploaded.name,
                file=uploaded)
        elif data.get('new_model'):
            # file will be created by the training process
            chosen = OcrModel.objects.create(
                document=self.parts[0].document,
                owner=self.user,
                name=data['new_model'])
        else:
            chosen = data.get('ocr_model') or None
    data['model'] = chosen
    return data
def process(self):
task = self.cleaned_data.get('task')
model = self.cleaned_data.get('model')
if task == self.TASK_BINARIZE:
if len(self.parts) == 1 and self.cleaned_data.get('bw_image'):
self.parts[0].bw_image = self.cleaned_data['bw_image']
......@@ -203,25 +229,13 @@ class DocumentProcessForm(BootstrapFormMixin, forms.Form):
text_direction=self.cleaned_data['text_direction'])
elif task == self.TASK_TRANSCRIBE:
if self.cleaned_data.get('upload_model'):
model = OcrModel.objects.create(
document=self.parts[0].document,
owner=self.user,
name=self.cleaned_data['upload_model'].name,
file=self.cleaned_data['upload_model'])
elif self.cleaned_data['ocr_model']:
model = self.cleaned_data['ocr_model']
else:
model = None
for part in self.parts:
part.task('transcribe', user_pk=self.user.pk, model_pk=model and model.pk or None)
elif task == self.TASK_TRAIN:
model = self.cleaned_data.get('upload_model') or self.cleaned_data.get('train_model')
OcrModel.train(self.parts,
self.cleaned_data['transcription'],
model=model,
model_name=self.cleaned_data['new_model'],
model,
user=self.user)
......
......@@ -292,12 +292,18 @@ class DocumentPart(OrderedModel):
@property
def binarized(self):
return self.bw_image is not None
try:
self.bw_image.file
except ValueError:
# catches ValueError: The 'bw_image' attribute has no file associated with it.
return False
else:
return True
@property
def segmented(self):
    """Whether at least one Line has been attached to this part."""
    line_total = self.lines.count()
    return line_total > 0
def make_external_id(self):
    """Build the stable external identifier used for this page in exports."""
    return 'eSc_page_{:d}'.format(self.pk)
......@@ -311,7 +317,7 @@ class DocumentPart(OrderedModel):
return 0
transcribed = LineTranscription.objects.filter(line__document_part=self).count()
self.transcription_progress = min(int(transcribed / total * 100), 100)
def recalculate_ordering(self, text_direction=None, line_level_treshold=1/100):
"""
Re-order the lines of the DocumentPart depending or text direction.
......@@ -773,7 +779,7 @@ def models_path(instance, filename):
class OcrModel(Versioned, models.Model):
name = models.CharField(max_length=256)
file = models.FileField(upload_to=models_path,
file = models.FileField(upload_to=models_path, null=True,
validators=[FileExtensionValidator(
allowed_extensions=['mlmodel'])])
owner = models.ForeignKey(User, null=True, on_delete=models.SET_NULL)
......@@ -801,20 +807,22 @@ class OcrModel(Versioned, models.Model):
return self.training_accuracy * 100
@classmethod
def train(cls, parts_qs, transcription, model=None, model_name=None, user=None):
def train(cls, parts_qs, transcription, model, user=None):
btasks = []
for part in parts_qs:
if not part.binarized:
for task in part.task('binarize', commit=False):
btasks.append(task)
if not (model or model_name):
raise ValueError("OcrModel.train() requires either a `model` or `model_name`.")
ttask = train.si(list(parts_qs.values_list('pk', flat=True)),
transcription.pk,
model_pk=model and model.pk or None,
model_name=model_name,
model_pk=model.pk,
user_pk=user and user.pk or None)
chord(btasks, ttask).delay()
def cancel_training(self):
    """Revoke the celery task currently training this model.

    The training task publishes its task id in redis under
    ``training-<pk>``; when that entry (or its task id) is missing
    there is nothing to revoke and the method returns silently, where
    the previous code raised TypeError on ``json.loads(None)``.
    """
    entry = redis_.get('training-%d' % self.pk)
    if not entry:
        # no training bookkeeping for this model: nothing to cancel
        return
    task_id = json.loads(entry).get('task_id')
    if task_id:
        revoke(task_id, terminate=True)
# versioning
def pack(self, **kwargs):
......
from rest_framework.pagination import PageNumberPagination
class CustomPagination(PageNumberPagination):
    """Page-number pagination letting the client pick a page size.

    Clients may request ``?paginate_by=<n>``; the size is capped at 50.
    """
    max_page_size = 50
    page_size_query_param = 'paginate_by'
......@@ -233,14 +233,14 @@ class partCard {
this.$element.remove();
}
select() {
select(scroll=true) {
if (this.locked) return;
lastSelected = this;
this.$element.addClass('bg-dark');
this.$element.css({'color': 'white'});
$('i', this.selectButton).removeClass('fa-square');
$('i', this.selectButton).addClass('fa-check-square');
this.$element.get(0).scrollIntoView();
if (scroll) this.$element.get(0).scrollIntoView();
this.selected = true;
}
unselect() {
......@@ -449,26 +449,35 @@ $(document).ready(function() {
// training
var max_accuracy = 0;
$alertsContainer.on('training:start', function(ev, data) {
$('#train-counter').addClass('ongoing');
$('#train-selected').addClass('blink');
$('#train-counter').text('Gathering data.');
$('#cancel-training').show();
});
$alertsContainer.on('training:gathering', function(ev, data) {
$('#train-selected').addClass('blink');
$('#cancel-training').show();
});
$alertsContainer.on('training:eval', function(ev, data) {
$('#train-counter').addClass('ongoing');
$('#train-selected').addClass('blink');
let accuracy = Math.round(data.data.accuracy*100,1);
if (max_accuracy < accuracy) {
$('#train-counter').text('Reached '+accuracy+'% at epoch #'+data.data.epoch);
}
$('#cancel-training').show();
});
$alertsContainer.on('training:done', function(ev, data) {
// $('#train-counter').removeClass('ongoing');
$('#train-selected').removeClass('blink');
$('#cancel-training').hide();
});
$alertsContainer.on('training:error', function(ev, data) {
// $('#train-counter').removeClass('ongoing');
$('#train-selected').removeClass('blink');
$('#train-counter').text('Error.');
$('#train-selected').removeClass('blink').addClass('btn-danger');
$('#cancel-training').hide();
});
$('#cancel-training').click(function(ev, data) {
let url = API.document + '/cancel_training/';
$.post(url, {})
.done(function(data) {
$('#train-selected').removeClass('blink');
$('#cancel-training').hide();
})
.fail(function(data) {
console.log("Couldn't cancel training");
});
});
// create & configure dropzone
......@@ -498,7 +507,7 @@ $(document).ready(function() {
$('#select-all').click(function(ev) {
var cards = partCard.getRange(0, $('#cards-container .card').length);
cards.each(function(i, el) {
$(el).data('partCard').select();
$(el).data('partCard').select(false);
});
partCard.refreshSelectedCount();
});
......@@ -535,7 +544,7 @@ $(document).ready(function() {
if (proc == 'import-xml' || proc == 'import-iiif') {
$('#import-counter').text('Queued.').show().parent().addClass('ongoing');;
} else if (proc == 'train') {
$('#train-counter').text('Queued.').show();
$('#train-selected').addClass('blink');
}
}).fail(function(xhr) {
var data = xhr.responseJSON;
......@@ -551,7 +560,7 @@ $(document).ready(function() {
/* fetch the images and create the cards */
var counter=0;
var getNextParts = function(page) {
var uri = API.parts + '?page=' + page;
var uri = API.parts + '?paginate_by=50&page=' + page;
$.get(uri, function(data) {
counter += data.results.length;
$('#loading-counter').html(counter+'/'+data.count);
......
......@@ -6,13 +6,49 @@ $(document).ready(function() {
$('#models-table tr.model-head').each(function(i, e) {
max_accuracy[$(e).data('id')] = $('td#accuracy-'+$(e).data('id'), e).data('value');
});
$alertsContainer.on('training:start', function(ev, data) {
let $row = $('tr#tr-'+data.id);
$('.training-ongoing', $row).show();
$('.training-done', $row).hide();
$('.training-error', $row).hide();
$('.cancel-training', $row).show();
});
$alertsContainer.on('training:gathering', function(ev, data) {
let $row = $('tr#tr-'+data.id);
$('.training-ongoing', $row).show();
$('.training-done', $row).hide();
$('.training-error', $row).hide();
$('.training-gathering', $row).css('display', 'flex');
$('.training-gathering .progress-bar', $row).css('width', Math.round(data.index/data.total*100)+'%');
$('.cancel-training', $row).show();
});
$alertsContainer.on('training:eval', function(ev, data) {
let $row = $('tr#tr-'+data.id);
$('.training-ongoing', $row).show();
$('.training-done', $row).hide();
$('.training-error', $row).hide();
$('.training-gathering', $row).hide();
if (max_accuracy[data.id] < data.accuracy) {
$row.data('value', data.accuracy);
$('td#accuracy-'+e.data('id'), $row).text(Math.round(data.accuracy*100, 1));
$('td#accuracy-'+data.id, $row).text(Math.round(data.accuracy*100*100)/100 + '%');
max_accuracy[data.id] = data.accuracy;
}
$('.cancel-training', $row).show();
});
$alertsContainer.on('training:done', function(ev, data) {
let $row = $('tr#tr-'+data.id);
$('.training-ongoing', $row).hide();
$('.training-done', $row).show();
// $('.training-error', $row).hide();
$('.training-gathering', $row).hide();
$('.cancel-training', $row).hide();
});
$alertsContainer.on('training:error', function(ev, data) {
let $row = $('tr#tr-'+data.id);
$('.training-ongoing', $row).hide();
$('.training-done', $row).hide();
$('.training-error', $row).show();
$('.cancel-training', $row).hide();
});
});
......@@ -5,6 +5,7 @@ import os.path
import redis
import subprocess
import torch
import shutil
from PIL import Image
from django.apps import apps
......@@ -158,7 +159,7 @@ def segment(instance_pk, user_pk=None, steps=None, text_direction=None, **kwargs
def add_data_to_training_set(data, target_set):
# reorder the lines inside the set to make sure we only open the image once
data.sort(key=lambda e: e['image'])
im = None;
im = None
for i, lt in enumerate(data):
if lt['image'] != im:
if im:
......@@ -166,18 +167,17 @@ def add_data_to_training_set(data, target_set):
im.close() # close previous image
im = Image.open(os.path.join(settings.MEDIA_ROOT, lt['image']))
logger.debug('Opened', im)
if lt['content']:
logger.debug('Loading {} {} {}'.format(i, lt['box'], lt['content']))
target_set.add_loaded(im.crop(lt['box']), lt['content'])
logger.debug('Loading {} {} {}'.format(i, lt['box'], lt['content']))
target_set.add_loaded(im.crop(lt['box']), lt['content'])
yield i, lt
im.close()
def train_(qs, document, transcription, model_pk=None, model_name=None, user=None):
def train_(qs, document, transcription, model=None, user=None):
DEVICE = getattr(settings, 'KRAKEN_TRAINING_DEVICE', 'cpu')
LAG = 5
send_event('document', document.pk, "training:start", {})
# [1,48,0,1 Cr3,3,32 Do0.1,2 Mp2,2 Cr3,3,64 Do0.1,2 Mp2,2 S1(1x12)1,3 Lbx100 Do]
# m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0])
# if not m:
......@@ -204,8 +204,11 @@ def train_(qs, document, transcription, model_pk=None, model_name=None, user=Non
preload=True)
partition = int(len(ground_truth) / 10)
add_data_to_training_set(ground_truth[partition:], gt_set)
for i, data in add_data_to_training_set(ground_truth[partition:], gt_set):
if i%10 == 0:
logger.debug('Gathering #{} {}/{}'.format(1, i, partition*10))
send_event('document', document.pk, "training:gathering",
{'id': model.pk, 'index': i, 'total': partition*10})
try:
gt_set.encode(None) # codec
except KrakenEncodeException:
......@@ -213,17 +216,16 @@ def train_(qs, document, transcription, model_pk=None, model_name=None, user=Non
train_loader = DataLoader(gt_set, batch_size=1, shuffle=True,
num_workers=0, pin_memory=True)
add_data_to_training_set(ground_truth[:partition], val_set)
for i, data in add_data_to_training_set(ground_truth[:partition], val_set):
if i%10 == 0:
logger.debug('Gathering #{} {}/{}'.format(2, i, partition))
send_event('document', document.pk, "training:gathering",
{'id': model.pk, 'index': partition*9+i, 'total': partition*10})
logger.debug('Done loading training data')
OcrModel = apps.get_model('core', 'OcrModel')
if model_pk:
model = OcrModel.objects.get(pk=model_pk)
nn = vgsl.TorchVGSLModel.load_model(model.file.path)
upload_to = model.file.name
fulldir = os.path.join(settings.MEDIA_ROOT, os.path.split(upload_to)[0], '')
else:
try:
model.file.path
except ValueError:
spec = '[1,48,0,1 Cr3,3,32 Do0.1,2 Mp2,2 Cr3,3,64 Do0.1,2 Mp2,2 S1(1x12)1,3 Lbx100 Do]'
spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label()+1)
nn = vgsl.TorchVGSLModel(spec)
......@@ -231,12 +233,7 @@ def train_(qs, document, transcription, model_pk=None, model_name=None, user=Non
nn.user_metadata['accuracy'] = []
nn.init_weights()
nn.add_codec(gt_set.codec)
filename = slugify(model_name) + '.mlmodel'
model = OcrModel.objects.create(name=filename,
owner=user,
document=document,
script=document.main_script,
version_author=user and user.username or 'unknown')
filename = slugify(model.name) + '.mlmodel'
upload_to = model.file.field.upload_to(model, filename)
fulldir = os.path.join(settings.MEDIA_ROOT, os.path.split(upload_to)[0], '')
if not os.path.exists(fulldir):
......@@ -244,10 +241,13 @@ def train_(qs, document, transcription, model_pk=None, model_name=None, user=Non
modelpath = os.path.join(fulldir, filename)
nn.save_model(path=modelpath)
model.file = upload_to
model.training = True
model.save()
model.save()
else:
nn = vgsl.TorchVGSLModel.load_model(model.file.path)
upload_to = model.file.name
fulldir = os.path.join(settings.MEDIA_ROOT, os.path.split(upload_to)[0], '')
modelpath = os.path.join(settings.MEDIA_ROOT, model.file.name)
val_set.training_set = list(zip(val_set._images, val_set._gt))
# #nn.train()
nn.set_num_threads(1)
......@@ -279,28 +279,23 @@ def train_(qs, document, transcription, model_pk=None, model_name=None, user=Non
new_version_filename = '%s/version_%d.mlmodel' % (os.path.split(upload_to)[0], epoch)
model.new_version(file=new_version_filename)
model.save()
send_event('document', document.pk, "training:eval", {
"id": model.pk,
"data": {
'epoch': epoch,
'accuracy': accuracy,
'chars': chars,
'error': error
}})
'versions': model.versions,
'epoch': epoch,
'accuracy': accuracy,
'chars': chars,
'error': error})
trainer.run(_print_eval, _progress)
send_event('document', document.pk, "training:done", {
"id": model.pk,
})
return model
best_version = os.path.join(fulldir, 'version_{}.mlmodel'.format(trainer.stopper.best_epoch))
shutil.copy(best_version, modelpath)
@shared_task
def train(part_pks, transcription_pk, model_pk=None, model_name=None, user_pk=None):
if not (model_pk or model_name):
raise ValueError("tasks.train() was called without either a model_pk or a model_name.")
@shared_task(bind=True)
def train(task, part_pks, transcription_pk, model_pk, user_pk=None):
if user_pk:
try:
user = User.objects.get(pk=user_pk)
......@@ -308,23 +303,33 @@ def train(part_pks, transcription_pk, model_pk=None, model_name=None, user_pk=No
user = None
else:
user = None
redis_.set('training-%d' % model_pk, json.dumps({'task_id': task.request.id}))
Line = apps.get_model('core', 'Line')
DocumentPart = apps.get_model('core', 'DocumentPart')
Transcription = apps.get_model('core', 'Transcription')
LineTranscription = apps.get_model('core', 'LineTranscription')
model = None
OcrModel = apps.get_model('core', 'OcrModel')
try:
model = OcrModel.objects.get(pk=model_pk)
model.training = True
model.save()
transcription = Transcription.objects.get(pk=transcription_pk)
document = transcription.document
send_event('document', document.pk, "training:start", {
"id": model.pk,
})
qs = (LineTranscription.objects
.filter(transcription=transcription,
line__document_part__pk__in=part_pks)
.exclude(content__isnull=True))
model = train_(qs, document, transcription, model_pk=model_pk, model_name=model_name, user=user)
train_(qs, document, transcription, model=model, user=user)
except Exception as e:
send_event('document', document.pk, "training:error", {})
send_event('document', document.pk, "training:error", {
"id": model.pk,
})
if user:
user.notify(_("Something went wrong during the training process!"),
id="training-error", level='danger')
......@@ -335,9 +340,12 @@ def train(part_pks, transcription_pk, model_pk=None, model_name=None, user_pk=No
id="training-success",
level='success')
finally:
if model:
model.training = False
model.save()
model.training = False
model.save()
send_event('document', document.pk, "training:done", {
"id": model.pk,
})
@shared_task
......
......@@ -12,6 +12,8 @@ urlpatterns = [
path('document/<int:pk>/part/<int:part_pk>/edit/', EditPart.as_view(), name='document-part-edit'),
path('document/<int:pk>/images/', DocumentImages.as_view(), name='document-images'),
path('models/', ModelsList.as_view(), name='user-models'),
path('model/<int:pk>/delete/', ModelDelete.as_view(), name='model-delete'),
path('model/<int:pk>/cancel_training/', ModelCancelTraining.as_view(), name='model-cancel-training'),
path('document/<int:document_pk>/models/', ModelsList.as_view(), name='document-models'),
path('document/<int:pk>/publish/', PublishDocument.as_view(), name='document-publish'),
path('document/<int:pk>/share/', ShareDocument.as_view(), name='document-share'),
......
......@@ -257,3 +257,39 @@ class ModelsList(LoginRequiredMixin, ListView):
context['document'] = self.document
context['object'] = self.document # legacy
return context
class ModelDelete(LoginRequiredMixin, SuccessMessageMixin, DeleteView):
    """Let a logged-in user delete one of their own OcrModels."""
    model = OcrModel
    success_message = _("Model deleted successfully!")

    def get_queryset(self):
        # restrict deletion to models owned by the requesting user
        return OcrModel.objects.filter(owner=self.request.user)

    def get_success_url(self):
        # honour an explicit ?next= redirect, else go back to the model list
        if 'next' in self.request.GET:
            return self.request.GET.get('next')
        return reverse('user-models')
class ModelCancelTraining(LoginRequiredMixin, SuccessMessageMixin, DetailView):
    """Cancel the ongoing training of an OcrModel (POST only).

    On success redirects to ``?next=`` (or the user's model list); on
    failure logs the exception and answers 400 with a JSON body.
    """
    model = OcrModel
    http_method_names = ('post',)

    def get_success_url(self):
        if 'next' in self.request.GET:
            return self.request.GET.get('next')
        else:
            return reverse('user-models')

    def post(self, request, *args, **kwargs):
        model = self.get_object()
        try:
            model.cancel_training()
        except Exception as e:
            logger.exception(e)
            # Bug fix: HttpResponse(dict) iterates the dict and emits only its
            # keys ("status"); serialize explicitly so clients get valid JSON.
            import json
            return HttpResponse(json.dumps({'status': 'failed'}), status=400,
                                content_type="application/json")
        else:
            return HttpResponseRedirect(self.get_success_url())
......@@ -264,7 +264,7 @@ REST_FRAMEWORK = {
#'rest_framework.permissions.DjangoModelPermissionsOrAnonReadOnly'
'rest_framework.permissions.IsAuthenticated'
],
'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination',
'DEFAULT_PAGINATION_CLASS': 'core.pagination.CustomPagination',
'PAGE_SIZE': 10,
}
......
......@@ -5,6 +5,10 @@ body {
max-width: 1800px;
}
/* Utility class: remove an element from the layout entirely. */
.hide {
display: none;
}
.container {
margin-right: auto;
margin-left: auto;
......
......@@ -42,10 +42,8 @@
{% endwith %}
{% if user.is_staff %}
{% with training_model=document.training_model %}
<button id="train-selected" class="btn btn-sm btn-primary js-proc-selected ml-auto {% if training_model %}ongoing blink{% endif %}" data-proc="train" title="{% trans 'Train a model from selected images.' %}"><i class="fas fa-subway mr-1"></i>{% trans "Train" %}</button>
<span id="train-counter" class="ml-1 {% if training_model %}ongoing{% endif %}">
{% if training_model.training_accuracy %}{% blocktrans with accuracy=training_model.training_accuracy epoch=training_model.training_epoch %}Reached {{accuracy}}% at epoch #{{epoch}}{% endblocktrans %}
{% else %}{% trans "Gathering data."%}{% endif %}</span>
<button id="train-selected" class="btn btn-sm btn-primary js-proc-selected ml-auto {% if training_model %}blink{% endif %}" data-proc="train" title="{% trans 'Train a model from selected images.' %}"><i class="fas fa-subway mr-1"></i>{% trans "Train" %}</button>