Commit c0b5d639 authored by CORNILLET Remi

Optim W

parent 01203114
@@ -58,8 +58,8 @@ def optim_pb_dual(g1, g2, rho1, h, data, sigma, dictionary, reg, ent, matK,
     n1, m1 = g1.shape
     n2, r = g2.shape
-    rep_g1 = tn.clone(g1, dtype = tn.float64)
-    rep_g2 = tn.clone(g2, dtype = tn.float64)
+    rep_g1 = tn.clone(g1).detach()
+    rep_g2 = tn.clone(g2).detach()
     for _ in range(iterdual):
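Note on the clone fix above: torch.clone does not accept a dtype keyword (only memory_format), so the old lines raise a TypeError; clone().detach() copies the values, keeps the dtype, and cuts the copy from the autograd graph. A minimal sketch using only plain PyTorch:

import torch as tn

g1 = tn.rand((3, 4), dtype=tn.float64)

# tn.clone(g1, dtype=tn.float64) would raise a TypeError, since clone()
# takes no dtype keyword. The replacement keeps values and dtype while
# detaching the copy from the autograd graph.
rep_g1 = tn.clone(g1).detach()

assert rep_g1.dtype == tn.float64
assert not rep_g1.requires_grad  # safe to update in place during the ascent loop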
@@ -68,19 +68,19 @@ def optim_pb_dual(g1, g2, rho1, h, data, sigma, dictionary, reg, ent, matK,
             rep_g1 = rep_g1 + stepg1*g1_grad
         for _ in range(iterg2):
-            g2_grad = grad2_pb_dual(g1, g2, rho1, h, sigma, dictionary, reg, ent, matK)
+            g2_grad = grad2_pb_dual(rep_g1, rep_g2, rho1, h, sigma, dictionary, reg, ent, matK)
             rep_g2 = rep_g2 + stepg2*g2_grad
-    return (g1, g2)
+    return (rep_g1, rep_g2)

 def solW_primal(g1, g2, h, rho1):
-    w = grad_e_dual((tn.matmul(g1, tn.transpose(h, 0, 1)) + g2 )/rho1)
+    w = -grad_e_dual(-(tn.matmul(g1, tn.transpose(h, 0, 1)) + g2 )/rho1)
     return w

 def optim_w(w_init, h, data, sigma, dictionary, reg, ent, rho1, rho2, cost,
-            method = 'sinkhorn', stepg1 = 1e-5, stepg2 = 1e-5, iterg1 = 1, iterg2 = 1, iterdual = 10, itermax = 10,
-            numItermax = 1000, thr = 1e-9, verbose = False):
+            method = 'sinkhorn', stepg1 = 5e-2, stepg2 = 5e-2, iterg1 = 1, iterg2 = 1, iterdual = 150, itermax = 1,
+            numItermax = 1000, thr = 1e-15, verbose = False):
     init_pb = pb(w = w_init, h = h, sigma = sigma, data = data, dictionary = dictionary, reg = reg,
                  ent = ent, rho1 = rho1, rho2=rho2, cost = cost, method = method, numItermax= numItermax,
@@ -93,15 +93,18 @@ def optim_w(w_init, h, data, sigma, dictionary, reg, ent, rho1, rho2, cost,
     """ assert n, m = data.shape """
     _, d = dictionary.shape
-    matK = compute_matK(cost, ent)
     g1 = tn.rand((n, m), dtype=tn.float64)  # can be replaced by a g_init
     g1 = simplex_norm(g1)
     g2 = tn.rand((n, r), dtype=tn.float64)
     g2 = simplex_norm(g2)
+    matK = compute_matK(cost, ent)

-    for i in range(itermax):
+    for i in range(itermax):
         g1, g2 = optim_pb_dual(g1, g2, rho1, h, data, sigma, dictionary, reg, ent,
                                matK, stepg1, stepg2, iterg1, iterg2, iterdual)
-    w = solW_primal(g1, g2, h, rho1)
+    w = solW_primal(g1, g2, h, rho1)
     return w
\ No newline at end of file
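Net effect of the solW_primal fix: with grad_e_dual(v) returning -softmax(v) (see the docstring change below), the two minus signs cancel and the recovered W becomes a softmax of the negated dual combination. A minimal sketch, assuming the softmax is taken row-wise:

import torch as tn

def solW_primal_sketch(g1, g2, h, rho1):
    # w = -grad_e_dual(-(g1 @ h^T + g2)/rho1) with grad_e_dual(v) = -softmax(v)
    # collapses to softmax(-(g1 @ h^T + g2)/rho1).
    v = -(tn.matmul(g1, tn.transpose(h, 0, 1)) + g2) / rho1
    return tn.softmax(v, dim=1)  # each row of W sums to 1 (simplex constraint)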
File added
File added
File added
File added
@@ -103,7 +103,7 @@ def e_dual(v):

 def grad_e_dual(v):
     """
-    Return softmax(v), or \frac{\exp(v[i])}{\sum_j \exp(v[j])}
+    Return -softmax(v), or -\frac{\exp(v[i])}{\sum_j \exp(v[j])}

     Parameters
     ----------
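For reference, a one-line implementation consistent with the revised docstring (the row-wise application for matrix input is an assumption; e_dual itself is not shown in this hunk):

import torch as tn

def grad_e_dual_sketch(v):
    # -softmax(v): entry i is -exp(v[i]) / sum_j exp(v[j])
    return -tn.softmax(v, dim=-1)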
@@ -57,11 +57,11 @@ def pb(w, h, sigma, data, dictionary, reg, ent, rho1, rho2, cost,
     approx = tn.matmul(w, h)
     temp = dg.f_mat(source = approx, aim = data, cost = cost, ent = ent,
                     method = method, numItermax= numItermax, thr = thr,
-                    verbose = False, log = False, warn = False)
+                    verbose = False)
     temp -= rho1 * dg.e_mat(w) + rho2 * dg.e_mat(h)
-    temp += reg * dg.f_mat(source = approx, aim = dictionary[:,sigma], cost = cost, ent = ent,
+    temp += reg * dg.f_mat(source = w, aim = dictionary[:,sigma], cost = cost, ent = ent,
                     method = method, numItermax = numItermax, thr = thr,
-                    verbose = False, log = False, warn = False)
+                    verbose = False)
     return temp

 def pb_dual(g1, g2, rho1, h, data, sigma, dictionary, reg, ent, matK):
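The second f_mat change swaps the regularizer's source from the reconstruction approx = w @ h to w itself, so the dictionary term now anchors W to dictionary[:, sigma]. A hedged restatement of the objective's shape (ot_loss and entropy stand in for dg.f_mat and dg.e_mat; they are not the project's API):

import torch as tn

def pb_sketch(w, h, data, dict_cols, rho1, rho2, reg, ot_loss, entropy):
    approx = tn.matmul(w, h)
    val = ot_loss(approx, data)                    # data-fitting OT term on W @ H
    val -= rho1 * entropy(w) + rho2 * entropy(h)   # entropic regularization of both factors
    val += reg * ot_loss(w, dict_cols)             # dictionary term: compares W, not W @ H
    return val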
@@ -111,6 +111,8 @@ def pb_dual(g1, g2, rho1, h, data, sigma, dictionary, reg, ent, matK):

 def grad1_pb_dual(g1, g2, rho1, h, data, ent, matK):
     """
     Compute the gradient w.r.t. g1 of the dual problem.
+    The dual objective is f(g1, g2) = \rho_1 E_*((g1 @ H^T + g2)/\rho_1) - \sum_i f^*_{Y_i}(g1_i) - \lambda \sum_j f^*_{D_j}(g2_j),
+    so this gradient is \nabla E_*((g1 @ H^T + g2)/\rho_1) @ H - \sum_i \nabla f^*_{Y_i}(g1_i).

     Parameters
     ----------
@@ -141,7 +143,7 @@ def grad1_pb_dual(g1, g2, rho1, h, data, ent, matK):
     e = dg.grad_e_dual((g1@hT + g2)/rho1)@h
     f_data = dg.grad_f_dual(g1, data, ent, matK)
-    return e+f_data
+    return e-f_data

 def grad2_pb_dual(g1, g2, rho1, h, sigma, dictionary, reg, ent, matK):
     """
@@ -180,5 +182,5 @@ def grad2_pb_dual(g1, g2, rho1, h, sigma, dictionary, reg, ent, matK):
     dico = dictionary[:,sigma]
     f_dico = dg.grad_f_dual(g2/reg, dico, ent, matK)
-    return e+f_dico
+    return e-f_dico
\ No newline at end of file
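Sign errors like the e+f vs. e-f fixes above are cheap to catch by checking a hand-coded gradient against torch.autograd on a toy objective. A sketch of the pattern (the toy objective is made up for illustration; it is not the project's dual):

import torch as tn

def toy_dual(g):
    # toy stand-in: a log-sum-exp term minus a conjugate-like quadratic
    return tn.logsumexp(g, dim=0) - 0.5 * (g * g).sum()

def toy_grad(g):
    # hand-coded gradient: softmax(g) - g  (note the minus sign)
    return tn.softmax(g, dim=0) - g

g = tn.rand(5, dtype=tn.float64, requires_grad=True)
toy_dual(g).backward()
assert tn.allclose(g.grad, toy_grad(g.detach()), atol=1e-10)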
@@ -7,7 +7,7 @@ Created on Tue Jan 17 13:57:40 2023
 import numpy as np
 import torch as tn
-from .base_function import *
+from wdnmf_dual.methods.base_function import *

 def test_simplex_prox_mat_projection():
     x = tn.ones(2, 3, dtype=tn.float64)