diff --git a/compute_strains.py b/compute_strains.py
index 6c69d3c810d36730bcdaad3a2fa9e074a26eea66..a6c2a27749a7dcffec21da78f3a21d63e1a97f92 100644
--- a/compute_strains.py
+++ b/compute_strains.py
@@ -46,7 +46,7 @@ def compute_strains(
             verbose=verbose)
         ref_mesh_n_cells = ref_mesh.GetNumberOfCells()
         if (verbose): print "ref_mesh_n_cells = " + str(ref_mesh_n_cells)
-        
+
         if (ref_mesh.GetCellData().HasArray("sector_id")):
             iarray_sector_id = ref_mesh.GetCellData().GetArray("sector_id")
             n_sector_ids = 0
@@ -57,7 +57,14 @@ def compute_strains(
                     n_sector_ids = sector_id+1
             if (verbose): print "n_sector_ids = " + str(n_sector_ids)
         else:
+            iarray_sector_id = None
             n_sector_ids = 0
+
+        if (ref_mesh.GetCellData().HasArray("part_id")):
+            part_id_array = ref_mesh.GetCellData().GetArray("part_id")
+        else:
+            part_id_array = None
+
     else:
         ref_mesh = None
         n_sector_ids = 0
@@ -104,10 +111,10 @@ def compute_strains(
         n_cells = mesh.GetNumberOfCells()
         if (ref_mesh is not None):
             assert (n_cells == ref_mesh_n_cells), "ref_mesh_n_cells ("+str(ref_mesh_n_cells)+") ≠ n_cells ("+str(n_cells)+"). Aborting."
-            if (ref_mesh.GetCellData().HasArray("part_id")):
-                mesh.GetCellData().AddArray(ref_mesh.GetCellData().GetArray("part_id"))
-            if (ref_mesh.GetCellData().HasArray("sector_id")):
-                mesh.GetCellData().AddArray(ref_mesh.GetCellData().GetArray("sector_id"))
+            if (part_id_array is not None):
+                mesh.GetCellData().AddArray(part_id_array)
+            if (iarray_sector_id is not None):
+                mesh.GetCellData().AddArray(iarray_sector_id)
         myvtk.addDeformationGradients(
             mesh=mesh,
             disp_array_name=disp_array_name,