Mentions légales du service

Skip to content
Snippets Groups Projects
Problem_ImageRegistration.py 7.62 KiB
Newer Older
#coding=utf8

################################################################################
###                                                                          ###
### Created by Martin Genet, 2016-2018                                       ###
###                                                                          ###
### École Polytechnique, Palaiseau, France                                   ###
###                                                                          ###
################################################################################

import dolfin
import os

import myPythonLibrary as mypy
import myVTKPythonLibrary as myvtk

import dolfin_dic as ddic
from Problem import Problem

################################################################################

class ImageRegistrationProblem(Problem):



    def __init__(self,
            mesh=None,
            mesh_folder=None,
            mesh_basename=None,
            U_family="Lagrange",
            U_degree=1):
        """Build an image registration problem.

        The mesh is either passed directly (``mesh``) or read from
        ``<mesh_folder>/<mesh_basename>.xml`` (see ``set_mesh``). A vector
        displacement field of the requested finite element family and degree
        is then defined on it (see ``set_displacement``). The list of energy
        terms starts empty; it is filled through ``add_image_energy`` and
        ``add_regul_energy``.
        """

        self.printer = mypy.Printer()

        mesh_kwargs = dict(
            mesh=mesh,
            mesh_folder=mesh_folder,
            mesh_basename=mesh_basename)
        self.set_mesh(**mesh_kwargs)

        displacement_kwargs = dict(
            U_family=U_family,
            U_degree=U_degree)
        self.set_displacement(**displacement_kwargs)

        # energy terms registered later by the caller
        self.energies = []



    def __del__(self):
        """Close the printer when the problem object is finalized.

        Guarded so that an instance without a ``printer`` attribute does not
        raise AttributeError during finalization. This covers an instance
        whose ``mypy.Printer()`` construction failed, or one created via
        ``__new__``.
        """

        printer = getattr(self, "printer", None)
        if (printer is not None):
            printer.close()



    def set_mesh(self,
            mesh=None,
            mesh_folder=None,
            mesh_basename=None):
        """Set the computational mesh and its integration measures.

        Either ``mesh`` is given directly, or the mesh is read with
        ``dolfin.Mesh`` from ``<mesh_folder>/<mesh_basename>.xml``.

        Also defines:
          - the measures ``dV`` ("dx"), ``dS`` ("ds") and ``dF`` ("dS");
          - the initial mesh volume ``mesh_V0``;
          - the boundary area ``mesh_S0``;
          - a characteristic length ``mesh_h0 = mesh_V0**(1/dim)``, stored
            as a ``dolfin.Constant``.

        Raises AssertionError if neither a mesh nor a complete file location
        is given, if the mesh file does not exist, or if the mesh geometric
        dimension is not 2 or 3.
        """

        self.printer.print_str("Loading mesh…")
        self.printer.inc()

        assert ((mesh is not None) or ((mesh_folder is not None) and (mesh_basename is not None))),\
            "Must provide a mesh (mesh = "+str(mesh)+") or a mesh file (mesh_folder = "+str(mesh_folder)+", mesh_basename = "+str(mesh_basename)+"). Aborting."

        if (mesh is None):
            self.mesh_folder = mesh_folder
            self.mesh_basename = mesh_basename
            self.mesh_filebasename = self.mesh_folder+"/"+self.mesh_basename
            self.mesh_filename = self.mesh_filebasename+"."+"xml"
            # The message must use the attribute: there is no local
            # ``mesh_filename``, and referencing one would turn this
            # AssertionError into a NameError.
            assert (os.path.exists(self.mesh_filename)),\
                "No mesh in "+self.mesh_filename+". Aborting."
            self.mesh = dolfin.Mesh(self.mesh_filename)
        else:
            self.mesh = mesh

        # geometric (embedding) dimension; only 2D and 3D are supported
        self.mesh_dimension = self.mesh.ufl_domain().geometric_dimension()
        assert (self.mesh_dimension in (2,3)),\
            "mesh_dimension ("+str(self.mesh_dimension)+") must be 2 or 3. Aborting."
        self.printer.print_var("mesh_dimension",self.mesh_dimension)

        self.printer.print_var("mesh_n_vertices",self.mesh.num_vertices())
        self.printer.print_var("mesh_n_cells",self.mesh.num_cells())

        # UFL measures: "dx" integrates over cells, "ds" over exterior
        # (boundary) facets, "dS" over interior facets. Note the attribute
        # naming: self.dS is the boundary measure, self.dF the interior one.
        self.dV = dolfin.Measure(
            "dx",
            domain=self.mesh)
        self.dS = dolfin.Measure(
            "ds",
            domain=self.mesh)
        self.dF = dolfin.Measure(
            "dS",
            domain=self.mesh)

        # initial volume, boundary area and characteristic length
        self.mesh_V0 = dolfin.assemble(dolfin.Constant(1) * self.dV)
        self.printer.print_sci("mesh_V0",self.mesh_V0)
        self.mesh_S0 = dolfin.assemble(dolfin.Constant(1) * self.dS)
        self.printer.print_sci("mesh_S0",self.mesh_S0)
        self.mesh_h0 = self.mesh_V0**(1./self.mesh_dimension)
        self.printer.print_sci("mesh_h0",self.mesh_h0)
        # wrapped in a Constant, presumably for use inside UFL forms
        self.mesh_h0 = dolfin.Constant(self.mesh_h0)

        self.printer.dec()



    def set_displacement(self,
            U_family="Lagrange",
            U_degree=1):
        """Create the vector displacement space and its working functions.

        On ``self.mesh`` this defines:
          - ``U``: current displacement;
          - ``Uold``: previous displacement;
          - ``DUold``: previous displacement increment;
          - ``dU``: displacement correction;
          - the trial and test functions ``dU_trial`` and ``dU_test``.

        ``U`` and ``Uold`` are explicitly zeroed and their norms reset to 0.
        Also defines ``I``, ``F = I + grad(U)`` and ``J = det(F)``, which are
        used for the mesh volume computation.
        """

        self.printer.print_str("Defining functions…")

        self.U_family = U_family
        self.U_degree = U_degree
        self.U_fe = dolfin.VectorElement(
            family=self.U_family,
            cell=self.mesh.ufl_cell(),
            degree=self.U_degree)
        self.U_fs = dolfin.FunctionSpace(
            self.mesh,
            self.U_fe)

        def make_function(label):
            # every working field lives in the same displacement space
            return dolfin.Function(self.U_fs, name=label)

        self.U = make_function("displacement")
        self.U.vector().zero()
        self.U_norm = 0.
        self.Uold = make_function("previous displacement")
        self.Uold.vector().zero()
        self.Uold_norm = 0.
        self.DUold = make_function("previous displacement increment")
        self.dU = make_function("displacement correction")
        self.dU_trial = dolfin.TrialFunction(self.U_fs)
        self.dU_test = dolfin.TestFunction(self.U_fs)

        # deformation gradient and its determinant, for mesh volume computation
        self.I = dolfin.Identity(self.mesh_dimension)
        self.F = self.I + dolfin.grad(self.U)
        self.J = dolfin.det(self.F)



    def reinit(self):
        """Zero the displacement state, reset the norms, and reinit every energy term."""

        for field in (self.U, self.Uold, self.DUold):
            field.vector().zero()
        self.U_norm = 0.
        self.Uold_norm = 0.

        for term in self.energies:
            term.reinit()



    def add_image_energy(self,
            energy):
        """Register an image-based energy term.

        The first image energy fixes ``images_n_frames`` and
        ``images_ref_frame``. Every later one must have an image series with
        the same number of frames and the same reference frame.

        Raises AssertionError on a mismatch.
        """

        if (hasattr(self, "images_n_frames")
        and hasattr(self, "images_ref_frame")):
            assert (energy.image_series.n_frames == self.images_n_frames),\
                "Image energy has n_frames = "+str(energy.image_series.n_frames)+", but previous image energies have n_frames = "+str(self.images_n_frames)+". Aborting."
            assert (energy.ref_frame == self.images_ref_frame),\
                "Image energy has ref_frame = "+str(energy.ref_frame)+", but previous image energies have ref_frame = "+str(self.images_ref_frame)+". Aborting."
        else:
            self.images_n_frames = energy.image_series.n_frames
            self.images_ref_frame = energy.ref_frame

        self.energies += [energy]



    def add_regul_energy(self,
            energy):
        """Register a regularization energy term (no frame-consistency check)."""

        self.energies.append(energy)



    def assemble_ener(self):
        """Assemble and return the total weighted energy.

        Each term's ``ener_form`` is assembled on its own and printed as
        ``ener_<name>``. The results are accumulated with the term weights
        ``w``, i.e. ``ener = sum(energy.w * assemble(energy.ener_form))``.
        Returns 0. when no energy term has been registered.
        """

        ener = 0.
        for energy in self.energies:
            ener_ = dolfin.assemble(
                energy.ener_form)
            self.printer.print_var("ener_"+energy.name,ener_)
            ener += energy.w * ener_

        return ener



    def assemble_res(self,
            res_vec):
        """Assemble the residual into the caller-provided tensor ``res_vec``.

        The residual form is minus the weighted sum of every energy's
        ``res_form``, with each weight wrapped in a ``dolfin.Constant``. It
        is assembled with ``dolfin.assemble(..., tensor=res_vec)``.
        Returns nothing.
        """

        form = 0.
        for term in self.energies:
            weighted = dolfin.Constant(term.w) * term.res_form
            form -= weighted

        dolfin.assemble(
            form,
            tensor=res_vec)



    def assemble_jac(self,
            jac_mat):
        """Assemble the Jacobian into the caller-provided tensor ``jac_mat``.

        The Jacobian form is the weighted sum of every energy's
        ``jac_form``, with each weight wrapped in a ``dolfin.Constant``. It
        is assembled with ``dolfin.assemble(..., tensor=jac_mat)``.
        Returns nothing.
        """

        form = 0.
        for term in self.energies:
            weighted = dolfin.Constant(term.w) * term.jac_form
            form += weighted

        dolfin.assemble(
            form,
            tensor=jac_mat)



    def call_before_solve(self,
            *kargs,
            **kwargs):
        """Forward the pre-solve hook, with all its arguments, to every energy term."""

        for term in self.energies:
            hook = term.call_before_solve
            hook(*kargs, **kwargs)



    def call_after_solve(self,
            *kargs,
            **kwargs):
        """Roll the displacement state forward, then forward the post-solve hook.

        ``DUold`` receives ``U - Uold``, ``Uold`` receives the values of
        ``U``, and ``Uold_norm`` takes ``U_norm``. Every energy term's
        ``call_after_solve`` is then called with the same arguments.
        """

        # the increment has to be taken before Uold is overwritten
        self.DUold.vector()[:] = self.U.vector()[:] - self.Uold.vector()[:]
        self.Uold.vector()[:] = self.U.vector()[:]
        self.Uold_norm = self.U_norm

        for term in self.energies:
            hook = term.call_after_solve
            hook(*kargs, **kwargs)



    def get_qoi_names(self):
        """Return the quantity-of-interest names.

        The list is ``"mesh_V"`` followed by each energy term's names, in
        registration order.
        """

        return ["mesh_V"] + [
            name
            for term in self.energies
            for name in term.get_qoi_names()]



    def get_qoi_values(self):
        """Return the quantity-of-interest values, matching ``get_qoi_names`` order.

        The current mesh volume is recomputed first and leads the list. Each
        energy term's values follow, in registration order.
        """

        self.compute_mesh_volume()
        return [self.mesh_V] + [
            value
            for term in self.energies
            for value in term.get_qoi_values()]



    def compute_mesh_volume(self):
        """Compute, store, print and return the current mesh volume.

        Assembles ``J * dV``, where ``J`` is the determinant of
        ``I + grad(U)``. The result is stored in ``self.mesh_V``.
        """

        volume = dolfin.assemble(self.J * self.dV)
        self.mesh_V = volume
        self.printer.print_sci("mesh_V",volume)
        return volume