Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 1ed5025c authored by Martin Genet's avatar Martin Genet
Browse files

Working w/ Katka on Generated Image Energy

parent e681a1d1
No related branches found
No related tags found
No related merge requests found
...@@ -29,6 +29,14 @@ class Energy(): ...@@ -29,6 +29,14 @@ class Energy():
def call_before_assembly(self,
*args,
**kwargs):
pass
def call_before_solve(self, def call_before_solve(self,
*args, *args,
**kwargs): **kwargs):
......
...@@ -29,8 +29,7 @@ class WarpedImageEnergy(Energy): ...@@ -29,8 +29,7 @@ class WarpedImageEnergy(Energy):
quadrature_degree, quadrature_degree,
name="im", name="im",
w=1., w=1.,
ref_frame=0, ref_frame=0):
dynamic_scaling=False):
self.problem = problem self.problem = problem
self.printer = self.problem.printer self.printer = self.problem.printer
...@@ -39,7 +38,6 @@ class WarpedImageEnergy(Energy): ...@@ -39,7 +38,6 @@ class WarpedImageEnergy(Energy):
self.name = name self.name = name
self.w = w self.w = w
self.ref_frame = ref_frame self.ref_frame = ref_frame
self.dynamic_scaling = dynamic_scaling
self.printer.print_str("Defining warped image correlation energy…") self.printer.print_str("Defining warped image correlation energy…")
self.printer.inc() self.printer.inc()
...@@ -78,6 +76,7 @@ class WarpedImageEnergy(Energy): ...@@ -78,6 +76,7 @@ class WarpedImageEnergy(Energy):
self.printer.print_str("Defining measure…") self.printer.print_str("Defining measure…")
# dV
self.form_compiler_parameters = { self.form_compiler_parameters = {
"quadrature_degree":self.quadrature_degree, "quadrature_degree":self.quadrature_degree,
"quadrature_scheme":"default"} "quadrature_scheme":"default"}
...@@ -86,50 +85,33 @@ class WarpedImageEnergy(Energy): ...@@ -86,50 +85,33 @@ class WarpedImageEnergy(Energy):
domain=self.problem.mesh, domain=self.problem.mesh,
metadata=self.form_compiler_parameters) metadata=self.form_compiler_parameters)
self.printer.print_str("Loading reference image…")
self.printer.inc()
# ref_frame # ref_frame
assert (abs(self.ref_frame) < self.image_series.n_frames),\ assert (abs(self.ref_frame) < self.image_series.n_frames),\
"abs(ref_frame) = "+str(abs(self.ref_frame))+" >= "+str(self.image_series.n_frames)+" = image_series.n_frames. Aborting." "abs(ref_frame) = "+str(abs(self.ref_frame))+" >= "+str(self.image_series.n_frames)+" = image_series.n_frames. Aborting."
self.ref_frame = self.ref_frame%self.image_series.n_frames self.ref_frame = self.ref_frame%self.image_series.n_frames
self.ref_image_filename = self.image_series.get_image_filename(self.ref_frame)
self.printer.print_var("ref_frame",self.ref_frame) self.printer.print_var("ref_frame",self.ref_frame)
# Iref self.printer.dec()
self.Iref = dolfin.Expression( self.printer.print_str("Defining deformed image…")
cppcode=ddic.get_ExprIm_cpp(
# Igen
self.Igen = dolfin.Expression(
cppcode=ddic.get_ExprGenIm_cpp(
im_dim=self.image_series.dimension, im_dim=self.image_series.dimension,
im_type="im", im_type="im"),
im_is_def=0),
element=self.fe) element=self.fe)
self.ref_image_filename = self.image_series.get_image_filename(self.ref_frame) self.Igen.init_image(self.ref_image_filename)
self.Iref.init_image(self.ref_image_filename) self.Igen.init_disp(self.problem.U)
self.Iref_int = dolfin.assemble(self.Iref * self.dV)/self.problem.mesh_V0
self.printer.print_var("Iref_int",self.Iref_int)
self.Iref_norm = (dolfin.assemble(self.Iref**2 * self.dV)/self.problem.mesh_V0)**(1./2) # DIgen
assert (self.Iref_norm > 0.),\ self.DIgen = dolfin.Expression(
"Iref_norm = "+str(self.Iref_norm)+" <= 0. Aborting." cppcode=ddic.get_ExprGenIm_cpp(
self.printer.print_var("Iref_norm",self.Iref_norm)
# DIref
self.DIref = dolfin.Expression(
cppcode=ddic.get_ExprIm_cpp(
im_dim=self.image_series.dimension, im_dim=self.image_series.dimension,
im_type="grad" if (self.image_series.grad_basename is None) else "grad_no_deriv", im_type="grad"),
im_is_def=0),
element=self.ve) element=self.ve)
self.ref_image_grad_filename = self.image_series.get_image_grad_filename(self.ref_frame) self.Igen.init_image(self.ref_image_filename)
self.DIref.init_image(self.ref_image_grad_filename) self.Igen.init_disp(self.problem.U)
self.printer.dec()
self.printer.print_str("Defining deformed image…")
self.scaling = numpy.array([1.,0.])
if (self.dynamic_scaling):
self.p = numpy.empty((2,2))
self.q = numpy.empty(2)
# Idef # Idef
self.Idef = dolfin.Expression( self.Idef = dolfin.Expression(
...@@ -140,7 +122,6 @@ class WarpedImageEnergy(Energy): ...@@ -140,7 +122,6 @@ class WarpedImageEnergy(Energy):
element=self.fe) element=self.fe)
self.Idef.init_image(self.ref_image_filename) self.Idef.init_image(self.ref_image_filename)
self.Idef.init_disp(self.problem.U) self.Idef.init_disp(self.problem.U)
self.Idef.init_dynamic_scaling(self.scaling)
# DIdef # DIdef
self.DIdef = dolfin.Expression( self.DIdef = dolfin.Expression(
...@@ -151,32 +132,9 @@ class WarpedImageEnergy(Energy): ...@@ -151,32 +132,9 @@ class WarpedImageEnergy(Energy):
element=self.ve) element=self.ve)
self.DIdef.init_image(self.ref_image_filename) self.DIdef.init_image(self.ref_image_filename)
self.DIdef.init_disp(self.problem.U) self.DIdef.init_disp(self.problem.U)
self.DIdef.init_dynamic_scaling(self.scaling)
self.printer.print_str("Defining previous image…") self.printer.print_str("Defining previous image…")
# Iold
self.Iold = dolfin.Expression(
cppcode=ddic.get_ExprIm_cpp(
im_dim=self.image_series.dimension,
im_type="im",
im_is_def=1),
element=self.fe)
self.Iold.init_image(self.ref_image_filename)
self.Iold.init_disp(self.problem.Uold)
self.Iold.init_dynamic_scaling(self.scaling) # 2016/07/25: ok, same scaling must apply to Idef & Iold…
# DIold
self.DIold = dolfin.Expression(
cppcode=ddic.get_ExprIm_cpp(
im_dim=self.image_series.dimension,
im_type="grad" if (self.image_series.grad_basename is None) else "grad_no_deriv",
im_is_def=1),
element=self.ve)
self.DIold.init_image(self.ref_image_filename)
self.DIold.init_disp(self.problem.Uold)
self.DIold.init_dynamic_scaling(self.scaling) # 2016/07/25: ok, same scaling must apply to Idef & Iold…
self.printer.print_str("Defining correlation energy…") self.printer.print_str("Defining correlation energy…")
# Phi_ref # Phi_ref
...@@ -187,15 +145,9 @@ class WarpedImageEnergy(Energy): ...@@ -187,15 +145,9 @@ class WarpedImageEnergy(Energy):
self.Phi_Iref.init_image(self.ref_image_filename) self.Phi_Iref.init_image(self.ref_image_filename)
# Psi_c # Psi_c
self.Psi_c = self.Phi_Iref * (self.Idef - self.Iref)**2/2 self.Psi_c = self.Phi_Iref * (self.Igen - self.Idef)**2/2
self.DPsi_c = self.Phi_Iref * (self.Idef - self.Iref) * dolfin.dot(self.DIdef, self.problem.dU_test) self.DPsi_c = self.Phi_Iref * (self.Igen - self.Idef) * dolfin.dot(self.DIgen - self.DIdef, self.problem.dU_test)
self.DDPsi_c = self.Phi_Iref * dolfin.dot(self.DIgen - self.DIdef, self.problem.dU_trial) * dolfin.dot(self.DIgen - self.DIdef, self.problem.dU_test)
self.DDPsi_c = self.Phi_Iref * dolfin.dot(self.DIdef, self.problem.dU_trial) * dolfin.dot(self.DIdef, self.problem.dU_test)
self.DDPsi_c_old = self.Phi_Iref * dolfin.dot(self.DIold, self.problem.dU_trial) * dolfin.dot(self.DIold, self.problem.dU_test)
self.DDPsi_c_ref = self.Phi_Iref * dolfin.dot(self.DIref, self.problem.dU_trial) * dolfin.dot(self.DIref, self.problem.dU_test)
self.Psi_c_old = self.Phi_Iref * (self.Idef - self.Iold)**2/2
self.DPsi_c_old = self.Phi_Iref * (self.Idef - self.Iold) * dolfin.dot(self.DIdef, self.problem.dU_test)
# forms # forms
self.ener_form = self.Psi_c * self.dV self.ener_form = self.Psi_c * self.dV
...@@ -208,7 +160,14 @@ class WarpedImageEnergy(Energy): ...@@ -208,7 +160,14 @@ class WarpedImageEnergy(Energy):
def reinit(self): def reinit(self):
self.scaling[:] = [1.,0.] pass
def call_before_assembly(self):
self.Igen.generate_image()
self.DIgen.generate_image()
...@@ -226,16 +185,6 @@ class WarpedImageEnergy(Energy): ...@@ -226,16 +185,6 @@ class WarpedImageEnergy(Energy):
self.def_grad_image_filename = self.image_series.get_image_grad_filename(k_frame) self.def_grad_image_filename = self.image_series.get_image_grad_filename(k_frame)
self.DIdef.init_image(self.def_grad_image_filename) self.DIdef.init_image(self.def_grad_image_filename)
self.printer.print_str("Loading previous image for correlation energy…")
# Iold
self.old_image_filename = self.image_series.get_image_filename(k_frame_old)
self.Iold.init_image(self.old_image_filename)
# DIold
self.old_grad_image_filename = self.image_series.get_image_grad_filename(k_frame_old)
self.DIold.init_image(self.old_grad_image_filename)
def call_after_solve(self): def call_after_solve(self):
...@@ -246,7 +195,7 @@ class WarpedImageEnergy(Energy): ...@@ -246,7 +195,7 @@ class WarpedImageEnergy(Energy):
def get_qoi_names(self): def get_qoi_names(self):
return [self.name+"_ener", self.name+"_err"] return [self.name+"_ener"]
...@@ -254,7 +203,5 @@ class WarpedImageEnergy(Energy): ...@@ -254,7 +203,5 @@ class WarpedImageEnergy(Energy):
self.ener = (dolfin.assemble(self.ener_form)/self.problem.mesh_V0)**(1./2) self.ener = (dolfin.assemble(self.ener_form)/self.problem.mesh_V0)**(1./2)
self.printer.print_sci(self.name+"_ener",self.ener) self.printer.print_sci(self.name+"_ener",self.ener)
self.err = self.ener/self.Iref_norm
self.printer.print_sci(self.name+"_err",self.err)
return [self.ener, self.err] return [self.ener]
...@@ -179,6 +179,8 @@ class NonlinearSolver(): ...@@ -179,6 +179,8 @@ class NonlinearSolver():
self.res_old_vec = self.res_vec.copy() self.res_old_vec = self.res_vec.copy()
self.res_old_norm = self.res_norm self.res_old_norm = self.res_norm
self.problem.call_before_assembly()
# linear system: residual assembly # linear system: residual assembly
self.printer.print_str("Residual assembly…",newline=False) self.printer.print_str("Residual assembly…",newline=False)
timer = time.time() timer = time.time()
...@@ -281,6 +283,7 @@ class NonlinearSolver(): ...@@ -281,6 +283,7 @@ class NonlinearSolver():
self.printer.print_var("relax_k",relax_k,-1) self.printer.print_var("relax_k",relax_k,-1)
# self.printer.print_sci("relax_a",relax_a) # self.printer.print_sci("relax_a",relax_a)
# self.printer.print_sci("relax_b",relax_b) # self.printer.print_sci("relax_b",relax_b)
self.problem.call_before_assembly()
if (need_update_c): if (need_update_c):
relax_c = relax_b - (relax_b - relax_a) / phi relax_c = relax_b - (relax_b - relax_a) / phi
relax_list.append(relax_c) relax_list.append(relax_c)
......
...@@ -179,6 +179,17 @@ class ImageRegistrationProblem(Problem): ...@@ -179,6 +179,17 @@ class ImageRegistrationProblem(Problem):
def call_before_assembly(self,
*kargs,
**kwargs):
for energy in self.energies:
energy.call_before_assembly(
*kargs,
**kwargs)
def assemble_ener(self): def assemble_ener(self):
ener_form = 0. ener_form = 0.
......
...@@ -14,6 +14,7 @@ def get_ExprIm_cpp( ...@@ -14,6 +14,7 @@ def get_ExprIm_cpp(
im_dim, im_dim,
im_type="im", im_type="im",
im_is_def=0, im_is_def=0,
disp_type="fenics", # "vtk"
verbose=0): verbose=0):
assert (im_dim in (2,3)) assert (im_dim in (2,3))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment