diff --git a/compute_strains.py b/compute_strains.py
index b38fcd36a1ea83c03b7594f80ceea962fb5f7565..8425d7d0ed1017014ee622239b55633b8e9d06cf 100644
--- a/compute_strains.py
+++ b/compute_strains.py
@@ -125,6 +125,9 @@ def compute_strains(
                 mesh.GetCellData().AddArray(iarray_part_id)
             if (iarray_sector_id is not None):
                 mesh.GetCellData().AddArray(iarray_sector_id)
+            if (write_strains_vs_radius) or (write_binned_strains_vs_radius):
+                assert (ref_mesh.GetCellData().HasArray("rr"))
+                mesh.GetCellData().AddArray(ref_mesh.GetCellData().GetArray("rr"))
         myvtk.addDeformationGradients(
             mesh=mesh,
             disp_array_name=disp_array_name,
@@ -206,16 +209,14 @@ def compute_strains(
             strain_file.write("\n")
 
         if (write_strains_vs_radius):
-            assert (ref_mesh.GetCellData().HasArray("rr"))
-            farray_rr = ref_mesh.GetCellData().GetArray("rr")
+            farray_rr = mesh.GetCellData().GetArray("rr")
             for k_cell in xrange(n_cells):
-                strain_vs_radius_file.write(" ".join([str(val) for val in [k_frame, farray_rr.GetTuple1(k_cell)]+list(farray_strain.GetTuple(k_cell))]) + "\n")
-            strain_vs_radius_file.write("\n")
-            strain_vs_radius_file.write("\n")
+                strains_vs_radius_file.write(" ".join([str(val) for val in [k_frame, farray_rr.GetTuple1(k_cell)]+list(farray_strain.GetTuple(k_cell))]) + "\n")
+            strains_vs_radius_file.write("\n")
+            strains_vs_radius_file.write("\n")
 
         if (write_binned_strains_vs_radius):
-            assert (ref_mesh.GetCellData().HasArray("rr"))
-            farray_rr = ref_mesh.GetCellData().GetArray("rr")
+            farray_rr = mesh.GetCellData().GetArray("rr")
             n_r = 10
             binned_strains = [[] for k_r in xrange(n_r)]
             for k_cell in xrange(n_cells):