diff --git a/app/apps/api/fields.py b/app/apps/api/fields.py new file mode 100644 index 0000000000000000000000000000000000000000..a8d600443d9ed3cdab2a6642d0da08cbb669d0e2 --- /dev/null +++ b/app/apps/api/fields.py @@ -0,0 +1,18 @@ +from rest_framework import serializers + + +class DisplayChoiceField(serializers.ChoiceField): + def to_representation(self, obj): + if obj == '' and self.allow_blank: + return obj + return self._choices[obj] + + def to_internal_value(self, data): + # To support inserts with the value + if data == '' and self.allow_blank: + return '' + + for key, val in self._choices.items(): + if val == data: + return key + self.fail('invalid_choice', input=data) diff --git a/app/apps/api/serializers.py b/app/apps/api/serializers.py index a84389e0015fe1e983a303faefd8e3e032b7357e..61532a14ac5367a20821b0fc24bcf713866e73e5 100644 --- a/app/apps/api/serializers.py +++ b/app/apps/api/serializers.py @@ -4,10 +4,12 @@ import html from django.conf import settings from django.db.utils import IntegrityError +from django.utils.translation import gettext_lazy as _ from rest_framework import serializers from easy_thumbnails.files import get_thumbnailer +from api.fields import DisplayChoiceField from users.models import User from core.models import (Document, DocumentPart, @@ -16,7 +18,9 @@ from core.models import (Document, Transcription, LineTranscription, BlockType, - LineType) + LineType, + OcrModel) +from core.tasks import (segtrain, train, segment, transcribe) logger = logging.getLogger(__name__) @@ -106,7 +110,7 @@ class DocumentSerializer(serializers.ModelSerializer): class PartSerializer(serializers.ModelSerializer): - image = ImageField(thumbnails=['card', 'large']) + image = ImageField(required=False, thumbnails=['card', 'large']) filename = serializers.CharField(read_only=True) bw_image = ImageField(thumbnails=['large'], required=False) workflow = serializers.JSONField(read_only=True) @@ -259,3 +263,179 @@ class PartDetailSerializer(PartSerializer): nex = DocumentPart.objects.filter( document=instance.document, order__gt=instance.order).order_by('order').first() return nex and nex.pk or None + + +class OcrModelSerializer(serializers.ModelSerializer): + owner = serializers.ReadOnlyField(source='owner.username') + job = DisplayChoiceField(choices=OcrModel.MODEL_JOB_CHOICES) + training = serializers.ReadOnlyField() + + class Meta: + model = OcrModel + fields = ('pk', 'name', 'file', 'job', + 'owner', 'training', 'versions') + + def create(self, data): + document = Document.objects.get(pk=self.context["view"].kwargs["document_pk"]) + data['document'] = document + data['owner'] = self.context["view"].request.user + obj = super().create(data) + return obj + + +class ProcessSerializerMixin(): + def __init__(self, document, user, *args, **kwargs): + self.document = document + self.user = user + super().__init__(*args, **kwargs) + + +class SegmentSerializer(ProcessSerializerMixin, serializers.Serializer): + STEPS_CHOICES = ( + ('both', _('Lines and regions')), + ('lines', _('Lines Baselines and Masks')), + ('masks', _('Only lines Masks')), + ('regions', _('Regions')), + ) + TEXT_DIRECTION_CHOICES = ( + ('horizontal-lr', _("Horizontal l2r")), + ('horizontal-rl', _("Horizontal r2l")), + ('vertical-lr', _("Vertical l2r")), + ('vertical-rl', _("Vertical r2l")) + ) + + parts = serializers.PrimaryKeyRelatedField(many=True, + queryset=DocumentPart.objects.all()) + steps = serializers.ChoiceField(choices=STEPS_CHOICES, + required=False, + default='both') + model = serializers.PrimaryKeyRelatedField(required=False, + allow_null=True, + queryset=OcrModel.objects.all()) + override = serializers.BooleanField(required=False, default=True) + text_direction = serializers.ChoiceField(default='horizontal-lr', + required=False, + choices=TEXT_DIRECTION_CHOICES) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_SEGMENT, + document=self.document) + self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document) + + def process(self): + model = self.validated_data.get('model') + parts = self.validated_data.get('parts') + for part in parts: + part.chain_tasks( + segment.si(part.pk, + user_pk=self.user.pk, + model_pk=model, + steps=self.validated_data.get('steps'), + text_direction=self.validated_data.get('text_direction'), + override=self.validated_data.get('override')) + ) + + +class SegTrainSerializer(ProcessSerializerMixin, serializers.Serializer): + parts = serializers.PrimaryKeyRelatedField(many=True, + queryset=DocumentPart.objects.all()) + model = serializers.PrimaryKeyRelatedField(required=False, + queryset=OcrModel.objects.all()) + model_name = serializers.CharField(required=False) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_SEGMENT, + document=self.document) + self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document) + + def validate_parts(self, data): + if len(data) < 2: + raise serializers.ValidationError("Segmentation training requires at least 2 images.") + return data + + def validate(self, data): + data = super().validate(data) + if not data.get('model') and not data.get('model_name'): + raise serializers.ValidationError( + _("Either use model_name to create a new model, add a model pk to retrain an existing one, or both to create a new model from an existing one.")) + return data + + def process(self): + model = self.validated_data.get('model') + if self.validated_data.get('model_name'): + file_ = model and model.file or None + model = OcrModel.objects.create( + document=self.document, + owner=self.user, + name=self.validated_data['model_name'], + job=OcrModel.MODEL_JOB_RECOGNIZE, + file=file_ + ) + + segtrain.delay(model.pk if model else None, + [part.pk for part in self.validated_data.get('parts')], + user_pk=self.user.pk) + + +class TrainSerializer(ProcessSerializerMixin, serializers.Serializer): + parts = serializers.PrimaryKeyRelatedField(many=True, + queryset=DocumentPart.objects.all()) + model = serializers.PrimaryKeyRelatedField(required=False, + queryset=OcrModel.objects.all()) + model_name = serializers.CharField(required=False) + transcription = serializers.PrimaryKeyRelatedField(queryset=Transcription.objects.all()) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.fields['transcription'].queryset = Transcription.objects.filter(document=self.document) + self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_RECOGNIZE, + document=self.document) + self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document) + + def validate(self, data): + data = super().validate(data) + if not data.get('model') and not data.get('model_name'): + raise serializers.ValidationError( + _("Either use model_name to create a new model, or add a model pk to retrain an existing one.")) + return data + + def process(self): + model = self.validated_data.get('model') + if self.validated_data.get('model_name'): + file_ = model and model.file or None + model = OcrModel.objects.create( + document=self.document, + owner=self.user, + name=self.validated_data['model_name'], + job=OcrModel.MODEL_JOB_RECOGNIZE, + file=file_) + + train.delay([part.pk for part in self.validated_data.get('parts')], + self.validated_data['transcription'].pk, + model.pk if model else None, + self.user.pk) + + +class TranscribeSerializer(ProcessSerializerMixin, serializers.Serializer): + parts = serializers.PrimaryKeyRelatedField(many=True, + queryset=DocumentPart.objects.all()) + model = serializers.PrimaryKeyRelatedField(queryset=OcrModel.objects.all()) + # transcription = serializers.PrimaryKeyRelatedField(queryset=Transcription.objects.all()) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # self.fields['transcription'].queryset = Transcription.objects.filter(document=self.document) + self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_RECOGNIZE, + document=self.document) + self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document) + + def process(self): + model = self.validated_data.get('model') + for part in self.validated_data.get('parts'): + part.chain_tasks( + transcribe.si(part.pk, + model_pk=model.pk, + user_pk=self.user.pk) + ) diff --git a/app/apps/api/tests.py b/app/apps/api/tests.py index 7de43df764c53d7866b4daccbc8f45c40bf61945..fb1ee141f2dd08f64d9a3fc76ad0354f9974bcf0 100644 --- a/app/apps/api/tests.py +++ b/app/apps/api/tests.py @@ -8,9 +8,10 @@ from django.core.files.uploadedfile import SimpleUploadedFile from django.test import override_settings from django.urls import reverse -from core.models import Block, Line, Transcription, LineTranscription +from core.models import Block, Line, Transcription, LineTranscription, OcrModel from core.tests.factory import CoreFactoryTestCase + class UserViewSetTestCase(CoreFactoryTestCase): def setUp(self): @@ -21,7 +22,7 @@ class UserViewSetTestCase(CoreFactoryTestCase): self.client.force_login(user) uri = reverse('api:user-onboarding') resp = self.client.put(uri, { - 'onboarding' : 'False', + 'onboarding': 'False', }, content_type='application/json') user.refresh_from_db() @@ -29,7 +30,40 @@ class UserViewSetTestCase(CoreFactoryTestCase): self.assertEqual(user.onboarding, False) +class OcrModelViewSetTestCase(CoreFactoryTestCase): + def setUp(self): + super().setUp() + self.part = self.factory.make_part() + self.user = self.part.document.owner + self.model = self.factory.make_model(document=self.part.document) + + def test_list(self): + self.client.force_login(self.user) + uri = reverse('api:model-list', kwargs={'document_pk': self.part.document.pk}) + with self.assertNumQueries(7): + resp = self.client.get(uri) + self.assertEqual(resp.status_code, 200) + + def test_detail(self): + self.client.force_login(self.user) + uri = reverse('api:model-detail', + kwargs={'document_pk': self.part.document.pk, + 'pk': self.model.pk}) + with self.assertNumQueries(6): + resp = self.client.get(uri) + self.assertEqual(resp.status_code, 200) + def test_create(self): + self.client.force_login(self.user) + uri = reverse('api:model-list', kwargs={'document_pk': self.part.document.pk}) + with self.assertNumQueries(4): + resp = self.client.post(uri, { + 'name': 'test.mlmodel', + 'file': self.factory.make_asset_file(name='test.mlmodel', + asset_name='fake_seg.mlmodel'), + 'job': 'Segment' + }) + self.assertEqual(resp.status_code, 201, resp.content) class DocumentViewSetTestCase(CoreFactoryTestCase): @@ -37,6 +71,31 @@ class DocumentViewSetTestCase(CoreFactoryTestCase): super().setUp() self.doc = self.factory.make_document() self.doc2 = self.factory.make_document(owner=self.doc.owner) + self.part = self.factory.make_part(document=self.doc) + self.part2 = self.factory.make_part(document=self.doc) + + self.line = Line.objects.create( + baseline=[[10, 25], [50, 25]], + mask=[10, 10, 50, 50], + document_part=self.part) + self.line2 = Line.objects.create( + baseline=[[10, 80], [50, 80]], + mask=[10, 60, 50, 100], + document_part=self.part) + self.transcription = Transcription.objects.create( + document=self.part.document, + name='test') + self.transcription2 = Transcription.objects.create( + document=self.part.document, + name='tr2') + self.lt = LineTranscription.objects.create( + transcription=self.transcription, + line=self.line, + content='test') + self.lt2 = LineTranscription.objects.create( + transcription=self.transcription2, + line=self.line2, + content='test2') def test_list(self): self.client.force_login(self.doc.owner) @@ -62,9 +121,81 @@ class DocumentViewSetTestCase(CoreFactoryTestCase): # Note: raises a 404 instead of 403 but its fine self.assertEqual(resp.status_code, 404) - # not used - # def test_update - # def test_create + def test_segtrain_less_two_parts(self): + self.client.force_login(self.doc.owner) + model = self.factory.make_model(job=OcrModel.MODEL_JOB_SEGMENT, document=self.doc) + uri = reverse('api:document-segtrain', kwargs={'pk': self.doc.pk}) + resp = self.client.post(uri, data={ + 'parts': [self.part.pk], + 'model': model.pk + }) + + self.assertEqual(resp.status_code, 400) + self.assertEqual(resp.json()['error'], {'parts': [ + 'Segmentation training requires at least 2 images.']}) + + def test_segtrain_new_model(self): + self.client.force_login(self.doc.owner) + uri = reverse('api:document-segtrain', kwargs={'pk': self.doc.pk}) + resp = self.client.post(uri, data={ + 'parts': [self.part.pk, self.part2.pk], + 'model_name': 'new model' + }) + self.assertEqual(resp.status_code, 200, resp.content) + self.assertEqual(OcrModel.objects.count(), 1) + self.assertEqual(OcrModel.objects.first().name, "new model") + + def test_segtrain_existing_model_rename(self): + self.client.force_login(self.doc.owner) + model = self.factory.make_model(job=OcrModel.MODEL_JOB_SEGMENT, document=self.doc) + uri = reverse('api:document-segtrain', kwargs={'pk': self.doc.pk}) + resp = self.client.post(uri, data={ + 'parts': [self.part.pk, self.part2.pk], + 'model': model.pk, + 'model_name': 'test new model' + }) + self.assertEqual(resp.status_code, 200, resp.content) + self.assertEqual(OcrModel.objects.count(), 2) + + def test_segment(self): + uri = reverse('api:document-segment', kwargs={'pk': self.doc.pk}) + self.client.force_login(self.doc.owner) + model = self.factory.make_model(job=OcrModel.MODEL_JOB_SEGMENT, document=self.doc) + resp = self.client.post(uri, data={ + 'parts': [self.part.pk, self.part2.pk], + 'seg_steps': 'both', + 'model': model.pk, + }) + self.assertEqual(resp.status_code, 200) + + def test_train_new_model(self): + self.client.force_login(self.doc.owner) + uri = reverse('api:document-train', kwargs={'pk': self.doc.pk}) + resp = self.client.post(uri, data={ + 'parts': [self.part.pk, self.part2.pk], + 'model_name': 'testing new model', + 'transcription': self.transcription.pk + }) + self.assertEqual(resp.status_code, 200) + self.assertEqual(OcrModel.objects.filter( + document=self.doc, + job=OcrModel.MODEL_JOB_RECOGNIZE).count(), 1) + + def test_transcribe(self): + trans = Transcription.objects.create(document=self.part.document) + + self.client.force_login(self.doc.owner) + model = self.factory.make_model(job=OcrModel.MODEL_JOB_RECOGNIZE, document=self.doc) + uri = reverse('api:document-transcribe', kwargs={'pk': self.doc.pk}) + resp = self.client.post(uri, data={ + 'parts': [self.part.pk, self.part2.pk], + 'model': model.pk, + 'transcription': trans.pk + }) + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.content, b'{"status":"ok"}') + # won't work with dummy model and image + # self.assertEqual(LineTranscription.objects.filter(transcription=trans).count(), 2) class PartViewSetTestCase(CoreFactoryTestCase): @@ -189,7 +320,7 @@ class BlockViewSetTestCase(CoreFactoryTestCase): # 5 insert resp = self.client.post(uri, { 'document_part': self.part.pk, - 'box': '[[10,10], [50,50]]' + 'box': '[[10,10], [20,20], [50,50]]' }) self.assertEqual(resp.status_code, 201, resp.content) @@ -215,15 +346,15 @@ class LineViewSetTestCase(CoreFactoryTestCase): box=[10, 10, 200, 200], document_part=self.part) self.line = Line.objects.create( - box=[60, 10, 100, 50], + mask=[60, 10, 100, 50], document_part=self.part, block=self.block) self.line2 = Line.objects.create( - box=[90, 10, 70, 50], + mask=[90, 10, 70, 50], document_part=self.part, block=self.block) self.orphan = Line.objects.create( - box=[0, 0, 10, 10], + mask=[0, 0, 10, 10], document_part=self.part, block=None) @@ -294,10 +425,10 @@ class LineTranscriptionViewSetTestCase(CoreFactoryTestCase): self.part = self.factory.make_part() self.user = self.part.document.owner self.line = Line.objects.create( - box=[10, 10, 50, 50], + mask=[10, 10, 50, 50], document_part=self.part) self.line2 = Line.objects.create( - box=[10, 60, 50, 100], + mask=[10, 60, 50, 100], document_part=self.part) self.transcription = Transcription.objects.create( document=self.part.document, @@ -361,7 +492,7 @@ class LineTranscriptionViewSetTestCase(CoreFactoryTestCase): uri = reverse('api:linetranscription-bulk-create', kwargs={'document_pk': self.part.document.pk, 'part_pk': self.part.pk}) ll = Line.objects.create( - box=[10, 10, 50, 50], + mask=[10, 10, 50, 50], document_part=self.part) with self.assertNumQueries(10): resp = self.client.post( diff --git a/app/apps/api/urls.py b/app/apps/api/urls.py index 3c773a394e07c9b1ea1c7c874d82c0b67cb53188..5c68c75144c91691322606062c09a3df48437534 100644 --- a/app/apps/api/urls.py +++ b/app/apps/api/urls.py @@ -10,7 +10,8 @@ from api.views import (DocumentViewSet, LineViewSet, BlockTypeViewSet, LineTypeViewSet, - LineTranscriptionViewSet) + LineTranscriptionViewSet, + OcrModelViewSet) router = routers.DefaultRouter() router.register(r'documents', DocumentViewSet) @@ -20,6 +21,7 @@ router.register(r'types/line', LineTypeViewSet) documents_router = routers.NestedSimpleRouter(router, r'documents', lookup='document') documents_router.register(r'parts', PartViewSet, basename='part') documents_router.register(r'transcriptions', DocumentTranscriptionViewSet, basename='transcription') +documents_router.register(r'models', OcrModelViewSet, basename='model') parts_router = routers.NestedSimpleRouter(documents_router, r'parts', lookup='part') parts_router.register(r'blocks', BlockViewSet) parts_router.register(r'lines', LineViewSet) diff --git a/app/apps/api/views.py b/app/apps/api/views.py index cad44e8f1912bceeb03155464e8691a0ba3f8f4d..19fc37659b438eebc26ae297a4cfada6ccc9d976 100644 --- a/app/apps/api/views.py +++ b/app/apps/api/views.py @@ -25,7 +25,13 @@ from api.serializers import (UserOnboardingSerializer, DetailedLineSerializer, LineOrderSerializer, TranscriptionSerializer, - LineTranscriptionSerializer) + LineTranscriptionSerializer, + SegmentSerializer, + TrainSerializer, + SegTrainSerializer, + TranscribeSerializer, + OcrModelSerializer) + from core.models import (Document, DocumentPart, Block, @@ -33,7 +39,10 @@ from core.models import (Document, BlockType, LineType, Transcription, - LineTranscription) + LineTranscription, + OcrModel, + AlreadyProcessingException) + from core.tasks import recalculate_masks from users.models import User from imports.forms import ImportForm, ExportForm @@ -50,7 +59,7 @@ class UserViewSet(ModelViewSet): @action(detail=False, methods=['put']) def onboarding(self, request): - serializer = UserOnboardingSerializer(self.request.user,data=request.data, partial=True) + serializer = UserOnboardingSerializer(self.request.user, data=request.data, partial=True) if serializer.is_valid(raise_exception=True): serializer.save() return Response(status=status.HTTP_200_OK) @@ -120,6 +129,42 @@ class DocumentViewSet(ModelViewSet): else: return self.form_error(json.dumps(form.errors)) + def get_process_response(self, request, serializer_class): + document = self.get_object() + serializer = serializer_class(document=document, + user=request.user, + data=request.data) + if serializer.is_valid(): + try: + serializer.process() + except AlreadyProcessingException: + return Response(status=status.HTTP_400_BAD_REQUEST, + data={'status': 'error', + 'error': 'Already processing.'}) + + return Response(status=status.HTTP_200_OK, + data={'status': 'ok'}) + else: + return Response(status=status.HTTP_400_BAD_REQUEST, + data={'status': 'error', + 'error': serializer.errors}) + + @action(detail=True, methods=['post']) + def segment(self, request, pk=None): + return self.get_process_response(request, SegmentSerializer) + + @action(detail=True, methods=['post']) + def train(self, request, pk=None): + return self.get_process_response(request, TrainSerializer) + + @action(detail=True, methods=['post']) + def segtrain(self, request, pk=None): + return self.get_process_response(request, SegTrainSerializer) + + @action(detail=True, methods=['post']) + def transcribe(self, request, pk=None): + return self.get_process_response(request, TranscribeSerializer) + class DocumentPermissionMixin(): def get_queryset(self): @@ -390,3 +435,22 @@ class LineTranscriptionViewSet(DocumentPermissionMixin, ModelViewSet): qs = LineTranscription.objects.filter(pk__in=lines) qs.update(content='') return Response(status=status.HTTP_204_NO_CONTENT, ) + + +class OcrModelViewSet(DocumentPermissionMixin, ModelViewSet): + queryset = OcrModel.objects.all() + serializer_class = OcrModelSerializer + + def get_queryset(self): + return (super().get_queryset() + .filter(document=self.kwargs['document_pk'])) + + @action(detail=True, methods=['post']) + def cancel_training(self, request, pk=None): + model = self.get_object() + try: + model.cancel_training() + except Exception as e: + logger.exception(e) + return Response({'status': 'failed'}, status=400) + return Response({'status': 'canceled'}) diff --git a/app/apps/core/forms.py b/app/apps/core/forms.py index 6d94585bcdb5ea98a7140116a2c202f6c7aa482d..c591e7fdfd4bd494a5a2b3dcab03bd0825912f47 100644 --- a/app/apps/core/forms.py +++ b/app/apps/core/forms.py @@ -109,6 +109,250 @@ MetadataFormSet = inlineformset_factory(Document, DocumentMetadata, extra=1, can_delete=True) +class DocumentProcessForm1(BootstrapFormMixin, forms.Form): + parts = forms.CharField() + + @cached_property + def parts(self): + pks = json.loads(self.data.get('parts')) + parts = DocumentPart.objects.filter( + document=self.document, pk__in=pks) + return parts + + def __init__(self, document, user, *args, **kwargs): + self.document = document + self.user = user + super().__init__(*args, **kwargs) + # self.fields['typology'].widget = forms.HiddenInput() # for now + # self.fields['typology'].initial = Typology.objects.get(name="Page") + # self.fields['typology'].widget.attrs['title'] = _("Default Typology") + if self.document.read_direction == self.document.READ_DIRECTION_RTL: + self.initial['text_direction'] = 'horizontal-rl' + self.fields['binarizer'].widget.attrs['disabled'] = True + self.fields['train_model'].queryset &= OcrModel.objects.filter(document=self.document) + self.fields['segtrain_model'].queryset &= OcrModel.objects.filter(document=self.document) + self.fields['seg_model'].queryset &= OcrModel.objects.filter(document=self.document) + self.fields['ocr_model'].queryset &= OcrModel.objects.filter( + Q(document=None, script=document.main_script) + | Q(document=self.document)) + self.fields['transcription'].queryset = Transcription.objects.filter(document=self.document) + + def process(self): + model = self.cleaned_data.get('model') + +class DocumentSegmentForm(DocumentProcessForm1): + SEG_STEPS_CHOICES = ( + ('both', _('Lines and regions')), + ('lines', _('Lines Baselines and Masks')), + ('masks', _('Only lines Masks')), + ('regions', _('Regions')), + ) + segmentation_steps = forms.ChoiceField(choices=SEG_STEPS_CHOICES, + initial='both', required=False) + seg_model = forms.ModelChoiceField(queryset=OcrModel.objects.filter(job=OcrModel.MODEL_JOB_SEGMENT), + label=_("Model"), empty_label="default ({name})".format( + name=settings.KRAKEN_DEFAULT_SEGMENTATION_MODEL.rsplit('/')[-1]), + required=False) + override = forms.BooleanField(required=False, initial=True, + help_text=_( + "If checked, deletes existing segmentation <b>and bound transcriptions</b> first!")) + TEXT_DIRECTION_CHOICES = (('horizontal-lr', _("Horizontal l2r")), + ('horizontal-rl', _("Horizontal r2l")), + ('vertical-lr', _("Vertical l2r")), + ('vertical-rl', _("Vertical r2l"))) + text_direction = forms.ChoiceField(initial='horizontal-lr', required=False, + choices=TEXT_DIRECTION_CHOICES) + upload_model = forms.FileField(required=False, + validators=[FileExtensionValidator( + allowed_extensions=['mlmodel', 'pronn', 'clstm'])]) + + def clean(self): + data = super().clean() + model_job = OcrModel.MODEL_JOB_SEGMENT + + if data.get('upload_model'): + model = OcrModel.objects.create( + document=self.parts[0].document, + owner=self.user, + name=data['upload_model'].name.rsplit('.', 1)[0], + job=model_job) + # Note: needs to save the file in a second step because the path needs the db PK + model.file = data['upload_model'] + model.save() + + elif data.get('seg_model'): + model = data.get('seg_model') + else: + model = None + + data['model'] = model + return data + + def process(self): + super().process() + + for part in self.parts: + part.task('segment', + user_pk=self.user.pk, + steps=self.cleaned_data.get('segmentation_steps'), + text_direction=self.cleaned_data.get('text_direction'), + model_pk=model and model.pk or None, + override=self.cleaned_data.get('override')) + + +class DocumentTrainForm(DocumentProcessForm1): + new_model = forms.CharField(required=False, label=_('Model name')) + train_model = forms.ModelChoiceField(queryset=OcrModel.objects + .filter(job=OcrModel.MODEL_JOB_RECOGNIZE), + label=_("Model"), required=False) + upload_model = forms.FileField(required=False, + validators=[FileExtensionValidator( + allowed_extensions=['mlmodel', 'pronn', 'clstm'])]) + + transcription = forms.ModelChoiceField(queryset=Transcription.objects.all(), required=False) + + def clean_train_model(self): + model = self.cleaned_data['train_model'] + if model and model.training: + raise AlreadyProcessingException + return model + + def clean(self): + data = super().clean() + + model_job = OcrModel.MODEL_JOB_RECOGNIZE + + if data.get('train_model'): + model = data.get('train_model') + + elif data.get('upload_model'): + model = OcrModel.objects.create( + document=self.parts[0].document, + owner=self.user, + name=data['upload_model'].name.rsplit('.', 1)[0], + job=model_job) + # Note: needs to save the file in a second step because the path needs the db PK + model.file = data['upload_model'] + model.save() + + elif data.get('new_model'): + # file will be created by the training process + model = OcrModel.objects.create( + document=self.parts[0].document, + owner=self.user, + name=data['new_model'], + job=model_job) + + else: + raise forms.ValidationError( + _("Either select a name for your new model or an existing one.")) + + data['model'] = model + return data + + + def process(self): + super().process() + + model.train(self.parts, + self.cleaned_data['transcription'], + user=self.user) + + +class DocumentSegtrainForm(DocumentProcessForm1): + segtrain_model = forms.ModelChoiceField(queryset=OcrModel.objects + .filter(job=OcrModel.MODEL_JOB_SEGMENT), + label=_("Model"), required=False) + upload_model = forms.FileField(required=False, + validators=[FileExtensionValidator( + allowed_extensions=['mlmodel', 'pronn', 'clstm'])]) + + new_model = forms.CharField(required=False, label=_('Model name')) + + def clean(self): + data = super().clean() + + + model_job = OcrModel.MODEL_JOB_SEGMENT + if len(self.parts) < 2: + raise forms.ValidationError("Segmentation training requires at least 2 images.") + + if data.get('segtrain_model'): + model = data.get('segtrain_model') + elif data.get('upload_model'): + model = OcrModel.objects.create( + document=self.parts[0].document, + owner=self.user, + name=data['upload_model'].name.rsplit('.', 1)[0], + job=model_job) + # Note: needs to save the file in a second step because the path needs the db PK + model.file = data['upload_model'] + model.save() + + elif data.get('new_model'): + # file will be created by the training process + model = OcrModel.objects.create( + document=self.parts[0].document, + owner=self.user, + name=data['new_model'], + job=model_job) + + else: + + raise forms.ValidationError( + _("Either select a name for your new model or an existing one.")) + + data['model'] = model + return data + + def process(self): + super().process() + model.segtrain(self.document, + self.parts, + user=self.user) + + +class DocumentTranscribeForm(DocumentProcessForm1): + + upload_model = forms.FileField(required=False, + validators=[FileExtensionValidator( + allowed_extensions=['mlmodel', 'pronn', 'clstm'])]) + ocr_model = forms.ModelChoiceField(queryset=OcrModel.objects + .filter(job=OcrModel.MODEL_JOB_RECOGNIZE), + label=_("Model"), required=False) + + def clean(self): + data = super().clean() + + model_job = OcrModel.MODEL_JOB_RECOGNIZE + + if data.get('upload_model'): + model = OcrModel.objects.create( + document=self.parts[0].document, + owner=self.user, + name=data['upload_model'].name.rsplit('.', 1)[0], + job=model_job) + # Note: needs to save the file in a second step because the path needs the db PK + model.file = data['upload_model'] + model.save() + + elif data.get('ocr_model'): + model = data.get('ocr_model') + else: + raise forms.ValidationError( + _("Either select a name for your new model or an existing one.")) + + data['model'] = model + return data + + def process(self): + super().process() + for part in self.parts: + part.task('transcribe', + user_pk=self.user.pk, + model_pk=model and model.pk or None) + + class DocumentProcessForm(BootstrapFormMixin, forms.Form): # TODO: split this form into one for each process?! TASK_BINARIZE = 'binarize' diff --git a/app/apps/core/models.py b/app/apps/core/models.py index 162a718152de1ee70bb08e8a5b8167dcefd991b1..aad2b42151b11b7c417bfdcdafa843a1ab8731c5 100644 --- a/app/apps/core/models.py +++ b/app/apps/core/models.py @@ -677,52 +677,45 @@ class DocumentPart(OrderedModel): self.save() self.recalculate_ordering(read_direction=read_direction) - def transcribe(self, model=None, text_direction=None): - if model: - trans, created = Transcription.objects.get_or_create( - name='kraken:' + model.name, - document=self.document) - model_ = kraken_models.load_any(model.file.path) - lines = self.lines.all() - text_direction = (text_direction - or (self.document.main_script - and self.document.main_script.text_direction) - or 'horizontal-lr') + def transcribe(self, model, text_direction=None): + trans, created = Transcription.objects.get_or_create( + name='kraken:' + model.name, + document=self.document) + model_ = kraken_models.load_any(model.file.path) + lines = self.lines.all() + text_direction = (text_direction + or (self.document.main_script + and self.document.main_script.text_direction) + or 'horizontal-lr') - with Image.open(self.image.file.name) as im: - for line in lines: - if not line.baseline: - bounds = { - 'boxes': [line.box], - 'text_direction': text_direction, - 'type': 'baselines', - # 'script_detection': True - } - else: - bounds = { - 'lines': [{'baseline': line.baseline, - 'boundary': line.mask, - 'text_direction': text_direction, - 'script': 'default'}], # self.document.main_script.name - 'type': 'baselines', - # 'selfcript_detection': True - } - it = rpred.rpred( - model_, im, - bounds=bounds, - pad=16, # TODO: % of the image? - bidi_reordering=True) - lt, created = LineTranscription.objects.get_or_create( - line=line, transcription=trans) - for pred in it: - lt.content = pred.prediction - lt.graphs = [(char, pred.cuts[i], float(pred.confidences[i])) - for i, char in enumerate(pred.prediction)] - lt.save() - else: - Transcription.objects.get_or_create( - name='manual', - document=self.document) + with Image.open(self.image.file.name) as im: + for line in lines: + if not line.baseline: + bounds = { + 'boxes': [line.box], + 'text_direction': text_direction, + 'type': 'baselines', + # 'script_detection': True + } + else: + bounds = { + 'lines': [{'baseline': line.baseline, + 'boundary': line.mask, + 'text_direction': text_direction, + 'script': 'default'}], # self.document.main_script.name + 'type': 'baselines', + # 'selfcript_detection': True + } + it = rpred.rpred( + model_, im, + bounds=bounds, + pad=16, # TODO: % of the image? + bidi_reordering=True) + lt, created = LineTranscription.objects.get_or_create( + line=line, transcription=trans) + for pred in it: + lt.content = pred.prediction + lt.save() self.workflow_state = self.WORKFLOW_STATE_TRANSCRIBING self.calculate_progress() @@ -1082,7 +1075,7 @@ class LineTranscription(Versioned, models.Model): def models_path(instance, filename): fn, ext = os.path.splitext(filename) - return 'models/%d/%s%s' % (instance.pk, slugify(fn), ext) + return 'models/%d/%s%s' % (instance.document.pk, slugify(fn), ext) class OcrModel(Versioned, models.Model): @@ -1104,9 +1097,9 @@ class OcrModel(Versioned, models.Model): training_accuracy = models.FloatField(default=0.0) training_total = models.IntegerField(default=0) training_errors = models.IntegerField(default=0) - document = models.ForeignKey(Document, blank=True, null=True, + document = models.ForeignKey(Document, related_name='ocr_models', - default=None, on_delete=models.SET_NULL) + default=None, on_delete=models.CASCADE) script = models.ForeignKey(Script, blank=True, null=True, on_delete=models.SET_NULL) version_ignore_fields = ('name', 'owner', 'document', 'script', 'training') @@ -1124,8 +1117,9 @@ class OcrModel(Versioned, models.Model): return self.training_accuracy * 100 def segtrain(self, document, parts_qs, user=None): - segtrain.delay(self.pk, document.pk, - list(parts_qs.values_list('pk', flat=True))) + segtrain.delay(self.pk, + list(parts_qs.values_list('pk', flat=True)), + user_pk=user and user.pk or None) def train(self, parts_qs, transcription, user=None): train.delay(list(parts_qs.values_list('pk', flat=True)), @@ -1139,7 +1133,7 @@ class OcrModel(Versioned, models.Model): task_id = json.loads(redis_.get('training-%d' % self.pk))['task_id'] elif self.job == self.MODEL_JOB_SEGMENT: task_id = json.loads(redis_.get('segtrain-%d' % self.pk))['task_id'] - except (TypeError, KeyError) as e: + except (TypeError, KeyError): raise ProcessFailureException(_("Couldn't find the training task.")) else: if task_id: diff --git a/app/apps/core/tasks.py b/app/apps/core/tasks.py index d102e803a1969d9a098086c75d14a2dc5156b10d..65d564ed5cdf7dff9e3414350725bda24781e185 100644 --- a/app/apps/core/tasks.py +++ b/app/apps/core/tasks.py @@ -40,6 +40,7 @@ def update_client_state(part_id, task, status, task_id=None, data=None): "data": data or {} }) + @shared_task(autoretry_for=(MemoryError,), default_retry_delay=60) def generate_part_thumbnails(instance_pk): if not getattr(settings, 'THUMBNAIL_ENABLE', True): @@ -129,7 +130,7 @@ def make_segmentation_training_data(part): @shared_task(bind=True, autoretry_for=(MemoryError,), default_retry_delay=60 * 60) -def segtrain(task, model_pk, document_pk, part_pks, user_pk=None): +def segtrain(task, model_pk, part_pks, user_pk=None): # # Note hack to circumvent AssertionError: daemonic processes are not allowed to have children from multiprocessing import current_process current_process().daemon = False @@ -151,6 +152,7 @@ def segtrain(task, model_pk, document_pk, part_pks, user_pk=None): OcrModel = apps.get_model('core', 'OcrModel') model = OcrModel.objects.get(pk=model_pk) + try: load = model.file.path upload_to = model.file.field.upload_to(model, model.name + '.mlmodel') @@ -412,6 +414,7 @@ def train(task, part_pks, transcription_pk, model_pk, user_pk=None): Transcription = apps.get_model('core', 'Transcription') LineTranscription = apps.get_model('core', 'LineTranscription') OcrModel = apps.get_model('core', 'OcrModel') + try: model = OcrModel.objects.get(pk=model_pk) model.training = True @@ -451,10 +454,12 @@ def train(task, part_pks, transcription_pk, model_pk, user_pk=None): @shared_task(autoretry_for=(MemoryError,), default_retry_delay=10 * 60) def transcribe(instance_pk, model_pk=None, user_pk=None, text_direction=None, **kwargs): + try: DocumentPart = apps.get_model('core', 'DocumentPart') part = DocumentPart.objects.get(pk=instance_pk) except DocumentPart.DoesNotExist: + logger.error('Trying to transcribe innexistant DocumentPart : %d', instance_pk) return @@ -466,18 +471,10 @@ def transcribe(instance_pk, model_pk=None, user_pk=None, text_direction=None, ** else: user = None - if model_pk: - try: - OcrModel = apps.get_model('core', 'OcrModel') - model = OcrModel.objects.get(pk=model_pk) - except OcrModel.DoesNotExist: - # Not sure how we should deal with this case - model = None - else: - model = None - try: - part.transcribe(model=model) + OcrModel = apps.get_model('core', 'OcrModel') + model = OcrModel.objects.get(pk=model_pk) + part.transcribe(model) except Exception as e: if user: user.notify(_("Something went wrong during the transcription!"), diff --git a/app/apps/core/tests/factory.py b/app/apps/core/tests/factory.py index 44bb89ba1e90a9c7ffcc002e2d1adfe0c77e12a5..98ffc303894f59f7dbb7ab3f4de9dd54740a00d0 100644 --- a/app/apps/core/tests/factory.py +++ b/app/apps/core/tests/factory.py @@ -84,17 +84,12 @@ class CoreFactory(): def make_asset_file(self, name='test.png', asset_name='segmentation/default.png'): fp = os.path.join(os.path.dirname(__file__), 'assets', asset_name) - with Image.open(fp, 'r') as image: - file = BytesIO() - file.name = name - image.save(file, 'png') - file.seek(0) - return file + return open(fp, 'rb') def make_model(self, job=OcrModel.MODEL_JOB_RECOGNIZE, document=None): spec = '[1,48,0,1 Lbx100 Do O1c10]' nn = vgsl.TorchVGSLModel(spec) - model_name = 'test-model' + model_name = 'test-model.mlmodel' model = OcrModel.objects.create(name=model_name, document=document, job=job)