diff --git a/._README.md b/._README.md new file mode 100644 index 0000000000000000000000000000000000000000..97192c5d8defd5f096e8bb9f8aaf963384dcb4c2 Binary files /dev/null and b/._README.md differ diff --git a/._core_utils.py b/._core_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4c7381b9ee4e5e4d1d3b9404b607a170638952e9 Binary files /dev/null and b/._core_utils.py differ diff --git a/._deepfind.py b/._deepfind.py new file mode 100644 index 0000000000000000000000000000000000000000..88716b326631ebe94e2c41209c99c07ba5816901 Binary files /dev/null and b/._deepfind.py differ diff --git a/._losses.py b/._losses.py new file mode 100644 index 0000000000000000000000000000000000000000..de5c5f6974bd6dc623366c4daba724ee97acfdf2 Binary files /dev/null and b/._losses.py differ diff --git a/._models.py b/._models.py new file mode 100644 index 0000000000000000000000000000000000000000..c8328815dd3b899007feb4f9cc33667374da8cd9 Binary files /dev/null and b/._models.py differ diff --git a/._utils.py b/._utils.py new file mode 100644 index 0000000000000000000000000000000000000000..790fa5ca839d74647a268d13cae16cc4126f9448 Binary files /dev/null and b/._utils.py differ diff --git a/README.md b/README.md index a232442cc4ab067e330a788ac84a29518805fad4..d553cc8fca6ab2e5ccdc2ce841c77d62f0edea5f 100644 --- a/README.md +++ b/README.md @@ -16,16 +16,26 @@ keras (2.1.6) numpy (1.14.3) h5py (2.7.1) lxml (4.3.2) -sklearn (0.0) +scikit-learn (0.19.1) +scikit-image (0.14.2) +matplotlib (2.2.3) ``` ## Installation guide First install the packages: ``` -pip install numpy tensorflow-gpu keras sklearn h5py lxml +pip install numpy tensorflow-gpu keras sklearn h5py lxml scikit-learn scikit-image matplotlib ``` For more details about installing Keras, please see [Keras installation instructions](https://keras.io/#installation). Once the dependencies are installed, the user should be able to run Deep Finder. -## Instructions for use \ No newline at end of file +## Instructions for use +Instructions for using Deep Finder are contained in folder examples/. The scripts contain comments on how the toolbox should be used. To run a script, first launch ipython: +``` +ipython +``` +The launch the script: +``` +%run script_file.py +``` \ No newline at end of file diff --git a/core_utils.py b/core_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..95bd64454ab993c583654c5f79e33ee8997eca6c --- /dev/null +++ b/core_utils.py @@ -0,0 +1,127 @@ +import numpy as np +import utils +import h5py + +import matplotlib +matplotlib.use('agg') # necessary else: AttributeError: 'NoneType' object has no attribute 'is_interactive' +import matplotlib.pyplot as plt + +def load_dataset(path_data, path_target): + data_list = [] + target_list = [] + for idx in range(0,len(path_data)): + data_list.append( utils.load_h5array(path_data[idx] )) + target_list.append(utils.load_h5array(path_target[idx])) + return data_list, target_list + + +def get_bootstrap_idx(objlist,Nbs): + # Get a vector containing the object class labels (from objlist): + Nobj = len(objlist) + label_list = np.zeros((Nobj,)) + for oo in range(0,Nobj): + label_list[oo] = float( objlist[oo].get('class_label') ) + + lblTAB = np.unique(label_list) # vector containing unique class labels + + # Bootstrap data so that we have equal frequencies (1/Nbs) for all classes: + # ->from label_list, sample Nbs objects from each class + bs_idx = [] + for l in lblTAB: + bs_idx.append( np.random.choice(np.squeeze(np.asarray(np.nonzero(label_list==l))), Nbs) ) + bs_idx = np.concatenate(bs_idx) + return bs_idx + +def get_patch_position(tomodim, p_in, obj, Lrnd): + # sample at coordinates specified in obj=objlist[idx] + x = int( obj.get('x') ) + y = int( obj.get('y') ) + z = int( obj.get('z') ) + + # Add random shift to coordinates: + x = x + np.random.choice(range(-Lrnd,Lrnd+1)) + y = y + np.random.choice(range(-Lrnd,Lrnd+1)) + z = z + np.random.choice(range(-Lrnd,Lrnd+1)) + + # Shift position if too close to border: + if (x<p_in) : x = p_in + if (y<p_in) : y = p_in + if (z<p_in) : z = p_in + if (x>tomodim[0]-p_in): x = tomodim[0]-p_in + if (y>tomodim[1]-p_in): y = tomodim[1]-p_in + if (z>tomodim[2]-p_in): z = tomodim[2]-p_in + + #else: # sample random position in tomogram + # x = np.int32( np.random.choice(range(p_in,tomodim[0]-p_in)) ) + # y = np.int32( np.random.choice(range(p_in,tomodim[0]-p_in)) ) + # z = np.int32( np.random.choice(range(p_in,tomodim[0]-p_in)) ) + + return x,y,z + +def save_history(history): + h5file = h5py.File('params_train_history.h5', 'w') + + # train and val loss & accuracy: + dset = h5file.create_dataset('acc', (len(history['acc']),)) + dset[:] = history['acc'] + dset = h5file.create_dataset('loss', (len(history['loss']),)) + dset[:] = history['loss'] + dset = h5file.create_dataset('val_acc', (len(history['val_acc']),)) + dset[:] = history['val_acc'] + dset = h5file.create_dataset('val_loss', (len(history['val_loss']),)) + dset[:] = history['val_loss'] + + # val precision, recall, F1: + dset = h5file.create_dataset('val_f1', np.shape(history['val_f1'])) + dset[:] = history['val_f1'] + dset = h5file.create_dataset('val_precision', np.shape(history['val_precision'])) + dset[:] = history['val_precision'] + dset = h5file.create_dataset('val_recall', np.shape(history['val_recall'])) + dset[:] = history['val_recall'] + + h5file.close() + return + +def plot_history(history): + Ncl = len(history['val_f1'][0]) + legend_names = [] + for lbl in range(0,Ncl): + legend_names.append('class '+str(lbl)) + + fig = plt.figure(figsize=(15,12)) + plt.subplot(321) + plt.plot(history['loss'] , label='train') + plt.plot(history['val_loss'], label='valid') + plt.ylabel('loss') + plt.xlabel('epochs') + plt.legend() + plt.grid() + + plt.subplot(323) + plt.plot(history['acc'] , label='train') + plt.plot(history['val_acc'], label='valid') + plt.ylabel('accuracy') + plt.xlabel('epochs') + plt.legend() + plt.grid() + + plt.subplot(322) + plt.plot(history['val_f1']) + plt.ylabel('F1-score') + plt.xlabel('epochs') + plt.legend(legend_names) + plt.grid() + + plt.subplot(324) + plt.plot(history['val_precision']) + plt.ylabel('precision') + plt.xlabel('epochs') + plt.grid() + + plt.subplot(326) + plt.plot(history['val_recall']) + plt.ylabel('recall') + plt.xlabel('epochs') + plt.grid() + + fig.savefig('history_train.png') \ No newline at end of file diff --git a/data/tomo10_data.h5 b/data/tomo10_data.h5 new file mode 100644 index 0000000000000000000000000000000000000000..c2a2cd29fcb711ecf5e29a38f8ac8ddf0c7fd8f3 Binary files /dev/null and b/data/tomo10_data.h5 differ diff --git a/data/tomo10_objlist.xml b/data/tomo10_objlist.xml new file mode 100644 index 0000000000000000000000000000000000000000..2989796109f52d9081a04d32d02ae68b1e482ab2 --- /dev/null +++ b/data/tomo10_objlist.xml @@ -0,0 +1,29 @@ +<objlist> + <object tomo_idx="0" class_label="1" x="16" y="83" z="81"/> + <object tomo_idx="0" class_label="1" x="73" y="31" z="75"/> + <object tomo_idx="0" class_label="1" x="41" y="62" z="78"/> + <object tomo_idx="0" class_label="2" x="83" y="54" z="31"/> + <object tomo_idx="0" class_label="2" x="39" y="62" z="34"/> + <object tomo_idx="0" class_label="2" x="71" y="32" z="20"/> + <object tomo_idx="0" class_label="3" x="66" y="70" z="51"/> + <object tomo_idx="0" class_label="3" x="78" y="24" z="55"/> + <object tomo_idx="0" class_label="3" x="78" y="58" z="50"/> + <object tomo_idx="0" class_label="4" x="23" y="51" z="18"/> + <object tomo_idx="0" class_label="4" x="32" y="41" z="64"/> + <object tomo_idx="0" class_label="4" x="62" y="18" z="68"/> + <object tomo_idx="0" class_label="5" x="41" y="83" z="67"/> + <object tomo_idx="0" class_label="5" x="70" y="50" z="65"/> + <object tomo_idx="0" class_label="5" x="21" y="80" z="62"/> + <object tomo_idx="0" class_label="6" x="36" y="57" z="50"/> + <object tomo_idx="0" class_label="6" x="40" y="24" z="74"/> + <object tomo_idx="0" class_label="6" x="84" y="50" z="84"/> + <object tomo_idx="0" class_label="7" x="19" y="63" z="42"/> + <object tomo_idx="0" class_label="7" x="66" y="21" z="43"/> + <object tomo_idx="0" class_label="7" x="47" y="77" z="24"/> + <object tomo_idx="0" class_label="8" x="56" y="33" z="34"/> + <object tomo_idx="0" class_label="8" x="62" y="54" z="53"/> + <object tomo_idx="0" class_label="8" x="76" y="42" z="36"/> + <object tomo_idx="0" class_label="9" x="40" y="15" z="34"/> + <object tomo_idx="0" class_label="9" x="31" y="35" z="36"/> + <object tomo_idx="0" class_label="9" x="17" y="23" z="75"/> +</objlist> diff --git a/data/tomo10_target.h5 b/data/tomo10_target.h5 new file mode 100644 index 0000000000000000000000000000000000000000..3d88a3e96e9b0ad4648c62763cc3da319ae44887 Binary files /dev/null and b/data/tomo10_target.h5 differ diff --git a/data/tomo1_data.h5 b/data/tomo1_data.h5 new file mode 100644 index 0000000000000000000000000000000000000000..3893566a08964ae34675504a09c2406ec46459f0 Binary files /dev/null and b/data/tomo1_data.h5 differ diff --git a/data/tomo1_objlist.xml b/data/tomo1_objlist.xml new file mode 100644 index 0000000000000000000000000000000000000000..025bbfa3edc60b9f075569995e175467e2466e12 --- /dev/null +++ b/data/tomo1_objlist.xml @@ -0,0 +1,29 @@ +<objlist> + <object tomo_idx="0" class_label="1" x="34" y="16" z="64"/> + <object tomo_idx="0" class_label="1" x="63" y="72" z="21"/> + <object tomo_idx="0" class_label="1" x="45" y="17" z="81"/> + <object tomo_idx="0" class_label="2" x="27" y="70" z="68"/> + <object tomo_idx="0" class_label="2" x="64" y="59" z="45"/> + <object tomo_idx="0" class_label="2" x="55" y="38" z="81"/> + <object tomo_idx="0" class_label="3" x="81" y="77" z="63"/> + <object tomo_idx="0" class_label="3" x="32" y="25" z="24"/> + <object tomo_idx="0" class_label="3" x="31" y="71" z="32"/> + <object tomo_idx="0" class_label="4" x="32" y="28" z="39"/> + <object tomo_idx="0" class_label="4" x="72" y="39" z="47"/> + <object tomo_idx="0" class_label="4" x="34" y="78" z="53"/> + <object tomo_idx="0" class_label="5" x="15" y="47" z="54"/> + <object tomo_idx="0" class_label="5" x="56" y="26" z="51"/> + <object tomo_idx="0" class_label="5" x="20" y="84" z="52"/> + <object tomo_idx="0" class_label="6" x="15" y="82" z="22"/> + <object tomo_idx="0" class_label="6" x="53" y="55" z="75"/> + <object tomo_idx="0" class_label="6" x="18" y="43" z="31"/> + <object tomo_idx="0" class_label="7" x="81" y="80" z="23"/> + <object tomo_idx="0" class_label="7" x="17" y="15" z="72"/> + <object tomo_idx="0" class_label="7" x="60" y="65" z="60"/> + <object tomo_idx="0" class_label="8" x="46" y="45" z="48"/> + <object tomo_idx="0" class_label="8" x="53" y="76" z="80"/> + <object tomo_idx="0" class_label="8" x="36" y="45" z="30"/> + <object tomo_idx="0" class_label="9" x="22" y="45" z="83"/> + <object tomo_idx="0" class_label="9" x="58" y="48" z="30"/> + <object tomo_idx="0" class_label="9" x="76" y="31" z="27"/> +</objlist> diff --git a/data/tomo1_target.h5 b/data/tomo1_target.h5 new file mode 100644 index 0000000000000000000000000000000000000000..fa97345ed41ad90d741b010ca8ed7db5e0ba8ab5 Binary files /dev/null and b/data/tomo1_target.h5 differ diff --git a/data/tomo2_data.h5 b/data/tomo2_data.h5 new file mode 100644 index 0000000000000000000000000000000000000000..6c073bcef2b745cdd327afbd0bb9a037305b17b1 Binary files /dev/null and b/data/tomo2_data.h5 differ diff --git a/data/tomo2_objlist.xml b/data/tomo2_objlist.xml new file mode 100644 index 0000000000000000000000000000000000000000..f6c9cd585d9dbb70ab1463881f1f94b63c498cd0 --- /dev/null +++ b/data/tomo2_objlist.xml @@ -0,0 +1,29 @@ +<objlist> + <object tomo_idx="0" class_label="1" x="45" y="46" z="70"/> + <object tomo_idx="0" class_label="1" x="26" y="24" z="20"/> + <object tomo_idx="0" class_label="1" x="18" y="70" z="72"/> + <object tomo_idx="0" class_label="2" x="60" y="43" z="51"/> + <object tomo_idx="0" class_label="2" x="15" y="44" z="35"/> + <object tomo_idx="0" class_label="2" x="38" y="49" z="28"/> + <object tomo_idx="0" class_label="3" x="66" y="18" z="79"/> + <object tomo_idx="0" class_label="3" x="35" y="83" z="43"/> + <object tomo_idx="0" class_label="3" x="23" y="27" z="61"/> + <object tomo_idx="0" class_label="4" x="83" y="47" z="40"/> + <object tomo_idx="0" class_label="4" x="32" y="55" z="41"/> + <object tomo_idx="0" class_label="4" x="38" y="65" z="83"/> + <object tomo_idx="0" class_label="5" x="76" y="78" z="22"/> + <object tomo_idx="0" class_label="5" x="26" y="36" z="44"/> + <object tomo_idx="0" class_label="5" x="56" y="21" z="44"/> + <object tomo_idx="0" class_label="6" x="44" y="69" z="57"/> + <object tomo_idx="0" class_label="6" x="46" y="51" z="45"/> + <object tomo_idx="0" class_label="6" x="59" y="80" z="51"/> + <object tomo_idx="0" class_label="7" x="61" y="30" z="32"/> + <object tomo_idx="0" class_label="7" x="62" y="69" z="38"/> + <object tomo_idx="0" class_label="7" x="44" y="47" z="14"/> + <object tomo_idx="0" class_label="8" x="27" y="17" z="47"/> + <object tomo_idx="0" class_label="8" x="66" y="28" z="57"/> + <object tomo_idx="0" class_label="8" x="31" y="64" z="29"/> + <object tomo_idx="0" class_label="9" x="22" y="30" z="84"/> + <object tomo_idx="0" class_label="9" x="76" y="50" z="84"/> + <object tomo_idx="0" class_label="9" x="74" y="18" z="25"/> +</objlist> diff --git a/data/tomo2_target.h5 b/data/tomo2_target.h5 new file mode 100644 index 0000000000000000000000000000000000000000..5ae01bdff70bf8d4bac31f715e160a2564228c8e Binary files /dev/null and b/data/tomo2_target.h5 differ diff --git a/data/tomo3_data.h5 b/data/tomo3_data.h5 new file mode 100644 index 0000000000000000000000000000000000000000..628b86ba802b3b58e7436f9b7910225e324be651 Binary files /dev/null and b/data/tomo3_data.h5 differ diff --git a/data/tomo3_objlist.xml b/data/tomo3_objlist.xml new file mode 100644 index 0000000000000000000000000000000000000000..60e531fafdfec178a631fb74dc4c42955424f787 --- /dev/null +++ b/data/tomo3_objlist.xml @@ -0,0 +1,29 @@ +<objlist> + <object tomo_idx="0" class_label="1" x="20" y="61" z="32"/> + <object tomo_idx="0" class_label="1" x="15" y="55" z="79"/> + <object tomo_idx="0" class_label="1" x="73" y="48" z="75"/> + <object tomo_idx="0" class_label="2" x="16" y="58" z="53"/> + <object tomo_idx="0" class_label="2" x="49" y="18" z="40"/> + <object tomo_idx="0" class_label="2" x="25" y="29" z="23"/> + <object tomo_idx="0" class_label="3" x="34" y="59" z="17"/> + <object tomo_idx="0" class_label="3" x="33" y="75" z="74"/> + <object tomo_idx="0" class_label="3" x="43" y="59" z="54"/> + <object tomo_idx="0" class_label="4" x="22" y="20" z="81"/> + <object tomo_idx="0" class_label="4" x="54" y="58" z="26"/> + <object tomo_idx="0" class_label="4" x="80" y="74" z="19"/> + <object tomo_idx="0" class_label="5" x="24" y="42" z="27"/> + <object tomo_idx="0" class_label="5" x="60" y="47" z="38"/> + <object tomo_idx="0" class_label="5" x="18" y="45" z="39"/> + <object tomo_idx="0" class_label="6" x="52" y="83" z="23"/> + <object tomo_idx="0" class_label="6" x="39" y="27" z="21"/> + <object tomo_idx="0" class_label="6" x="27" y="38" z="51"/> + <object tomo_idx="0" class_label="7" x="66" y="22" z="78"/> + <object tomo_idx="0" class_label="7" x="56" y="27" z="54"/> + <object tomo_idx="0" class_label="7" x="77" y="29" z="24"/> + <object tomo_idx="0" class_label="8" x="28" y="77" z="15"/> + <object tomo_idx="0" class_label="8" x="35" y="37" z="84"/> + <object tomo_idx="0" class_label="8" x="72" y="22" z="52"/> + <object tomo_idx="0" class_label="9" x="77" y="48" z="53"/> + <object tomo_idx="0" class_label="9" x="59" y="22" z="22"/> + <object tomo_idx="0" class_label="9" x="23" y="23" z="63"/> +</objlist> diff --git a/data/tomo3_target.h5 b/data/tomo3_target.h5 new file mode 100644 index 0000000000000000000000000000000000000000..5bd9e139e3b5bf457b6343b1ecd3441e45982470 Binary files /dev/null and b/data/tomo3_target.h5 differ diff --git a/data/tomo4_data.h5 b/data/tomo4_data.h5 new file mode 100644 index 0000000000000000000000000000000000000000..c7da220cedbf80a45db154d77de0dc200f5cd20f Binary files /dev/null and b/data/tomo4_data.h5 differ diff --git a/data/tomo4_objlist.xml b/data/tomo4_objlist.xml new file mode 100644 index 0000000000000000000000000000000000000000..1cd3b43b26369fd408291baeda6779c57fe0bfb1 --- /dev/null +++ b/data/tomo4_objlist.xml @@ -0,0 +1,29 @@ +<objlist> + <object tomo_idx="0" class_label="1" x="53" y="64" z="78"/> + <object tomo_idx="0" class_label="1" x="83" y="58" z="26"/> + <object tomo_idx="0" class_label="1" x="83" y="42" z="62"/> + <object tomo_idx="0" class_label="2" x="41" y="25" z="58"/> + <object tomo_idx="0" class_label="2" x="39" y="75" z="67"/> + <object tomo_idx="0" class_label="2" x="35" y="38" z="56"/> + <object tomo_idx="0" class_label="3" x="53" y="39" z="44"/> + <object tomo_idx="0" class_label="3" x="23" y="44" z="44"/> + <object tomo_idx="0" class_label="3" x="22" y="66" z="66"/> + <object tomo_idx="0" class_label="4" x="21" y="29" z="47"/> + <object tomo_idx="0" class_label="4" x="61" y="26" z="27"/> + <object tomo_idx="0" class_label="4" x="25" y="63" z="50"/> + <object tomo_idx="0" class_label="5" x="79" y="16" z="32"/> + <object tomo_idx="0" class_label="5" x="79" y="26" z="80"/> + <object tomo_idx="0" class_label="5" x="19" y="30" z="67"/> + <object tomo_idx="0" class_label="6" x="71" y="42" z="44"/> + <object tomo_idx="0" class_label="6" x="74" y="69" z="71"/> + <object tomo_idx="0" class_label="6" x="30" y="84" z="39"/> + <object tomo_idx="0" class_label="7" x="27" y="44" z="16"/> + <object tomo_idx="0" class_label="7" x="81" y="27" z="54"/> + <object tomo_idx="0" class_label="7" x="40" y="30" z="79"/> + <object tomo_idx="0" class_label="8" x="60" y="39" z="65"/> + <object tomo_idx="0" class_label="8" x="55" y="57" z="41"/> + <object tomo_idx="0" class_label="8" x="19" y="83" z="70"/> + <object tomo_idx="0" class_label="9" x="77" y="76" z="52"/> + <object tomo_idx="0" class_label="9" x="62" y="80" z="83"/> + <object tomo_idx="0" class_label="9" x="21" y="77" z="28"/> +</objlist> diff --git a/data/tomo4_target.h5 b/data/tomo4_target.h5 new file mode 100644 index 0000000000000000000000000000000000000000..0d667ed0c38f60438cbf34928dddd552fe99d6da Binary files /dev/null and b/data/tomo4_target.h5 differ diff --git a/data/tomo5_data.h5 b/data/tomo5_data.h5 new file mode 100644 index 0000000000000000000000000000000000000000..bc40b003298a9b9e3b33986a5a94026119182fd9 Binary files /dev/null and b/data/tomo5_data.h5 differ diff --git a/data/tomo5_objlist.xml b/data/tomo5_objlist.xml new file mode 100644 index 0000000000000000000000000000000000000000..f28c232cd999e70ae975e5cb8b6e0c7a43f3529e --- /dev/null +++ b/data/tomo5_objlist.xml @@ -0,0 +1,29 @@ +<objlist> + <object tomo_idx="0" class_label="1" x="60" y="65" z="23"/> + <object tomo_idx="0" class_label="1" x="58" y="37" z="53"/> + <object tomo_idx="0" class_label="1" x="81" y="66" z="73"/> + <object tomo_idx="0" class_label="2" x="18" y="60" z="56"/> + <object tomo_idx="0" class_label="2" x="65" y="34" z="74"/> + <object tomo_idx="0" class_label="2" x="55" y="24" z="73"/> + <object tomo_idx="0" class_label="3" x="49" y="50" z="71"/> + <object tomo_idx="0" class_label="3" x="27" y="82" z="17"/> + <object tomo_idx="0" class_label="3" x="60" y="82" z="20"/> + <object tomo_idx="0" class_label="4" x="33" y="23" z="42"/> + <object tomo_idx="0" class_label="4" x="18" y="42" z="79"/> + <object tomo_idx="0" class_label="4" x="52" y="70" z="66"/> + <object tomo_idx="0" class_label="5" x="36" y="18" z="77"/> + <object tomo_idx="0" class_label="5" x="19" y="55" z="76"/> + <object tomo_idx="0" class_label="5" x="24" y="36" z="60"/> + <object tomo_idx="0" class_label="6" x="16" y="24" z="61"/> + <object tomo_idx="0" class_label="6" x="55" y="65" z="53"/> + <object tomo_idx="0" class_label="6" x="19" y="60" z="29"/> + <object tomo_idx="0" class_label="7" x="71" y="80" z="61"/> + <object tomo_idx="0" class_label="7" x="82" y="43" z="67"/> + <object tomo_idx="0" class_label="7" x="46" y="77" z="25"/> + <object tomo_idx="0" class_label="8" x="34" y="43" z="23"/> + <object tomo_idx="0" class_label="8" x="63" y="49" z="42"/> + <object tomo_idx="0" class_label="8" x="22" y="45" z="45"/> + <object tomo_idx="0" class_label="9" x="26" y="20" z="18"/> + <object tomo_idx="0" class_label="9" x="58" y="25" z="21"/> + <object tomo_idx="0" class_label="9" x="73" y="45" z="16"/> +</objlist> diff --git a/data/tomo5_target.h5 b/data/tomo5_target.h5 new file mode 100644 index 0000000000000000000000000000000000000000..51aa5285d51668c5187709738479b52945d4e469 Binary files /dev/null and b/data/tomo5_target.h5 differ diff --git a/data/tomo6_data.h5 b/data/tomo6_data.h5 new file mode 100644 index 0000000000000000000000000000000000000000..a70fbc2f417535c5c00f60261a93652e142ed653 Binary files /dev/null and b/data/tomo6_data.h5 differ diff --git a/data/tomo6_objlist.xml b/data/tomo6_objlist.xml new file mode 100644 index 0000000000000000000000000000000000000000..6389680ee3600fec74a61bc331d7b8bf128d6c6e --- /dev/null +++ b/data/tomo6_objlist.xml @@ -0,0 +1,29 @@ +<objlist> + <object tomo_idx="0" class_label="1" x="26" y="50" z="61"/> + <object tomo_idx="0" class_label="1" x="55" y="31" z="40"/> + <object tomo_idx="0" class_label="1" x="20" y="32" z="24"/> + <object tomo_idx="0" class_label="2" x="64" y="49" z="23"/> + <object tomo_idx="0" class_label="2" x="42" y="19" z="69"/> + <object tomo_idx="0" class_label="2" x="27" y="14" z="30"/> + <object tomo_idx="0" class_label="3" x="80" y="77" z="56"/> + <object tomo_idx="0" class_label="3" x="81" y="16" z="26"/> + <object tomo_idx="0" class_label="3" x="15" y="68" z="82"/> + <object tomo_idx="0" class_label="4" x="53" y="59" z="64"/> + <object tomo_idx="0" class_label="4" x="42" y="74" z="77"/> + <object tomo_idx="0" class_label="4" x="74" y="38" z="56"/> + <object tomo_idx="0" class_label="5" x="52" y="17" z="78"/> + <object tomo_idx="0" class_label="5" x="72" y="80" z="81"/> + <object tomo_idx="0" class_label="5" x="63" y="36" z="14"/> + <object tomo_idx="0" class_label="6" x="39" y="50" z="23"/> + <object tomo_idx="0" class_label="6" x="40" y="65" z="60"/> + <object tomo_idx="0" class_label="6" x="82" y="56" z="34"/> + <object tomo_idx="0" class_label="7" x="80" y="38" z="28"/> + <object tomo_idx="0" class_label="7" x="58" y="81" z="31"/> + <object tomo_idx="0" class_label="7" x="55" y="30" z="65"/> + <object tomo_idx="0" class_label="8" x="18" y="50" z="37"/> + <object tomo_idx="0" class_label="8" x="78" y="74" z="35"/> + <object tomo_idx="0" class_label="8" x="25" y="26" z="54"/> + <object tomo_idx="0" class_label="9" x="46" y="71" z="18"/> + <object tomo_idx="0" class_label="9" x="80" y="17" z="67"/> + <object tomo_idx="0" class_label="9" x="24" y="70" z="15"/> +</objlist> diff --git a/data/tomo6_target.h5 b/data/tomo6_target.h5 new file mode 100644 index 0000000000000000000000000000000000000000..ffafda0f417e18bd1fd8c4aee2d78f846ffc025b Binary files /dev/null and b/data/tomo6_target.h5 differ diff --git a/data/tomo7_data.h5 b/data/tomo7_data.h5 new file mode 100644 index 0000000000000000000000000000000000000000..eca49e3088eb1ea47953719fb83eb135220243f5 Binary files /dev/null and b/data/tomo7_data.h5 differ diff --git a/data/tomo7_objlist.xml b/data/tomo7_objlist.xml new file mode 100644 index 0000000000000000000000000000000000000000..f3cb6c8fcbb57ffd156f9884744d41d08c9a0648 --- /dev/null +++ b/data/tomo7_objlist.xml @@ -0,0 +1,29 @@ +<objlist> + <object tomo_idx="0" class_label="1" x="44" y="77" z="71"/> + <object tomo_idx="0" class_label="1" x="36" y="52" z="52"/> + <object tomo_idx="0" class_label="1" x="47" y="21" z="27"/> + <object tomo_idx="0" class_label="2" x="14" y="59" z="78"/> + <object tomo_idx="0" class_label="2" x="23" y="46" z="29"/> + <object tomo_idx="0" class_label="2" x="18" y="45" z="45"/> + <object tomo_idx="0" class_label="3" x="82" y="74" z="73"/> + <object tomo_idx="0" class_label="3" x="26" y="78" z="59"/> + <object tomo_idx="0" class_label="3" x="76" y="45" z="55"/> + <object tomo_idx="0" class_label="4" x="84" y="70" z="37"/> + <object tomo_idx="0" class_label="4" x="16" y="31" z="23"/> + <object tomo_idx="0" class_label="4" x="23" y="81" z="45"/> + <object tomo_idx="0" class_label="5" x="63" y="17" z="74"/> + <object tomo_idx="0" class_label="5" x="77" y="57" z="78"/> + <object tomo_idx="0" class_label="5" x="58" y="72" z="25"/> + <object tomo_idx="0" class_label="6" x="81" y="19" z="71"/> + <object tomo_idx="0" class_label="6" x="43" y="76" z="17"/> + <object tomo_idx="0" class_label="6" x="24" y="25" z="66"/> + <object tomo_idx="0" class_label="7" x="31" y="68" z="37"/> + <object tomo_idx="0" class_label="7" x="62" y="33" z="53"/> + <object tomo_idx="0" class_label="7" x="52" y="41" z="46"/> + <object tomo_idx="0" class_label="8" x="67" y="18" z="51"/> + <object tomo_idx="0" class_label="8" x="52" y="14" z="43"/> + <object tomo_idx="0" class_label="8" x="23" y="66" z="70"/> + <object tomo_idx="0" class_label="9" x="59" y="82" z="82"/> + <object tomo_idx="0" class_label="9" x="69" y="29" z="19"/> + <object tomo_idx="0" class_label="9" x="37" y="17" z="83"/> +</objlist> diff --git a/data/tomo7_target.h5 b/data/tomo7_target.h5 new file mode 100644 index 0000000000000000000000000000000000000000..69fcd41e59ac652fe9ac77c0498ae150b7fdadc9 Binary files /dev/null and b/data/tomo7_target.h5 differ diff --git a/data/tomo8_data.h5 b/data/tomo8_data.h5 new file mode 100644 index 0000000000000000000000000000000000000000..ccea7817e5114b68cea9f313510a4c24ea7a3aaf Binary files /dev/null and b/data/tomo8_data.h5 differ diff --git a/data/tomo8_objlist.xml b/data/tomo8_objlist.xml new file mode 100644 index 0000000000000000000000000000000000000000..07b5b388010a15a2f98a5b96c266f5c76ba80a45 --- /dev/null +++ b/data/tomo8_objlist.xml @@ -0,0 +1,29 @@ +<objlist> + <object tomo_idx="0" class_label="1" x="16" y="18" z="18"/> + <object tomo_idx="0" class_label="1" x="19" y="19" z="57"/> + <object tomo_idx="0" class_label="1" x="74" y="22" z="31"/> + <object tomo_idx="0" class_label="2" x="50" y="60" z="66"/> + <object tomo_idx="0" class_label="2" x="82" y="82" z="16"/> + <object tomo_idx="0" class_label="2" x="35" y="60" z="47"/> + <object tomo_idx="0" class_label="3" x="65" y="77" z="67"/> + <object tomo_idx="0" class_label="3" x="32" y="33" z="81"/> + <object tomo_idx="0" class_label="3" x="56" y="35" z="19"/> + <object tomo_idx="0" class_label="4" x="31" y="70" z="60"/> + <object tomo_idx="0" class_label="4" x="34" y="49" z="79"/> + <object tomo_idx="0" class_label="4" x="82" y="50" z="78"/> + <object tomo_idx="0" class_label="5" x="63" y="36" z="44"/> + <object tomo_idx="0" class_label="5" x="36" y="35" z="42"/> + <object tomo_idx="0" class_label="5" x="78" y="17" z="75"/> + <object tomo_idx="0" class_label="6" x="43" y="50" z="24"/> + <object tomo_idx="0" class_label="6" x="49" y="83" z="51"/> + <object tomo_idx="0" class_label="6" x="35" y="17" z="43"/> + <object tomo_idx="0" class_label="7" x="50" y="20" z="39"/> + <object tomo_idx="0" class_label="7" x="23" y="28" z="71"/> + <object tomo_idx="0" class_label="7" x="41" y="74" z="67"/> + <object tomo_idx="0" class_label="8" x="73" y="37" z="66"/> + <object tomo_idx="0" class_label="8" x="41" y="45" z="66"/> + <object tomo_idx="0" class_label="8" x="16" y="47" z="20"/> + <object tomo_idx="0" class_label="9" x="19" y="66" z="45"/> + <object tomo_idx="0" class_label="9" x="63" y="66" z="27"/> + <object tomo_idx="0" class_label="9" x="79" y="52" z="38"/> +</objlist> diff --git a/data/tomo8_target.h5 b/data/tomo8_target.h5 new file mode 100644 index 0000000000000000000000000000000000000000..a8f33b5b8f00a0c9ca1f5e3f314f4a5c90638ad7 Binary files /dev/null and b/data/tomo8_target.h5 differ diff --git a/data/tomo9_data.h5 b/data/tomo9_data.h5 new file mode 100644 index 0000000000000000000000000000000000000000..4dae03be4bd57ba6c3060928e3798a8ece6a664e Binary files /dev/null and b/data/tomo9_data.h5 differ diff --git a/data/tomo9_objlist.xml b/data/tomo9_objlist.xml new file mode 100644 index 0000000000000000000000000000000000000000..0f13054191a23c74938f4aac69685a295e7482f7 --- /dev/null +++ b/data/tomo9_objlist.xml @@ -0,0 +1,29 @@ +<objlist> + <object tomo_idx="0" class_label="1" x="24" y="80" z="75"/> + <object tomo_idx="0" class_label="1" x="77" y="59" z="83"/> + <object tomo_idx="0" class_label="1" x="30" y="58" z="15"/> + <object tomo_idx="0" class_label="2" x="44" y="31" z="45"/> + <object tomo_idx="0" class_label="2" x="32" y="76" z="17"/> + <object tomo_idx="0" class_label="2" x="56" y="39" z="21"/> + <object tomo_idx="0" class_label="3" x="52" y="18" z="73"/> + <object tomo_idx="0" class_label="3" x="49" y="60" z="34"/> + <object tomo_idx="0" class_label="3" x="77" y="77" z="16"/> + <object tomo_idx="0" class_label="4" x="64" y="27" z="23"/> + <object tomo_idx="0" class_label="4" x="65" y="69" z="29"/> + <object tomo_idx="0" class_label="4" x="80" y="53" z="33"/> + <object tomo_idx="0" class_label="5" x="69" y="81" z="62"/> + <object tomo_idx="0" class_label="5" x="48" y="68" z="83"/> + <object tomo_idx="0" class_label="5" x="47" y="16" z="41"/> + <object tomo_idx="0" class_label="6" x="72" y="47" z="43"/> + <object tomo_idx="0" class_label="6" x="66" y="79" z="51"/> + <object tomo_idx="0" class_label="6" x="48" y="41" z="59"/> + <object tomo_idx="0" class_label="7" x="47" y="25" z="15"/> + <object tomo_idx="0" class_label="7" x="77" y="60" z="18"/> + <object tomo_idx="0" class_label="7" x="83" y="34" z="45"/> + <object tomo_idx="0" class_label="8" x="73" y="73" z="72"/> + <object tomo_idx="0" class_label="8" x="75" y="33" z="56"/> + <object tomo_idx="0" class_label="8" x="15" y="38" z="19"/> + <object tomo_idx="0" class_label="9" x="17" y="39" z="76"/> + <object tomo_idx="0" class_label="9" x="79" y="40" z="15"/> + <object tomo_idx="0" class_label="9" x="44" y="46" z="84"/> +</objlist> diff --git a/data/tomo9_target.h5 b/data/tomo9_target.h5 new file mode 100644 index 0000000000000000000000000000000000000000..895a0bae5febd2df4e2a8486adb2671283d1735c Binary files /dev/null and b/data/tomo9_target.h5 differ diff --git a/deepfind.py b/deepfind.py index a9a6ef774a2c8f1247f37f8c7b8a44717446c4bd..9475ee96d901bea88d1ecb0cbfaae81800dc4244 100644 --- a/deepfind.py +++ b/deepfind.py @@ -14,6 +14,8 @@ from lxml import etree import models import losses +import core_utils +import utils class deepfind: def __init__(self, Ncl): @@ -28,22 +30,35 @@ class deepfind: self.Nvalid = 100 # number of samples for validation self.optimizer = Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0) + self.flag_direct_read = 1 + self.flag_batch_bootstrap = 1 self.Lrnd = 13 # random shifts applied when sampling data- and target-patches (in voxels) - - self.net = models.my_model(self.dim_in, self.Ncl) - self.net.compile(optimizer=self.optimizer, loss=losses.tversky_loss, metrics=['accuracy']) - + # Segmentation, parameters for dividing data in patches: self.P = 192 # patch length (in pixels) /!\ has to a multiple of 4 (because of 2 pooling layers), so that dim_in=dim_out self.poverlap = 70 # patch overlap (in pixels) self.pcrop = 25 # how many pixels to crop from border - # Parameters: - # path_data : path to data (i.e. tomograms) - # path_target : path to targets (i.e. annotated volumes) - # objlist_train : object list containing coordinates of macromolecules - # objlist_valid : another object list for monitoring the training (in particular detect overfitting) + # This function launches the training procedure. For each epoch, an image is plotted, displaying the progression with different metrics: loss, accuracy, f1-score, recall, precision. Every 10 epochs, the current network weights are saved. + # The network is trained on small 3D patches (i.e. sub-volumes), sampled from the larger tomograms (due to memory limitation). The patch sampling is not realized randomly, but is guided by the macromolecule coordinates contained in so-called object lists (objlist). + # Concerning the loading of the dataset, two options are possible: + # flag_direct_read=0: + # flag_direct_read=1: + # Arguments: + # path_data : a list containing the paths to data files (i.e. tomograms) + # path_target : a list containing the paths to target files (i.e. annotated volumes) + # objlist_train : an xml structure containing information about annotated objects: origin tomogram (should correspond to the index of 'path_data' argument), coordinates, class. During training, these coordinates are used for guiding the patch sampling procedure. + # objlist_valid : same as 'objlist_train', but objects contained in this xml structure are not used for training, but for validation. It allows to monitor the training and check for over/under-fitting. Ideally, the validation objects should originate from different tomograms than training objects. def train(self, path_data, path_target, objlist_train, objlist_valid): + # Build network: + net = models.my_model(self.dim_in, self.Ncl) + net.compile(optimizer=self.optimizer, loss='categorical_crossentropy', metrics=['accuracy']) + + # Load whole dataset: + if self.flag_direct_read == False: + print('Loading dataset ...') + data_list, target_list = core_utils.load_dataset(path_data, path_target) + print('Launch training ...') # Declare lists for storing training statistics: @@ -61,18 +76,24 @@ class deepfind: # TRAINING: start = time.time() for it in range(self.steps_per_epoch): - batch_data, batch_target = self.generate_batch_direct_read(path_data, path_target, objlist_train, self.batch_size) - loss_train = self.net.train_on_batch(batch_data, batch_target) + if self.flag_direct_read: + batch_data, batch_target = self.generate_batch_direct_read(path_data, path_target, self.batch_size, objlist_train) + else: + batch_data, batch_target = self.generate_batch_from_array(data_list, target_list, self.batch_size, objlist_train) + loss_train = net.train_on_batch(batch_data, batch_target) print('epoch %d/%d - it %d/%d - loss: %0.3f - acc: %0.3f' % (e+1, self.epochs, it+1, self.steps_per_epoch, loss_train[0], loss_train[1])) hist_loss_train.append(loss_train[0]) hist_acc_train.append( loss_train[1]) # VALIDATION (compute statistics to monitor training): - batch_data_valid, batch_target_valid = self.generate_batch_direct_read(path_data, path_target, objlist_valid, self.Nvalid) - loss_val = self.net.evaluate(batch_data_valid, batch_target_valid, verbose=0) + if self.flag_direct_read: + batch_data_valid, batch_target_valid = self.generate_batch_direct_read(path_data, path_target, self.batch_size, objlist_valid) + else: + batch_data_valid, batch_target_valid = self.generate_batch_from_array(data_list, target_list, self.batch_size, objlist_valid) + loss_val = net.evaluate(batch_data_valid, batch_target_valid, verbose=0) - batch_pred = self.net.predict(batch_data_valid) + batch_pred = net.predict(batch_data_valid) scores = precision_recall_fscore_support(batch_target_valid.argmax(axis=-1).flatten(), batch_pred.argmax(axis=-1).flatten(), average=None) hist_loss_valid.append(loss_val[0]) @@ -87,16 +108,16 @@ class deepfind: print('EPOCH %d/%d - valid loss: %0.3f - valid acc: %0.3f - %0.2fsec' % (e+1, self.epochs, loss_val[0], loss_val[1], end-start)) print('=============================================================') - # Save training history: + # Save and plot training history: history = {'loss':hist_loss_train, 'acc':hist_acc_train, 'val_loss':hist_loss_valid, 'val_acc':hist_acc_valid, 'val_f1':hist_f1, 'val_recall':hist_recall, 'val_precision':hist_precision} - self.save_history(history) + core_utils.save_history(history) + core_utils.plot_history(history) if (e+1)%10 == 0: # save weights every 10 epochs - self.net.save('params_model_epoch'+str(e+1)+'.h5') + net.save('params_model_epoch'+str(e+1)+'.h5') print "Model took %0.2f seconds to train"%np.sum(process_time) - self.net.save('params_model_FINAL.h5') - + net.save('params_model_FINAL.h5') # This function generates training batches: # - Data and target patches are sampled, in order to avoid loading whole tomograms. @@ -108,47 +129,47 @@ class deepfind: # This is usefull when a class is under-represented. # /!\ TO DO: add a test to ignore objlist coordinate too close to border (using tomodim=h5data['dataset'].shape) - def generate_batch_direct_read(self, path_data, path_target, objlist, batch_size): - Nobj = objlist.shape[0] + def generate_batch_direct_read(self, path_data, path_target, batch_size, objlist=None): p_in = np.int( np.floor(self.dim_in /2) ) batch_data = np.zeros((batch_size, self.dim_in, self.dim_in, self.dim_in, 1)) batch_target = np.zeros((batch_size, self.dim_in, self.dim_in, self.dim_in, self.Ncl)) - - lblTAB = np.unique(objlist[:,0]) - Nbs = 100 - - # Bootstrap data so that we have equal frequencies (1/Nbs) for all classes: - bs_idx = [] - for l in lblTAB: - bs_idx.append( np.random.choice(np.squeeze(np.asarray(np.nonzero(objlist[:,0]==l))), Nbs) ) - bs_idx = np.concatenate(bs_idx) - for i in range(batch_size): - # Choose random sample in training set: - index = np.random.choice(bs_idx) - tomoID = objlist[index,1] + # The batch is generated by randomly sampling data patches. + if self.flag_batch_bootstrap: # choose from bootstrapped objlist + pool = core_utils.get_bootstrap_idx(objlist,Nbs=batch_size) + else: # choose from whole objlist + pool = range(0,len(objlist)) - # Add random shift to coordinates: - x = objlist[index,4] + np.random.choice(range(-self.Lrnd,self.Lrnd+1)) - y = objlist[index,3] + np.random.choice(range(-self.Lrnd,self.Lrnd+1)) - z = objlist[index,2] + np.random.choice(range(-self.Lrnd,self.Lrnd+1)) + + for i in range(batch_size): + # Choose random object in training set: + index = np.random.choice(pool) + + tomoID = int( objlist[index].get('tomo_idx') ) + + h5file = h5py.File(path_data[tomoID], 'r') + tomodim = h5file['dataset'].shape # get tomo dimensions without loading the array + h5file.close() + + x,y,z = core_utils.get_patch_position(tomodim, p_in, objlist[index], self.Lrnd) # Load data and target patches: - h5data = h5py.File(path_data[tomoID], 'r') - patch_data = h5data['dataset'][x-p_in:x+p_in, y-p_in:y+p_in, z-p_in:z+p_in] - patch_data = (patch_data - np.mean(patch_data)) / np.std(patch_data) # normalize - h5data.close() + h5file = h5py.File(path_data[tomoID], 'r') + patch_data = h5file['dataset'][x-p_in:x+p_in, y-p_in:y+p_in, z-p_in:z+p_in] + h5file.close() - h5target = h5py.File(path_target[tomoID], 'r') - patch_batch = h5target['dataset'][x-p_in:x+p_in, y-p_in:y+p_in, z-p_in:z+p_in] - #patch_batch[patch_batch==-1] = 0 # /!\ -1 labels from 'ignore mask' could generate trouble - patch_batch_onehot = to_categorical(patch_batch, self.Ncl) - h5target.close() + h5file = h5py.File(path_target[tomoID], 'r') + patch_target = h5file['dataset'][x-p_in:x+p_in, y-p_in:y+p_in, z-p_in:z+p_in] + h5file.close() + + # Process the patches in order to be used by network: + patch_data = (patch_data - np.mean(patch_data)) / np.std(patch_data) # normalize + patch_target_onehot = to_categorical(patch_target, self.Ncl) # Store into batch array: batch_data[i,:,:,:,0] = patch_data - batch_target[i] = patch_batch_onehot + batch_target[i] = patch_target_onehot # Data augmentation (180degree rotation around tilt axis): if np.random.uniform()<0.5: @@ -157,38 +178,59 @@ class deepfind: return batch_data, batch_target - def save_history(self, history): - h5trainhist = h5py.File('params_train_history.h5', 'w') - - # train and val loss & accuracy: - dset = h5trainhist.create_dataset('acc', (len(history['acc']),)) - dset[:] = history['acc'] - dset = h5trainhist.create_dataset('loss', (len(history['loss']),)) - dset[:] = history['loss'] - dset = h5trainhist.create_dataset('val_acc', (len(history['val_acc']),)) - dset[:] = history['val_acc'] - dset = h5trainhist.create_dataset('val_loss', (len(history['val_loss']),)) - dset[:] = history['val_loss'] - - # val precision, recall, F1: - dset = h5trainhist.create_dataset('val_f1', np.shape(history['val_f1'])) - dset[:] = history['val_f1'] - dset = h5trainhist.create_dataset('val_precision', np.shape(history['val_precision'])) - dset[:] = history['val_precision'] - dset = h5trainhist.create_dataset('val_recall', np.shape(history['val_recall'])) - dset[:] = history['val_recall'] - - h5trainhist.close() - return + def generate_batch_from_array(self, data, target, batch_size, objlist=None): + p_in = np.int( np.floor(self.dim_in /2) ) + + batch_data = np.zeros((batch_size, self.dim_in, self.dim_in, self.dim_in, 1)) + batch_target = np.zeros((batch_size, self.dim_in, self.dim_in, self.dim_in, self.Ncl)) + + # The batch is generated by randomly sampling data patches. + if self.flag_batch_bootstrap: # choose from bootstrapped objlist + pool = core_utils.get_bootstrap_idx(objlist,Nbs=batch_size) + else: # choose from whole objlist + pool = range(0,len(objlist)) + + for i in range(batch_size): + # choose random sample in training set: + index = np.random.choice(pool) + + tomoID = int( objlist[index].get('tomo_idx') ) + + tomodim = data[tomoID].shape + + sample_data = data[tomoID] + sample_target = target[tomoID] + + dim = sample_data.shape + + # Get patch position: + x,y,z = core_utils.get_patch_position(tomodim, p_in, objlist[index], self.Lrnd) + + # Get patch: + patch_data = sample_data[x-p_in:x+p_in, y-p_in:y+p_in, z-p_in:z+p_in] + patch_batch = sample_target[x-p_in:x+p_in, y-p_in:y+p_in, z-p_in:z+p_in] + + # Process the patches in order to be used by network: + patch_data = (patch_data - np.mean(patch_data)) / np.std(patch_data) # normalize + patch_batch_onehot = to_categorical(patch_batch, self.Ncl) + + # Store into batch array: + batch_data[i,:,:,:,0] = patch_data + batch_target[i] = patch_batch_onehot + + # Data augmentation (180degree rotation around tilt axis): + if np.random.uniform()<0.5: + batch_data[i] = np.rot90(batch_data[i] , k=2, axes=(0,2)) + batch_target[i] = np.rot90(batch_target[i], k=2, axes=(0,2)) + + return batch_data, batch_target def segment(self, dataArray, weights_path): - self.net = models.my_model(self.P, self.Ncl) - self.net.load_weights(weights_path) - # Load data: - #h5data = h5py.File(path_data, 'r') - #dataArray = h5data['dataset'][:] - #h5data.close() + # Build network: + net = models.my_model(self.P, self.Ncl) + net.load_weights(weights_path) + dataArray = (dataArray[:] - np.mean(dataArray[:])) / np.std(dataArray[:]) # normalize dim = dataArray.shape @@ -225,7 +267,7 @@ class deepfind: print('Segmenting patch ' + str(patchCount) + ' / ' + str(Npatch) + ' ...' ) patch = dataArray[x-l:x+l, y-l:y+l, z-l:z+l] patch = np.reshape(patch, (1,self.P,self.P,self.P,1)) # reshape for keras [batch,x,y,z,channel] - pred = self.net.predict(patch, batch_size=1) + pred = net.predict(patch, batch_size=1) predArray[x-lcrop:x+lcrop, y-lcrop:y+lcrop, z-lcrop:z+lcrop, :] = predArray[x-lcrop:x+lcrop, y-lcrop:y+lcrop, z-lcrop:z+lcrop, :] + pred[0, l-lcrop:l+lcrop,l-lcrop:l+lcrop,l-lcrop:l+lcrop, :] normArray[x-lcrop:x+lcrop, y-lcrop:y+lcrop, z-lcrop:z+lcrop] = normArray[x-lcrop:x+lcrop, y-lcrop:y+lcrop, z-lcrop:z+lcrop] + np.ones((self.P-2*self.pcrop,self.P-2*self.pcrop,self.P-2*self.pcrop)) @@ -240,17 +282,21 @@ class deepfind: print "Model took %0.2f seconds to predict"%(end - start) return predArray # predArray is the array containing the scoremaps - - # Save scoremaps: - #path, filename = os.path.split(path_data) - #scoremap_file = filename[:-3]+'_scoremaps.h5' - #h5scoremap = h5py.File(scoremap_file, 'w') - #for cl in range(0,self.Ncl): - # dset = h5scoremap.create_dataset('class'+str(cl), (dim[0], dim[1], dim[2]), dtype='float16' ) - # dset[:] = np.float16(predArray[:,:,:,cl]) - #h5scoremap.close() - # For binning: skimage.measure.block_reduce(mat, (2,2), np.mean) + def segment_single_block(self, dataArray, weights_path): + # Build network: + net = models.my_model(self.P, self.Ncl) + net.load_weights(weights_path) + + dim = dataArray.shape + dataArray = (dataArray[:] - np.mean(dataArray[:])) / np.std(dataArray[:]) # normalize + dataArray = np.reshape(dataArray, (1,dim[0],dim[1],dim[2],1)) # reshape for keras [batch,x,y,z,channel] + + pred = net.predict(dataArray, batch_size=1) + predArray = pred[0,:,:,:,:] + + return predArray + def cluster(self, labelmap, sizeThr, clustRadius): Nclass = len(np.unique(labelmap)) - 1 # object classes only (background class not considered) @@ -287,13 +333,7 @@ class deepfind: labelcount[l] = np.size(np.nonzero( np.array(clustMember)==l+1 )) winninglabel = np.argmax(labelcount)+1 - # Store cluster infos in array: - #objlist[c,0] = clustSize - #objlist[c,1] = centroid[0] - #objlist[c,2] = centroid[1] - #objlist[c,3] = centroid[2] - #objlist[c,4] = winninglabel - + # Store cluster infos in xml structure: obj = etree.SubElement(objlist, 'object') obj.set('cluster_size', str(clustSize)) obj.set('class_label' , str(winninglabel)) diff --git a/examples b/examples new file mode 160000 index 0000000000000000000000000000000000000000..c7220091707ec653bece687be8d61e114269f0b3 --- /dev/null +++ b/examples @@ -0,0 +1 @@ +Subproject commit c7220091707ec653bece687be8d61e114269f0b3 diff --git a/utils.py b/utils.py index aece2163df2bfa0dc345fe22bfc7bcdbfd5b9244..9f564705751fb5c223e30daa22417af1cbe43e61 100644 --- a/utils.py +++ b/utils.py @@ -2,9 +2,45 @@ import numpy as np import h5py from skimage.measure import block_reduce from lxml import etree +from copy import deepcopy +from sklearn.metrics import pairwise_distances + +import matplotlib +matplotlib.use('agg') # necessary else: AttributeError: 'NoneType' object has no attribute 'is_interactive' +import matplotlib.pyplot as plt #def bin_data: +# Realizes quick visualization of a volume, by plotting its orthoslices, in the same fashion as the matlab function 'tom_volxyz' (TOM toolbox) +# If volume type is int8, the function assumes that the volume is a labelmap, and hence plots in color scale. +# Else, it assumes that the volume is tomographic data, and plots in gray scale. +def plot_volume_orthoslices(vol, filename): + # Get central slices along each dimension: + dim = vol.shape + idx0 = np.round(dim[0]/2) + idx1 = np.round(dim[1]/2) + idx2 = np.round(dim[2]/2) + + slice0 = vol[idx0,:,:] + slice1 = vol[:,idx1,:] + slice2 = vol[:,:,idx2] + + # Build image containing orthoslices: + img_array = np.zeros((slice0.shape[0]+slice1.shape[0], slice0.shape[1]+slice1.shape[0])) + img_array[0:slice0.shape[0], 0:slice0.shape[1]] = slice0 + img_array[slice0.shape[0]-1:-1, 0:slice0.shape[1]] = slice1 + img_array[0:slice0.shape[0], slice0.shape[1]-1:-1] = np.flipud(np.rot90(slice2)) + + # Drop the plot: + fig = plt.figure(figsize=(10,10)) + if vol.dtype==np.int8: + plt.imshow(img_array, cmap='CMRmap', vmin=np.min(vol), vmax=np.max(vol)) + else: + mu = np.mean(vol) # Get mean and std of data for plot range: + sig = np.std(vol) + plt.imshow(img_array, cmap='gray', vmin=mu-5*sig, vmax=mu+5*sig) + fig.savefig(filename) + def write_objlist(objlist, filename): tree = etree.ElementTree(objlist) tree.write(filename, pretty_print=True) @@ -13,6 +49,63 @@ def read_objlist(filename): tree = etree.parse(filename) objlist = tree.getroot() return objlist + +def print_objlist(objlist): + print(etree.tostring(objlist)) + +# /!\ for now this function does not know how to handle empty objlists +def get_Ntp_from_objlist(objl_gt, objl_df, tol_pos_err): + # tolerated position error (in voxel) + Ngt = len(objl_gt) + Ndf = len(objl_df) + coords_gt = np.zeros((Ngt,3)) + coords_df = np.zeros((Ndf,3)) + + for idx in range(0,Ngt): + coords_gt[idx,0] = objl_gt[idx].get('x') + coords_gt[idx,1] = objl_gt[idx].get('y') + coords_gt[idx,2] = objl_gt[idx].get('z') + for idx in range(0,Ndf): + coords_df[idx,0] = objl_df[idx].get('x') + coords_df[idx,1] = objl_df[idx].get('y') + coords_df[idx,2] = objl_df[idx].get('z') + + # Get pairwise distance matrix: + D = pairwise_distances(coords_gt, coords_df, metric='euclidean') + + # Get pairs that are closer than tol_pos_err: + D = D<=tol_pos_err + + # A detected object is considered a true positive (TP) if it is closer than tol_pos_err to a ground truth object. + match_vector = np.sum(D,axis=1) + Ntp = np.sum(match_vector==1) + return Ntp + +def objlist_get_class(objlistIN, label): + N = len(objlistIN) + label_list = np.zeros((N,)) + for idx in range(0,N): + label_list[idx] = objlistIN[idx].get('class_label') + idx_class = np.nonzero(label_list==label) + idx_class = idx_class[0] + + objlistOUT = etree.Element('objlist') + for idx in range(0,len(idx_class)): + objlistOUT.append( deepcopy(objlistIN[idx_class[idx]]) ) # deepcopy is necessary, else the object is removed from objlIN when appended to objlOUT + return objlistOUT + +def objlist_above_thr(objlistIN, thr): + N = len(objlistIN) + clust_size_list = np.zeros((N,)) + for idx in range(0,N): + clust_size_list[idx] = objlistIN[idx].get('cluster_size') + idx_thr = np.nonzero(clust_size_list>=thr) + idx_thr = idx_thr[0] + + objlistOUT = etree.Element('objlist') + for idx in range(0,len(idx_thr)): + objlistOUT.append( deepcopy(objlistIN[idx_thr[idx]]) ) # deepcopy is necessary, else the object is removed from objlIN when appended to objlOUT + return objlistOUT def bin_scoremaps(scoremaps): dim = scoremaps.shape @@ -28,34 +121,40 @@ def scoremaps2labelmap(scoremaps): return labelmap def load_h5array(filename): # rename to read_ - h5data = h5py.File(filename, 'r') - dataArray = h5data['dataset'][:] - h5data.close() + h5file = h5py.File(filename, 'r') + dataArray = h5file['dataset'][:] + h5file.close() return dataArray +def write_h5array(array, filename): + h5file = h5py.File(filename, 'w') + dset = h5file.create_dataset('dataset', array.shape, dtype='float16' ) + dset[:] = np.float16(array) + h5file.close() + def load_scoremaps(filename): # rename to read_ - h5data = h5py.File(filename, 'r') - datasetnames = h5data.keys() + h5file = h5py.File(filename, 'r') + datasetnames = h5file.keys() Ncl = len(datasetnames) - dim = h5data['class0'].shape + dim = h5file['class0'].shape scoremaps = np.zeros((dim[0],dim[1],dim[2],Ncl)) for cl in range(0,Ncl): - scoremaps[:,:,:,cl] = h5data['class'+str(cl)][:] - h5data.close() + scoremaps[:,:,:,cl] = h5file['class'+str(cl)][:] + h5file.close() return scoremaps def write_scoremaps(scoremaps, filename): - h5scoremap = h5py.File(filename, 'w') + h5file = h5py.File(filename, 'w') dim = scoremaps.shape Ncl = dim[3] for cl in range(0,Ncl): - dset = h5scoremap.create_dataset('class'+str(cl), (dim[0], dim[1], dim[2]), dtype='float16' ) + dset = h5file.create_dataset('class'+str(cl), (dim[0], dim[1], dim[2]), dtype='float16' ) dset[:] = np.float16(scoremaps[:,:,:,cl]) - h5scoremap.close() + h5file.close() def write_labelmap(labelmap, filename): dim = labelmap.shape - h5lblmap = h5py.File(filename, 'w') - dset = h5lblmap.create_dataset('dataset', (dim[0],dim[1],dim[2]), dtype='int8' ) + h5file = h5py.File(filename, 'w') + dset = h5file.create_dataset('dataset', (dim[0],dim[1],dim[2]), dtype='int8' ) dset[:] = np.int8(labelmap) - h5lblmap.close() \ No newline at end of file + h5file.close() \ No newline at end of file