Mentions légales du service
Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
E
eScriptorium
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Container Registry
Model registry
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
scripta
eScriptorium
Commits
6642419a
Commit
6642419a
authored
4 years ago
by
Robin Tissot
Browse files
Options
Downloads
Plain Diff
Merge branch 'develop'
parents
f22c1ee5
08e4e0e2
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
app/apps/core/tasks.py
+9
-6
9 additions, 6 deletions
app/apps/core/tasks.py
with
9 additions
and
6 deletions
app/apps/core/tasks.py
+
9
−
6
View file @
6642419a
...
@@ -17,7 +17,6 @@ from celery import shared_task
...
@@ -17,7 +17,6 @@ from celery import shared_task
from
celery.signals
import
before_task_publish
,
task_prerun
,
task_success
,
task_failure
from
celery.signals
import
before_task_publish
,
task_prerun
,
task_success
,
task_failure
from
django_redis
import
get_redis_connection
from
django_redis
import
get_redis_connection
from
easy_thumbnails.files
import
get_thumbnailer
from
easy_thumbnails.files
import
get_thumbnailer
from
kraken.lib
import
default_specs
from
kraken.lib
import
train
as
kraken_train
from
kraken.lib
import
train
as
kraken_train
...
@@ -196,6 +195,7 @@ def segtrain(task, model_pk, part_pks, user_pk=None):
...
@@ -196,6 +195,7 @@ def segtrain(task, model_pk, part_pks, user_pk=None):
threads
=
LOAD_THREADS
,
threads
=
LOAD_THREADS
,
augment
=
True
,
augment
=
True
,
resize
=
'
both
'
,
resize
=
'
both
'
,
hyper_params
=
{
'
epochs
'
:
20
},
load_hyper_parameters
=
True
)
load_hyper_parameters
=
True
)
if
not
os
.
path
.
exists
(
os
.
path
.
split
(
modelpath
)[
0
]):
if
not
os
.
path
.
exists
(
os
.
path
.
split
(
modelpath
)[
0
]):
...
@@ -222,9 +222,15 @@ def segtrain(task, model_pk, part_pks, user_pk=None):
...
@@ -222,9 +222,15 @@ def segtrain(task, model_pk, part_pks, user_pk=None):
})
})
trainer
.
run
(
_print_eval
)
trainer
.
run
(
_print_eval
)
best_version
=
os
.
path
.
join
(
os
.
path
.
dirname
(
modelpath
),
best_version
=
os
.
path
.
join
(
os
.
path
.
dirname
(
modelpath
),
f
'
version_
{
trainer
.
stopper
.
best_epoch
}
.mlmodel
'
)
f
'
version_
{
trainer
.
stopper
.
best_epoch
}
.mlmodel
'
)
shutil
.
copy
(
best_version
,
modelpath
)
try
:
shutil
.
copy
(
best_version
,
modelpath
)
except
FileNotFoundError
:
user
.
notify
(
_
(
"
Training didn
'
t get better results than base model!
"
),
id
=
"
seg-no-gain-error
"
,
level
=
'
danger
'
)
except
Exception
as
e
:
except
Exception
as
e
:
send_event
(
'
document
'
,
document
.
pk
,
"
training:error
"
,
{
send_event
(
'
document
'
,
document
.
pk
,
"
training:error
"
,
{
...
@@ -240,7 +246,6 @@ def segtrain(task, model_pk, part_pks, user_pk=None):
...
@@ -240,7 +246,6 @@ def segtrain(task, model_pk, part_pks, user_pk=None):
user
.
notify
(
_
(
"
Training finished!
"
),
user
.
notify
(
_
(
"
Training finished!
"
),
id
=
"
training-success
"
,
id
=
"
training-success
"
,
level
=
'
success
'
)
level
=
'
success
'
)
user
.
notify
(
report
)
finally
:
finally
:
model
.
training
=
False
model
.
training
=
False
model
.
save
()
model
.
save
()
...
@@ -361,8 +366,6 @@ def train_(qs, document, transcription, model=None, user=None):
...
@@ -361,8 +366,6 @@ def train_(qs, document, transcription, model=None, user=None):
DEVICE
=
getattr
(
settings
,
'
KRAKEN_TRAINING_DEVICE
'
,
'
cpu
'
)
DEVICE
=
getattr
(
settings
,
'
KRAKEN_TRAINING_DEVICE
'
,
'
cpu
'
)
LOAD_THREADS
=
getattr
(
settings
,
'
KRAKEN_TRAINING_LOAD_THREADS
'
,
0
)
LOAD_THREADS
=
getattr
(
settings
,
'
KRAKEN_TRAINING_LOAD_THREADS
'
,
0
)
hyper_params
=
default_specs
.
RECOGNITION_HYPER_PARAMS
.
copy
()
hyper_params
[
'
batch_size
'
]
=
1
trainer
=
(
kraken_train
.
KrakenTrainer
trainer
=
(
kraken_train
.
KrakenTrainer
.
recognition_train_gen
(
device
=
DEVICE
,
.
recognition_train_gen
(
device
=
DEVICE
,
load
=
load
,
load
=
load
,
...
@@ -373,7 +376,7 @@ def train_(qs, document, transcription, model=None, user=None):
...
@@ -373,7 +376,7 @@ def train_(qs, document, transcription, model=None, user=None):
resize
=
'
add
'
,
resize
=
'
add
'
,
threads
=
LOAD_THREADS
,
threads
=
LOAD_THREADS
,
augment
=
True
,
augment
=
True
,
hyper_params
=
hyper_params
,
hyper_params
=
{
'
batch_size
'
:
1
}
,
load_hyper_parameters
=
True
))
load_hyper_parameters
=
True
))
def
_print_eval
(
epoch
=
0
,
accuracy
=
0
,
chars
=
0
,
error
=
0
,
val_metric
=
0
):
def
_print_eval
(
epoch
=
0
,
accuracy
=
0
,
chars
=
0
,
error
=
0
,
val_metric
=
0
):
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment