diff --git a/deepfinder/training.py b/deepfinder/training.py index 3bca9dd02a775014572547757eec817aa33b6766..e4daeef677476929c966b7ed1aef70ca6b3b9fa8 100644 --- a/deepfinder/training.py +++ b/deepfinder/training.py @@ -158,6 +158,7 @@ class Train(core.DeepFinder): self.steps_per_epoch = 100 self.steps_per_valid = 10 # number of samples for validation self.optimizer = Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0) + self.loss = losses.tversky_loss self.flag_direct_read = 1 self.flag_batch_bootstrap = 0 @@ -231,7 +232,7 @@ class Train(core.DeepFinder): # Build network (not in constructor, else not possible to init model with weights from previous train round): - self.net.compile(optimizer=self.optimizer, loss=losses.tversky_loss, metrics=['accuracy']) + self.net.compile(optimizer=self.optimizer, loss=self.loss, metrics=['accuracy']) # Load whole dataset: if self.flag_direct_read == False: