Commit f22c1ee5 authored by Robin Tissot's avatar Robin Tissot
Browse files

Merge branch 'develop'

parents af0172a1 025e4b35
env/
app/static/
app/media/
app/test_media/
app/logs/
ci/
......@@ -8,4 +8,7 @@ __pycache__
app/escriptorium/local_settings.py
app/static/
app/media/
*.log
\ No newline at end of file
*.log
front/node_modules
front/dist
variables.env
stages:
- test
- build
- deploy
build:
stage: build
image: node:12-alpine
artifacts:
paths:
- front/dist
expire_in: 2 weeks
before_script:
- cd front
- npm ci
script:
- npm run production
docker-build:
stage: build
image: docker:19.03.1
services:
- docker:dind
variables:
DOCKER_DRIVER: overlay2
DOCKER_HOST: tcp://docker:2375/
except:
- schedules
script:
- ci/build.sh
FROM node:12-alpine as frontend
WORKDIR /build
COPY ./front /build
RUN npm ci && npm run production
# pull official base image
FROM python:3.7.5-buster
......@@ -13,6 +19,7 @@ ENV PYTHONUNBUFFERED 1
ARG VERSION_DATE="passthistobuildcmd"
ENV VERSION_DATE=$VERSION_DATE
ENV FRONTEND_DIR=/usr/src/app/front
# update apk
RUN apt-get update
......@@ -23,16 +30,17 @@ RUN adduser --system --no-create-home --ingroup uwsgi uwsgi
RUN apt-get install netcat-traditional jpegoptim pngcrush
RUN apt-get --assume-yes install libvips
COPY ./requirements.txt /usr/src/app/requirements.txt
# set work directory
WORKDIR /usr/src/app
RUN pip install --upgrade pip
COPY ./app/requirements.txt /usr/src/app/requirements.txt
RUN pip install -U -r requirements.txt
COPY ./entrypoint.sh /usr/src/app/entrypoint.sh
COPY . /usr/src/app/
COPY ./app/entrypoint.sh /usr/src/app/entrypoint.sh
COPY ./app /usr/src/app/
COPY --from=frontend /build/dist /usr/src/app/front
# run entrypoint.sh
ENTRYPOINT ["/usr/src/app/entrypoint.sh"]
env/
static/
media/
test_media/
logs/
from rest_framework import serializers
class DisplayChoiceField(serializers.ChoiceField):
def to_representation(self, obj):
if obj == '' and self.allow_blank:
return obj
return self._choices[obj]
def to_internal_value(self, data):
# To support inserts with the value
if data == '' and self.allow_blank:
return ''
for key, val in self._choices.items():
if val == data:
return key
self.fail('invalid_choice', input=data)
......@@ -4,10 +4,12 @@ import html
from django.conf import settings
from django.db.utils import IntegrityError
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from easy_thumbnails.files import get_thumbnailer
from api.fields import DisplayChoiceField
from users.models import User
from core.models import (Document,
DocumentPart,
......@@ -16,7 +18,9 @@ from core.models import (Document,
Transcription,
LineTranscription,
BlockType,
LineType)
LineType,
OcrModel)
from core.tasks import (segtrain, train, segment, transcribe)
logger = logging.getLogger(__name__)
......@@ -106,7 +110,7 @@ class DocumentSerializer(serializers.ModelSerializer):
class PartSerializer(serializers.ModelSerializer):
image = ImageField(thumbnails=['card', 'large'])
image = ImageField(required=False, thumbnails=['card', 'large'])
filename = serializers.CharField(read_only=True)
bw_image = ImageField(thumbnails=['large'], required=False)
workflow = serializers.JSONField(read_only=True)
......@@ -259,3 +263,179 @@ class PartDetailSerializer(PartSerializer):
nex = DocumentPart.objects.filter(
document=instance.document, order__gt=instance.order).order_by('order').first()
return nex and nex.pk or None
class OcrModelSerializer(serializers.ModelSerializer):
owner = serializers.ReadOnlyField(source='owner.username')
job = DisplayChoiceField(choices=OcrModel.MODEL_JOB_CHOICES)
training = serializers.ReadOnlyField()
class Meta:
model = OcrModel
fields = ('pk', 'name', 'file', 'job',
'owner', 'training', 'versions')
def create(self, data):
document = Document.objects.get(pk=self.context["view"].kwargs["document_pk"])
data['document'] = document
data['owner'] = self.context["view"].request.user
obj = super().create(data)
return obj
class ProcessSerializerMixin():
def __init__(self, document, user, *args, **kwargs):
self.document = document
self.user = user
super().__init__(*args, **kwargs)
class SegmentSerializer(ProcessSerializerMixin, serializers.Serializer):
STEPS_CHOICES = (
('both', _('Lines and regions')),
('lines', _('Lines Baselines and Masks')),
('masks', _('Only lines Masks')),
('regions', _('Regions')),
)
TEXT_DIRECTION_CHOICES = (
('horizontal-lr', _("Horizontal l2r")),
('horizontal-rl', _("Horizontal r2l")),
('vertical-lr', _("Vertical l2r")),
('vertical-rl', _("Vertical r2l"))
)
parts = serializers.PrimaryKeyRelatedField(many=True,
queryset=DocumentPart.objects.all())
steps = serializers.ChoiceField(choices=STEPS_CHOICES,
required=False,
default='both')
model = serializers.PrimaryKeyRelatedField(required=False,
allow_null=True,
queryset=OcrModel.objects.all())
override = serializers.BooleanField(required=False, default=True)
text_direction = serializers.ChoiceField(default='horizontal-lr',
required=False,
choices=TEXT_DIRECTION_CHOICES)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_SEGMENT,
document=self.document)
self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document)
def process(self):
model = self.validated_data.get('model')
parts = self.validated_data.get('parts')
for part in parts:
part.chain_tasks(
segment.si(part.pk,
user_pk=self.user.pk,
model_pk=model,
steps=self.validated_data.get('steps'),
text_direction=self.validated_data.get('text_direction'),
override=self.validated_data.get('override'))
)
class SegTrainSerializer(ProcessSerializerMixin, serializers.Serializer):
parts = serializers.PrimaryKeyRelatedField(many=True,
queryset=DocumentPart.objects.all())
model = serializers.PrimaryKeyRelatedField(required=False,
queryset=OcrModel.objects.all())
model_name = serializers.CharField(required=False)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_SEGMENT,
document=self.document)
self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document)
def validate_parts(self, data):
if len(data) < 2:
raise serializers.ValidationError("Segmentation training requires at least 2 images.")
return data
def validate(self, data):
data = super().validate(data)
if not data.get('model') and not data.get('model_name'):
raise serializers.ValidationError(
_("Either use model_name to create a new model, add a model pk to retrain an existing one, or both to create a new model from an existing one."))
return data
def process(self):
model = self.validated_data.get('model')
if self.validated_data.get('model_name'):
file_ = model and model.file or None
model = OcrModel.objects.create(
document=self.document,
owner=self.user,
name=self.validated_data['model_name'],
job=OcrModel.MODEL_JOB_RECOGNIZE,
file=file_
)
segtrain.delay(model.pk if model else None,
[part.pk for part in self.validated_data.get('parts')],
user_pk=self.user.pk)
class TrainSerializer(ProcessSerializerMixin, serializers.Serializer):
parts = serializers.PrimaryKeyRelatedField(many=True,
queryset=DocumentPart.objects.all())
model = serializers.PrimaryKeyRelatedField(required=False,
queryset=OcrModel.objects.all())
model_name = serializers.CharField(required=False)
transcription = serializers.PrimaryKeyRelatedField(queryset=Transcription.objects.all())
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fields['transcription'].queryset = Transcription.objects.filter(document=self.document)
self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_RECOGNIZE,
document=self.document)
self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document)
def validate(self, data):
data = super().validate(data)
if not data.get('model') and not data.get('model_name'):
raise serializers.ValidationError(
_("Either use model_name to create a new model, or add a model pk to retrain an existing one."))
return data
def process(self):
model = self.validated_data.get('model')
if self.validated_data.get('model_name'):
file_ = model and model.file or None
model = OcrModel.objects.create(
document=self.document,
owner=self.user,
name=self.validated_data['model_name'],
job=OcrModel.MODEL_JOB_RECOGNIZE,
file=file_)
train.delay([part.pk for part in self.validated_data.get('parts')],
self.validated_data['transcription'].pk,
model.pk if model else None,
self.user.pk)
class TranscribeSerializer(ProcessSerializerMixin, serializers.Serializer):
parts = serializers.PrimaryKeyRelatedField(many=True,
queryset=DocumentPart.objects.all())
model = serializers.PrimaryKeyRelatedField(queryset=OcrModel.objects.all())
# transcription = serializers.PrimaryKeyRelatedField(queryset=Transcription.objects.all())
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# self.fields['transcription'].queryset = Transcription.objects.filter(document=self.document)
self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_RECOGNIZE,
document=self.document)
self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document)
def process(self):
model = self.validated_data.get('model')
for part in self.validated_data.get('parts'):
part.chain_tasks(
transcribe.si(part.pk,
model_pk=model.pk,
user_pk=self.user.pk)
)
......@@ -8,9 +8,10 @@ from django.core.files.uploadedfile import SimpleUploadedFile
from django.test import override_settings
from django.urls import reverse
from core.models import Block, Line, Transcription, LineTranscription
from core.models import Block, Line, Transcription, LineTranscription, OcrModel
from core.tests.factory import CoreFactoryTestCase
class UserViewSetTestCase(CoreFactoryTestCase):
def setUp(self):
......@@ -21,7 +22,7 @@ class UserViewSetTestCase(CoreFactoryTestCase):
self.client.force_login(user)
uri = reverse('api:user-onboarding')
resp = self.client.put(uri, {
'onboarding' : 'False',
'onboarding': 'False',
}, content_type='application/json')
user.refresh_from_db()
......@@ -29,7 +30,40 @@ class UserViewSetTestCase(CoreFactoryTestCase):
self.assertEqual(user.onboarding, False)
class OcrModelViewSetTestCase(CoreFactoryTestCase):
def setUp(self):
super().setUp()
self.part = self.factory.make_part()
self.user = self.part.document.owner
self.model = self.factory.make_model(document=self.part.document)
def test_list(self):
self.client.force_login(self.user)
uri = reverse('api:model-list', kwargs={'document_pk': self.part.document.pk})
with self.assertNumQueries(7):
resp = self.client.get(uri)
self.assertEqual(resp.status_code, 200)
def test_detail(self):
self.client.force_login(self.user)
uri = reverse('api:model-detail',
kwargs={'document_pk': self.part.document.pk,
'pk': self.model.pk})
with self.assertNumQueries(6):
resp = self.client.get(uri)
self.assertEqual(resp.status_code, 200)
def test_create(self):
self.client.force_login(self.user)
uri = reverse('api:model-list', kwargs={'document_pk': self.part.document.pk})
with self.assertNumQueries(4):
resp = self.client.post(uri, {
'name': 'test.mlmodel',
'file': self.factory.make_asset_file(name='test.mlmodel',
asset_name='fake_seg.mlmodel'),
'job': 'Segment'
})
self.assertEqual(resp.status_code, 201, resp.content)
class DocumentViewSetTestCase(CoreFactoryTestCase):
......@@ -37,6 +71,31 @@ class DocumentViewSetTestCase(CoreFactoryTestCase):
super().setUp()
self.doc = self.factory.make_document()
self.doc2 = self.factory.make_document(owner=self.doc.owner)
self.part = self.factory.make_part(document=self.doc)
self.part2 = self.factory.make_part(document=self.doc)
self.line = Line.objects.create(
baseline=[[10, 25], [50, 25]],
mask=[10, 10, 50, 50],
document_part=self.part)
self.line2 = Line.objects.create(
baseline=[[10, 80], [50, 80]],
mask=[10, 60, 50, 100],
document_part=self.part)
self.transcription = Transcription.objects.create(
document=self.part.document,
name='test')
self.transcription2 = Transcription.objects.create(
document=self.part.document,
name='tr2')
self.lt = LineTranscription.objects.create(
transcription=self.transcription,
line=self.line,
content='test')
self.lt2 = LineTranscription.objects.create(
transcription=self.transcription2,
line=self.line2,
content='test2')
def test_list(self):
self.client.force_login(self.doc.owner)
......@@ -62,9 +121,81 @@ class DocumentViewSetTestCase(CoreFactoryTestCase):
# Note: raises a 404 instead of 403 but its fine
self.assertEqual(resp.status_code, 404)
# not used
# def test_update
# def test_create
def test_segtrain_less_two_parts(self):
self.client.force_login(self.doc.owner)
model = self.factory.make_model(job=OcrModel.MODEL_JOB_SEGMENT, document=self.doc)
uri = reverse('api:document-segtrain', kwargs={'pk': self.doc.pk})
resp = self.client.post(uri, data={
'parts': [self.part.pk],
'model': model.pk
})
self.assertEqual(resp.status_code, 400)
self.assertEqual(resp.json()['error'], {'parts': [
'Segmentation training requires at least 2 images.']})
def test_segtrain_new_model(self):
self.client.force_login(self.doc.owner)
uri = reverse('api:document-segtrain', kwargs={'pk': self.doc.pk})
resp = self.client.post(uri, data={
'parts': [self.part.pk, self.part2.pk],
'model_name': 'new model'
})
self.assertEqual(resp.status_code, 200, resp.content)
self.assertEqual(OcrModel.objects.count(), 1)
self.assertEqual(OcrModel.objects.first().name, "new model")
def test_segtrain_existing_model_rename(self):
self.client.force_login(self.doc.owner)
model = self.factory.make_model(job=OcrModel.MODEL_JOB_SEGMENT, document=self.doc)
uri = reverse('api:document-segtrain', kwargs={'pk': self.doc.pk})
resp = self.client.post(uri, data={
'parts': [self.part.pk, self.part2.pk],
'model': model.pk,
'model_name': 'test new model'
})
self.assertEqual(resp.status_code, 200, resp.content)
self.assertEqual(OcrModel.objects.count(), 2)
def test_segment(self):
uri = reverse('api:document-segment', kwargs={'pk': self.doc.pk})
self.client.force_login(self.doc.owner)
model = self.factory.make_model(job=OcrModel.MODEL_JOB_SEGMENT, document=self.doc)
resp = self.client.post(uri, data={
'parts': [self.part.pk, self.part2.pk],
'seg_steps': 'both',
'model': model.pk,
})
self.assertEqual(resp.status_code, 200)
def test_train_new_model(self):
self.client.force_login(self.doc.owner)
uri = reverse('api:document-train', kwargs={'pk': self.doc.pk})
resp = self.client.post(uri, data={
'parts': [self.part.pk, self.part2.pk],
'model_name': 'testing new model',
'transcription': self.transcription.pk
})
self.assertEqual(resp.status_code, 200)
self.assertEqual(OcrModel.objects.filter(
document=self.doc,
job=OcrModel.MODEL_JOB_RECOGNIZE).count(), 1)
def test_transcribe(self):
trans = Transcription.objects.create(document=self.part.document)
self.client.force_login(self.doc.owner)
model = self.factory.make_model(job=OcrModel.MODEL_JOB_RECOGNIZE, document=self.doc)
uri = reverse('api:document-transcribe', kwargs={'pk': self.doc.pk})
resp = self.client.post(uri, data={
'parts': [self.part.pk, self.part2.pk],
'model': model.pk,
'transcription': trans.pk
})
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.content, b'{"status":"ok"}')
# won't work with dummy model and image
# self.assertEqual(LineTranscription.objects.filter(transcription=trans).count(), 2)
class PartViewSetTestCase(CoreFactoryTestCase):
......@@ -189,7 +320,7 @@ class BlockViewSetTestCase(CoreFactoryTestCase):
# 5 insert
resp = self.client.post(uri, {
'document_part': self.part.pk,
'box': '[[10,10], [50,50]]'
'box': '[[10,10], [20,20], [50,50]]'
})
self.assertEqual(resp.status_code, 201, resp.content)
......@@ -215,15 +346,15 @@ class LineViewSetTestCase(CoreFactoryTestCase):
box=[10, 10, 200, 200],
document_part=self.part)
self.line = Line.objects.create(
box=[60, 10, 100, 50],
mask=[60, 10, 100, 50],
document_part=self.part,
block=self.block)
self.line2 = Line.objects.create(
box=[90, 10, 70, 50],
mask=[90, 10, 70, 50],
document_part=self.part,
block=self.block)
self.orphan = Line.objects.create(
box=[0, 0, 10, 10],
mask=[0, 0, 10, 10],
document_part=self.part,
block=None)
......@@ -294,10 +425,10 @@ class LineTranscriptionViewSetTestCase(CoreFactoryTestCase):
self.part = self.factory.make_part()
self.user = self.part.document.owner
self.line = Line.objects.create(
box=[10, 10, 50, 50],
mask=[10, 10, 50, 50],
document_part=self.part)
self.line2 = Line.objects.create(
box=[10, 60, 50, 100],
mask=[10, 60, 50, 100],
document_part=self.part)
self.transcription = Transcription.objects.create(
document=self.part.document,
......@@ -361,7 +492,7 @@ class LineTranscriptionViewSetTestCase(CoreFactoryTestCase):
uri = reverse('api:linetranscription-bulk-create',
kwargs={'document_pk': self.part.document.pk, 'part_pk': self.part.pk})
ll = Line.objects.create(
box=[10, 10, 50, 50],
mask=[10, 10, 50, 50],
document_part=self.part)
with self.assertNumQueries(10):
resp = self.client.post(
......
......@@ -10,7 +10,8 @@ from api.views import (DocumentViewSet,
LineViewSet,
BlockTypeViewSet,
LineTypeViewSet,
LineTranscriptionViewSet)
LineTranscriptionViewSet,
OcrModelViewSet)
router = routers.DefaultRouter()
router.register(r'documents', DocumentViewSet)
......@@ -20,6 +21,7 @@ router.register(r'types/line', LineTypeViewSet)
documents_router = routers.NestedSimpleRouter(router, r'documents', lookup='document')
documents_router.register(r'parts', PartViewSet, basename='part')
documents_router.register(r'transcriptions', DocumentTranscriptionViewSet, basename='transcription')
documents_router.register(r'models', OcrModelViewSet, basename='model')
parts_router = routers.NestedSimpleRouter(documents_router, r'parts', lookup='part')
parts_router.register(r'blocks', BlockViewSet)
parts_router.register(r'lines', LineViewSet)
......
......@@ -25,7 +25,13 @@ from api.serializers import (UserOnboardingSerializer,
DetailedLineSerializer,
LineOrderSerializer,
TranscriptionSerializer,
LineTranscriptionSerializer)
LineTranscriptionSerializer,
SegmentSerializer,
TrainSerializer,
SegTrainSerializer,
TranscribeSerializer,
OcrModelSerializer)
from core.models import (Document,
DocumentPart,
Block,
......@@ -33,7 +39,10 @@ from core.models import (Document,
BlockType,
LineType,
Transcription,
LineTranscription)
LineTranscription,
OcrModel,
AlreadyProcessingException)