From da4f04c6760d81cd23646133b05a547fef9d4736 Mon Sep 17 00:00:00 2001
From: cpatte <cecile.patte@inria.fr>
Date: Tue, 15 Jan 2019 08:40:59 +0100
Subject: [PATCH] add assertion error, add argument print_warped_mesh, change
 name of argument field_name into working_displacement_field_name

---
 compute_warped_images.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/compute_warped_images.py b/compute_warped_images.py
index a3590b3..a2e02dd 100644
--- a/compute_warped_images.py
+++ b/compute_warped_images.py
@@ -29,7 +29,8 @@ def compute_warped_images(
         ref_frame=0,
         ref_image_model=None,
         working_ext="vtu",
-        field_name="displacement",
+        working_displacement_field_name="displacement",
+        print_warped_mesh=0,
         verbose=0):
 
     ref_image_zfill = len(glob.glob(ref_image_folder+"/"+ref_image_basename+"_*.vti")[0].rsplit("_")[-1].split(".")[0])
@@ -70,7 +71,8 @@ def compute_warped_images(
         mesh = myvtk.readUGrid(
             filename=working_folder+"/"+working_basename+"_"+str(k_frame).zfill(working_zfill)+"."+working_ext)
         # print mesh
-        mesh.GetPointData().SetActiveVectors(field_name)
+        assert (mesh.GetPointData().HasArray(working_displacement_field_name)), "no array '" + working_displacement_field_name + "' in mesh"
+        mesh.GetPointData().SetActiveVectors(working_displacement_field_name)
 
         warp = vtk.vtkWarpVector()
         if (vtk.vtkVersion.GetVTKMajorVersion() >= 6):
@@ -79,9 +81,10 @@ def compute_warped_images(
             warp.SetInput(mesh)
         warp.Update()
         warped_mesh = warp.GetOutput()
-        myvtk.writeUGrid(
-            ugrid=warped_mesh,
-            filename=working_folder+"/"+working_basename+"-warped_"+str(k_frame).zfill(working_zfill)+"."+working_ext)
+        if print_warped_mesh:
+            myvtk.writeUGrid(
+                ugrid=warped_mesh,
+                filename=working_folder+"/"+working_basename+"-warped_"+str(k_frame).zfill(working_zfill)+"."+working_ext)
 
         probe = vtk.vtkProbeFilter()
         if (vtk.vtkVersion.GetVTKMajorVersion() >= 6):
@@ -93,7 +96,7 @@ def compute_warped_images(
         probe.Update()
         probed_image = probe.GetOutput()
         scalars_mask = probed_image.GetPointData().GetArray("vtkValidPointMask")
-        scalars_U = probed_image.GetPointData().GetArray(field_name)
+        scalars_U = probed_image.GetPointData().GetArray(working_displacement_field_name)
         #myvtk.writeImage(
             #image=probed_image,
             #filename=working_folder+"/"+working_basename+"_"+str(k_frame).zfill(working_zfill)+".vti")
-- 
GitLab