Mentions légales du service

Skip to content
Snippets Groups Projects

Replace the FK towards Document on OcrModel by a M2M relation

Merged Eva Bardou requested to merge db-ocrmodel-refactoring into develop
All threads resolved!
Files
11
+ 19
13
@@ -4,6 +4,7 @@ import html
from django.conf import settings
from django.db.utils import IntegrityError
from django.utils import timezone
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
@@ -20,7 +21,8 @@ from core.models import (Document,
BlockType,
LineType,
Script,
OcrModel)
OcrModel,
OcrModelDocument)
from core.tasks import (segtrain, train, segment, transcribe)
logger = logging.getLogger(__name__)
@@ -298,9 +300,9 @@ class OcrModelSerializer(serializers.ModelSerializer):
def create(self, data):
document = Document.objects.get(pk=self.context["view"].kwargs["document_pk"])
data['document'] = document
data['owner'] = self.context["view"].request.user
obj = super().create(data)
document.ocr_models.add(obj)
return obj
@@ -340,8 +342,7 @@ class SegmentSerializer(ProcessSerializerMixin, serializers.Serializer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_SEGMENT,
document=self.document)
self.fields['model'].queryset = self.document.ocr_models.filter(job=OcrModel.MODEL_JOB_SEGMENT)
self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document)
def process(self):
@@ -367,8 +368,7 @@ class SegTrainSerializer(ProcessSerializerMixin, serializers.Serializer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_SEGMENT,
document=self.document)
self.fields['model'].queryset = self.document.ocr_models.filter(job=OcrModel.MODEL_JOB_SEGMENT)
self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document)
def validate_parts(self, data):
@@ -388,14 +388,18 @@ class SegTrainSerializer(ProcessSerializerMixin, serializers.Serializer):
if self.validated_data.get('model_name'):
file_ = model and model.file or None
model = OcrModel.objects.create(
document=self.document,
owner=self.user,
name=self.validated_data['model_name'],
job=OcrModel.MODEL_JOB_RECOGNIZE,
file=file_
)
OcrModelDocument.objects.create(
document=self.document,
ocr_model=model,
executed_on=timezone.now(),
)
segtrain.delay(model.pk if model else None,
segtrain.delay(model.pk if model else None, self.document.pk,
[part.pk for part in self.validated_data.get('parts')],
user_pk=self.user.pk)
@@ -411,8 +415,7 @@ class TrainSerializer(ProcessSerializerMixin, serializers.Serializer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fields['transcription'].queryset = Transcription.objects.filter(document=self.document)
self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_RECOGNIZE,
document=self.document)
self.fields['model'].queryset = self.document.ocr_models.filter(job=OcrModel.MODEL_JOB_RECOGNIZE)
self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document)
def validate(self, data):
@@ -427,11 +430,15 @@ class TrainSerializer(ProcessSerializerMixin, serializers.Serializer):
if self.validated_data.get('model_name'):
file_ = model and model.file or None
model = OcrModel.objects.create(
document=self.document,
owner=self.user,
name=self.validated_data['model_name'],
job=OcrModel.MODEL_JOB_RECOGNIZE,
file=file_)
OcrModelDocument.objects.create(
document=self.document,
ocr_model=model,
executed_on=timezone.now(),
)
train.delay([part.pk for part in self.validated_data.get('parts')],
self.validated_data['transcription'].pk,
@@ -448,8 +455,7 @@ class TranscribeSerializer(ProcessSerializerMixin, serializers.Serializer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# self.fields['transcription'].queryset = Transcription.objects.filter(document=self.document)
self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_RECOGNIZE,
document=self.document)
self.fields['model'].queryset = self.document.ocr_models.filter(job=OcrModel.MODEL_JOB_RECOGNIZE)
self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document)
def process(self):
Loading