diff --git a/app/apps/api/serializers.py b/app/apps/api/serializers.py index f57fc5d0e243ef4d341a533b78924eff83884287..f5cc969b79c8387116f03f6516d0e001949dffb8 100644 --- a/app/apps/api/serializers.py +++ b/app/apps/api/serializers.py @@ -300,9 +300,9 @@ class OcrModelSerializer(serializers.ModelSerializer): def create(self, data): document = Document.objects.get(pk=self.context["view"].kwargs["document_pk"]) - data['document'] = document data['owner'] = self.context["view"].request.user obj = super().create(data) + document.ocr_models.add(obj) return obj diff --git a/app/apps/api/tests.py b/app/apps/api/tests.py index cf1ccf27e00d47b951067e8d870c86b75a1d6a61..6523e1f2070b693120f9d096678f9dc6699dfdaa 100644 --- a/app/apps/api/tests.py +++ b/app/apps/api/tests.py @@ -40,7 +40,7 @@ class OcrModelViewSetTestCase(CoreFactoryTestCase): def test_list(self): self.client.force_login(self.user) uri = reverse('api:model-list', kwargs={'document_pk': self.part.document.pk}) - with self.assertNumQueries(7): + with self.assertNumQueries(8): resp = self.client.get(uri) self.assertEqual(resp.status_code, 200) @@ -49,14 +49,14 @@ class OcrModelViewSetTestCase(CoreFactoryTestCase): uri = reverse('api:model-detail', kwargs={'document_pk': self.part.document.pk, 'pk': self.model.pk}) - with self.assertNumQueries(6): + with self.assertNumQueries(7): resp = self.client.get(uri) self.assertEqual(resp.status_code, 200) def test_create(self): self.client.force_login(self.user) uri = reverse('api:model-list', kwargs={'document_pk': self.part.document.pk}) - with self.assertNumQueries(4): + with self.assertNumQueries(6): resp = self.client.post(uri, { 'name': 'test.mlmodel', 'file': self.factory.make_asset_file(name='test.mlmodel', @@ -100,7 +100,7 @@ class DocumentViewSetTestCase(CoreFactoryTestCase): def test_list(self): self.client.force_login(self.doc.owner) uri = reverse('api:document-list') - with self.assertNumQueries(8): + with self.assertNumQueries(10): resp = self.client.get(uri) self.assertEqual(resp.status_code, 200) @@ -108,7 +108,7 @@ class DocumentViewSetTestCase(CoreFactoryTestCase): self.client.force_login(self.doc.owner) uri = reverse('api:document-detail', kwargs={'pk': self.doc.pk}) - with self.assertNumQueries(7): + with self.assertNumQueries(8): resp = self.client.get(uri) self.assertEqual(resp.status_code, 200) diff --git a/app/apps/core/tests/factory.py b/app/apps/core/tests/factory.py index 2c4028a24bf89e54afe0ae63d9999a08ee4ccde5..f80e24ee9a92a8865e8555e6802f0be02610c3c9 100644 --- a/app/apps/core/tests/factory.py +++ b/app/apps/core/tests/factory.py @@ -91,11 +91,11 @@ class CoreFactory(): spec = '[1,48,0,1 Lbx100 Do O1c10]' nn = vgsl.TorchVGSLModel(spec) model_name = 'test-model.mlmodel' - model = document.ocr_models.add( - name=model_name, - job=job, - through_defaults={'executed_on': timezone.now()} - ) + model = OcrModel.objects.create(name=model_name, + owner=document.owner, + job=job) + + document.ocr_models.add(model) modeldir = os.path.join(settings.MEDIA_ROOT, os.path.split( model.file.field.upload_to(model, 'test-model.mlmodel'))[0]) if not os.path.exists(modeldir):