diff --git a/ImageIterator.py b/ImageIterator.py
index 3e6fae3e2a48cef8ab652587cdd62b1d1796d4be..e650c94a79d970e8f14af73361e1ed66a9e0ff24 100644
--- a/ImageIterator.py
+++ b/ImageIterator.py
@@ -84,6 +84,8 @@ class ImageIterator():
                 basename=self.initialize_U_basename,
                 ext=self.initialize_U_ext)
 
+            dof_to_vertex_map = dolfin.dof_to_vertex_map(self.problem.U_fs)
+
         n_iter_tot = 0
         global_success = True
         for forward_or_backward in ["forward","backward"]:
@@ -112,10 +114,13 @@ class ImageIterator():
                 if (self.initialize_U_from_file):
                     mesh = mesh_series.get_mesh(k_frame)
                     array_U = mesh.GetPointData().GetArray(self.initialize_U_array_name)
-                    array_U = vtk.util.numpy_support.vtk_to_numpy(array_U)
-                    array_U = array_U.astype(float)
+                    array_U = vtk.util.numpy_support.vtk_to_numpy(array_U)[:,:self.problem.mesh_dimension]
+                    # print array_U
+                    # array_U = array_U.astype(float)
+                    # print array_U
                     array_U = numpy.reshape(array_U, array_U.size)
-                    self.problem.U.vector()[:] = array_U[dolfin.dof_to_vertex_map(self.problem.U_fs)]
+                    # print array_U
+                    self.problem.U.vector()[:] = array_U[dof_to_vertex_map]
 
                 elif (self.initialize_DU_with_DUold):
                     self.problem.U.vector().axpy(1., self.problem.DUold.vector())