diff --git a/compute_warped_images.py b/compute_warped_images.py
index 085354bebd203d9592af86afaa96d44dbc5af4f2..a3590b348e76bdf75c98c58a36a0f1513cd29904 100644
--- a/compute_warped_images.py
+++ b/compute_warped_images.py
@@ -29,6 +29,7 @@ def compute_warped_images(
         ref_frame=0,
         ref_image_model=None,
         working_ext="vtu",
+        field_name="displacement",
         verbose=0):
 
     ref_image_zfill = len(glob.glob(ref_image_folder+"/"+ref_image_basename+"_*.vti")[0].rsplit("_")[-1].split(".")[0])
@@ -56,7 +57,7 @@ def compute_warped_images(
 
     working_zfill = len(glob.glob(working_folder+"/"+working_basename+"_*."+working_ext)[0].rsplit("_")[-1].split(".")[0])
     n_frames = len(glob.glob(working_folder+"/"+working_basename+"_"+"[0-9]"*working_zfill+"."+working_ext))
-    #n_frames = 1
+    # n_frames = 1
 
     X = numpy.empty(3)
     U = numpy.empty(3)
@@ -68,7 +69,8 @@ def compute_warped_images(
 
         mesh = myvtk.readUGrid(
             filename=working_folder+"/"+working_basename+"_"+str(k_frame).zfill(working_zfill)+"."+working_ext)
-        #print mesh
+        # print mesh
+        mesh.GetPointData().SetActiveVectors(field_name)
 
         warp = vtk.vtkWarpVector()
         if (vtk.vtkVersion.GetVTKMajorVersion() >= 6):
@@ -77,9 +79,9 @@ def compute_warped_images(
             warp.SetInput(mesh)
         warp.Update()
         warped_mesh = warp.GetOutput()
-        #myvtk.writeUGrid(
-            #ugrid=warped_mesh,
-            #filename=working_folder+"/"+working_basename+"-warped_"+str(k_frame).zfill(working_zfill)+"."+working_ext)
+        myvtk.writeUGrid(
+            ugrid=warped_mesh,
+            filename=working_folder+"/"+working_basename+"-warped_"+str(k_frame).zfill(working_zfill)+"."+working_ext)
 
         probe = vtk.vtkProbeFilter()
         if (vtk.vtkVersion.GetVTKMajorVersion() >= 6):
@@ -91,7 +93,7 @@ def compute_warped_images(
         probe.Update()
         probed_image = probe.GetOutput()
         scalars_mask = probed_image.GetPointData().GetArray("vtkValidPointMask")
-        scalars_U = probed_image.GetPointData().GetArray("displacement")
+        scalars_U = probed_image.GetPointData().GetArray(field_name)
         #myvtk.writeImage(
             #image=probed_image,
             #filename=working_folder+"/"+working_basename+"_"+str(k_frame).zfill(working_zfill)+".vti")