Mentions légales du service

Skip to content
Snippets Groups Projects
Commit b6980931 authored by MOEBEL Emmanuel's avatar MOEBEL Emmanuel
Browse files

train : added sample_weights option (to be tested)

parent 1d6bafb6
Branches
Tags
No related merge requests found
...@@ -166,6 +166,7 @@ class Train(core.DeepFinder): ...@@ -166,6 +166,7 @@ class Train(core.DeepFinder):
self.Lrnd = 13 # random shifts applied when sampling data- and target-patches (in voxels) self.Lrnd = 13 # random shifts applied when sampling data- and target-patches (in voxels)
self.class_weight = None self.class_weight = None
self.sample_weights = None # np array same lenght as objl_train
self.check_attributes() self.check_attributes()
...@@ -261,8 +262,16 @@ class Train(core.DeepFinder): ...@@ -261,8 +262,16 @@ class Train(core.DeepFinder):
if self.flag_direct_read: if self.flag_direct_read:
batch_data, batch_target = self.generate_batch_direct_read(path_data, path_target, self.batch_size, objlist_train) batch_data, batch_target = self.generate_batch_direct_read(path_data, path_target, self.batch_size, objlist_train)
else: else:
batch_data, batch_target = self.generate_batch_from_array(data_list, target_list, self.batch_size, objlist_train) batch_data, batch_target, idx_list = self.generate_batch_from_array(data_list, target_list, self.batch_size, objlist_train)
loss_train = self.net.train_on_batch(batch_data, batch_target, class_weight=self.class_weight)
if self.sample_weights is not None:
sample_weight = self.sample_weights[idx_list]
else:
sample_weight = None
loss_train = self.net.train_on_batch(batch_data, batch_target,
class_weight=self.class_weight,
sample_weight=sample_weight)
self.display('epoch %d/%d - it %d/%d - loss: %0.3f - acc: %0.3f' % (e + 1, self.epochs, it + 1, self.steps_per_epoch, loss_train[0], loss_train[1])) self.display('epoch %d/%d - it %d/%d - loss: %0.3f - acc: %0.3f' % (e + 1, self.epochs, it + 1, self.steps_per_epoch, loss_train[0], loss_train[1]))
list_loss_train.append(loss_train[0]) list_loss_train.append(loss_train[0])
...@@ -280,7 +289,7 @@ class Train(core.DeepFinder): ...@@ -280,7 +289,7 @@ class Train(core.DeepFinder):
if self.flag_direct_read: if self.flag_direct_read:
batch_data_valid, batch_target_valid = self.generate_batch_direct_read(path_data, path_target, self.batch_size, objlist_valid) batch_data_valid, batch_target_valid = self.generate_batch_direct_read(path_data, path_target, self.batch_size, objlist_valid)
else: else:
batch_data_valid, batch_target_valid = self.generate_batch_from_array(data_list, target_list, self.batch_size, objlist_valid) batch_data_valid, batch_target_valid, idx_list = self.generate_batch_from_array(data_list, target_list, self.batch_size, objlist_valid)
loss_val = self.net.evaluate(batch_data_valid, batch_target_valid, verbose=0) # TODO replace by loss() to reduce computation loss_val = self.net.evaluate(batch_data_valid, batch_target_valid, verbose=0) # TODO replace by loss() to reduce computation
batch_pred = self.net.predict(batch_data_valid) batch_pred = self.net.predict(batch_data_valid)
#loss_val = K.eval(losses.tversky_loss(K.constant(batch_target_valid), K.constant(batch_pred))) #loss_val = K.eval(losses.tversky_loss(K.constant(batch_target_valid), K.constant(batch_pred)))
...@@ -417,9 +426,11 @@ class Train(core.DeepFinder): ...@@ -417,9 +426,11 @@ class Train(core.DeepFinder):
else: # choose from whole objlist else: # choose from whole objlist
pool = range(0, len(objlist)) pool = range(0, len(objlist))
idx_list = []
for i in range(batch_size): for i in range(batch_size):
# choose random sample in training set: # choose random sample in training set:
index = np.random.choice(pool) index = np.random.choice(pool)
idx_list.append(index)
tomoID = int(objlist[index]['tomo_idx']) tomoID = int(objlist[index]['tomo_idx'])
...@@ -448,4 +459,5 @@ class Train(core.DeepFinder): ...@@ -448,4 +459,5 @@ class Train(core.DeepFinder):
batch_data[i] = np.rot90(batch_data[i], k=2, axes=(0, 2)) batch_data[i] = np.rot90(batch_data[i], k=2, axes=(0, 2))
batch_target[i] = np.rot90(batch_target[i], k=2, axes=(0, 2)) batch_target[i] = np.rot90(batch_target[i], k=2, axes=(0, 2))
return batch_data, batch_target return batch_data, batch_target, idx_list
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment