From d13f6a69da7153d7d4b433bcbd517c41147cdfda Mon Sep 17 00:00:00 2001
From: Eva Bardou <ebardou@teklia.com>
Date: Wed, 19 May 2021 14:59:31 +0200
Subject: [PATCH] Fix all broken behaviours introduced by the previous commit

---
 app/apps/api/serializers.py    |  31 ++++---
 app/apps/api/tests.py          |  14 ++--
 app/apps/api/views.py          |   2 +-
 app/apps/core/forms.py         | 142 ++++++++++++++++++++++++++++-----
 app/apps/core/models.py        |   3 +-
 app/apps/core/tasks.py         |  11 ++-
 app/apps/core/tests/factory.py |  10 ++-
 app/apps/core/tests/tasks.py   |   2 +-
 app/apps/core/views.py         |   2 +-
 9 files changed, 164 insertions(+), 53 deletions(-)

diff --git a/app/apps/api/serializers.py b/app/apps/api/serializers.py
index 220b4642..fc2d0296 100644
--- a/app/apps/api/serializers.py
+++ b/app/apps/api/serializers.py
@@ -20,7 +20,8 @@ from core.models import (Document,
                          BlockType,
                          LineType,
                          Script,
-                         OcrModel)
+                         OcrModel,
+                         OcrModelDocument)
 from core.tasks import (segtrain, train, segment, transcribe)
 
 logger = logging.getLogger(__name__)
@@ -340,8 +341,7 @@ class SegmentSerializer(ProcessSerializerMixin, serializers.Serializer):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_SEGMENT,
-                                                                document=self.document)
+        self.fields['model'].queryset = self.document.ocr_models.filter(job=OcrModel.MODEL_JOB_SEGMENT)
         self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document)
 
     def process(self):
@@ -367,8 +367,7 @@ class SegTrainSerializer(ProcessSerializerMixin, serializers.Serializer):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_SEGMENT,
-                                                                document=self.document)
+        self.fields['model'].queryset = self.document.ocr_models.filter(job=OcrModel.MODEL_JOB_SEGMENT)
         self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document)
 
     def validate_parts(self, data):
@@ -388,14 +387,19 @@ class SegTrainSerializer(ProcessSerializerMixin, serializers.Serializer):
         if self.validated_data.get('model_name'):
             file_ = model and model.file or None
             model = OcrModel.objects.create(
-                document=self.document,
                 owner=self.user,
                 name=self.validated_data['model_name'],
                 job=OcrModel.MODEL_JOB_RECOGNIZE,
                 file=file_
             )
+            OcrModelDocument.objects.create(
+                document=self.document,
+                ocr_model=model,
+                trained_on=False,
+                executed_on=True,
+            )
 
-        segtrain.delay(model.pk if model else None,
+        segtrain.delay(model.pk if model else None, self.document.pk,
                        [part.pk for part in self.validated_data.get('parts')],
                        user_pk=self.user.pk)
 
@@ -411,8 +415,7 @@ class TrainSerializer(ProcessSerializerMixin, serializers.Serializer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.fields['transcription'].queryset = Transcription.objects.filter(document=self.document)
-        self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_RECOGNIZE,
-                                                                document=self.document)
+        self.fields['model'].queryset = self.document.ocr_models.filter(job=OcrModel.MODEL_JOB_RECOGNIZE)
         self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document)
 
     def validate(self, data):
@@ -427,11 +430,16 @@ class TrainSerializer(ProcessSerializerMixin, serializers.Serializer):
         if self.validated_data.get('model_name'):
             file_ = model and model.file or None
             model = OcrModel.objects.create(
-                document=self.document,
                 owner=self.user,
                 name=self.validated_data['model_name'],
                 job=OcrModel.MODEL_JOB_RECOGNIZE,
                 file=file_)
+            OcrModelDocument.objects.create(
+                document=self.document,
+                ocr_model=model,
+                trained_on=False,
+                executed_on=True,
+            )
 
         train.delay([part.pk for part in self.validated_data.get('parts')],
                     self.validated_data['transcription'].pk,
@@ -448,8 +456,7 @@ class TranscribeSerializer(ProcessSerializerMixin, serializers.Serializer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         # self.fields['transcription'].queryset = Transcription.objects.filter(document=self.document)
-        self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_RECOGNIZE,
-                                                                document=self.document)
+        self.fields['model'].queryset = self.document.ocr_models.filter(job=OcrModel.MODEL_JOB_RECOGNIZE)
         self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document)
 
     def process(self):
diff --git a/app/apps/api/tests.py b/app/apps/api/tests.py
index fb1ee141..cf1ccf27 100644
--- a/app/apps/api/tests.py
+++ b/app/apps/api/tests.py
@@ -35,7 +35,7 @@ class OcrModelViewSetTestCase(CoreFactoryTestCase):
         super().setUp()
         self.part = self.factory.make_part()
         self.user = self.part.document.owner
-        self.model = self.factory.make_model(document=self.part.document)
+        self.model = self.factory.make_model(self.part.document)
 
     def test_list(self):
         self.client.force_login(self.user)
@@ -123,7 +123,7 @@ class DocumentViewSetTestCase(CoreFactoryTestCase):
 
     def test_segtrain_less_two_parts(self):
         self.client.force_login(self.doc.owner)
-        model = self.factory.make_model(job=OcrModel.MODEL_JOB_SEGMENT, document=self.doc)
+        model = self.factory.make_model(self.doc, job=OcrModel.MODEL_JOB_SEGMENT)
         uri = reverse('api:document-segtrain', kwargs={'pk': self.doc.pk})
         resp = self.client.post(uri, data={
                 'parts': [self.part.pk],
@@ -147,7 +147,7 @@ class DocumentViewSetTestCase(CoreFactoryTestCase):
 
     def test_segtrain_existing_model_rename(self):
         self.client.force_login(self.doc.owner)
-        model = self.factory.make_model(job=OcrModel.MODEL_JOB_SEGMENT, document=self.doc)
+        model = self.factory.make_model(self.doc, job=OcrModel.MODEL_JOB_SEGMENT)
         uri = reverse('api:document-segtrain', kwargs={'pk': self.doc.pk})
         resp = self.client.post(uri, data={
             'parts': [self.part.pk, self.part2.pk],
@@ -160,7 +160,7 @@ class DocumentViewSetTestCase(CoreFactoryTestCase):
     def test_segment(self):
         uri = reverse('api:document-segment', kwargs={'pk': self.doc.pk})
         self.client.force_login(self.doc.owner)
-        model = self.factory.make_model(job=OcrModel.MODEL_JOB_SEGMENT, document=self.doc)
+        model = self.factory.make_model(self.doc, job=OcrModel.MODEL_JOB_SEGMENT)
         resp = self.client.post(uri, data={
             'parts': [self.part.pk, self.part2.pk],
             'seg_steps': 'both',
@@ -177,15 +177,13 @@ class DocumentViewSetTestCase(CoreFactoryTestCase):
             'transcription': self.transcription.pk
         })
         self.assertEqual(resp.status_code, 200)
-        self.assertEqual(OcrModel.objects.filter(
-            document=self.doc,
-            job=OcrModel.MODEL_JOB_RECOGNIZE).count(), 1)
+        self.assertEqual(self.doc.ocr_models.filter(job=OcrModel.MODEL_JOB_RECOGNIZE).count(), 1)
 
     def test_transcribe(self):
         trans = Transcription.objects.create(document=self.part.document)
 
         self.client.force_login(self.doc.owner)
-        model = self.factory.make_model(job=OcrModel.MODEL_JOB_RECOGNIZE, document=self.doc)
+        model = self.factory.make_model(self.doc, job=OcrModel.MODEL_JOB_RECOGNIZE)
         uri = reverse('api:document-transcribe', kwargs={'pk': self.doc.pk})
         resp = self.client.post(uri, data={
             'parts': [self.part.pk, self.part2.pk],
diff --git a/app/apps/api/views.py b/app/apps/api/views.py
index f33ba76d..66cc6afc 100644
--- a/app/apps/api/views.py
+++ b/app/apps/api/views.py
@@ -469,7 +469,7 @@ class OcrModelViewSet(DocumentPermissionMixin, ModelViewSet):
 
     def get_queryset(self):
         return (super().get_queryset()
-                .filter(document=self.kwargs['document_pk']))
+                .filter(documents=self.kwargs['document_pk']))
 
     @action(detail=True, methods=['post'])
     def cancel_training(self, request, pk=None):
diff --git a/app/apps/core/forms.py b/app/apps/core/forms.py
index c591e7fd..d9811a51 100644
--- a/app/apps/core/forms.py
+++ b/app/apps/core/forms.py
@@ -12,7 +12,7 @@ from django.utils.translation import gettext_lazy as _
 
 from bootstrap.forms import BootstrapFormMixin
 from core.models import (Document, Metadata, DocumentMetadata,
-                         DocumentPart, OcrModel, Transcription,
+                         DocumentPart, OcrModel, OcrModelDocument, Transcription,
                          BlockType, LineType, AlreadyProcessingException)
 from users.models import User
 
@@ -129,12 +129,12 @@ class DocumentProcessForm1(BootstrapFormMixin, forms.Form):
         if self.document.read_direction == self.document.READ_DIRECTION_RTL:
             self.initial['text_direction'] = 'horizontal-rl'
         self.fields['binarizer'].widget.attrs['disabled'] = True
-        self.fields['train_model'].queryset &= OcrModel.objects.filter(document=self.document)
-        self.fields['segtrain_model'].queryset &= OcrModel.objects.filter(document=self.document)
-        self.fields['seg_model'].queryset &= OcrModel.objects.filter(document=self.document)
+        self.fields['train_model'].queryset &= self.document.ocr_models.all()
+        self.fields['segtrain_model'].queryset &= self.document.ocr_models.all()
+        self.fields['seg_model'].queryset &= self.document.ocr_models.all()
         self.fields['ocr_model'].queryset &= OcrModel.objects.filter(
-            Q(document=None, script=document.main_script)
-            | Q(document=self.document))
+            Q(documents=None, script=self.document.main_script)
+            | Q(documents=self.document))
         self.fields['transcription'].queryset = Transcription.objects.filter(document=self.document)
 
     def process(self):
@@ -172,16 +172,29 @@ class DocumentSegmentForm(DocumentProcessForm1):
 
         if data.get('upload_model'):
             model = OcrModel.objects.create(
-                document=self.parts[0].document,
                 owner=self.user,
                 name=data['upload_model'].name.rsplit('.', 1)[0],
                 job=model_job)
+            OcrModelDocument.objects.create(
+                document=self.parts[0].document,
+                ocr_model=model,
+                trained_on=False,
+                executed_on=True,
+            )
             # Note: needs to save the file in a second step because the path needs the db PK
             model.file = data['upload_model']
             model.save()
 
         elif data.get('seg_model'):
             model = data.get('seg_model')
+            ocr_model_document, created = OcrModelDocument.objects.get_or_create(
+                ocr_model=model,
+                document=self.parts[0].document,
+                defaults={'trained_on': False, 'executed_on': True}
+            )
+            if not created:
+                ocr_model_document.executed_on = True
+                ocr_model_document.save()
         else:
             model = None
 
@@ -224,13 +237,26 @@ class DocumentTrainForm(DocumentProcessForm1):
 
         if data.get('train_model'):
             model = data.get('train_model')
+            ocr_model_document, created = OcrModelDocument.objects.get_or_create(
+                ocr_model=model,
+                document=self.parts[0].document,
+                defaults={'trained_on': True, 'executed_on': False}
+            )
+            if not created:
+                ocr_model_document.trained_on = True
+                ocr_model_document.save()
 
         elif data.get('upload_model'):
             model = OcrModel.objects.create(
-                document=self.parts[0].document,
                 owner=self.user,
                 name=data['upload_model'].name.rsplit('.', 1)[0],
                 job=model_job)
+            OcrModelDocument.objects.create(
+                document=self.parts[0].document,
+                ocr_model=model,
+                trained_on=True,
+                executed_on=False,
+            )
             # Note: needs to save the file in a second step because the path needs the db PK
             model.file = data['upload_model']
             model.save()
@@ -238,10 +264,15 @@ class DocumentTrainForm(DocumentProcessForm1):
         elif data.get('new_model'):
             # file will be created by the training process
             model = OcrModel.objects.create(
-                document=self.parts[0].document,
                 owner=self.user,
                 name=data['new_model'],
                 job=model_job)
+            OcrModelDocument.objects.create(
+                document=self.parts[0].document,
+                ocr_model=model,
+                trained_on=True,
+                executed_on=False,
+            )
 
         else:
             raise forms.ValidationError(
@@ -279,12 +310,25 @@ class DocumentSegtrainForm(DocumentProcessForm1):
 
         if data.get('segtrain_model'):
             model = data.get('segtrain_model')
+            ocr_model_document, created = OcrModelDocument.objects.get_or_create(
+                ocr_model=model,
+                document=self.parts[0].document,
+                defaults={'trained_on': True, 'executed_on': False}
+            )
+            if not created:
+                ocr_model_document.trained_on = True
+                ocr_model_document.save()
         elif data.get('upload_model'):
             model = OcrModel.objects.create(
-                document=self.parts[0].document,
                 owner=self.user,
                 name=data['upload_model'].name.rsplit('.', 1)[0],
                 job=model_job)
+            OcrModelDocument.objects.create(
+                document=self.parts[0].document,
+                ocr_model=model,
+                trained_on=True,
+                executed_on=False,
+            )
             # Note: needs to save the file in a second step because the path needs the db PK
             model.file = data['upload_model']
             model.save()
@@ -292,10 +336,15 @@ class DocumentSegtrainForm(DocumentProcessForm1):
         elif data.get('new_model'):
             # file will be created by the training process
             model = OcrModel.objects.create(
-                document=self.parts[0].document,
                 owner=self.user,
                 name=data['new_model'],
                 job=model_job)
+            OcrModelDocument.objects.create(
+                document=self.parts[0].document,
+                ocr_model=model,
+                trained_on=True,
+                executed_on=False,
+            )
 
         else:
 
@@ -328,16 +377,29 @@ class DocumentTranscribeForm(DocumentProcessForm1):
 
         if data.get('upload_model'):
             model = OcrModel.objects.create(
-                document=self.parts[0].document,
                 owner=self.user,
                 name=data['upload_model'].name.rsplit('.', 1)[0],
                 job=model_job)
+            OcrModelDocument.objects.create(
+                document=self.parts[0].document,
+                ocr_model=model,
+                trained_on=False,
+                executed_on=True,
+            )
             # Note: needs to save the file in a second step because the path needs the db PK
             model.file = data['upload_model']
             model.save()
 
         elif data.get('ocr_model'):
             model = data.get('ocr_model')
+            ocr_model_document, created = OcrModelDocument.objects.get_or_create(
+                ocr_model=model,
+                document=self.parts[0].document,
+                defaults={'trained_on': False, 'executed_on': True}
+            )
+            if not created:
+                ocr_model_document.executed_on = True
+                ocr_model_document.save()
         else:
             raise forms.ValidationError(
                     _("Either select a name for your new model or an existing one."))
@@ -436,12 +498,12 @@ class DocumentProcessForm(BootstrapFormMixin, forms.Form):
         if self.document.read_direction == self.document.READ_DIRECTION_RTL:
             self.initial['text_direction'] = 'horizontal-rl'
         self.fields['binarizer'].widget.attrs['disabled'] = True
-        self.fields['train_model'].queryset &= OcrModel.objects.filter(document=self.document)
-        self.fields['segtrain_model'].queryset &= OcrModel.objects.filter(document=self.document)
-        self.fields['seg_model'].queryset &= OcrModel.objects.filter(document=self.document)
+        self.fields['train_model'].queryset &= self.document.ocr_models.all()
+        self.fields['segtrain_model'].queryset &= self.document.ocr_models.all()
+        self.fields['seg_model'].queryset &= self.document.ocr_models.all()
         self.fields['ocr_model'].queryset &= OcrModel.objects.filter(
-            Q(document=None, script=document.main_script)
-            | Q(document=self.document))
+            Q(documents=None, script=self.document.main_script)
+            | Q(documents=self.document))
         self.fields['transcription'].queryset = Transcription.objects.filter(document=self.document)
 
     @cached_property
@@ -487,14 +549,35 @@ class DocumentProcessForm(BootstrapFormMixin, forms.Form):
 
         if task == self.TASK_TRAIN and data.get('train_model'):
             model = data.get('train_model')
+            ocr_model_document, created = OcrModelDocument.objects.get_or_create(
+                ocr_model=model,
+                document=self.parts[0].document,
+                defaults={'trained_on': True, 'executed_on': False}
+            )
+            if not created:
+                ocr_model_document.trained_on = True
+                ocr_model_document.save()
         elif task == self.TASK_SEGTRAIN and data.get('segtrain_model'):
             model = data.get('segtrain_model')
+            ocr_model_document, created = OcrModelDocument.objects.get_or_create(
+                ocr_model=model,
+                document=self.parts[0].document,
+                defaults={'trained_on': True, 'executed_on': False}
+            )
+            if not created:
+                ocr_model_document.trained_on = True
+                ocr_model_document.save()
         elif data.get('upload_model'):
             model = OcrModel.objects.create(
-                document=self.parts[0].document,
                 owner=self.user,
                 name=data['upload_model'].name.rsplit('.', 1)[0],
                 job=model_job)
+            OcrModelDocument.objects.create(
+                document=self.parts[0].document,
+                ocr_model=model,
+                trained_on=False,
+                executed_on=True,
+            )
             # Note: needs to save the file in a second step because the path needs the db PK
             model.file = data['upload_model']
             model.save()
@@ -502,14 +585,35 @@ class DocumentProcessForm(BootstrapFormMixin, forms.Form):
         elif data.get('new_model'):
             # file will be created by the training process
             model = OcrModel.objects.create(
-                document=self.parts[0].document,
                 owner=self.user,
                 name=data['new_model'],
                 job=model_job)
+            OcrModelDocument.objects.create(
+                document=self.parts[0].document,
+                ocr_model=model,
+                trained_on=True,
+                executed_on=False,
+            )
         elif data.get('ocr_model'):
             model = data.get('ocr_model')
+            ocr_model_document, created = OcrModelDocument.objects.get_or_create(
+                ocr_model=model,
+                document=self.parts[0].document,
+                defaults={'trained_on': False, 'executed_on': True}
+            )
+            if not created:
+                ocr_model_document.executed_on = True
+                ocr_model_document.save()
         elif data.get('seg_model'):
             model = data.get('seg_model')
+            ocr_model_document, created = OcrModelDocument.objects.get_or_create(
+                ocr_model=model,
+                document=self.parts[0].document,
+                defaults={'trained_on': False, 'executed_on': True}
+            )
+            if not created:
+                ocr_model_document.executed_on = True
+                ocr_model_document.save()
         else:
             if task in (self.TASK_TRAIN, self.TASK_SEGTRAIN):
                 raise forms.ValidationError(
diff --git a/app/apps/core/models.py b/app/apps/core/models.py
index c0dd0046..20b352fd 100644
--- a/app/apps/core/models.py
+++ b/app/apps/core/models.py
@@ -1083,7 +1083,7 @@ class LineTranscription(Versioned, models.Model):
 
 def models_path(instance, filename):
     fn, ext = os.path.splitext(filename)
-    return 'models/%d/%s%s' % (instance.document.pk, slugify(fn), ext)
+    return 'models/%s%s' % (slugify(fn), ext)
 
 
 class OcrModel(Versioned, models.Model):
@@ -1126,6 +1126,7 @@ class OcrModel(Versioned, models.Model):
 
     def segtrain(self, document, parts_qs, user=None):
         segtrain.delay(self.pk,
+                       document.pk,
                        list(parts_qs.values_list('pk', flat=True)),
                        user_pk=user and user.pk or None)
 
diff --git a/app/apps/core/tasks.py b/app/apps/core/tasks.py
index 7d5732be..b10c637c 100644
--- a/app/apps/core/tasks.py
+++ b/app/apps/core/tasks.py
@@ -130,7 +130,7 @@ def make_segmentation_training_data(part):
 
 
 @shared_task(bind=True, autoretry_for=(MemoryError,), default_retry_delay=60 * 60)
-def segtrain(task, model_pk, part_pks, user_pk=None):
+def segtrain(task, model_pk, document_pk, part_pks, user_pk=None):
     # # Note hack to circumvent AssertionError: daemonic processes are not allowed to have children
     from multiprocessing import current_process
     current_process().daemon = False
@@ -166,8 +166,7 @@ def segtrain(task, model_pk, part_pks, user_pk=None):
     try:
         model.training = True
         model.save()
-        document = model.document
-        send_event('document', document.pk, "training:start", {
+        send_event('document', document_pk, "training:start", {
             "id": model.pk,
         })
         qs = DocumentPart.objects.filter(pk__in=part_pks)
@@ -213,7 +212,7 @@ def segtrain(task, model_pk, part_pks, user_pk=None):
             model.new_version(file=new_version_filename)
             model.save()
 
-            send_event('document', document.pk, "training:eval", {
+            send_event('document', document_pk, "training:eval", {
                 "id": model.pk,
                 'versions': model.versions,
                 'epoch': epoch,
@@ -234,7 +233,7 @@ def segtrain(task, model_pk, part_pks, user_pk=None):
                         id="seg-no-gain-error", level='danger')
 
     except Exception as e:
-        send_event('document', document.pk, "training:error", {
+        send_event('document', document_pk, "training:error", {
             "id": model.pk,
         })
         if user:
@@ -251,7 +250,7 @@ def segtrain(task, model_pk, part_pks, user_pk=None):
         model.training = False
         model.save()
 
-        send_event('document', document.pk, "training:done", {
+        send_event('document', document_pk, "training:done", {
             "id": model.pk,
         })
 
diff --git a/app/apps/core/tests/factory.py b/app/apps/core/tests/factory.py
index 98ffc303..9c5655cc 100644
--- a/app/apps/core/tests/factory.py
+++ b/app/apps/core/tests/factory.py
@@ -86,13 +86,15 @@ class CoreFactory():
         fp = os.path.join(os.path.dirname(__file__), 'assets', asset_name)
         return open(fp, 'rb')
 
-    def make_model(self, job=OcrModel.MODEL_JOB_RECOGNIZE, document=None):
+    def make_model(self, document, job=OcrModel.MODEL_JOB_RECOGNIZE):
         spec = '[1,48,0,1 Lbx100 Do O1c10]'
         nn = vgsl.TorchVGSLModel(spec)
         model_name = 'test-model.mlmodel'
-        model = OcrModel.objects.create(name=model_name,
-                                        document=document,
-                                        job=job)
+        model = document.ocr_models.add(
+            name=model_name,
+            job=job,
+            through_defaults={'trained_on': False, 'executed_on': True}
+        )
         modeldir = os.path.join(settings.MEDIA_ROOT, os.path.split(
             model.file.field.upload_to(model, 'test-model.mlmodel'))[0])
         if not os.path.exists(modeldir):
diff --git a/app/apps/core/tests/tasks.py b/app/apps/core/tests/tasks.py
index 00087bb6..10ebf07c 100644
--- a/app/apps/core/tests/tasks.py
+++ b/app/apps/core/tests/tasks.py
@@ -68,7 +68,7 @@ class TasksTestCase(CoreFactoryTestCase):
     
     def test_train_existing_transcription_model(self):
         self.makeTranscriptionContent()
-        model = self.factory.make_model(document=self.part.document)
+        model = self.factory.make_model(self.part.document)
         self.client.force_login(self.part.document.owner)
         uri = reverse('document-parts-process', kwargs={'pk': self.part.document.pk})
         with self.assertNumQueries(17):
diff --git a/app/apps/core/views.py b/app/apps/core/views.py
index 1304370d..a140d63d 100644
--- a/app/apps/core/views.py
+++ b/app/apps/core/views.py
@@ -274,7 +274,7 @@ class ModelsList(LoginRequiredMixin, ListView):
                 self.document = Document.objects.for_user(self.request.user).get(pk=self.kwargs.get('document_pk'))
             except Document.DoesNotExist:
                 raise PermissionDenied
-            return OcrModel.objects.filter(document=self.document)
+            return self.document.ocr_models.all()
         else:
             self.document = None
             return OcrModel.objects.filter(owner=self.request.user)
-- 
GitLab