Commit af3b4e2c authored by CORNILLET Remi

Add solver, start testing

parent 8abdbffc
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 15 14:37:33 2023

@author: rcornill
"""
import torch as tn
from wdnmf_dual.methods.base_function import *
from wdnmf_dual.methods.grad_pb_dual import *
def optim_pb_dual_h(g, rho2, w, data, sigma, dictionary, ent, matK,
                    step_h, iterdual = 10):
    """
    Parameters
    ----------
    g : 2d tn.tensor
        Dual variable, size n*m
    rho2 : float64
        Regularization rate for the nonnegativity constraint w.r.t. H
    w : 2d tn.tensor
        Template matrix, size n*r
    data : 2d tn.tensor
        Data entry, of size n*m
    sigma : 1d array-like
        Array of size r, corresponding to the r selected templates of the dictionary
    dictionary : 2d tn.tensor
        Dictionary matrix, each column is a different template, size n*d with d > r
    ent : float
        Entropic regularization parameter
    matK : 2d tn.tensor
        Kernel matrix computed from the cost matrix and ent (see compute_matK)
    step_h : float
        Step size of the gradient ascent on g
    iterdual : int, optional
        Number of gradient steps on g before updating H. The default is 10.

    Returns
    -------
    g : 2d tn.tensor
        Dual variable, size n*m, updated
    """
    rep_g = tn.clone(g).detach()
    for _ in range(iterdual):
        g_grad = gradh_pb_dual(rep_g, rho2, w, data, ent, matK)
        rep_g = rep_g + step_h*g_grad
    return rep_g

def solH_primal(g, w, rho2):
    h = -grad_e_dual(-tn.matmul(tn.transpose(w, 0, 1), g)/rho2)
    return h

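# Note: solH_primal recovers the primal abundance matrix from the optimal
# dual variable: h = -grad_e_dual(-W^T g / rho2). This mirrors the
# -wT@g/rho2 term inside gradh_pb_dual, the corresponding dual gradient.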
def optim_h(w, h_init, data, sigma, dictionary, reg, ent, rho1, rho2, cost,
            step_h = 1e-3, iterdual = 500, method = 'sinkhorn',
            numItermax = 1000, thr = 1e-9, verbose = False):
    init_pb = pb(w = w, h = h_init, sigma = sigma, data = data, dictionary = dictionary, reg = reg,
                 ent = ent, rho1 = rho1, rho2 = rho2, cost = cost, method = method,
                 numItermax = numItermax, thr = thr, verbose = verbose)
    # Temporary, for the testing phase; to be removed
    print("Initial cost: " + str(float(init_pb)))
    n, r = w.shape
    _, m = h_init.shape
    # data is expected to have shape (n, m)
    _, d = dictionary.shape
    matK = compute_matK(cost, ent)
    g = tn.rand((n, m), dtype=tn.float64)  # could be replaced by a g_init argument
    g = simplex_norm(g)
    g = optim_pb_dual_h(g, rho2, w, data, sigma, dictionary, ent, matK,
                        step_h, iterdual)
    h = solH_primal(g, w, rho2)
    return h
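Below is a minimal usage sketch for optim_h on synthetic data, assuming the file above is the wdnmf_dual.methods.Optim_H module imported by the test file further down; the sizes and the regularization values are illustrative placeholders, not values from the commit.

import torch as tn
from wdnmf_dual.methods.base_function import simplex_norm
from wdnmf_dual.methods.Optim_H import optim_h

n, r, m = 3, 2, 4
dictionary = simplex_norm(tn.abs(tn.rand(n, r, dtype=tn.float64)))
h_true = simplex_norm(tn.rand(r, m, dtype=tn.float64))
data = dictionary @ h_true                       # noiseless synthetic mixture
cost = tn.tensor([[(i - j)**2 for j in range(n)]
                  for i in range(n)], dtype=tn.float64)
sigma = list(range(r))                           # identity template selection
h_est = optim_h(dictionary, simplex_norm(tn.rand(r, m, dtype=tn.float64)),
                data, sigma, dictionary, reg = 1.0, ent = 1e-1,
                rho1 = 1.0, rho2 = 1.0, cost = cost,
                step_h = 1e-3, iterdual = 500)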
@@ -8,7 +8,7 @@ import torch as tn
 from wdnmf_dual.methods.base_function import *
 from wdnmf_dual.methods.grad_pb_dual import *
-def optim_pb_dual(g1, g2, rho1, h, data, sigma, dictionary, reg, ent, matK,
+def optim_pb_dual_w(g1, g2, rho1, h, data, sigma, dictionary, reg, ent, matK,
                     stepg1, stepg2, iterg1, iterg2, iterdual = 10):
     """
@@ -79,8 +79,9 @@ def solW_primal(g1, g2, h, rho1):
     return w
 def optim_w(w_init, h, data, sigma, dictionary, reg, ent, rho1, rho2, cost,
-            method = 'sinkhorn', stepg1 = 2e-3, stepg2 = 2e-3, iterg1 = 1, iterg2 = 1, iterdual = 150, itermax = 1,
-            numItermax = 500, thr = 1e-9, verbose = False):
+            stepg1 = 2e-3, stepg2 = 2e-3, iterg1 = 1, iterg2 = 1, iterdual = 150,
+            method = 'sinkhorn', numItermax = 500,
+            thr = 1e-9, verbose = False):
     init_pb = pb(w = w_init, h = h, sigma = sigma, data = data, dictionary = dictionary, reg = reg,
                  ent = ent, rho1 = rho1, rho2 = rho2, cost = cost, method = method, numItermax = numItermax,
@@ -101,9 +102,7 @@ def optim_w(w_init, h, data, sigma, dictionary, reg, ent, rho1, rho2, cost,
     g2 = tn.rand((n, r), dtype=tn.float64)
     g2 = simplex_norm(g2)
-    for i in range(itermax):
-        g1, g2 = optim_pb_dual(g1, g2, rho1, h, data, sigma, dictionary, reg, ent,
-                               matK, stepg1, stepg2, iterg1, iterg2, iterdual)
+    g1, g2 = optim_pb_dual_w(g1, g2, rho1, h, data, sigma, dictionary, reg, ent,
+                             matK, stepg1, stepg2, iterg1, iterg2, iterdual)
     w = solW_primal(g1, g2, h, rho1)
@@ -183,4 +183,9 @@ def grad2_pb_dual(g1, g2, rho1, h, sigma, dictionary, reg, ent, matK):
     f_dico = dg.grad_f_dual(g2/reg, dico, ent, matK)
     return -e-f_dico
+def gradh_pb_dual(g, rho2, w, data, ent, matK):
+    wT = tn.transpose(w, 0, 1)
+    e = w@dg.grad_e_dual(-wT@g/rho2)
+    f = dg.grad_f_dual(g, data, ent, matK)
+    return -e-f
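The new gradh_pb_dual assembles the two terms of the dual gradient in g: a W-weighted grad_e_dual term coming from the nonnegativity penalty on H (rho2), and a grad_f_dual term against the data. A quick shape sanity check, assuming the wdnmf_dual package from this commit is importable and that compute_matK and simplex_norm behave as they are used elsewhere in the diff:

import torch as tn
from wdnmf_dual.methods import grad_pb_dual as gpd
from wdnmf_dual.methods.base_function import compute_matK, simplex_norm

n, r, m = 3, 2, 4
w = simplex_norm(tn.rand(n, r, dtype=tn.float64))
data = simplex_norm(tn.rand(n, m, dtype=tn.float64))
g = simplex_norm(tn.rand(n, m, dtype=tn.float64))
cost = tn.tensor([[(i - j)**2 for j in range(n)]
                  for i in range(n)], dtype=tn.float64)
matK = compute_matK(cost, 1e-1)
grad = gpd.gradh_pb_dual(g, 1.0, w, data, 1e-1, matK)
assert grad.shape == g.shape  # the gradient lives in the same space as g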
# -*- coding: utf-8 -*-
"""
Created on Mon Mar 20 09:54:05 2023
@author: rcornill
"""
import torch as tn
from wdnmf_dual.methods.base_function import *
from wdnmf_dual.methods.grad_pb_dual import *
from wdnmf_dual.methods.Optim_W import optim_w
from wdnmf_dual.methods.Optim_H import optim_h
def solver_wdnmf(data, dictionary,
                 reg, ent, rho1, rho2, cost, r,
                 w_init = None, h_init = None, sigma_init = None,
                 step_g1 = 1e-3, step_g2 = 1e-3, iter_g1 = 1, iter_g2 = 1,
                 iterdual_w = 1,
                 step_h = 1e-3, iter_h = 1, iterdual_h = 1,
                 itertotal = 1000,
                 method = 'sinkhorn', numItermax = 1000, thr = 1e-9,
                 verbose = False):
    n, m = data.shape
    if w_init is None:
        w = simplex_norm(tn.rand((n, r), dtype=tn.float64))
    else:
        w = tn.clone(w_init)
    if h_init is None:
        h = simplex_norm(tn.rand((r, m), dtype=tn.float64))
    else:
        h = tn.clone(h_init)
    if sigma_init is None:
        sigma = [i for i in range(r)]
    else:
        sigma = sigma_init.copy()
    for i in range(itertotal):
        w = optim_w(w, h, data, sigma, dictionary, reg,
                    ent, rho1, rho2, cost, step_g1, step_g2, iter_g1, iter_g2,
                    iterdual_w, method, numItermax, thr, verbose)
        h = optim_h(w, h, data, sigma, dictionary, reg,
                    ent, rho1, rho2, cost, step_h, iterdual_h, method,
                    numItermax, thr, verbose)
        # sigma update not implemented yet (see update_sigma below)
    return (w, h, sigma)
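A minimal end-to-end sketch of solver_wdnmf on synthetic data; the module path in the import is hypothetical (the commit page does not show the file name), and the regularization values are arbitrary placeholders.

import torch as tn
from wdnmf_dual.methods.base_function import simplex_norm
from wdnmf_dual.methods.Solver import solver_wdnmf  # hypothetical module path

n, r, m = 3, 2, 4
dictionary = simplex_norm(tn.abs(tn.rand(n, r, dtype=tn.float64)))
h_true = simplex_norm(tn.rand(r, m, dtype=tn.float64))
data = dictionary @ h_true
cost = tn.tensor([[(i - j)**2 for j in range(n)]
                  for i in range(n)], dtype=tn.float64)
w, h, sigma = solver_wdnmf(data, dictionary, reg = 1.0, ent = 1e-1,
                           rho1 = 1.0, rho2 = 1.0, cost = cost, r = r,
                           itertotal = 10)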
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 21 15:38:21 2023
@author: rcornill
"""
import torch as tn
import ot
from scipy.optimize import linear_sum_assignment
def assignment(loss_mat): #, optimal = False):
    _, indice = linear_sum_assignment(loss_mat.detach().numpy())
    return indice

def loss_matrix(w, dico, cost, ent, numItermax = 1000, method = 'sinkhorn',
                thr = 1e-9, verbose = False, log = False, warn = False):
    """
    Compute the entropic Wasserstein distance between each column of w
    and each column of dico.

    Parameters
    ----------
    w : Float tn.tensor
        2d tensor, size n*r1, each column is a distribution
    dico : Float tn.tensor
        2d tensor, size n*r2, each column is a distribution
    cost : Float tn.tensor (array-like)
        Cost matrix for the entropic Wasserstein distance, containing the
        ground distances between the support points of the columns.
    ent : float
        A positive float, the entropic regularization parameter.

    Returns
    -------
    loss_mat : 2d tn.tensor
        Matrix whose entry (i, j) is the Wasserstein distance between
        column i of w and column j of dico.
    """
    _, r1 = w.shape
    _, r2 = dico.shape
    loss_mat = tn.ones(r1, r2)
    for i in range(r1):
        for j in range(r2):
            loss_mat[i, j] = ot.bregman.sinkhorn2(a = w[:, i], b = dico[:, j],
                                                  M = cost, reg = ent, method = method,
                                                  numItermax = numItermax, stopThr = thr,
                                                  verbose = verbose,
                                                  log = log, warn = warn)
    return loss_mat

def update_sigma(d, w, cost, ent):
    loss_mat = loss_matrix(w, d, cost, ent)
    sigma = assignment(loss_mat)
    return sigma
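A small self-contained check of the assignment step; it only needs torch, POT, and SciPy in addition to the functions above. The two 3-bin templates and the squared-distance cost are illustrative.

import torch as tn
n = 3
cost = tn.tensor([[(i - j)**2 for j in range(n)]
                  for i in range(n)], dtype=tn.float64)
# Two templates whose columns each sum to one
dico = tn.tensor([[0.8, 0.1], [0.1, 0.1], [0.1, 0.8]], dtype=tn.float64)
w = dico[:, [1, 0]]   # the same templates, in swapped order
sigma = update_sigma(dico, w, cost, ent = 1e-1)
print(sigma)          # expected: [1 0], each column of w matched to its template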
@@ -2,7 +2,8 @@ import wdnmf_dual as wd
 import torch as tn
 from wdnmf_dual.methods.grad_pb_dual import pb_dual, grad1_pb_dual, pb
 from wdnmf_dual.methods.base_function import compute_matK, simplex_norm
-from wdnmf_dual.methods.Optim_pb import optim_w
+from wdnmf_dual.methods.Optim_W import optim_w
+from wdnmf_dual.methods.Optim_H import optim_h
 n = 3
 m = 4
@@ -20,7 +21,7 @@ for i in range(n):
         cost[i, j] = (i-j)**2
 matK = compute_matK(cost, ent)
-
+"""
 for i in range(5):
     #aim = tn.ones(n, m) + tn.rand(n,m)
     #aim = simplex_norm(aim)
@@ -61,4 +62,23 @@ for i in range(5):
     print("Cost after optim: " + str(float(out4)))
     print("Optimal cost: " + str(float(pb(w = dictionary, h = h, sigma = sigma, data = data, dictionary = dictionary, reg = reg,
                                           ent = ent, rho1 = rho1, rho2 = rho2, cost = cost))))
-    print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
+    print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")"""
+#dictionary = tn.tensor([[0.05, 0.05],[0.9, 0.05], [0.05, 0.9]], dtype=tn.float64)
+dictionary = tn.abs(tn.rand(n, 2, dtype=tn.float64))
+dictionary = simplex_norm(dictionary)
+#w = tn.tensor([[0.9, 0.05],[0.05, 0.05], [0.05, 0.9]], dtype=tn.float64)
+h_init = simplex_norm(tn.rand(r, m, dtype=tn.float64))
+h = simplex_norm(tn.rand(r, m, dtype=tn.float64))
+data = dictionary@h
+out5 = optim_h(dictionary, h_init, data, sigma, dictionary, reg, ent, rho1, rho2, cost,
+               step_h = 1e-2, iterdual = 1000,
+               numItermax = 1000, thr = 1e-9, verbose = False)
+print(out5)
+print(h)
+print(h_init)