diff --git a/app/apps/api/serializers.py b/app/apps/api/serializers.py index 220b4642cad65332ddbcfe54e75cdc545788d109..f5cc969b79c8387116f03f6516d0e001949dffb8 100644 --- a/app/apps/api/serializers.py +++ b/app/apps/api/serializers.py @@ -4,6 +4,7 @@ import html from django.conf import settings from django.db.utils import IntegrityError +from django.utils import timezone from django.utils.translation import gettext_lazy as _ from rest_framework import serializers @@ -20,7 +21,8 @@ from core.models import (Document, BlockType, LineType, Script, - OcrModel) + OcrModel, + OcrModelDocument) from core.tasks import (segtrain, train, segment, transcribe) logger = logging.getLogger(__name__) @@ -298,9 +300,9 @@ class OcrModelSerializer(serializers.ModelSerializer): def create(self, data): document = Document.objects.get(pk=self.context["view"].kwargs["document_pk"]) - data['document'] = document data['owner'] = self.context["view"].request.user obj = super().create(data) + document.ocr_models.add(obj) return obj @@ -340,8 +342,7 @@ class SegmentSerializer(ProcessSerializerMixin, serializers.Serializer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_SEGMENT, - document=self.document) + self.fields['model'].queryset = self.document.ocr_models.filter(job=OcrModel.MODEL_JOB_SEGMENT) self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document) def process(self): @@ -367,8 +368,7 @@ class SegTrainSerializer(ProcessSerializerMixin, serializers.Serializer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_SEGMENT, - document=self.document) + self.fields['model'].queryset = self.document.ocr_models.filter(job=OcrModel.MODEL_JOB_SEGMENT) self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document) def validate_parts(self, data): @@ -388,14 +388,18 @@ class SegTrainSerializer(ProcessSerializerMixin, serializers.Serializer): if self.validated_data.get('model_name'): file_ = model and model.file or None model = OcrModel.objects.create( - document=self.document, owner=self.user, name=self.validated_data['model_name'], job=OcrModel.MODEL_JOB_RECOGNIZE, file=file_ ) + OcrModelDocument.objects.create( + document=self.document, + ocr_model=model, + executed_on=timezone.now(), + ) - segtrain.delay(model.pk if model else None, + segtrain.delay(model.pk if model else None, self.document.pk, [part.pk for part in self.validated_data.get('parts')], user_pk=self.user.pk) @@ -411,8 +415,7 @@ class TrainSerializer(ProcessSerializerMixin, serializers.Serializer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.fields['transcription'].queryset = Transcription.objects.filter(document=self.document) - self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_RECOGNIZE, - document=self.document) + self.fields['model'].queryset = self.document.ocr_models.filter(job=OcrModel.MODEL_JOB_RECOGNIZE) self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document) def validate(self, data): @@ -427,11 +430,15 @@ class TrainSerializer(ProcessSerializerMixin, serializers.Serializer): if self.validated_data.get('model_name'): file_ = model and model.file or None model = OcrModel.objects.create( - document=self.document, owner=self.user, name=self.validated_data['model_name'], job=OcrModel.MODEL_JOB_RECOGNIZE, file=file_) + OcrModelDocument.objects.create( + document=self.document, + ocr_model=model, + executed_on=timezone.now(), + ) train.delay([part.pk for part in self.validated_data.get('parts')], self.validated_data['transcription'].pk, @@ -448,8 +455,7 @@ class TranscribeSerializer(ProcessSerializerMixin, serializers.Serializer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # self.fields['transcription'].queryset = Transcription.objects.filter(document=self.document) - self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_RECOGNIZE, - document=self.document) + self.fields['model'].queryset = self.document.ocr_models.filter(job=OcrModel.MODEL_JOB_RECOGNIZE) self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document) def process(self): diff --git a/app/apps/api/tests.py b/app/apps/api/tests.py index fb1ee141f2dd08f64d9a3fc76ad0354f9974bcf0..6523e1f2070b693120f9d096678f9dc6699dfdaa 100644 --- a/app/apps/api/tests.py +++ b/app/apps/api/tests.py @@ -35,12 +35,12 @@ class OcrModelViewSetTestCase(CoreFactoryTestCase): super().setUp() self.part = self.factory.make_part() self.user = self.part.document.owner - self.model = self.factory.make_model(document=self.part.document) + self.model = self.factory.make_model(self.part.document) def test_list(self): self.client.force_login(self.user) uri = reverse('api:model-list', kwargs={'document_pk': self.part.document.pk}) - with self.assertNumQueries(7): + with self.assertNumQueries(8): resp = self.client.get(uri) self.assertEqual(resp.status_code, 200) @@ -49,14 +49,14 @@ class OcrModelViewSetTestCase(CoreFactoryTestCase): uri = reverse('api:model-detail', kwargs={'document_pk': self.part.document.pk, 'pk': self.model.pk}) - with self.assertNumQueries(6): + with self.assertNumQueries(7): resp = self.client.get(uri) self.assertEqual(resp.status_code, 200) def test_create(self): self.client.force_login(self.user) uri = reverse('api:model-list', kwargs={'document_pk': self.part.document.pk}) - with self.assertNumQueries(4): + with self.assertNumQueries(6): resp = self.client.post(uri, { 'name': 'test.mlmodel', 'file': self.factory.make_asset_file(name='test.mlmodel', @@ -100,7 +100,7 @@ class DocumentViewSetTestCase(CoreFactoryTestCase): def test_list(self): self.client.force_login(self.doc.owner) uri = reverse('api:document-list') - with self.assertNumQueries(8): + with self.assertNumQueries(10): resp = self.client.get(uri) self.assertEqual(resp.status_code, 200) @@ -108,7 +108,7 @@ class DocumentViewSetTestCase(CoreFactoryTestCase): self.client.force_login(self.doc.owner) uri = reverse('api:document-detail', kwargs={'pk': self.doc.pk}) - with self.assertNumQueries(7): + with self.assertNumQueries(8): resp = self.client.get(uri) self.assertEqual(resp.status_code, 200) @@ -123,7 +123,7 @@ class DocumentViewSetTestCase(CoreFactoryTestCase): def test_segtrain_less_two_parts(self): self.client.force_login(self.doc.owner) - model = self.factory.make_model(job=OcrModel.MODEL_JOB_SEGMENT, document=self.doc) + model = self.factory.make_model(self.doc, job=OcrModel.MODEL_JOB_SEGMENT) uri = reverse('api:document-segtrain', kwargs={'pk': self.doc.pk}) resp = self.client.post(uri, data={ 'parts': [self.part.pk], @@ -147,7 +147,7 @@ class DocumentViewSetTestCase(CoreFactoryTestCase): def test_segtrain_existing_model_rename(self): self.client.force_login(self.doc.owner) - model = self.factory.make_model(job=OcrModel.MODEL_JOB_SEGMENT, document=self.doc) + model = self.factory.make_model(self.doc, job=OcrModel.MODEL_JOB_SEGMENT) uri = reverse('api:document-segtrain', kwargs={'pk': self.doc.pk}) resp = self.client.post(uri, data={ 'parts': [self.part.pk, self.part2.pk], @@ -160,7 +160,7 @@ class DocumentViewSetTestCase(CoreFactoryTestCase): def test_segment(self): uri = reverse('api:document-segment', kwargs={'pk': self.doc.pk}) self.client.force_login(self.doc.owner) - model = self.factory.make_model(job=OcrModel.MODEL_JOB_SEGMENT, document=self.doc) + model = self.factory.make_model(self.doc, job=OcrModel.MODEL_JOB_SEGMENT) resp = self.client.post(uri, data={ 'parts': [self.part.pk, self.part2.pk], 'seg_steps': 'both', @@ -177,15 +177,13 @@ class DocumentViewSetTestCase(CoreFactoryTestCase): 'transcription': self.transcription.pk }) self.assertEqual(resp.status_code, 200) - self.assertEqual(OcrModel.objects.filter( - document=self.doc, - job=OcrModel.MODEL_JOB_RECOGNIZE).count(), 1) + self.assertEqual(self.doc.ocr_models.filter(job=OcrModel.MODEL_JOB_RECOGNIZE).count(), 1) def test_transcribe(self): trans = Transcription.objects.create(document=self.part.document) self.client.force_login(self.doc.owner) - model = self.factory.make_model(job=OcrModel.MODEL_JOB_RECOGNIZE, document=self.doc) + model = self.factory.make_model(self.doc, job=OcrModel.MODEL_JOB_RECOGNIZE) uri = reverse('api:document-transcribe', kwargs={'pk': self.doc.pk}) resp = self.client.post(uri, data={ 'parts': [self.part.pk, self.part2.pk], diff --git a/app/apps/api/views.py b/app/apps/api/views.py index f33ba76d4db9f3d9c3e0acd681d53083e812252e..66cc6afc8601985e861519aaa3778d2b8929415d 100644 --- a/app/apps/api/views.py +++ b/app/apps/api/views.py @@ -469,7 +469,7 @@ class OcrModelViewSet(DocumentPermissionMixin, ModelViewSet): def get_queryset(self): return (super().get_queryset() - .filter(document=self.kwargs['document_pk'])) + .filter(documents=self.kwargs['document_pk'])) @action(detail=True, methods=['post']) def cancel_training(self, request, pk=None): diff --git a/app/apps/core/admin.py b/app/apps/core/admin.py index bf33b7d0bb797f8a1d0e72f101cfba14295bcab8..fc4881acc459bfda675ba233c15881cad5a6c3e6 100644 --- a/app/apps/core/admin.py +++ b/app/apps/core/admin.py @@ -6,6 +6,7 @@ from core.models import (Document, DocumentMetadata, LineTranscription, OcrModel, + OcrModelDocument, Script, DocumentType, DocumentPartType, @@ -17,9 +18,13 @@ class MetadataInline(admin.TabularInline): model = DocumentMetadata +class OcrModelDocumentInline(admin.TabularInline): + model = OcrModelDocument + + class DocumentAdmin(admin.ModelAdmin): list_display = ['pk', 'name', 'owner'] - inlines = (MetadataInline,) + inlines = (MetadataInline, OcrModelDocumentInline) class DocumentPartAdmin(admin.ModelAdmin): @@ -42,6 +47,11 @@ class ScriptAdmin(admin.ModelAdmin): class OcrModelAdmin(admin.ModelAdmin): list_display = ['name', 'job', 'owner', 'script', 'training'] + inlines = (OcrModelDocumentInline,) + + +class OcrModelDocumentAdmin(admin.ModelAdmin): + list_display = ['document', 'ocr_model', 'trained_on', 'executed_on', 'created_at'] admin.site.register(Document, DocumentAdmin) @@ -54,3 +64,4 @@ admin.site.register(LineType) admin.site.register(Script, ScriptAdmin) admin.site.register(Metadata) admin.site.register(OcrModel, OcrModelAdmin) +admin.site.register(OcrModelDocument, OcrModelDocumentAdmin) diff --git a/app/apps/core/forms.py b/app/apps/core/forms.py index c591e7fdfd4bd494a5a2b3dcab03bd0825912f47..dbddbe187612c210da646ed1a973631f1bb9aca5 100644 --- a/app/apps/core/forms.py +++ b/app/apps/core/forms.py @@ -7,12 +7,13 @@ from django.conf import settings from django.core.validators import FileExtensionValidator, MinValueValidator, MaxValueValidator from django.db.models import Q from django.forms.models import inlineformset_factory +from django.utils import timezone from django.utils.functional import cached_property from django.utils.translation import gettext_lazy as _ from bootstrap.forms import BootstrapFormMixin from core.models import (Document, Metadata, DocumentMetadata, - DocumentPart, OcrModel, Transcription, + DocumentPart, OcrModel, OcrModelDocument, Transcription, BlockType, LineType, AlreadyProcessingException) from users.models import User @@ -129,12 +130,12 @@ class DocumentProcessForm1(BootstrapFormMixin, forms.Form): if self.document.read_direction == self.document.READ_DIRECTION_RTL: self.initial['text_direction'] = 'horizontal-rl' self.fields['binarizer'].widget.attrs['disabled'] = True - self.fields['train_model'].queryset &= OcrModel.objects.filter(document=self.document) - self.fields['segtrain_model'].queryset &= OcrModel.objects.filter(document=self.document) - self.fields['seg_model'].queryset &= OcrModel.objects.filter(document=self.document) + self.fields['train_model'].queryset &= self.document.ocr_models.all() + self.fields['segtrain_model'].queryset &= self.document.ocr_models.all() + self.fields['seg_model'].queryset &= self.document.ocr_models.all() self.fields['ocr_model'].queryset &= OcrModel.objects.filter( - Q(document=None, script=document.main_script) - | Q(document=self.document)) + Q(documents=None, script=self.document.main_script) + | Q(documents=self.document)) self.fields['transcription'].queryset = Transcription.objects.filter(document=self.document) def process(self): @@ -172,16 +173,28 @@ class DocumentSegmentForm(DocumentProcessForm1): if data.get('upload_model'): model = OcrModel.objects.create( - document=self.parts[0].document, owner=self.user, name=data['upload_model'].name.rsplit('.', 1)[0], job=model_job) + OcrModelDocument.objects.create( + document=self.parts[0].document, + ocr_model=model, + executed_on=timezone.now(), + ) # Note: needs to save the file in a second step because the path needs the db PK model.file = data['upload_model'] model.save() elif data.get('seg_model'): model = data.get('seg_model') + ocr_model_document, created = OcrModelDocument.objects.get_or_create( + ocr_model=model, + document=self.parts[0].document, + defaults={'executed_on': timezone.now()} + ) + if not created: + ocr_model_document.executed_on = timezone.now() + ocr_model_document.save() else: model = None @@ -224,13 +237,25 @@ class DocumentTrainForm(DocumentProcessForm1): if data.get('train_model'): model = data.get('train_model') + ocr_model_document, created = OcrModelDocument.objects.get_or_create( + ocr_model=model, + document=self.parts[0].document, + defaults={'trained_on': timezone.now()} + ) + if not created: + ocr_model_document.trained_on = timezone.now() + ocr_model_document.save() elif data.get('upload_model'): model = OcrModel.objects.create( - document=self.parts[0].document, owner=self.user, name=data['upload_model'].name.rsplit('.', 1)[0], job=model_job) + OcrModelDocument.objects.create( + document=self.parts[0].document, + ocr_model=model, + trained_on=timezone.now(), + ) # Note: needs to save the file in a second step because the path needs the db PK model.file = data['upload_model'] model.save() @@ -238,10 +263,14 @@ class DocumentTrainForm(DocumentProcessForm1): elif data.get('new_model'): # file will be created by the training process model = OcrModel.objects.create( - document=self.parts[0].document, owner=self.user, name=data['new_model'], job=model_job) + OcrModelDocument.objects.create( + document=self.parts[0].document, + ocr_model=model, + trained_on=timezone.now(), + ) else: raise forms.ValidationError( @@ -279,12 +308,24 @@ class DocumentSegtrainForm(DocumentProcessForm1): if data.get('segtrain_model'): model = data.get('segtrain_model') + ocr_model_document, created = OcrModelDocument.objects.get_or_create( + ocr_model=model, + document=self.parts[0].document, + defaults={'trained_on': timezone.now()} + ) + if not created: + ocr_model_document.trained_on = timezone.now() + ocr_model_document.save() elif data.get('upload_model'): model = OcrModel.objects.create( - document=self.parts[0].document, owner=self.user, name=data['upload_model'].name.rsplit('.', 1)[0], job=model_job) + OcrModelDocument.objects.create( + document=self.parts[0].document, + ocr_model=model, + trained_on=timezone.now(), + ) # Note: needs to save the file in a second step because the path needs the db PK model.file = data['upload_model'] model.save() @@ -292,10 +333,14 @@ class DocumentSegtrainForm(DocumentProcessForm1): elif data.get('new_model'): # file will be created by the training process model = OcrModel.objects.create( - document=self.parts[0].document, owner=self.user, name=data['new_model'], job=model_job) + OcrModelDocument.objects.create( + document=self.parts[0].document, + ocr_model=model, + trained_on=timezone.now(), + ) else: @@ -328,16 +373,28 @@ class DocumentTranscribeForm(DocumentProcessForm1): if data.get('upload_model'): model = OcrModel.objects.create( - document=self.parts[0].document, owner=self.user, name=data['upload_model'].name.rsplit('.', 1)[0], job=model_job) + OcrModelDocument.objects.create( + document=self.parts[0].document, + ocr_model=model, + executed_on=timezone.now(), + ) # Note: needs to save the file in a second step because the path needs the db PK model.file = data['upload_model'] model.save() elif data.get('ocr_model'): model = data.get('ocr_model') + ocr_model_document, created = OcrModelDocument.objects.get_or_create( + ocr_model=model, + document=self.parts[0].document, + defaults={'executed_on': timezone.now()} + ) + if not created: + ocr_model_document.executed_on = timezone.now() + ocr_model_document.save() else: raise forms.ValidationError( _("Either select a name for your new model or an existing one.")) @@ -436,12 +493,12 @@ class DocumentProcessForm(BootstrapFormMixin, forms.Form): if self.document.read_direction == self.document.READ_DIRECTION_RTL: self.initial['text_direction'] = 'horizontal-rl' self.fields['binarizer'].widget.attrs['disabled'] = True - self.fields['train_model'].queryset &= OcrModel.objects.filter(document=self.document) - self.fields['segtrain_model'].queryset &= OcrModel.objects.filter(document=self.document) - self.fields['seg_model'].queryset &= OcrModel.objects.filter(document=self.document) + self.fields['train_model'].queryset &= self.document.ocr_models.all() + self.fields['segtrain_model'].queryset &= self.document.ocr_models.all() + self.fields['seg_model'].queryset &= self.document.ocr_models.all() self.fields['ocr_model'].queryset &= OcrModel.objects.filter( - Q(document=None, script=document.main_script) - | Q(document=self.document)) + Q(documents=None, script=self.document.main_script) + | Q(documents=self.document)) self.fields['transcription'].queryset = Transcription.objects.filter(document=self.document) @cached_property @@ -487,14 +544,34 @@ class DocumentProcessForm(BootstrapFormMixin, forms.Form): if task == self.TASK_TRAIN and data.get('train_model'): model = data.get('train_model') + ocr_model_document, created = OcrModelDocument.objects.get_or_create( + ocr_model=model, + document=self.parts[0].document, + defaults={'trained_on': timezone.now()} + ) + if not created: + ocr_model_document.trained_on = timezone.now() + ocr_model_document.save() elif task == self.TASK_SEGTRAIN and data.get('segtrain_model'): model = data.get('segtrain_model') + ocr_model_document, created = OcrModelDocument.objects.get_or_create( + ocr_model=model, + document=self.parts[0].document, + defaults={'trained_on': timezone.now()} + ) + if not created: + ocr_model_document.trained_on = timezone.now() + ocr_model_document.save() elif data.get('upload_model'): model = OcrModel.objects.create( - document=self.parts[0].document, owner=self.user, name=data['upload_model'].name.rsplit('.', 1)[0], job=model_job) + OcrModelDocument.objects.create( + document=self.parts[0].document, + ocr_model=model, + executed_on=timezone.now(), + ) # Note: needs to save the file in a second step because the path needs the db PK model.file = data['upload_model'] model.save() @@ -502,14 +579,34 @@ class DocumentProcessForm(BootstrapFormMixin, forms.Form): elif data.get('new_model'): # file will be created by the training process model = OcrModel.objects.create( - document=self.parts[0].document, owner=self.user, name=data['new_model'], job=model_job) + OcrModelDocument.objects.create( + document=self.parts[0].document, + ocr_model=model, + trained_on=timezone.now(), + ) elif data.get('ocr_model'): model = data.get('ocr_model') + ocr_model_document, created = OcrModelDocument.objects.get_or_create( + ocr_model=model, + document=self.parts[0].document, + defaults={'executed_on': timezone.now()} + ) + if not created: + ocr_model_document.executed_on = timezone.now() + ocr_model_document.save() elif data.get('seg_model'): model = data.get('seg_model') + ocr_model_document, created = OcrModelDocument.objects.get_or_create( + ocr_model=model, + document=self.parts[0].document, + defaults={'executed_on': timezone.now()} + ) + if not created: + ocr_model_document.executed_on = timezone.now() + ocr_model_document.save() else: if task in (self.TASK_TRAIN, self.TASK_SEGTRAIN): raise forms.ValidationError( diff --git a/app/apps/core/migrations/0045_auto_20210521_1034.py b/app/apps/core/migrations/0045_auto_20210521_1034.py new file mode 100644 index 0000000000000000000000000000000000000000..1deb56b1d079224722452b5809392605a2aeefc8 --- /dev/null +++ b/app/apps/core/migrations/0045_auto_20210521_1034.py @@ -0,0 +1,55 @@ +# Generated by Django 2.2.19 on 2021-05-21 10:34 + +from django.db import migrations, models +import django.db.models.deletion +from django.utils import timezone + + +def populate_m2m(apps, schema_editor): + OcrModel = apps.get_model('core', 'OcrModel') + OcrModelDocument = apps.get_model('core', 'OcrModelDocument') + + OcrModelDocument.objects.bulk_create([ + OcrModelDocument( + document_id=model.document_id, + ocr_model_id=model.id, + executed_on=timezone.now(), + ) for model in OcrModel.objects.exclude(document__isnull=True) + ]) + + +class Migration(migrations.Migration): + + dependencies = [ + ('core', '0044_auto_20210520_1332'), + ] + + operations = [ + migrations.CreateModel( + name='OcrModelDocument', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('trained_on', models.DateTimeField(null=True)), + ('executed_on', models.DateTimeField(null=True)), + ('document', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='ocr_model_documents', to='core.Document')), + ('ocr_model', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='ocr_model_documents', to='core.OcrModel')), + ], + options={ + 'unique_together': {('document', 'ocr_model')}, + }, + ), + migrations.AddField( + model_name='ocrmodel', + name='documents', + field=models.ManyToManyField(related_name='ocr_models', through='core.OcrModelDocument', to='core.Document'), + ), + migrations.RunPython( + populate_m2m, + reverse_code=migrations.RunPython.noop, + ), + migrations.RemoveField( + model_name='ocrmodel', + name='document', + ), + ] diff --git a/app/apps/core/models.py b/app/apps/core/models.py index 9299c1bcd024634896fdecc6d9dfc2630aeef831..f199741fb8fa07c77c474e7da5adf6b19a520af3 100644 --- a/app/apps/core/models.py +++ b/app/apps/core/models.py @@ -1083,7 +1083,7 @@ class LineTranscription(Versioned, models.Model): def models_path(instance, filename): fn, ext = os.path.splitext(filename) - return 'models/%d/%s%s' % (instance.document.pk, slugify(fn), ext) + return 'models/%d/%s%s' % (instance.owner.pk, slugify(fn), ext) class OcrModel(Versioned, models.Model): @@ -1105,12 +1105,12 @@ class OcrModel(Versioned, models.Model): training_accuracy = models.FloatField(default=0.0) training_total = models.IntegerField(default=0) training_errors = models.IntegerField(default=0) - document = models.ForeignKey(Document, - related_name='ocr_models', - default=None, on_delete=models.CASCADE) + documents = models.ManyToManyField(Document, + through='core.OcrModelDocument', + related_name='ocr_models') script = models.ForeignKey(Script, blank=True, null=True, on_delete=models.SET_NULL) - version_ignore_fields = ('name', 'owner', 'document', 'script', 'training') + version_ignore_fields = ('name', 'owner', 'documents', 'script', 'training') version_history_max_length = None # keep em all class Meta: @@ -1126,6 +1126,7 @@ class OcrModel(Versioned, models.Model): def segtrain(self, document, parts_qs, user=None): segtrain.delay(self.pk, + document.pk, list(parts_qs.values_list('pk', flat=True)), user_pk=user and user.pk or None) @@ -1178,6 +1179,17 @@ class OcrModel(Versioned, models.Model): super().delete_revision(revision) +class OcrModelDocument(models.Model): + document = models.ForeignKey(Document, on_delete=models.CASCADE, related_name='ocr_model_documents') + ocr_model = models.ForeignKey(OcrModel, on_delete=models.CASCADE, related_name='ocr_model_documents') + created_at = models.DateTimeField(auto_now_add=True) + trained_on = models.DateTimeField(null=True) + executed_on = models.DateTimeField(null=True) + + class Meta: + unique_together = (('document', 'ocr_model'),) + + @receiver(pre_delete, sender=DocumentPart, dispatch_uid='thumbnails_delete_signal') def delete_thumbnails(sender, instance, using, **kwargs): thumbnailer = get_thumbnailer(instance.image) diff --git a/app/apps/core/tasks.py b/app/apps/core/tasks.py index 7d5732bedc51145a64030969e9652a74b9a73c62..b10c637c4bdbd48ae84f5dc4214fafc1eee90a93 100644 --- a/app/apps/core/tasks.py +++ b/app/apps/core/tasks.py @@ -130,7 +130,7 @@ def make_segmentation_training_data(part): @shared_task(bind=True, autoretry_for=(MemoryError,), default_retry_delay=60 * 60) -def segtrain(task, model_pk, part_pks, user_pk=None): +def segtrain(task, model_pk, document_pk, part_pks, user_pk=None): # # Note hack to circumvent AssertionError: daemonic processes are not allowed to have children from multiprocessing import current_process current_process().daemon = False @@ -166,8 +166,7 @@ def segtrain(task, model_pk, part_pks, user_pk=None): try: model.training = True model.save() - document = model.document - send_event('document', document.pk, "training:start", { + send_event('document', document_pk, "training:start", { "id": model.pk, }) qs = DocumentPart.objects.filter(pk__in=part_pks) @@ -213,7 +212,7 @@ def segtrain(task, model_pk, part_pks, user_pk=None): model.new_version(file=new_version_filename) model.save() - send_event('document', document.pk, "training:eval", { + send_event('document', document_pk, "training:eval", { "id": model.pk, 'versions': model.versions, 'epoch': epoch, @@ -234,7 +233,7 @@ def segtrain(task, model_pk, part_pks, user_pk=None): id="seg-no-gain-error", level='danger') except Exception as e: - send_event('document', document.pk, "training:error", { + send_event('document', document_pk, "training:error", { "id": model.pk, }) if user: @@ -251,7 +250,7 @@ def segtrain(task, model_pk, part_pks, user_pk=None): model.training = False model.save() - send_event('document', document.pk, "training:done", { + send_event('document', document_pk, "training:done", { "id": model.pk, }) diff --git a/app/apps/core/tests/factory.py b/app/apps/core/tests/factory.py index 98ffc303894f59f7dbb7ab3f4de9dd54740a00d0..f80e24ee9a92a8865e8555e6802f0be02610c3c9 100644 --- a/app/apps/core/tests/factory.py +++ b/app/apps/core/tests/factory.py @@ -6,6 +6,7 @@ import os.path from django.conf import settings from django.core.files.uploadedfile import SimpleUploadedFile from django.test import TestCase +from django.utils import timezone from django_redis import get_redis_connection from kraken.lib import vgsl @@ -86,13 +87,15 @@ class CoreFactory(): fp = os.path.join(os.path.dirname(__file__), 'assets', asset_name) return open(fp, 'rb') - def make_model(self, job=OcrModel.MODEL_JOB_RECOGNIZE, document=None): + def make_model(self, document, job=OcrModel.MODEL_JOB_RECOGNIZE): spec = '[1,48,0,1 Lbx100 Do O1c10]' nn = vgsl.TorchVGSLModel(spec) model_name = 'test-model.mlmodel' model = OcrModel.objects.create(name=model_name, - document=document, + owner=document.owner, job=job) + + document.ocr_models.add(model) modeldir = os.path.join(settings.MEDIA_ROOT, os.path.split( model.file.field.upload_to(model, 'test-model.mlmodel'))[0]) if not os.path.exists(modeldir): diff --git a/app/apps/core/tests/tasks.py b/app/apps/core/tests/tasks.py index 00087bb60f68fc7fbc037e04cb2f2075f2851040..10ebf07c65ec5f4644b93c3c4333b058b89f49a2 100644 --- a/app/apps/core/tests/tasks.py +++ b/app/apps/core/tests/tasks.py @@ -68,7 +68,7 @@ class TasksTestCase(CoreFactoryTestCase): def test_train_existing_transcription_model(self): self.makeTranscriptionContent() - model = self.factory.make_model(document=self.part.document) + model = self.factory.make_model(self.part.document) self.client.force_login(self.part.document.owner) uri = reverse('document-parts-process', kwargs={'pk': self.part.document.pk}) with self.assertNumQueries(17): diff --git a/app/apps/core/views.py b/app/apps/core/views.py index 1304370de8fdd3aaf3ab4f5d816bf53623224d3f..a140d63d3ecd99ef868ee91d9ede80a119aa5424 100644 --- a/app/apps/core/views.py +++ b/app/apps/core/views.py @@ -274,7 +274,7 @@ class ModelsList(LoginRequiredMixin, ListView): self.document = Document.objects.for_user(self.request.user).get(pk=self.kwargs.get('document_pk')) except Document.DoesNotExist: raise PermissionDenied - return OcrModel.objects.filter(document=self.document) + return self.document.ocr_models.all() else: self.document = None return OcrModel.objects.filter(owner=self.request.user)