diff --git a/app/apps/api/fields.py b/app/apps/api/fields.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8d600443d9ed3cdab2a6642d0da08cbb669d0e2
--- /dev/null
+++ b/app/apps/api/fields.py
@@ -0,0 +1,18 @@
+from rest_framework import serializers
+
+
+class DisplayChoiceField(serializers.ChoiceField):
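+    """ChoiceField that accepts and renders the human-readable label
+    (e.g. 'Segment') instead of the stored key."""
+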
+    def to_representation(self, obj):
+        if obj == '' and self.allow_blank:
+            return obj
+        return self._choices[obj]
+
+    def to_internal_value(self, data):
+        # Support writes that send the display value instead of the stored key.
+        if data == '' and self.allow_blank:
+            return ''
+
+        for key, val in self._choices.items():
+            if val == data:
+                return key
+        self.fail('invalid_choice', input=data)
diff --git a/app/apps/api/serializers.py b/app/apps/api/serializers.py
index a84389e0015fe1e983a303faefd8e3e032b7357e..61532a14ac5367a20821b0fc24bcf713866e73e5 100644
--- a/app/apps/api/serializers.py
+++ b/app/apps/api/serializers.py
@@ -4,10 +4,12 @@ import html
 
 from django.conf import settings
 from django.db.utils import IntegrityError
+from django.utils.translation import gettext_lazy as _
 
 from rest_framework import serializers
 from easy_thumbnails.files import get_thumbnailer
 
+from api.fields import DisplayChoiceField
 from users.models import User
 from core.models import (Document,
                          DocumentPart,
@@ -16,7 +18,9 @@ from core.models import (Document,
                          Transcription,
                          LineTranscription,
                          BlockType,
-                         LineType)
+                         LineType,
+                         OcrModel)
+from core.tasks import (segtrain, train, segment, transcribe)
 
 logger = logging.getLogger(__name__)
 
@@ -106,7 +110,7 @@ class DocumentSerializer(serializers.ModelSerializer):
 
 
 class PartSerializer(serializers.ModelSerializer):
-    image = ImageField(thumbnails=['card', 'large'])
+    image = ImageField(required=False, thumbnails=['card', 'large'])
     filename = serializers.CharField(read_only=True)
     bw_image = ImageField(thumbnails=['large'], required=False)
     workflow = serializers.JSONField(read_only=True)
@@ -259,3 +263,179 @@ class PartDetailSerializer(PartSerializer):
         nex = DocumentPart.objects.filter(
             document=instance.document, order__gt=instance.order).order_by('order').first()
         return nex and nex.pk or None
+
+
+class OcrModelSerializer(serializers.ModelSerializer):
+    owner = serializers.ReadOnlyField(source='owner.username')
+    job = DisplayChoiceField(choices=OcrModel.MODEL_JOB_CHOICES)
+    training = serializers.ReadOnlyField()
+
+    class Meta:
+        model = OcrModel
+        fields = ('pk', 'name', 'file', 'job',
+                  'owner', 'training', 'versions')
+
+    def create(self, data):
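+        # The document pk comes from the nested route /documents/<pk>/models/.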
+        document = Document.objects.get(pk=self.context["view"].kwargs["document_pk"])
+        data['document'] = document
+        data['owner'] = self.context["view"].request.user
+        obj = super().create(data)
+        return obj
+
+
+class ProcessSerializerMixin():
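+    """Carries the target document and requesting user so process serializers
+    can scope their field querysets and launch tasks on the user's behalf."""
+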
+    def __init__(self, document, user, *args, **kwargs):
+        self.document = document
+        self.user = user
+        super().__init__(*args, **kwargs)
+
+
+class SegmentSerializer(ProcessSerializerMixin, serializers.Serializer):
+    STEPS_CHOICES = (
+        ('both', _('Lines and regions')),
+        ('lines', _('Lines Baselines and Masks')),
+        ('masks', _('Only lines Masks')),
+        ('regions', _('Regions')),
+    )
+    TEXT_DIRECTION_CHOICES = (
+        ('horizontal-lr', _("Horizontal l2r")),
+        ('horizontal-rl', _("Horizontal r2l")),
+        ('vertical-lr', _("Vertical l2r")),
+        ('vertical-rl', _("Vertical r2l"))
+    )
+
+    parts = serializers.PrimaryKeyRelatedField(many=True,
+                                               queryset=DocumentPart.objects.all())
+    steps = serializers.ChoiceField(choices=STEPS_CHOICES,
+                                    required=False,
+                                    default='both')
+    model = serializers.PrimaryKeyRelatedField(required=False,
+                                               allow_null=True,
+                                               queryset=OcrModel.objects.all())
+    override = serializers.BooleanField(required=False, default=True)
+    text_direction = serializers.ChoiceField(default='horizontal-lr',
+                                             required=False,
+                                             choices=TEXT_DIRECTION_CHOICES)
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_SEGMENT,
+                                                                document=self.document)
+        self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document)
+
+    def process(self):
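+        # Queue one segmentation task per selected part via part.chain_tasks().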
+        model = self.validated_data.get('model')
+        parts = self.validated_data.get('parts')
+        for part in parts:
+            part.chain_tasks(
+                segment.si(part.pk,
+                           user_pk=self.user.pk,
+                           model_pk=model.pk if model else None,
+                           steps=self.validated_data.get('steps'),
+                           text_direction=self.validated_data.get('text_direction'),
+                           override=self.validated_data.get('override'))
+            )
+
+
+class SegTrainSerializer(ProcessSerializerMixin, serializers.Serializer):
+    parts = serializers.PrimaryKeyRelatedField(many=True,
+                                               queryset=DocumentPart.objects.all())
+    model = serializers.PrimaryKeyRelatedField(required=False,
+                                               queryset=OcrModel.objects.all())
+    model_name = serializers.CharField(required=False)
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_SEGMENT,
+                                                                document=self.document)
+        self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document)
+
+    def validate_parts(self, data):
+        if len(data) < 2:
+            raise serializers.ValidationError("Segmentation training requires at least 2 images.")
+        return data
+
+    def validate(self, data):
+        data = super().validate(data)
+        if not data.get('model') and not data.get('model_name'):
+            raise serializers.ValidationError(
+                _("Either use model_name to create a new model, add a model pk to retrain an existing one, or both to create a new model from an existing one."))
+        return data
+
+    def process(self):
+        model = self.validated_data.get('model')
+        if self.validated_data.get('model_name'):
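+            # Reuse the selected model's file, if any, so the new model can be
+            # trained starting from it.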
+            file_ = model and model.file or None
+            model = OcrModel.objects.create(
+                document=self.document,
+                owner=self.user,
+                name=self.validated_data['model_name'],
+                job=OcrModel.MODEL_JOB_SEGMENT,
+                file=file_
+            )
+
+        segtrain.delay(model.pk if model else None,
+                       [part.pk for part in self.validated_data.get('parts')],
+                       user_pk=self.user.pk)
+
+
+class TrainSerializer(ProcessSerializerMixin, serializers.Serializer):
+    parts = serializers.PrimaryKeyRelatedField(many=True,
+                                               queryset=DocumentPart.objects.all())
+    model = serializers.PrimaryKeyRelatedField(required=False,
+                                               queryset=OcrModel.objects.all())
+    model_name = serializers.CharField(required=False)
+    transcription = serializers.PrimaryKeyRelatedField(queryset=Transcription.objects.all())
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.fields['transcription'].queryset = Transcription.objects.filter(document=self.document)
+        self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_RECOGNIZE,
+                                                                document=self.document)
+        self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document)
+
+    def validate(self, data):
+        data = super().validate(data)
+        if not data.get('model') and not data.get('model_name'):
+            raise serializers.ValidationError(
+                    _("Either use model_name to create a new model, or add a model pk to retrain an existing one."))
+        return data
+
+    def process(self):
+        model = self.validated_data.get('model')
+        if self.validated_data.get('model_name'):
+            file_ = model and model.file or None
+            model = OcrModel.objects.create(
+                document=self.document,
+                owner=self.user,
+                name=self.validated_data['model_name'],
+                job=OcrModel.MODEL_JOB_RECOGNIZE,
+                file=file_)
+
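+        # core.tasks.train expects (part_pks, transcription_pk, model_pk, user_pk).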
+        train.delay([part.pk for part in self.validated_data.get('parts')],
+                    self.validated_data['transcription'].pk,
+                    model.pk if model else None,
+                    self.user.pk)
+
+
+class TranscribeSerializer(ProcessSerializerMixin, serializers.Serializer):
+    parts = serializers.PrimaryKeyRelatedField(many=True,
+                                               queryset=DocumentPart.objects.all())
+    model = serializers.PrimaryKeyRelatedField(queryset=OcrModel.objects.all())
+    # transcription = serializers.PrimaryKeyRelatedField(queryset=Transcription.objects.all())
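+    # note: DocumentPart.transcribe() derives the target Transcription
+    # from the model name ('kraken:<model name>').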
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # self.fields['transcription'].queryset = Transcription.objects.filter(document=self.document)
+        self.fields['model'].queryset = OcrModel.objects.filter(job=OcrModel.MODEL_JOB_RECOGNIZE,
+                                                                document=self.document)
+        self.fields['parts'].queryset = DocumentPart.objects.filter(document=self.document)
+
+    def process(self):
+        model = self.validated_data.get('model')
+        for part in self.validated_data.get('parts'):
+            part.chain_tasks(
+                transcribe.si(part.pk,
+                              model_pk=model.pk,
+                              user_pk=self.user.pk)
+            )
diff --git a/app/apps/api/tests.py b/app/apps/api/tests.py
index 7de43df764c53d7866b4daccbc8f45c40bf61945..fb1ee141f2dd08f64d9a3fc76ad0354f9974bcf0 100644
--- a/app/apps/api/tests.py
+++ b/app/apps/api/tests.py
@@ -8,9 +8,10 @@ from django.core.files.uploadedfile import SimpleUploadedFile
 from django.test import override_settings
 from django.urls import reverse
 
-from core.models import Block, Line, Transcription, LineTranscription
+from core.models import Block, Line, Transcription, LineTranscription, OcrModel
 from core.tests.factory import CoreFactoryTestCase
 
+
 class UserViewSetTestCase(CoreFactoryTestCase):
 
     def setUp(self):
@@ -21,7 +22,7 @@ class UserViewSetTestCase(CoreFactoryTestCase):
         self.client.force_login(user)
         uri = reverse('api:user-onboarding')
         resp = self.client.put(uri, {
-                'onboarding' : 'False',
+                'onboarding': 'False',
                 }, content_type='application/json')
 
         user.refresh_from_db()
@@ -29,7 +30,40 @@ class UserViewSetTestCase(CoreFactoryTestCase):
         self.assertEqual(user.onboarding, False)
 
 
+class OcrModelViewSetTestCase(CoreFactoryTestCase):
+    def setUp(self):
+        super().setUp()
+        self.part = self.factory.make_part()
+        self.user = self.part.document.owner
+        self.model = self.factory.make_model(document=self.part.document)
+
+    def test_list(self):
+        self.client.force_login(self.user)
+        uri = reverse('api:model-list', kwargs={'document_pk': self.part.document.pk})
+        with self.assertNumQueries(7):
+            resp = self.client.get(uri)
+        self.assertEqual(resp.status_code, 200)
+
+    def test_detail(self):
+        self.client.force_login(self.user)
+        uri = reverse('api:model-detail',
+                      kwargs={'document_pk': self.part.document.pk,
+                              'pk': self.model.pk})
+        with self.assertNumQueries(6):
+            resp = self.client.get(uri)
+        self.assertEqual(resp.status_code, 200)
 
+
+    def test_create(self):
+        self.client.force_login(self.user)
+        uri = reverse('api:model-list', kwargs={'document_pk': self.part.document.pk})
+        with self.assertNumQueries(4):
+            resp = self.client.post(uri, {
+                'name': 'test.mlmodel',
+                'file': self.factory.make_asset_file(name='test.mlmodel',
+                                                     asset_name='fake_seg.mlmodel'),
+                'job': 'Segment'
+            })
+        self.assertEqual(resp.status_code, 201, resp.content)
 
 
 class DocumentViewSetTestCase(CoreFactoryTestCase):
@@ -37,6 +71,31 @@ class DocumentViewSetTestCase(CoreFactoryTestCase):
         super().setUp()
         self.doc = self.factory.make_document()
         self.doc2 = self.factory.make_document(owner=self.doc.owner)
+        self.part = self.factory.make_part(document=self.doc)
+        self.part2 = self.factory.make_part(document=self.doc)
+
+        self.line = Line.objects.create(
+            baseline=[[10, 25], [50, 25]],
+            mask=[10, 10, 50, 50],
+            document_part=self.part)
+        self.line2 = Line.objects.create(
+            baseline=[[10, 80], [50, 80]],
+            mask=[10, 60, 50, 100],
+            document_part=self.part)
+        self.transcription = Transcription.objects.create(
+            document=self.part.document,
+            name='test')
+        self.transcription2 = Transcription.objects.create(
+            document=self.part.document,
+            name='tr2')
+        self.lt = LineTranscription.objects.create(
+            transcription=self.transcription,
+            line=self.line,
+            content='test')
+        self.lt2 = LineTranscription.objects.create(
+            transcription=self.transcription2,
+            line=self.line2,
+            content='test2')
 
     def test_list(self):
         self.client.force_login(self.doc.owner)
@@ -62,9 +121,81 @@ class DocumentViewSetTestCase(CoreFactoryTestCase):
         # Note: raises a 404 instead of 403 but its fine
         self.assertEqual(resp.status_code, 404)
 
-    # not used
-    # def test_update
-    # def test_create
+    def test_segtrain_less_two_parts(self):
+        self.client.force_login(self.doc.owner)
+        model = self.factory.make_model(job=OcrModel.MODEL_JOB_SEGMENT, document=self.doc)
+        uri = reverse('api:document-segtrain', kwargs={'pk': self.doc.pk})
+        resp = self.client.post(uri, data={
+                'parts': [self.part.pk],
+                'model': model.pk
+        })
+
+        self.assertEqual(resp.status_code, 400)
+        self.assertEqual(resp.json()['error'], {'parts': [
+            'Segmentation training requires at least 2 images.']})
+
+    def test_segtrain_new_model(self):
+        self.client.force_login(self.doc.owner)
+        uri = reverse('api:document-segtrain', kwargs={'pk': self.doc.pk})
+        resp = self.client.post(uri, data={
+                'parts': [self.part.pk, self.part2.pk],
+                'model_name': 'new model'
+        })
+        self.assertEqual(resp.status_code, 200, resp.content)
+        self.assertEqual(OcrModel.objects.count(), 1)
+        self.assertEqual(OcrModel.objects.first().name, "new model")
+
+    def test_segtrain_existing_model_rename(self):
+        self.client.force_login(self.doc.owner)
+        model = self.factory.make_model(job=OcrModel.MODEL_JOB_SEGMENT, document=self.doc)
+        uri = reverse('api:document-segtrain', kwargs={'pk': self.doc.pk})
+        resp = self.client.post(uri, data={
+            'parts': [self.part.pk, self.part2.pk],
+            'model': model.pk,
+            'model_name': 'test new model'
+        })
+        self.assertEqual(resp.status_code, 200, resp.content)
+        self.assertEqual(OcrModel.objects.count(), 2)
+
+    def test_segment(self):
+        uri = reverse('api:document-segment', kwargs={'pk': self.doc.pk})
+        self.client.force_login(self.doc.owner)
+        model = self.factory.make_model(job=OcrModel.MODEL_JOB_SEGMENT, document=self.doc)
+        resp = self.client.post(uri, data={
+            'parts': [self.part.pk, self.part2.pk],
+            'seg_steps': 'both',
+            'model': model.pk,
+        })
+        self.assertEqual(resp.status_code, 200)
+
+    def test_train_new_model(self):
+        self.client.force_login(self.doc.owner)
+        uri = reverse('api:document-train', kwargs={'pk': self.doc.pk})
+        resp = self.client.post(uri, data={
+            'parts': [self.part.pk, self.part2.pk],
+            'model_name': 'testing new model',
+            'transcription': self.transcription.pk
+        })
+        self.assertEqual(resp.status_code, 200)
+        self.assertEqual(OcrModel.objects.filter(
+            document=self.doc,
+            job=OcrModel.MODEL_JOB_RECOGNIZE).count(), 1)
+
+    def test_transcribe(self):
+        trans = Transcription.objects.create(document=self.part.document)
+
+        self.client.force_login(self.doc.owner)
+        model = self.factory.make_model(job=OcrModel.MODEL_JOB_RECOGNIZE, document=self.doc)
+        uri = reverse('api:document-transcribe', kwargs={'pk': self.doc.pk})
+        resp = self.client.post(uri, data={
+            'parts': [self.part.pk, self.part2.pk],
+            'model': model.pk,
+            'transcription': trans.pk
+        })
+        self.assertEqual(resp.status_code, 200)
+        self.assertEqual(resp.content, b'{"status":"ok"}')
+        # won't work with dummy model and image
+        # self.assertEqual(LineTranscription.objects.filter(transcription=trans).count(), 2)
 
 
 class PartViewSetTestCase(CoreFactoryTestCase):
@@ -189,7 +320,7 @@ class BlockViewSetTestCase(CoreFactoryTestCase):
             # 5 insert
             resp = self.client.post(uri, {
                 'document_part': self.part.pk,
-                'box': '[[10,10], [50,50]]'
+                'box': '[[10,10], [20,20], [50,50]]'
             })
         self.assertEqual(resp.status_code, 201, resp.content)
 
@@ -215,15 +346,15 @@ class LineViewSetTestCase(CoreFactoryTestCase):
                 box=[10, 10, 200, 200],
                 document_part=self.part)
         self.line = Line.objects.create(
-                box=[60, 10, 100, 50],
+                mask=[60, 10, 100, 50],
                 document_part=self.part,
                 block=self.block)
         self.line2 = Line.objects.create(
-                box=[90, 10, 70, 50],
+                mask=[90, 10, 70, 50],
                 document_part=self.part,
                 block=self.block)
         self.orphan = Line.objects.create(
-            box=[0, 0, 10, 10],
+            mask=[0, 0, 10, 10],
             document_part=self.part,
             block=None)
 
@@ -294,10 +425,10 @@ class LineTranscriptionViewSetTestCase(CoreFactoryTestCase):
         self.part = self.factory.make_part()
         self.user = self.part.document.owner
         self.line = Line.objects.create(
-            box=[10, 10, 50, 50],
+            mask=[10, 10, 50, 50],
             document_part=self.part)
         self.line2 = Line.objects.create(
-            box=[10, 60, 50, 100],
+            mask=[10, 60, 50, 100],
             document_part=self.part)
         self.transcription = Transcription.objects.create(
             document=self.part.document,
@@ -361,7 +492,7 @@ class LineTranscriptionViewSetTestCase(CoreFactoryTestCase):
         uri = reverse('api:linetranscription-bulk-create',
                       kwargs={'document_pk': self.part.document.pk, 'part_pk': self.part.pk})
         ll = Line.objects.create(
-            box=[10, 10, 50, 50],
+            mask=[10, 10, 50, 50],
             document_part=self.part)
         with self.assertNumQueries(10):
             resp = self.client.post(
diff --git a/app/apps/api/urls.py b/app/apps/api/urls.py
index 3c773a394e07c9b1ea1c7c874d82c0b67cb53188..5c68c75144c91691322606062c09a3df48437534 100644
--- a/app/apps/api/urls.py
+++ b/app/apps/api/urls.py
@@ -10,7 +10,8 @@ from api.views import (DocumentViewSet,
                        LineViewSet,
                        BlockTypeViewSet,
                        LineTypeViewSet,
-                       LineTranscriptionViewSet)
+                       LineTranscriptionViewSet,
+                       OcrModelViewSet)
 
 router = routers.DefaultRouter()
 router.register(r'documents', DocumentViewSet)
@@ -20,6 +21,7 @@ router.register(r'types/line', LineTypeViewSet)
 documents_router = routers.NestedSimpleRouter(router, r'documents', lookup='document')
 documents_router.register(r'parts', PartViewSet, basename='part')
 documents_router.register(r'transcriptions', DocumentTranscriptionViewSet, basename='transcription')
+documents_router.register(r'models', OcrModelViewSet, basename='model')
 parts_router = routers.NestedSimpleRouter(documents_router, r'parts', lookup='part')
 parts_router.register(r'blocks', BlockViewSet)
 parts_router.register(r'lines', LineViewSet)
diff --git a/app/apps/api/views.py b/app/apps/api/views.py
index cad44e8f1912bceeb03155464e8691a0ba3f8f4d..19fc37659b438eebc26ae297a4cfada6ccc9d976 100644
--- a/app/apps/api/views.py
+++ b/app/apps/api/views.py
@@ -25,7 +25,13 @@ from api.serializers import (UserOnboardingSerializer,
                              DetailedLineSerializer,
                              LineOrderSerializer,
                              TranscriptionSerializer,
-                             LineTranscriptionSerializer)
+                             LineTranscriptionSerializer,
+                             SegmentSerializer,
+                             TrainSerializer,
+                             SegTrainSerializer,
+                             TranscribeSerializer,
+                             OcrModelSerializer)
+
 from core.models import (Document,
                          DocumentPart,
                          Block,
@@ -33,7 +39,10 @@ from core.models import (Document,
                          BlockType,
                          LineType,
                          Transcription,
-                         LineTranscription)
+                         LineTranscription,
+                         OcrModel,
+                         AlreadyProcessingException)
+
 from core.tasks import recalculate_masks
 from users.models import User
 from imports.forms import ImportForm, ExportForm
@@ -50,7 +59,7 @@ class UserViewSet(ModelViewSet):
 
     @action(detail=False, methods=['put'])
     def onboarding(self, request):
-        serializer = UserOnboardingSerializer(self.request.user,data=request.data, partial=True)
+        serializer = UserOnboardingSerializer(self.request.user, data=request.data, partial=True)
         if serializer.is_valid(raise_exception=True):
             serializer.save()
             return Response(status=status.HTTP_200_OK)
@@ -120,6 +129,42 @@ class DocumentViewSet(ModelViewSet):
         else:
             return self.form_error(json.dumps(form.errors))
 
+    def get_process_response(self, request, serializer_class):
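+        """Validate a process serializer against the posted data, run it and
+        wrap the result in a {'status': ...} payload."""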
+        document = self.get_object()
+        serializer = serializer_class(document=document,
+                                      user=request.user,
+                                      data=request.data)
+        if serializer.is_valid():
+            try:
+                serializer.process()
+            except AlreadyProcessingException:
+                return Response(status=status.HTTP_400_BAD_REQUEST,
+                                data={'status': 'error',
+                                      'error': 'Already processing.'})
+
+            return Response(status=status.HTTP_200_OK,
+                            data={'status': 'ok'})
+        else:
+            return Response(status=status.HTTP_400_BAD_REQUEST,
+                            data={'status': 'error',
+                                  'error': serializer.errors})
+
+    @action(detail=True, methods=['post'])
+    def segment(self, request, pk=None):
+        return self.get_process_response(request, SegmentSerializer)
+
+    @action(detail=True, methods=['post'])
+    def train(self, request, pk=None):
+        return self.get_process_response(request, TrainSerializer)
+
+    @action(detail=True, methods=['post'])
+    def segtrain(self, request, pk=None):
+        return self.get_process_response(request, SegTrainSerializer)
+
+    @action(detail=True, methods=['post'])
+    def transcribe(self, request, pk=None):
+        return self.get_process_response(request, TranscribeSerializer)
+
 
 class DocumentPermissionMixin():
     def get_queryset(self):
@@ -390,3 +435,22 @@ class LineTranscriptionViewSet(DocumentPermissionMixin, ModelViewSet):
         qs = LineTranscription.objects.filter(pk__in=lines)
         qs.update(content='')
         return Response(status=status.HTTP_204_NO_CONTENT, )
+
+
+class OcrModelViewSet(DocumentPermissionMixin, ModelViewSet):
+    queryset = OcrModel.objects.all()
+    serializer_class = OcrModelSerializer
+
+    def get_queryset(self):
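+        # the viewset is nested under /documents/<document_pk>/models/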
+        return (super().get_queryset()
+                .filter(document=self.kwargs['document_pk']))
+
+    @action(detail=True, methods=['post'])
+    def cancel_training(self, request, pk=None):
+        model = self.get_object()
+        try:
+            model.cancel_training()
+        except Exception as e:
+            logger.exception(e)
+            return Response({'status': 'failed'}, status=400)
+        return Response({'status': 'canceled'})
diff --git a/app/apps/core/forms.py b/app/apps/core/forms.py
index 6d94585bcdb5ea98a7140116a2c202f6c7aa482d..c591e7fdfd4bd494a5a2b3dcab03bd0825912f47 100644
--- a/app/apps/core/forms.py
+++ b/app/apps/core/forms.py
@@ -109,6 +109,250 @@ MetadataFormSet = inlineformset_factory(Document, DocumentMetadata,
                                         extra=1, can_delete=True)
 
 
+class DocumentProcessForm1(BootstrapFormMixin, forms.Form):
+    parts = forms.CharField()
+
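+    # The payload's 'parts' value is a JSON-encoded list of DocumentPart pks;
+    # expose it as a queryset scoped to the form's document.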
+    @cached_property
+    def parts(self):
+        pks = json.loads(self.data.get('parts'))
+        parts = DocumentPart.objects.filter(
+            document=self.document, pk__in=pks)
+        return parts
+
+    def __init__(self, document, user, *args, **kwargs):
+        self.document = document
+        self.user = user
+        super().__init__(*args, **kwargs)
+        # self.fields['typology'].widget = forms.HiddenInput()  # for now
+        # self.fields['typology'].initial = Typology.objects.get(name="Page")
+        # self.fields['typology'].widget.attrs['title'] = _("Default Typology")
+        if self.document.read_direction == self.document.READ_DIRECTION_RTL:
+            self.initial['text_direction'] = 'horizontal-rl'
+        # Each subclass only declares a subset of these fields, so restrict the
+        # querysets to the current document only when the field is present.
+        for field_name in ('train_model', 'segtrain_model', 'seg_model'):
+            if field_name in self.fields:
+                self.fields[field_name].queryset &= OcrModel.objects.filter(document=self.document)
+        if 'ocr_model' in self.fields:
+            self.fields['ocr_model'].queryset &= OcrModel.objects.filter(
+                Q(document=None, script=document.main_script)
+                | Q(document=self.document))
+        if 'transcription' in self.fields:
+            self.fields['transcription'].queryset = Transcription.objects.filter(document=self.document)
+
+    def process(self):
+        # Concrete forms resolve 'model' in their clean(); hand it back to them.
+        return self.cleaned_data.get('model')
+
+
+class DocumentSegmentForm(DocumentProcessForm1):
+    SEG_STEPS_CHOICES = (
+        ('both', _('Lines and regions')),
+        ('lines', _('Lines Baselines and Masks')),
+        ('masks', _('Only lines Masks')),
+        ('regions', _('Regions')),
+    )
+    segmentation_steps = forms.ChoiceField(choices=SEG_STEPS_CHOICES,
+                                           initial='both', required=False)
+    seg_model = forms.ModelChoiceField(queryset=OcrModel.objects.filter(job=OcrModel.MODEL_JOB_SEGMENT),
+                                       label=_("Model"), empty_label="default ({name})".format(
+            name=settings.KRAKEN_DEFAULT_SEGMENTATION_MODEL.rsplit('/')[-1]),
+                                       required=False)
+    override = forms.BooleanField(required=False, initial=True,
+                                  help_text=_(
+                                      "If checked, deletes existing segmentation <b>and bound transcriptions</b> first!"))
+    TEXT_DIRECTION_CHOICES = (('horizontal-lr', _("Horizontal l2r")),
+                              ('horizontal-rl', _("Horizontal r2l")),
+                              ('vertical-lr', _("Vertical l2r")),
+                              ('vertical-rl', _("Vertical r2l")))
+    text_direction = forms.ChoiceField(initial='horizontal-lr', required=False,
+                                       choices=TEXT_DIRECTION_CHOICES)
+    upload_model = forms.FileField(required=False,
+                                   validators=[FileExtensionValidator(
+                                       allowed_extensions=['mlmodel', 'pronn', 'clstm'])])
+
+    def clean(self):
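+        # Precedence: an uploaded file wins over a selected model; with neither,
+        # model stays None and the default kraken segmentation model is used.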
+        data = super().clean()
+        model_job = OcrModel.MODEL_JOB_SEGMENT
+
+        if data.get('upload_model'):
+            model = OcrModel.objects.create(
+                document=self.parts[0].document,
+                owner=self.user,
+                name=data['upload_model'].name.rsplit('.', 1)[0],
+                job=model_job)
+            # Note: needs to save the file in a second step because the path needs the db PK
+            model.file = data['upload_model']
+            model.save()
+
+        elif data.get('seg_model'):
+            model = data.get('seg_model')
+        else:
+            model = None
+
+        data['model'] = model
+        return data
+
+    def process(self):
+        model = super().process()
+
+        for part in self.parts:
+            part.task('segment',
+                      user_pk=self.user.pk,
+                      steps=self.cleaned_data.get('segmentation_steps'),
+                      text_direction=self.cleaned_data.get('text_direction'),
+                      model_pk=model and model.pk or None,
+                      override=self.cleaned_data.get('override'))
+
+
+class DocumentTrainForm(DocumentProcessForm1):
+    new_model = forms.CharField(required=False, label=_('Model name'))
+    train_model = forms.ModelChoiceField(queryset=OcrModel.objects
+                                         .filter(job=OcrModel.MODEL_JOB_RECOGNIZE),
+                                         label=_("Model"), required=False)
+    upload_model = forms.FileField(required=False,
+                                   validators=[FileExtensionValidator(
+                                       allowed_extensions=['mlmodel', 'pronn', 'clstm'])])
+
+    transcription = forms.ModelChoiceField(queryset=Transcription.objects.all(), required=False)
+
+    def clean_train_model(self):
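+        # A model that is already training cannot be queued for another run.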
+        model = self.cleaned_data['train_model']
+        if model and model.training:
+            raise AlreadyProcessingException
+        return model
+
+    def clean(self):
+        data = super().clean()
+
+        model_job = OcrModel.MODEL_JOB_RECOGNIZE
+
+        if data.get('train_model'):
+            model = data.get('train_model')
+
+        elif data.get('upload_model'):
+            model = OcrModel.objects.create(
+                document=self.parts[0].document,
+                owner=self.user,
+                name=data['upload_model'].name.rsplit('.', 1)[0],
+                job=model_job)
+            # Note: needs to save the file in a second step because the path needs the db PK
+            model.file = data['upload_model']
+            model.save()
+
+        elif data.get('new_model'):
+            # file will be created by the training process
+            model = OcrModel.objects.create(
+                document=self.parts[0].document,
+                owner=self.user,
+                name=data['new_model'],
+                job=model_job)
+
+        else:
+            raise forms.ValidationError(
+                    _("Either select a name for your new model or an existing one."))
+
+        data['model'] = model
+        return data
+
+    def process(self):
+        model = super().process()
+
+        model.train(self.parts,
+                    self.cleaned_data['transcription'],
+                    user=self.user)
+
+
+class DocumentSegtrainForm(DocumentProcessForm1):
+    segtrain_model = forms.ModelChoiceField(queryset=OcrModel.objects
+                                            .filter(job=OcrModel.MODEL_JOB_SEGMENT),
+                                            label=_("Model"), required=False)
+    upload_model = forms.FileField(required=False,
+                                   validators=[FileExtensionValidator(
+                                       allowed_extensions=['mlmodel', 'pronn', 'clstm'])])
+
+    new_model = forms.CharField(required=False, label=_('Model name'))
+
+    def clean(self):
+        data = super().clean()
+
+        model_job = OcrModel.MODEL_JOB_SEGMENT
+        if len(self.parts) < 2:
+            raise forms.ValidationError("Segmentation training requires at least 2 images.")
+
+        if data.get('segtrain_model'):
+            model = data.get('segtrain_model')
+        elif data.get('upload_model'):
+            model = OcrModel.objects.create(
+                document=self.parts[0].document,
+                owner=self.user,
+                name=data['upload_model'].name.rsplit('.', 1)[0],
+                job=model_job)
+            # Note: needs to save the file in a second step because the path needs the db PK
+            model.file = data['upload_model']
+            model.save()
+
+        elif data.get('new_model'):
+            # file will be created by the training process
+            model = OcrModel.objects.create(
+                document=self.parts[0].document,
+                owner=self.user,
+                name=data['new_model'],
+                job=model_job)
+
+        else:
+            raise forms.ValidationError(
+                _("Either select a name for your new model or an existing one."))
+
+        data['model'] = model
+        return data
+
+    def process(self):
+        model = super().process()
+        model.segtrain(self.document,
+                       self.parts,
+                       user=self.user)
+
+
+class DocumentTranscribeForm(DocumentProcessForm1):
+
+    upload_model = forms.FileField(required=False,
+                                   validators=[FileExtensionValidator(
+                                       allowed_extensions=['mlmodel', 'pronn', 'clstm'])])
+    ocr_model = forms.ModelChoiceField(queryset=OcrModel.objects
+                                       .filter(job=OcrModel.MODEL_JOB_RECOGNIZE),
+                                       label=_("Model"), required=False)
+
+    def clean(self):
+        data = super().clean()
+
+        model_job = OcrModel.MODEL_JOB_RECOGNIZE
+
+        if data.get('upload_model'):
+            model = OcrModel.objects.create(
+                document=self.parts[0].document,
+                owner=self.user,
+                name=data['upload_model'].name.rsplit('.', 1)[0],
+                job=model_job)
+            # Note: needs to save the file in a second step because the path needs the db PK
+            model.file = data['upload_model']
+            model.save()
+
+        elif data.get('ocr_model'):
+            model = data.get('ocr_model')
+        else:
+            raise forms.ValidationError(
+                _("Either select an existing model or upload a new one."))
+
+        data['model'] = model
+        return data
+
+    def process(self):
+        model = super().process()
+        for part in self.parts:
+            part.task('transcribe',
+                      user_pk=self.user.pk,
+                      model_pk=model and model.pk or None)
+
+
 class DocumentProcessForm(BootstrapFormMixin, forms.Form):
     # TODO: split this form into one for each process?!
     TASK_BINARIZE = 'binarize'
diff --git a/app/apps/core/models.py b/app/apps/core/models.py
index 162a718152de1ee70bb08e8a5b8167dcefd991b1..aad2b42151b11b7c417bfdcdafa843a1ab8731c5 100644
--- a/app/apps/core/models.py
+++ b/app/apps/core/models.py
@@ -677,52 +677,45 @@ class DocumentPart(OrderedModel):
         self.save()
         self.recalculate_ordering(read_direction=read_direction)
 
-    def transcribe(self, model=None, text_direction=None):
-        if model:
-            trans, created = Transcription.objects.get_or_create(
-                name='kraken:' + model.name,
-                document=self.document)
-            model_ = kraken_models.load_any(model.file.path)
-            lines = self.lines.all()
-            text_direction = (text_direction
-                              or (self.document.main_script
-                                  and self.document.main_script.text_direction)
-                              or 'horizontal-lr')
+    def transcribe(self, model, text_direction=None):
+        trans, created = Transcription.objects.get_or_create(
+            name='kraken:' + model.name,
+            document=self.document)
+        model_ = kraken_models.load_any(model.file.path)
+        lines = self.lines.all()
+        text_direction = (text_direction
+                          or (self.document.main_script
+                              and self.document.main_script.text_direction)
+                          or 'horizontal-lr')
 
-            with Image.open(self.image.file.name) as im:
-                for line in lines:
-                    if not line.baseline:
-                        bounds = {
-                            'boxes': [line.box],
-                            'text_direction': text_direction,
-                            'type': 'baselines',
-                            # 'script_detection': True
-                        }
-                    else:
-                        bounds = {
-                            'lines': [{'baseline': line.baseline,
-                                       'boundary': line.mask,
-                                       'text_direction': text_direction,
-                                       'script': 'default'}],  # self.document.main_script.name
-                            'type': 'baselines',
-                            # 'selfcript_detection': True
-                        }
-                    it = rpred.rpred(
-                            model_, im,
-                            bounds=bounds,
-                            pad=16,  # TODO: % of the image?
-                            bidi_reordering=True)
-                    lt, created = LineTranscription.objects.get_or_create(
-                        line=line, transcription=trans)
-                    for pred in it:
-                        lt.content = pred.prediction
-                        lt.graphs = [(char, pred.cuts[i], float(pred.confidences[i]))
-                                     for i, char in enumerate(pred.prediction)]
-                    lt.save()
-        else:
-            Transcription.objects.get_or_create(
-                name='manual',
-                document=self.document)
+        with Image.open(self.image.file.name) as im:
+            for line in lines:
+                if not line.baseline:
+                    bounds = {
+                        'boxes': [line.box],
+                        'text_direction': text_direction,
+                        'type': 'baselines',
+                        # 'script_detection': True
+                    }
+                else:
+                    bounds = {
+                        'lines': [{'baseline': line.baseline,
+                                   'boundary': line.mask,
+                                   'text_direction': text_direction,
+                                   'script': 'default'}],  # self.document.main_script.name
+                        'type': 'baselines',
+                        # 'selfcript_detection': True
+                    }
+                it = rpred.rpred(
+                        model_, im,
+                        bounds=bounds,
+                        pad=16,  # TODO: % of the image?
+                        bidi_reordering=True)
+                lt, created = LineTranscription.objects.get_or_create(
+                    line=line, transcription=trans)
+                for pred in it:
+                    lt.content = pred.prediction
+                lt.save()
 
         self.workflow_state = self.WORKFLOW_STATE_TRANSCRIBING
         self.calculate_progress()
@@ -1082,7 +1075,7 @@ class LineTranscription(Versioned, models.Model):
 
 def models_path(instance, filename):
     fn, ext = os.path.splitext(filename)
-    return 'models/%d/%s%s' % (instance.pk, slugify(fn), ext)
+    return 'models/%d/%s%s' % (instance.document.pk, slugify(fn), ext)
 
 
 class OcrModel(Versioned, models.Model):
@@ -1104,9 +1097,9 @@ class OcrModel(Versioned, models.Model):
     training_accuracy = models.FloatField(default=0.0)
     training_total = models.IntegerField(default=0)
     training_errors = models.IntegerField(default=0)
-    document = models.ForeignKey(Document, blank=True, null=True,
+    document = models.ForeignKey(Document,
                                  related_name='ocr_models',
-                                 default=None, on_delete=models.SET_NULL)
+                                 default=None, on_delete=models.CASCADE)
     script = models.ForeignKey(Script, blank=True, null=True, on_delete=models.SET_NULL)
 
     version_ignore_fields = ('name', 'owner', 'document', 'script', 'training')
@@ -1124,8 +1117,9 @@ class OcrModel(Versioned, models.Model):
         return self.training_accuracy * 100
 
     def segtrain(self, document, parts_qs, user=None):
-        segtrain.delay(self.pk, document.pk,
-                       list(parts_qs.values_list('pk', flat=True)))
+        segtrain.delay(self.pk,
+                       list(parts_qs.values_list('pk', flat=True)),
+                       user_pk=user and user.pk or None)
 
     def train(self, parts_qs, transcription, user=None):
         train.delay(list(parts_qs.values_list('pk', flat=True)),
@@ -1139,7 +1133,7 @@ class OcrModel(Versioned, models.Model):
                 task_id = json.loads(redis_.get('training-%d' % self.pk))['task_id']
             elif self.job == self.MODEL_JOB_SEGMENT:
                 task_id = json.loads(redis_.get('segtrain-%d' % self.pk))['task_id']
-        except (TypeError, KeyError) as e:
+        except (TypeError, KeyError):
             raise ProcessFailureException(_("Couldn't find the training task."))
         else:
             if task_id:
diff --git a/app/apps/core/tasks.py b/app/apps/core/tasks.py
index d102e803a1969d9a098086c75d14a2dc5156b10d..65d564ed5cdf7dff9e3414350725bda24781e185 100644
--- a/app/apps/core/tasks.py
+++ b/app/apps/core/tasks.py
@@ -40,6 +40,7 @@ def update_client_state(part_id, task, status, task_id=None, data=None):
         "data": data or {}
     })
 
+
 @shared_task(autoretry_for=(MemoryError,), default_retry_delay=60)
 def generate_part_thumbnails(instance_pk):
     if not getattr(settings, 'THUMBNAIL_ENABLE', True):
@@ -129,7 +130,7 @@ def make_segmentation_training_data(part):
 
 
 @shared_task(bind=True, autoretry_for=(MemoryError,), default_retry_delay=60 * 60)
-def segtrain(task, model_pk, document_pk, part_pks, user_pk=None):
+def segtrain(task, model_pk, part_pks, user_pk=None):
     # # Note hack to circumvent AssertionError: daemonic processes are not allowed to have children
     from multiprocessing import current_process
     current_process().daemon = False
@@ -151,6 +152,7 @@ def segtrain(task, model_pk, document_pk, part_pks, user_pk=None):
     OcrModel = apps.get_model('core', 'OcrModel')
 
     model = OcrModel.objects.get(pk=model_pk)
+
     try:
         load = model.file.path
         upload_to = model.file.field.upload_to(model, model.name + '.mlmodel')
@@ -412,6 +414,7 @@ def train(task, part_pks, transcription_pk, model_pk, user_pk=None):
     Transcription = apps.get_model('core', 'Transcription')
     LineTranscription = apps.get_model('core', 'LineTranscription')
     OcrModel = apps.get_model('core', 'OcrModel')
+
     try:
         model = OcrModel.objects.get(pk=model_pk)
         model.training = True
@@ -451,10 +454,12 @@ def train(task, part_pks, transcription_pk, model_pk, user_pk=None):
 
 @shared_task(autoretry_for=(MemoryError,), default_retry_delay=10 * 60)
 def transcribe(instance_pk, model_pk=None, user_pk=None, text_direction=None, **kwargs):
+
     try:
         DocumentPart = apps.get_model('core', 'DocumentPart')
         part = DocumentPart.objects.get(pk=instance_pk)
     except DocumentPart.DoesNotExist:
         logger.error('Trying to transcribe innexistant DocumentPart : %d', instance_pk)
         return
 
@@ -466,18 +471,10 @@ def transcribe(instance_pk, model_pk=None, user_pk=None, text_direction=None, **
     else:
         user = None
 
-    if model_pk:
-        try:
-            OcrModel = apps.get_model('core', 'OcrModel')
-            model = OcrModel.objects.get(pk=model_pk)
-        except OcrModel.DoesNotExist:
-            # Not sure how we should deal with this case
-            model = None
-    else:
-        model = None
-
     try:
-        part.transcribe(model=model)
+        OcrModel = apps.get_model('core', 'OcrModel')
+        model = OcrModel.objects.get(pk=model_pk)
+        part.transcribe(model)
     except Exception as e:
         if user:
             user.notify(_("Something went wrong during the transcription!"),
diff --git a/app/apps/core/tests/factory.py b/app/apps/core/tests/factory.py
index 44bb89ba1e90a9c7ffcc002e2d1adfe0c77e12a5..98ffc303894f59f7dbb7ab3f4de9dd54740a00d0 100644
--- a/app/apps/core/tests/factory.py
+++ b/app/apps/core/tests/factory.py
@@ -84,17 +84,12 @@ class CoreFactory():
 
     def make_asset_file(self, name='test.png', asset_name='segmentation/default.png'):
         fp = os.path.join(os.path.dirname(__file__), 'assets', asset_name)
-        with Image.open(fp, 'r') as image:
-            file = BytesIO()
-            file.name = name
-            image.save(file, 'png')
-            file.seek(0)
-        return file
+        return open(fp, 'rb')
 
     def make_model(self, job=OcrModel.MODEL_JOB_RECOGNIZE, document=None):
         spec = '[1,48,0,1 Lbx100 Do O1c10]'
         nn = vgsl.TorchVGSLModel(spec)
-        model_name = 'test-model'
+        model_name = 'test-model.mlmodel'
         model = OcrModel.objects.create(name=model_name,
                                         document=document,
                                         job=job)