diff --git a/compute_strains.py b/compute_strains.py
index 87fb9a6158e27b2df698f281a76be35514a0c2dd..b38fcd36a1ea83c03b7594f80ceea962fb5f7565 100644
--- a/compute_strains.py
+++ b/compute_strains.py
@@ -49,6 +49,18 @@ def compute_strains(
         ref_mesh_n_cells = ref_mesh.GetNumberOfCells()
         if (verbose): print "ref_mesh_n_cells = " + str(ref_mesh_n_cells)
 
+        if (ref_mesh.GetCellData().HasArray("part_id")):
+            iarray_part_id = ref_mesh.GetCellData().GetArray("part_id")
+            n_part_ids = 0
+            for k_cell in xrange(ref_mesh_n_cells):
+                part_id = int(iarray_part_id.GetTuple1(k_cell))
+                if (part_id > n_part_ids-1):
+                    n_part_ids = part_id+1
+            if (verbose): print "n_part_ids = " + str(n_part_ids)
+        else:
+            iarray_part_id = None
+            n_part_ids = 0
+
         if (ref_mesh.GetCellData().HasArray("sector_id")):
             iarray_sector_id = ref_mesh.GetCellData().GetArray("sector_id")
             n_sector_ids = 0
@@ -62,13 +74,9 @@ def compute_strains(
             iarray_sector_id = None
             n_sector_ids = 0
 
-        if (ref_mesh.GetCellData().HasArray("part_id")):
-            iarray_part_id = ref_mesh.GetCellData().GetArray("part_id")
-        else:
-            iarray_part_id = None
-
     else:
         ref_mesh = None
+        n_part_ids = 0
         n_sector_ids = 0
 
     working_filenames = glob.glob(working_folder+"/"+working_basename+"_[0-9]*."+working_ext)
@@ -142,6 +150,13 @@ def compute_strains(
                 field_name="part_id",
                 threshold_value=0.5,
                 threshold_by_upper_or_lower="lower")
+            n_points = mesh.GetNumberOfPoints()
+            n_cells = mesh.GetNumberOfCells()
+            if (iarray_part_id is not None):
+                iarray_part_id = mesh.GetCellData().GetArray("part_id")
+                n_part_ids = 0
+            if (iarray_sector_id is not None):
+                iarray_sector_id = mesh.GetCellData().GetArray("sector_id")
         myvtk.writeUGrid(
             ugrid=mesh,
             filename=working_folder+"/"+working_basename+"_"+str(k_frame).zfill(working_zfill)+"."+working_ext,
@@ -154,24 +169,27 @@ def compute_strains(
                 farray_strain = mesh.GetCellData().GetArray(strain_array_name)
 
         if (write_strains):
-            if (n_sector_ids == 0):
-                strains_all = []
-                for k_cell in xrange(n_cells):
-                    strains_all.append(farray_strain.GetTuple(k_cell))
-            elif (n_sector_ids == 1):
-                strains_all = []
-                for k_cell in xrange(n_cells):
-                    sector_id = int(iarray_sector_id.GetTuple(k_cell)[0])
-                    if (sector_id < 0): continue
-                    strains_all.append(farray_strain.GetTuple(k_cell))
+            if (n_sector_ids in (0,1)):
+                if (n_part_ids == 0):
+                    strains_all = [farray_strain.GetTuple(k_cell) for k_cell in xrange(n_cells)]
+                else:
+                    strains_all = [farray_strain.GetTuple(k_cell) for k_cell in xrange(n_cells) if iarray_part_id.GetTuple1(k_cell) > 0]
             elif (n_sector_ids > 1):
                 strains_all = []
                 strains_per_sector = [[] for sector_id in xrange(n_sector_ids)]
-                for k_cell in xrange(n_cells):
-                    sector_id = int(iarray_sector_id.GetTuple(k_cell)[0])
-                    if (sector_id < 0): continue
-                    strains_all.append(farray_strain.GetTuple(k_cell))
-                    strains_per_sector[sector_id].append(farray_strain.GetTuple(k_cell))
+                if (n_part_ids == 0):
+                    for k_cell in xrange(n_cells):
+                        strains_all.append(farray_strain.GetTuple(k_cell))
+                        sector_id = int(iarray_sector_id.GetTuple1(k_cell))
+                        strains_per_sector[sector_id].append(farray_strain.GetTuple(k_cell))
+                else:
+                    for k_cell in xrange(n_cells):
+                        part_id = int(iarray_part_id.GetTuple1(k_cell))
+                        if (part_id > 0): continue
+                        strains_all.append(farray_strain.GetTuple(k_cell))
+                        sector_id = int(iarray_sector_id.GetTuple1(k_cell))
+                        if (sector_id < 0): continue
+                        strains_per_sector[sector_id].append(farray_strain.GetTuple(k_cell))
 
             if (temporal_offset is not None) and (temporal_resolution is not None):
                 strain_file.write(str(temporal_offset + k_frame*temporal_resolution))