
Commit 9a077c05 authored by MOEBEL Emmanuel

Train: added self.loss argument. This makes it easier to load custom losses.

parent 4b20c6cd
@@ -158,6 +158,7 @@ class Train(core.DeepFinder):
         self.steps_per_epoch = 100
         self.steps_per_valid = 10  # number of samples for validation
         self.optimizer = Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
+        self.loss = losses.tversky_loss

         self.flag_direct_read = 1
         self.flag_batch_bootstrap = 0
@@ -231,7 +232,7 @@ class Train(core.DeepFinder):
         # Build network (not in constructor, else not possible to init model with weights from previous train round):
-        self.net.compile(optimizer=self.optimizer, loss=losses.tversky_loss, metrics=['accuracy'])
+        self.net.compile(optimizer=self.optimizer, loss=self.loss, metrics=['accuracy'])

         # Load whole dataset:
         if self.flag_direct_read == False:
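With this change, a custom loss can be plugged in by overwriting the new self.loss attribute before training starts; the compile() call then uses it instead of the hard-coded losses.tversky_loss. Below is a minimal sketch (not part of the commit), assuming a Keras-style loss callable; the import path, constructor arguments, and launch() call for Train are assumptions and may differ from the actual DeepFinder API.

    from tensorflow.keras import backend as K
    from deepfind.training import Train  # hypothetical import path

    def dice_loss(y_true, y_pred, smooth=1e-6):
        # Any callable with the Keras (y_true, y_pred) -> scalar signature can be
        # assigned to self.loss, since it is passed straight to net.compile().
        y_true_f = K.flatten(y_true)
        y_pred_f = K.flatten(y_pred)
        intersection = K.sum(y_true_f * y_pred_f)
        return 1.0 - (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

    trainer = Train(Ncl=3, dim_in=56)  # hypothetical constructor arguments
    trainer.loss = dice_loss           # replaces the default losses.tversky_loss
    # trainer.launch(...)              # training then compiles the net with the custom loss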