diff --git a/compute_warped_images.py b/compute_warped_images.py
index 085354bebd203d9592af86afaa96d44dbc5af4f2..a2e02ddb7db88544b67023c78005e4bfd584f412 100644
--- a/compute_warped_images.py
+++ b/compute_warped_images.py
@@ -29,6 +29,8 @@ def compute_warped_images(
         ref_frame=0,
         ref_image_model=None,
         working_ext="vtu",
+        working_displacement_field_name="displacement",
+        print_warped_mesh=0,
         verbose=0):
 
     ref_image_zfill = len(glob.glob(ref_image_folder+"/"+ref_image_basename+"_*.vti")[0].rsplit("_")[-1].split(".")[0])
@@ -56,7 +58,7 @@ def compute_warped_images(
 
     working_zfill = len(glob.glob(working_folder+"/"+working_basename+"_*."+working_ext)[0].rsplit("_")[-1].split(".")[0])
     n_frames = len(glob.glob(working_folder+"/"+working_basename+"_"+"[0-9]"*working_zfill+"."+working_ext))
-    #n_frames = 1
+    # n_frames = 1
 
     X = numpy.empty(3)
     U = numpy.empty(3)
@@ -68,7 +70,9 @@ def compute_warped_images(
 
         mesh = myvtk.readUGrid(
             filename=working_folder+"/"+working_basename+"_"+str(k_frame).zfill(working_zfill)+"."+working_ext)
-        #print mesh
+        # print mesh
+        assert (mesh.GetPointData().HasArray(working_displacement_field_name)), "no array '" + working_displacement_field_name + "' in mesh"
+        mesh.GetPointData().SetActiveVectors(working_displacement_field_name)
 
         warp = vtk.vtkWarpVector()
         if (vtk.vtkVersion.GetVTKMajorVersion() >= 6):
@@ -77,9 +81,10 @@ def compute_warped_images(
             warp.SetInput(mesh)
         warp.Update()
         warped_mesh = warp.GetOutput()
-        #myvtk.writeUGrid(
-            #ugrid=warped_mesh,
-            #filename=working_folder+"/"+working_basename+"-warped_"+str(k_frame).zfill(working_zfill)+"."+working_ext)
+        if print_warped_mesh:
+            myvtk.writeUGrid(
+                ugrid=warped_mesh,
+                filename=working_folder+"/"+working_basename+"-warped_"+str(k_frame).zfill(working_zfill)+"."+working_ext)
 
         probe = vtk.vtkProbeFilter()
         if (vtk.vtkVersion.GetVTKMajorVersion() >= 6):
@@ -91,7 +96,7 @@ def compute_warped_images(
         probe.Update()
         probed_image = probe.GetOutput()
         scalars_mask = probed_image.GetPointData().GetArray("vtkValidPointMask")
-        scalars_U = probed_image.GetPointData().GetArray("displacement")
+        scalars_U = probed_image.GetPointData().GetArray(working_displacement_field_name)
         #myvtk.writeImage(
             #image=probed_image,
             #filename=working_folder+"/"+working_basename+"_"+str(k_frame).zfill(working_zfill)+".vti")