From d13f6a69da7153d7d4b433bcbd517c41147cdfda Mon Sep 17 00:00:00 2001 From: Eva Bardou <ebardou@teklia.com> Date: Wed, 19 May 2021 14:59:31 +0200 Subject: [PATCH] Fix all broken behaviours introduced by the previous commit --- app/apps/api/serializers.py | 31 ++++--- app/apps/api/tests.py | 14 ++-- app/apps/api/views.py | 2 +- app/apps/core/forms.py | 142 ++++++++++++++++++++++++++++----- app/apps/core/models.py | 3 +- app/apps/core/tasks.py | 11 ++- app/apps/core/tests/factory.py | 10 ++- app/apps/core/tests/tasks.py | 2 +- app/apps/core/views.py | 2 +- 9 files changed, 164 insertions(+), 53 deletions(-) diff --git a/app/apps/api/serializers.py b/app/apps/api/serializers.py index 220b4642..fc2d0296 100644 --- a/app/apps/api/serializers.py +++ b/app/apps/api/serializers.py @@ -20,7 +20,8 @@ from core.models import (Document, BlockType, LineType, Script, - OcrModel) + OcrModel, + OcrModelDocument) from core.tasks import (segtrain, train, segment, transcribe) logger = logging.getLogger(__name__) @@ -340,8 +341,7 @@ class SegmentSerializer(ProcessSerializerMixin, serializers.Serializer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_SEGMENT, - document=self.document) + self.fields['model'].queryset = self.document.ocr_models.filter(job=OcrModel.MODEL_JOB_SEGMENT) self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document) def process(self): @@ -367,8 +367,7 @@ class SegTrainSerializer(ProcessSerializerMixin, serializers.Serializer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_SEGMENT, - document=self.document) + self.fields['model'].queryset = self.document.ocr_models.filter(job=OcrModel.MODEL_JOB_SEGMENT) self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document) def validate_parts(self, data): @@ -388,14 +387,19 @@ class SegTrainSerializer(ProcessSerializerMixin, serializers.Serializer): if self.validated_data.get('model_name'): file_ = model and model.file or None model = OcrModel.objects.create( - document=self.document, owner=self.user, name=self.validated_data['model_name'], job=OcrModel.MODEL_JOB_RECOGNIZE, file=file_ ) + OcrModelDocument.objects.create( + document=self.document, + ocr_model=model, + trained_on=False, + executed_on=True, + ) - segtrain.delay(model.pk if model else None, + segtrain.delay(model.pk if model else None, self.document.pk, [part.pk for part in self.validated_data.get('parts')], user_pk=self.user.pk) @@ -411,8 +415,7 @@ class TrainSerializer(ProcessSerializerMixin, serializers.Serializer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.fields['transcription'].queryset = Transcription.objects.filter(document=self.document) - self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_RECOGNIZE, - document=self.document) + self.fields['model'].queryset = self.document.ocr_models.filter(job=OcrModel.MODEL_JOB_RECOGNIZE) self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document) def validate(self, data): @@ -427,11 +430,16 @@ class TrainSerializer(ProcessSerializerMixin, serializers.Serializer): if self.validated_data.get('model_name'): file_ = model and model.file or None model = OcrModel.objects.create( - document=self.document, owner=self.user, name=self.validated_data['model_name'], job=OcrModel.MODEL_JOB_RECOGNIZE, file=file_) + OcrModelDocument.objects.create( + document=self.document, + ocr_model=model, + trained_on=False, + executed_on=True, + ) train.delay([part.pk for part in self.validated_data.get('parts')], self.validated_data['transcription'].pk, @@ -448,8 +456,7 @@ class TranscribeSerializer(ProcessSerializerMixin, serializers.Serializer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # self.fields['transcription'].queryset = Transcription.objects.filter(document=self.document) - self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_RECOGNIZE, - document=self.document) + self.fields['model'].queryset = self.document.ocr_models.filter(job=OcrModel.MODEL_JOB_RECOGNIZE) self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document) def process(self): diff --git a/app/apps/api/tests.py b/app/apps/api/tests.py index fb1ee141..cf1ccf27 100644 --- a/app/apps/api/tests.py +++ b/app/apps/api/tests.py @@ -35,7 +35,7 @@ class OcrModelViewSetTestCase(CoreFactoryTestCase): super().setUp() self.part = self.factory.make_part() self.user = self.part.document.owner - self.model = self.factory.make_model(document=self.part.document) + self.model = self.factory.make_model(self.part.document) def test_list(self): self.client.force_login(self.user) @@ -123,7 +123,7 @@ class DocumentViewSetTestCase(CoreFactoryTestCase): def test_segtrain_less_two_parts(self): self.client.force_login(self.doc.owner) - model = self.factory.make_model(job=OcrModel.MODEL_JOB_SEGMENT, document=self.doc) + model = self.factory.make_model(self.doc, job=OcrModel.MODEL_JOB_SEGMENT) uri = reverse('api:document-segtrain', kwargs={'pk': self.doc.pk}) resp = self.client.post(uri, data={ 'parts': [self.part.pk], @@ -147,7 +147,7 @@ class DocumentViewSetTestCase(CoreFactoryTestCase): def test_segtrain_existing_model_rename(self): self.client.force_login(self.doc.owner) - model = self.factory.make_model(job=OcrModel.MODEL_JOB_SEGMENT, document=self.doc) + model = self.factory.make_model(self.doc, job=OcrModel.MODEL_JOB_SEGMENT) uri = reverse('api:document-segtrain', kwargs={'pk': self.doc.pk}) resp = self.client.post(uri, data={ 'parts': [self.part.pk, self.part2.pk], @@ -160,7 +160,7 @@ class DocumentViewSetTestCase(CoreFactoryTestCase): def test_segment(self): uri = reverse('api:document-segment', kwargs={'pk': self.doc.pk}) self.client.force_login(self.doc.owner) - model = self.factory.make_model(job=OcrModel.MODEL_JOB_SEGMENT, document=self.doc) + model = self.factory.make_model(self.doc, job=OcrModel.MODEL_JOB_SEGMENT) resp = self.client.post(uri, data={ 'parts': [self.part.pk, self.part2.pk], 'seg_steps': 'both', @@ -177,15 +177,13 @@ class DocumentViewSetTestCase(CoreFactoryTestCase): 'transcription': self.transcription.pk }) self.assertEqual(resp.status_code, 200) - self.assertEqual(OcrModel.objects.filter( - document=self.doc, - job=OcrModel.MODEL_JOB_RECOGNIZE).count(), 1) + self.assertEqual(self.doc.ocr_models.filter(job=OcrModel.MODEL_JOB_RECOGNIZE).count(), 1) def test_transcribe(self): trans = Transcription.objects.create(document=self.part.document) self.client.force_login(self.doc.owner) - model = self.factory.make_model(job=OcrModel.MODEL_JOB_RECOGNIZE, document=self.doc) + model = self.factory.make_model(self.doc, job=OcrModel.MODEL_JOB_RECOGNIZE) uri = reverse('api:document-transcribe', kwargs={'pk': self.doc.pk}) resp = self.client.post(uri, data={ 'parts': [self.part.pk, self.part2.pk], diff --git a/app/apps/api/views.py b/app/apps/api/views.py index f33ba76d..66cc6afc 100644 --- a/app/apps/api/views.py +++ b/app/apps/api/views.py @@ -469,7 +469,7 @@ class OcrModelViewSet(DocumentPermissionMixin, ModelViewSet): def get_queryset(self): return (super().get_queryset() - .filter(document=self.kwargs['document_pk'])) + .filter(documents=self.kwargs['document_pk'])) @action(detail=True, methods=['post']) def cancel_training(self, request, pk=None): diff --git a/app/apps/core/forms.py b/app/apps/core/forms.py index c591e7fd..d9811a51 100644 --- a/app/apps/core/forms.py +++ b/app/apps/core/forms.py @@ -12,7 +12,7 @@ from django.utils.translation import gettext_lazy as _ from bootstrap.forms import BootstrapFormMixin from core.models import (Document, Metadata, DocumentMetadata, - DocumentPart, OcrModel, Transcription, + DocumentPart, OcrModel, OcrModelDocument, Transcription, BlockType, LineType, AlreadyProcessingException) from users.models import User @@ -129,12 +129,12 @@ class DocumentProcessForm1(BootstrapFormMixin, forms.Form): if self.document.read_direction == self.document.READ_DIRECTION_RTL: self.initial['text_direction'] = 'horizontal-rl' self.fields['binarizer'].widget.attrs['disabled'] = True - self.fields['train_model'].queryset &= OcrModel.objects.filter(document=self.document) - self.fields['segtrain_model'].queryset &= OcrModel.objects.filter(document=self.document) - self.fields['seg_model'].queryset &= OcrModel.objects.filter(document=self.document) + self.fields['train_model'].queryset &= self.document.ocr_models.all() + self.fields['segtrain_model'].queryset &= self.document.ocr_models.all() + self.fields['seg_model'].queryset &= self.document.ocr_models.all() self.fields['ocr_model'].queryset &= OcrModel.objects.filter( - Q(document=None, script=document.main_script) - | Q(document=self.document)) + Q(documents=None, script=self.document.main_script) + | Q(documents=self.document)) self.fields['transcription'].queryset = Transcription.objects.filter(document=self.document) def process(self): @@ -172,16 +172,29 @@ class DocumentSegmentForm(DocumentProcessForm1): if data.get('upload_model'): model = OcrModel.objects.create( - document=self.parts[0].document, owner=self.user, name=data['upload_model'].name.rsplit('.', 1)[0], job=model_job) + OcrModelDocument.objects.create( + document=self.parts[0].document, + ocr_model=model, + trained_on=False, + executed_on=True, + ) # Note: needs to save the file in a second step because the path needs the db PK model.file = data['upload_model'] model.save() elif data.get('seg_model'): model = data.get('seg_model') + ocr_model_document, created = OcrModelDocument.objects.get_or_create( + ocr_model=model, + document=self.parts[0].document, + defaults={'trained_on': False, 'executed_on': True} + ) + if not created: + ocr_model_document.executed_on = True + ocr_model_document.save() else: model = None @@ -224,13 +237,26 @@ class DocumentTrainForm(DocumentProcessForm1): if data.get('train_model'): model = data.get('train_model') + ocr_model_document, created = OcrModelDocument.objects.get_or_create( + ocr_model=model, + document=self.parts[0].document, + defaults={'trained_on': True, 'executed_on': False} + ) + if not created: + ocr_model_document.trained_on = True + ocr_model_document.save() elif data.get('upload_model'): model = OcrModel.objects.create( - document=self.parts[0].document, owner=self.user, name=data['upload_model'].name.rsplit('.', 1)[0], job=model_job) + OcrModelDocument.objects.create( + document=self.parts[0].document, + ocr_model=model, + trained_on=True, + executed_on=False, + ) # Note: needs to save the file in a second step because the path needs the db PK model.file = data['upload_model'] model.save() @@ -238,10 +264,15 @@ class DocumentTrainForm(DocumentProcessForm1): elif data.get('new_model'): # file will be created by the training process model = OcrModel.objects.create( - document=self.parts[0].document, owner=self.user, name=data['new_model'], job=model_job) + OcrModelDocument.objects.create( + document=self.parts[0].document, + ocr_model=model, + trained_on=True, + executed_on=False, + ) else: raise forms.ValidationError( @@ -279,12 +310,25 @@ class DocumentSegtrainForm(DocumentProcessForm1): if data.get('segtrain_model'): model = data.get('segtrain_model') + ocr_model_document, created = OcrModelDocument.objects.get_or_create( + ocr_model=model, + document=self.parts[0].document, + defaults={'trained_on': True, 'executed_on': False} + ) + if not created: + ocr_model_document.trained_on = True + ocr_model_document.save() elif data.get('upload_model'): model = OcrModel.objects.create( - document=self.parts[0].document, owner=self.user, name=data['upload_model'].name.rsplit('.', 1)[0], job=model_job) + OcrModelDocument.objects.create( + document=self.parts[0].document, + ocr_model=model, + trained_on=True, + executed_on=False, + ) # Note: needs to save the file in a second step because the path needs the db PK model.file = data['upload_model'] model.save() @@ -292,10 +336,15 @@ class DocumentSegtrainForm(DocumentProcessForm1): elif data.get('new_model'): # file will be created by the training process model = OcrModel.objects.create( - document=self.parts[0].document, owner=self.user, name=data['new_model'], job=model_job) + OcrModelDocument.objects.create( + document=self.parts[0].document, + ocr_model=model, + trained_on=True, + executed_on=False, + ) else: @@ -328,16 +377,29 @@ class DocumentTranscribeForm(DocumentProcessForm1): if data.get('upload_model'): model = OcrModel.objects.create( - document=self.parts[0].document, owner=self.user, name=data['upload_model'].name.rsplit('.', 1)[0], job=model_job) + OcrModelDocument.objects.create( + document=self.parts[0].document, + ocr_model=model, + trained_on=False, + executed_on=True, + ) # Note: needs to save the file in a second step because the path needs the db PK model.file = data['upload_model'] model.save() elif data.get('ocr_model'): model = data.get('ocr_model') + ocr_model_document, created = OcrModelDocument.objects.get_or_create( + ocr_model=model, + document=self.parts[0].document, + defaults={'trained_on': False, 'executed_on': True} + ) + if not created: + ocr_model_document.executed_on = True + ocr_model_document.save() else: raise forms.ValidationError( _("Either select a name for your new model or an existing one.")) @@ -436,12 +498,12 @@ class DocumentProcessForm(BootstrapFormMixin, forms.Form): if self.document.read_direction == self.document.READ_DIRECTION_RTL: self.initial['text_direction'] = 'horizontal-rl' self.fields['binarizer'].widget.attrs['disabled'] = True - self.fields['train_model'].queryset &= OcrModel.objects.filter(document=self.document) - self.fields['segtrain_model'].queryset &= OcrModel.objects.filter(document=self.document) - self.fields['seg_model'].queryset &= OcrModel.objects.filter(document=self.document) + self.fields['train_model'].queryset &= self.document.ocr_models.all() + self.fields['segtrain_model'].queryset &= self.document.ocr_models.all() + self.fields['seg_model'].queryset &= self.document.ocr_models.all() self.fields['ocr_model'].queryset &= OcrModel.objects.filter( - Q(document=None, script=document.main_script) - | Q(document=self.document)) + Q(documents=None, script=self.document.main_script) + | Q(documents=self.document)) self.fields['transcription'].queryset = Transcription.objects.filter(document=self.document) @cached_property @@ -487,14 +549,35 @@ class DocumentProcessForm(BootstrapFormMixin, forms.Form): if task == self.TASK_TRAIN and data.get('train_model'): model = data.get('train_model') + ocr_model_document, created = OcrModelDocument.objects.get_or_create( + ocr_model=model, + document=self.parts[0].document, + defaults={'trained_on': True, 'executed_on': False} + ) + if not created: + ocr_model_document.trained_on = True + ocr_model_document.save() elif task == self.TASK_SEGTRAIN and data.get('segtrain_model'): model = data.get('segtrain_model') + ocr_model_document, created = OcrModelDocument.objects.get_or_create( + ocr_model=model, + document=self.parts[0].document, + defaults={'trained_on': True, 'executed_on': False} + ) + if not created: + ocr_model_document.trained_on = True + ocr_model_document.save() elif data.get('upload_model'): model = OcrModel.objects.create( - document=self.parts[0].document, owner=self.user, name=data['upload_model'].name.rsplit('.', 1)[0], job=model_job) + OcrModelDocument.objects.create( + document=self.parts[0].document, + ocr_model=model, + trained_on=False, + executed_on=True, + ) # Note: needs to save the file in a second step because the path needs the db PK model.file = data['upload_model'] model.save() @@ -502,14 +585,35 @@ class DocumentProcessForm(BootstrapFormMixin, forms.Form): elif data.get('new_model'): # file will be created by the training process model = OcrModel.objects.create( - document=self.parts[0].document, owner=self.user, name=data['new_model'], job=model_job) + OcrModelDocument.objects.create( + document=self.parts[0].document, + ocr_model=model, + trained_on=True, + executed_on=False, + ) elif data.get('ocr_model'): model = data.get('ocr_model') + ocr_model_document, created = OcrModelDocument.objects.get_or_create( + ocr_model=model, + document=self.parts[0].document, + defaults={'trained_on': False, 'executed_on': True} + ) + if not created: + ocr_model_document.executed_on = True + ocr_model_document.save() elif data.get('seg_model'): model = data.get('seg_model') + ocr_model_document, created = OcrModelDocument.objects.get_or_create( + ocr_model=model, + document=self.parts[0].document, + defaults={'trained_on': False, 'executed_on': True} + ) + if not created: + ocr_model_document.executed_on = True + ocr_model_document.save() else: if task in (self.TASK_TRAIN, self.TASK_SEGTRAIN): raise forms.ValidationError( diff --git a/app/apps/core/models.py b/app/apps/core/models.py index c0dd0046..20b352fd 100644 --- a/app/apps/core/models.py +++ b/app/apps/core/models.py @@ -1083,7 +1083,7 @@ class LineTranscription(Versioned, models.Model): def models_path(instance, filename): fn, ext = os.path.splitext(filename) - return 'models/%d/%s%s' % (instance.document.pk, slugify(fn), ext) + return 'models/%s%s' % (slugify(fn), ext) class OcrModel(Versioned, models.Model): @@ -1126,6 +1126,7 @@ class OcrModel(Versioned, models.Model): def segtrain(self, document, parts_qs, user=None): segtrain.delay(self.pk, + document.pk, list(parts_qs.values_list('pk', flat=True)), user_pk=user and user.pk or None) diff --git a/app/apps/core/tasks.py b/app/apps/core/tasks.py index 7d5732be..b10c637c 100644 --- a/app/apps/core/tasks.py +++ b/app/apps/core/tasks.py @@ -130,7 +130,7 @@ def make_segmentation_training_data(part): @shared_task(bind=True, autoretry_for=(MemoryError,), default_retry_delay=60 * 60) -def segtrain(task, model_pk, part_pks, user_pk=None): +def segtrain(task, model_pk, document_pk, part_pks, user_pk=None): # # Note hack to circumvent AssertionError: daemonic processes are not allowed to have children from multiprocessing import current_process current_process().daemon = False @@ -166,8 +166,7 @@ def segtrain(task, model_pk, part_pks, user_pk=None): try: model.training = True model.save() - document = model.document - send_event('document', document.pk, "training:start", { + send_event('document', document_pk, "training:start", { "id": model.pk, }) qs = DocumentPart.objects.filter(pk__in=part_pks) @@ -213,7 +212,7 @@ def segtrain(task, model_pk, part_pks, user_pk=None): model.new_version(file=new_version_filename) model.save() - send_event('document', document.pk, "training:eval", { + send_event('document', document_pk, "training:eval", { "id": model.pk, 'versions': model.versions, 'epoch': epoch, @@ -234,7 +233,7 @@ def segtrain(task, model_pk, part_pks, user_pk=None): id="seg-no-gain-error", level='danger') except Exception as e: - send_event('document', document.pk, "training:error", { + send_event('document', document_pk, "training:error", { "id": model.pk, }) if user: @@ -251,7 +250,7 @@ def segtrain(task, model_pk, part_pks, user_pk=None): model.training = False model.save() - send_event('document', document.pk, "training:done", { + send_event('document', document_pk, "training:done", { "id": model.pk, }) diff --git a/app/apps/core/tests/factory.py b/app/apps/core/tests/factory.py index 98ffc303..9c5655cc 100644 --- a/app/apps/core/tests/factory.py +++ b/app/apps/core/tests/factory.py @@ -86,13 +86,15 @@ class CoreFactory(): fp = os.path.join(os.path.dirname(__file__), 'assets', asset_name) return open(fp, 'rb') - def make_model(self, job=OcrModel.MODEL_JOB_RECOGNIZE, document=None): + def make_model(self, document, job=OcrModel.MODEL_JOB_RECOGNIZE): spec = '[1,48,0,1 Lbx100 Do O1c10]' nn = vgsl.TorchVGSLModel(spec) model_name = 'test-model.mlmodel' - model = OcrModel.objects.create(name=model_name, - document=document, - job=job) + model = document.ocr_models.add( + name=model_name, + job=job, + through_defaults={'trained_on': False, 'executed_on': True} + ) modeldir = os.path.join(settings.MEDIA_ROOT, os.path.split( model.file.field.upload_to(model, 'test-model.mlmodel'))[0]) if not os.path.exists(modeldir): diff --git a/app/apps/core/tests/tasks.py b/app/apps/core/tests/tasks.py index 00087bb6..10ebf07c 100644 --- a/app/apps/core/tests/tasks.py +++ b/app/apps/core/tests/tasks.py @@ -68,7 +68,7 @@ class TasksTestCase(CoreFactoryTestCase): def test_train_existing_transcription_model(self): self.makeTranscriptionContent() - model = self.factory.make_model(document=self.part.document) + model = self.factory.make_model(self.part.document) self.client.force_login(self.part.document.owner) uri = reverse('document-parts-process', kwargs={'pk': self.part.document.pk}) with self.assertNumQueries(17): diff --git a/app/apps/core/views.py b/app/apps/core/views.py index 1304370d..a140d63d 100644 --- a/app/apps/core/views.py +++ b/app/apps/core/views.py @@ -274,7 +274,7 @@ class ModelsList(LoginRequiredMixin, ListView): self.document = Document.objects.for_user(self.request.user).get(pk=self.kwargs.get('document_pk')) except Document.DoesNotExist: raise PermissionDenied - return OcrModel.objects.filter(document=self.document) + return self.document.ocr_models.all() else: self.document = None return OcrModel.objects.filter(owner=self.request.user) -- GitLab