#coding=utf8 ################################################################################ ### ### ### Created by Martin Genet, 2016-2018 ### ### ### ### École Polytechnique, Palaiseau, France ### ### ### ################################################################################ import dolfin import os import myPythonLibrary as mypy import myVTKPythonLibrary as myvtk import dolfin_dic as ddic from Problem import Problem ################################################################################ class ImageRegistrationProblem(Problem): def __init__(self, mesh=None, mesh_folder=None, mesh_basename=None, U_family="Lagrange", U_degree=1): self.printer = mypy.Printer() self.set_mesh( mesh=mesh, mesh_folder=mesh_folder, mesh_basename=mesh_basename) self.set_displacement( U_family=U_family, U_degree=U_degree) self.energies = [] def __del__(self): self.printer.close() def set_mesh(self, mesh=None, mesh_folder=None, mesh_basename=None): self.printer.print_str("Loading mesh…") self.printer.inc() assert ((mesh is not None) or ((mesh_folder is not None) and (mesh_basename is not None))),\ "Must provide a mesh (mesh = "+str(mesh)+") or a mesh file (mesh_folder = "+str(mesh_folder)+", mesh_basename = "+str(mesh_basename)+"). Aborting." if (mesh is None): self.mesh_folder = mesh_folder self.mesh_basename = mesh_basename self.mesh_filebasename = self.mesh_folder+"/"+self.mesh_basename self.mesh_filename = self.mesh_filebasename+"."+"xml" assert (os.path.exists(self.mesh_filename)),\ "No mesh in "+mesh_filename+". Aborting." self.mesh = dolfin.Mesh(self.mesh_filename) else: self.mesh = mesh self.mesh_dimension = self.mesh.ufl_domain().geometric_dimension() assert (self.mesh_dimension in (2,3)),\ "mesh_dimension ("+str(self.mesh_dimension)+") must be 2 or 3. Aborting." self.printer.print_var("mesh_dimension",self.mesh_dimension) self.printer.print_var("mesh_n_vertices",self.mesh.num_vertices()) self.printer.print_var("mesh_n_cells",self.mesh.num_cells()) self.dV = dolfin.Measure( "dx", domain=self.mesh) self.dS = dolfin.Measure( "ds", domain=self.mesh) self.dF = dolfin.Measure( "dS", domain=self.mesh) self.mesh_V0 = dolfin.assemble(dolfin.Constant(1) * self.dV) self.printer.print_sci("mesh_V0",self.mesh_V0) self.mesh_S0 = dolfin.assemble(dolfin.Constant(1) * self.dS) self.printer.print_sci("mesh_S0",self.mesh_S0) self.mesh_h0 = self.mesh_V0**(1./self.mesh_dimension) self.printer.print_sci("mesh_h0",self.mesh_h0) self.mesh_h0 = dolfin.Constant(self.mesh_h0) self.printer.dec() def set_displacement(self, U_family="Lagrange", U_degree=1): self.printer.print_str("Defining functions…") self.U_family = U_family self.U_degree = U_degree self.U_fe = dolfin.VectorElement( family=self.U_family, cell=self.mesh.ufl_cell(), degree=self.U_degree) self.U_fs = dolfin.FunctionSpace( self.mesh, self.U_fe) self.U = dolfin.Function( self.U_fs, name="displacement") self.U.vector().zero() self.U_norm = 0. self.Uold = dolfin.Function( self.U_fs, name="previous displacement") self.Uold.vector().zero() self.Uold_norm = 0. self.DUold = dolfin.Function( self.U_fs, name="previous displacement increment") self.dU = dolfin.Function( self.U_fs, name="displacement correction") self.dU_trial = dolfin.TrialFunction(self.U_fs) self.dU_test = dolfin.TestFunction(self.U_fs) # for mesh volume computation self.I = dolfin.Identity(self.mesh_dimension) self.F = self.I + dolfin.grad(self.U) self.J = dolfin.det(self.F) def reinit(self): self.U.vector().zero() self.U_norm = 0. self.Uold.vector().zero() self.Uold_norm = 0. self.DUold.vector().zero() for energy in self.energies: energy.reinit() def add_image_energy(self, energy): if (hasattr(self, "images_n_frames") and hasattr(self, "images_ref_frame")): assert (energy.image_series.n_frames == self.images_n_frames) assert (energy.ref_frame == self.images_ref_frame) else: self.images_n_frames = energy.image_series.n_frames self.images_ref_frame = energy.ref_frame self.energies += [energy] def add_regul_energy(self, energy): self.energies += [energy] def assemble_ener(self): ener_form = 0. for energy in self.energies: ener_form += dolfin.Constant(energy.w) * energy.ener_form ener = dolfin.assemble( ener_form) #self.printer.print_var("ener",ener) return ener def assemble_res(self, res_vec): res_form = 0. for energy in self.energies: res_form -= dolfin.Constant(energy.w) * energy.res_form res_vec = dolfin.assemble( res_form, tensor=res_vec) #self.printer.print_var("res_vec",res_vec.array()) def assemble_jac(self, jac_mat): jac_form = 0. for energy in self.energies: jac_form += dolfin.Constant(energy.w) * energy.jac_form jac_mat = dolfin.assemble( jac_form, tensor=jac_mat) #self.printer.print_var("jac_mat",jac_mat.array()) def call_before_solve(self, *kargs, **kwargs): for energy in self.energies: energy.call_before_solve( *kargs, **kwargs) def call_after_solve(self, *kargs, **kwargs): self.DUold.vector()[:] = self.U.vector()[:] - self.Uold.vector()[:] self.Uold.vector()[:] = self.U.vector()[:] self.Uold_norm = self.U_norm for energy in self.energies: energy.call_after_solve( *kargs, **kwargs) def get_qoi_names(self): names = ["mesh_V"] for energy in self.energies: names += energy.get_qoi_names() return names def get_qoi_values(self): self.compute_mesh_volume() values = [self.mesh_V] for energy in self.energies: values += energy.get_qoi_values() return values def compute_mesh_volume(self): self.mesh_V = dolfin.assemble(self.J * self.dV) self.printer.print_sci("mesh_V",self.mesh_V) return self.mesh_V