Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 7d386c57 authored by hhakim's avatar hhakim
Browse files

Add a script for showing truncated SVD accuracy versus the RCG.

It's a way somehow to show out the interest of a particuler FAµST.
parent a0acaa1e
No related branches found
No related tags found
No related merge requests found
# this script aims to show how SVD could be useful to get a better RCG than
# the full matrix while not totally losing the accuracy
# The SVD case here can be seen as a particular situation for a FAµST
from sys import argv
from matplotlib import pyplot as plt
import numpy as np
from numpy.linalg import norm
from numpy.random import rand
from threading import Thread
from scipy.linalg import svd
class TruncSvd(Thread):
def __init__(self, offset, M, U, S, V, r, errs, rcgs):
Thread.__init__(self)
self.offset = offset
self.M = M
self.U = U
self.V = V
self.S = S
self.r = r
self.errs = errs
self.rcgs = rcgs
def run(s):
S = np.zeros([s.U.shape[1], s.V.shape[0]])
for i in range(0, s.S.shape[0]):
S[i, i] = s.S[i]
# assert((abs(M - s.U.dot(S).dot(s.V))/abs(s.M) < .01).all())
while(s.r <= min(s.M.shape[0], s.M.shape[1])):
s.Mr = (s.U[:, 0:s.r].dot(S[0:s.r, 0:s.r])).dot(s.V[0:s.r, :])
s.errs[0, s.r-1] = norm(s.M-s.Mr, 'fro')/norm(s.M, 'fro')
s.rcgs[0, s.r-1] = s.M.shape[0] * s.M.shape[1]
s.rcgs[0, s.r-1] /= (s.r*(s.M.shape[0]+s.M.shape[1]))
s.r += s.offset
if __name__ == '__main__':
if(len(argv) < 3):
print("USAGE: ", argv[0], "<num_lines> <num_cols>")
exit(1)
m = int(argv[1])
n = int(argv[2])
nthreads = 8
errs = np.zeros([1, min(m, n)])
rcgs = np.zeros([1, min(m, n)])
M = rand(m, n)
U, S, V = svd(M)
ths = list()
for i in range(1, nthreads+1):
th = TruncSvd(nthreads, M, U, S, V, i, errs, rcgs)
ths.append(th)
th.start()
for th in ths:
th.join()
plt.scatter(rcgs[0, :], errs[0, :], s=1)
# plt.gca().set_xscale('log', basex=.5)
plt.gca().invert_xaxis()
plt.xlabel('Relative Complexity Gain (RCG)')
plt.ylabel('Relative Error')
plt.title('Relative Error over RCG of Truncated SVDs for a dense matrix M ('
+ str(m) + 'x' + str(n)+')')
# f = open("svd_err_vs_rcg_output_"+str(nthreads), "w")
# for i in range(0,errs.shape[1]):
# f.write(str(errs[0,i])+" "+str(rcgs[0,i])+"\n")
# f.close()
# plt.plot(rcgs, errs)
plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment