From 08e216b3fa705b76e8b255957888ffa8ea2b2b81 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?C=C3=A9cile=20Patte?= <cecile.patte@inria.fr>
Date: Wed, 28 Mar 2018 12:24:36 +0200
Subject: [PATCH] Allow for reference mesh with PartId but no local basis in
 compute_strains

---
 compute_strains.py | 49 +++++++++++++++++++++++-----------------------
 1 file changed, 25 insertions(+), 24 deletions(-)

diff --git a/compute_strains.py b/compute_strains.py
index 6346951..6c69d3c 100644
--- a/compute_strains.py
+++ b/compute_strains.py
@@ -27,9 +27,9 @@ def compute_strains(
         disp_array_name="displacement",
         defo_grad_array_name="DeformationGradient",
         strain_array_name="Strain",
-        mesh_w_local_basis_folder=None,
-        mesh_w_local_basis_basename=None,
-        mesh_w_local_basis_ext="vtk",
+        ref_mesh_folder=None,
+        ref_mesh_basename=None,
+        ref_mesh_ext="vtk",
         CYL_or_PPS="PPS",
         write_strains=1,
         temporal_offset=None,
@@ -40,17 +40,17 @@ def compute_strains(
         write_binned_strains_vs_radius=0,
         verbose=1):
 
-    if (mesh_w_local_basis_folder is not None) and (mesh_w_local_basis_basename is not None):
-        mesh_w_local_basis = myvtk.readUGrid(
-            filename=mesh_w_local_basis_folder+"/"+mesh_w_local_basis_basename+"."+mesh_w_local_basis_ext,
+    if (ref_mesh_folder is not None) and (ref_mesh_basename is not None):
+        ref_mesh = myvtk.readUGrid(
+            filename=ref_mesh_folder+"/"+ref_mesh_basename+"."+ref_mesh_ext,
             verbose=verbose)
-        mesh_w_local_basis_n_cells = mesh_w_local_basis.GetNumberOfCells()
-        if (verbose): print "mesh_w_local_basis_n_cells = " + str(mesh_w_local_basis_n_cells)
-
-        if (mesh_w_local_basis.GetCellData().HasArray("sector_id")):
-            iarray_sector_id = mesh_w_local_basis.GetCellData().GetArray("sector_id")
+        ref_mesh_n_cells = ref_mesh.GetNumberOfCells()
+        if (verbose): print "ref_mesh_n_cells = " + str(ref_mesh_n_cells)
+        
+        if (ref_mesh.GetCellData().HasArray("sector_id")):
+            iarray_sector_id = ref_mesh.GetCellData().GetArray("sector_id")
             n_sector_ids = 0
-            for k_cell in xrange(mesh_w_local_basis_n_cells):
+            for k_cell in xrange(ref_mesh_n_cells):
                 sector_id = int(iarray_sector_id.GetTuple1(k_cell))
                 if (sector_id < 0): continue
                 if (sector_id > n_sector_ids-1):
@@ -59,7 +59,7 @@ def compute_strains(
         else:
             n_sector_ids = 0
     else:
-        mesh_w_local_basis = None
+        ref_mesh = None
         n_sector_ids = 0
 
     working_filenames = glob.glob(working_folder+"/"+working_basename+"_[0-9]*."+working_ext)
@@ -102,10 +102,12 @@ def compute_strains(
             filename=mesh_filename,
             verbose=verbose)
         n_cells = mesh.GetNumberOfCells()
-        if (mesh_w_local_basis is not None):
-            assert (n_cells == mesh_w_local_basis_n_cells), "mesh_w_local_basis_n_cells ("+str(mesh_w_local_basis_n_cells)+") ≠ n_cells ("+str(n_cells)+"). Aborting."
-            mesh.GetCellData().AddArray(mesh_w_local_basis.GetCellData().GetArray("part_id"))
-            mesh.GetCellData().AddArray(mesh_w_local_basis.GetCellData().GetArray("sector_id"))
+        if (ref_mesh is not None):
+            assert (n_cells == ref_mesh_n_cells), "ref_mesh_n_cells ("+str(ref_mesh_n_cells)+") ≠ n_cells ("+str(n_cells)+"). Aborting."
+            if (ref_mesh.GetCellData().HasArray("part_id")):
+                mesh.GetCellData().AddArray(ref_mesh.GetCellData().GetArray("part_id"))
+            if (ref_mesh.GetCellData().HasArray("sector_id")):
+                mesh.GetCellData().AddArray(ref_mesh.GetCellData().GetArray("sector_id"))
         myvtk.addDeformationGradients(
             mesh=mesh,
             disp_array_name=disp_array_name,
@@ -122,7 +124,7 @@ def compute_strains(
             mesh=mesh,
             defo_grad_array_name=defo_grad_array_name,
             strain_array_name=strain_array_name,
-            mesh_w_local_basis=mesh_w_local_basis,
+            mesh_w_local_basis=ref_mesh,
             verbose=verbose)
         myvtk.writeUGrid(
             ugrid=mesh,
@@ -130,8 +132,7 @@ def compute_strains(
             verbose=verbose)
 
         if (write_strains) or (write_strains_vs_radius) or (write_binned_strains_vs_radius):
-            if (mesh_w_local_basis is not None):
-                assert (mesh.GetCellData().HasArray(strain_array_name+"_"+CYL_or_PPS))
+            if (ref_mesh is not None) and (mesh.GetCellData().HasArray(strain_array_name+"_"+CYL_or_PPS)):
                 farray_strain = mesh.GetCellData().GetArray(strain_array_name+"_"+CYL_or_PPS)
             else:
                 farray_strain = mesh.GetCellData().GetArray(strain_array_name)
@@ -171,16 +172,16 @@ def compute_strains(
             strain_file.write("\n")
 
         if (write_strains_vs_radius):
-            assert (mesh_w_local_basis.GetCellData().HasArray("rr"))
-            farray_rr = mesh_w_local_basis.GetCellData().GetArray("rr")
+            assert (ref_mesh.GetCellData().HasArray("rr"))
+            farray_rr = ref_mesh.GetCellData().GetArray("rr")
             for k_cell in xrange(n_cells):
                 strain_vs_radius_file.write(" ".join([str(val) for val in [k_frame, farray_rr.GetTuple1(k_cell)]+list(farray_strain.GetTuple(k_cell))]) + "\n")
             strain_vs_radius_file.write("\n")
             strain_vs_radius_file.write("\n")
 
         if (write_binned_strains_vs_radius):
-            assert (mesh_w_local_basis.GetCellData().HasArray("rr"))
-            farray_rr = mesh_w_local_basis.GetCellData().GetArray("rr")
+            assert (ref_mesh.GetCellData().HasArray("rr"))
+            farray_rr = ref_mesh.GetCellData().GetArray("rr")
             n_r = 10
             binned_strains = [[] for k_r in xrange(n_r)]
             for k_cell in xrange(n_cells):
-- 
GitLab