From 00059c9a3a88ae8ce1cb6a69168eb58e7fa1020d Mon Sep 17 00:00:00 2001
From: Martin Genet <martin.genet@polytechnique.edu>
Date: Tue, 25 Jun 2019 16:55:06 +0200
Subject: [PATCH] Better mechanism to overwrite/keep temporary images in
 generate_images

---
 downsample_images.py | 24 +++++++++++++++++-------
 generate_images.py   | 35 ++++++++++++++++++++---------------
 normalize_images.py  |  3 ++-
 3 files changed, 39 insertions(+), 23 deletions(-)

diff --git a/downsample_images.py b/downsample_images.py
index 4031a70..41f0ea2 100644
--- a/downsample_images.py
+++ b/downsample_images.py
@@ -31,6 +31,7 @@ def downsample_images(
         downsampling_factors,
         images_ext="vti",
         keep_resolution=0,
+        overwrite_orig_images=1,
         write_temp_images=0,
         verbose=0):
 
@@ -185,7 +186,11 @@ def downsample_images(
     else:
         rfft.SetInputData(image_downsampled) # MG20190520: Not sure why this does not work.
 
-    writer.SetInputConnection(rfft.GetOutputPort())
+    extract = vtk.vtkImageExtractComponents()
+    extract.SetInputConnection(rfft.GetOutputPort())
+    extract.SetComponents(0)
+
+    writer.SetInputConnection(extract.GetOutputPort())
 
     if (keep_resolution):
         for k_frame in range(images_nframes):
@@ -201,7 +206,7 @@ def downsample_images(
                 writer_mul.SetFileName(images_folder+"/"+images_basename+"_mul"+"_"+str(k_frame).zfill(images_zfill)+"."+images_ext)
                 writer_mul.Write()
 
-            writer.SetFileName(images_folder+"/"+images_basename+"_downsampled"+"_"+str(k_frame).zfill(images_zfill)+"."+images_ext)
+            writer.SetFileName(images_folder+"/"+images_basename+("_downsampled")*(not overwrite_orig_images)+"_"+str(k_frame).zfill(images_zfill)+"."+images_ext)
             writer.Write()
     else:
         for k_frame in range(images_nframes):
@@ -255,11 +260,16 @@ def downsample_images(
                 writer_sel.SetFileName(images_folder+"/"+images_basename+"_sel"+"_"+str(k_frame).zfill(images_zfill)+"."+images_ext)
                 writer_sel.Write()
 
-            rfft = vtk.vtkImageRFFT()             # MG20190520: Not sure why this is needed.
-            rfft.SetDimensionality(images_ndim)   # MG20190520: Not sure why this is needed.
-            rfft.SetInputData(image_downsampled)  # MG20190520: Not sure why this is needed.
+            rfft = vtk.vtkImageRFFT()                 # MG20190520: Not sure why this is needed.
+            rfft.SetDimensionality(images_ndim)       # MG20190520: Not sure why this is needed.
+            rfft.SetInputData(image_downsampled)      # MG20190520: Not sure why this is needed.
             rfft.Update()
 
-            writer.SetInputData(rfft.GetOutput()) # MG20190520: Not sure why this is needed.
-            writer.SetFileName(images_folder+"/"+images_basename+"_downsampled"+"_"+str(k_frame).zfill(images_zfill)+"."+images_ext)
+            extract = vtk.vtkImageExtractComponents() # MG20190520: Not sure why this is needed.
+            extract.SetInputData(rfft.GetOutput())    # MG20190520: Not sure why this is needed.
+            extract.SetComponents(0)                  # MG20190520: Not sure why this is needed.
+            extract.Update()                          # MG20190520: Not sure why this is needed.
+
+            writer.SetInputData(extract.GetOutput())  # MG20190520: Not sure why this is needed.
+            writer.SetFileName(images_folder+"/"+images_basename+("_downsampled")*(not overwrite_orig_images)+"_"+str(k_frame).zfill(images_zfill)+"."+images_ext)
             writer.Write()
diff --git a/generate_images.py b/generate_images.py
index 2cf9491..cbc4260 100644
--- a/generate_images.py
+++ b/generate_images.py
@@ -62,7 +62,8 @@ def generate_images(
         noise,
         deformation,
         evolution,
-        generate_image_gradient=False,
+        generate_image_gradient=0,
+        keep_temp_images=0,
         verbose=0):
 
     mypy.my_print(verbose, "*** generate_images ***")
@@ -188,16 +189,18 @@ def generate_images(
             images_basename=images["basename"],
             downsampling_factors=images["upsampling_factors"],
             keep_resolution=0,
+            overwrite_orig_images=(not keep_temp_images),
             write_temp_images=0,
             verbose=verbose)
 
-        for k_frame in range(images["n_frames"]):
-            os.rename(
-                images["folder"]+"/"+images["basename"]+"_"+str(k_frame).zfill(images["zfill"])+"."+images["ext"],
-                images["folder"]+"/"+images["basename"]+"_upsampled"+"_"+str(k_frame).zfill(images["zfill"])+"."+images["ext"])
-            os.rename(
-                images["folder"]+"/"+images["basename"]+"_downsampled"+"_"+str(k_frame).zfill(images["zfill"])+"."+images["ext"],
-                images["folder"]+"/"+images["basename"]+"_"+str(k_frame).zfill(images["zfill"])+"."+images["ext"])
+        if (keep_temp_images):
+            for k_frame in range(images["n_frames"]):
+                os.rename(
+                    images["folder"]+"/"+images["basename"]+"_"+str(k_frame).zfill(images["zfill"])+"."+images["ext"],
+                    images["folder"]+"/"+images["basename"]+"_upsampled"+"_"+str(k_frame).zfill(images["zfill"])+"."+images["ext"])
+                os.rename(
+                    images["folder"]+"/"+images["basename"]+"_downsampled"+"_"+str(k_frame).zfill(images["zfill"])+"."+images["ext"],
+                    images["folder"]+"/"+images["basename"]+"_"+str(k_frame).zfill(images["zfill"])+"."+images["ext"])
 
     if (images["data_type"] in ("float")):
         normalizing = False
@@ -208,12 +211,14 @@ def generate_images(
             images_folder=images["folder"],
             images_basename=images["basename"],
             images_datatype=images["data_type"],
+            overwrite_orig_images=(not keep_temp_images),
             verbose=verbose)
 
-        for k_frame in range(images["n_frames"]):
-            os.rename(
-                src=images["folder"]+"/"+images["basename"]+"_"+str(k_frame).zfill(images["zfill"])+"."+images["ext"],
-                dst=images["folder"]+"/"+images["basename"]+"_prenormalized"+"_"+str(k_frame).zfill(images["zfill"])+"."+images["ext"])
-            os.rename(
-                src=images["folder"]+"/"+images["basename"]+"_normalized"+"_"+str(k_frame).zfill(images["zfill"])+"."+images["ext"],
-                dst=images["folder"]+"/"+images["basename"]+"_"+str(k_frame).zfill(images["zfill"])+"."+images["ext"])
+        if (keep_temp_images):
+            for k_frame in range(images["n_frames"]):
+                os.rename(
+                    src=images["folder"]+"/"+images["basename"]+"_"+str(k_frame).zfill(images["zfill"])+"."+images["ext"],
+                    dst=images["folder"]+"/"+images["basename"]+"_prenormalized"+"_"+str(k_frame).zfill(images["zfill"])+"."+images["ext"])
+                os.rename(
+                    src=images["folder"]+"/"+images["basename"]+"_normalized"+"_"+str(k_frame).zfill(images["zfill"])+"."+images["ext"],
+                    dst=images["folder"]+"/"+images["basename"]+"_"+str(k_frame).zfill(images["zfill"])+"."+images["ext"])
diff --git a/normalize_images.py b/normalize_images.py
index e292458..d2b1c82 100644
--- a/normalize_images.py
+++ b/normalize_images.py
@@ -30,6 +30,7 @@ def normalize_images(
         images_basename,
         images_datatype,
         images_ext="vti",
+        overwrite_orig_images=1,
         verbose=0):
 
     mypy.my_print(verbose, "*** normalize_images ***")
@@ -94,5 +95,5 @@ def normalize_images(
         mypy.my_print(verbose, "k_frame = "+str(k_frame))
 
         reader.SetFileName(images_folder+"/"+images_basename              +"_"+str(k_frame).zfill(images_zfill)+"."+images_ext)
-        writer.SetFileName(images_folder+"/"+images_basename+"_normalized"+"_"+str(k_frame).zfill(images_zfill)+"."+images_ext)
+        writer.SetFileName(images_folder+"/"+images_basename+("_normalized")*(not overwrite_orig_images)+"_"+str(k_frame).zfill(images_zfill)+"."+images_ext)
         writer.Write()
-- 
GitLab