diff --git a/compute_warped_images.py b/compute_warped_images.py
index a3590b348e76bdf75c98c58a36a0f1513cd29904..a2e02ddb7db88544b67023c78005e4bfd584f412 100644
--- a/compute_warped_images.py
+++ b/compute_warped_images.py
@@ -29,7 +29,8 @@ def compute_warped_images(
         ref_frame=0,
         ref_image_model=None,
         working_ext="vtu",
-        field_name="displacement",
+        working_displacement_field_name="displacement",
+        print_warped_mesh=0,
         verbose=0):
 
     ref_image_zfill = len(glob.glob(ref_image_folder+"/"+ref_image_basename+"_*.vti")[0].rsplit("_")[-1].split(".")[0])
@@ -70,7 +71,8 @@ def compute_warped_images(
         mesh = myvtk.readUGrid(
             filename=working_folder+"/"+working_basename+"_"+str(k_frame).zfill(working_zfill)+"."+working_ext)
         # print mesh
-        mesh.GetPointData().SetActiveVectors(field_name)
+        assert (mesh.GetPointData().HasArray(working_displacement_field_name)), "no array '" + working_displacement_field_name + "' in mesh"
+        mesh.GetPointData().SetActiveVectors(working_displacement_field_name)
 
         warp = vtk.vtkWarpVector()
         if (vtk.vtkVersion.GetVTKMajorVersion() >= 6):
@@ -79,9 +81,10 @@ def compute_warped_images(
             warp.SetInput(mesh)
         warp.Update()
         warped_mesh = warp.GetOutput()
-        myvtk.writeUGrid(
-            ugrid=warped_mesh,
-            filename=working_folder+"/"+working_basename+"-warped_"+str(k_frame).zfill(working_zfill)+"."+working_ext)
+        if print_warped_mesh:
+            myvtk.writeUGrid(
+                ugrid=warped_mesh,
+                filename=working_folder+"/"+working_basename+"-warped_"+str(k_frame).zfill(working_zfill)+"."+working_ext)
 
         probe = vtk.vtkProbeFilter()
         if (vtk.vtkVersion.GetVTKMajorVersion() >= 6):
@@ -93,7 +96,7 @@ def compute_warped_images(
         probe.Update()
         probed_image = probe.GetOutput()
         scalars_mask = probed_image.GetPointData().GetArray("vtkValidPointMask")
-        scalars_U = probed_image.GetPointData().GetArray(field_name)
+        scalars_U = probed_image.GetPointData().GetArray(working_displacement_field_name)
         #myvtk.writeImage(
             #image=probed_image,
             #filename=working_folder+"/"+working_basename+"_"+str(k_frame).zfill(working_zfill)+".vti")