Mentions légales du service

Skip to content
Snippets Groups Projects
Commit c060da2f authored by hhakim's avatar hhakim
Browse files

Add an option to pick operator norm or Frobenius to compute relative error.

Calculate the norm directly with singular values instead of computing the approximate of M and then calculate the norm with it.
parent 7ffaccb5
No related branches found
No related tags found
No related merge requests found
...@@ -27,25 +27,48 @@ class TruncSvd(Thread): ...@@ -27,25 +27,48 @@ class TruncSvd(Thread):
self.rcs = rcs self.rcs = rcs
def run(s): def run(s):
global norm_ord
S = np.zeros([s.U.shape[1], s.V.shape[0]]) S = np.zeros([s.U.shape[1], s.V.shape[0]])
if(norm_ord == "fro"): nfroM = 0
for i in range(0, s.S.shape[0]): for i in range(0, s.S.shape[0]):
S[i, i] = s.S[i] S[i, i] = s.S[i]
if(norm_ord == "fro"): nfroM += S[i,i]**2
if(norm_ord == "fro"): nfroM = np.sqrt(nfroM)
# assert((abs(M - s.U.dot(S).dot(s.V))/abs(s.M) < .01).all()) # assert((abs(M - s.U.dot(S).dot(s.V))/abs(s.M) < .01).all())
while(s.r <= min(s.M.shape[0], s.M.shape[1])): while(s.r < min(s.M.shape[0], s.M.shape[1])):
s.Mr = (s.U[:, 0:s.r].dot(S[0:s.r, 0:s.r])).dot(s.V[0:s.r, :]) #s.Mr = (s.U[:, 0:s.r].dot(S[0:s.r, 0:s.r])).dot(s.V[0:s.r, :])
#s.errs[0, s.r-1] = S[s.r-1,s.r-1] / S[0,0] #s.errs[0, s.r-1] = norm(s.M-s.Mr, norm_ord)/norm(s.M, norm_ord)
s.errs[0, s.r-1] = norm(s.M-s.Mr, 'fro')/norm(s.M, 'fro') if(norm_ord == 2):
err = S[s.r, s.r]/S[0,0]
#assert(err == s.errs[0, s.r-1])
#print(s.errs[0, s.r-1], err)
s.errs[0, s.r-1] = err
elif(norm_ord == "fro"):
err = 0
for j in range(s.r, s.S.shape[0]):
err += s.S[j]**2
err = np.sqrt(err)/nfroM
#print(s.errs[0, s.r-1], err)
#assert(err == s.errs[0, s.r-1])
s.errs[0, s.r-1] = err
s.rcs[0, s.r-1] = (s.r*(s.M.shape[0]+s.M.shape[1])) s.rcs[0, s.r-1] = (s.r*(s.M.shape[0]+s.M.shape[1]))
s.rcs[0, s.r-1] /= s.M.shape[0] * s.M.shape[1] s.rcs[0, s.r-1] /= s.M.shape[0] * s.M.shape[1]
s.r += s.offset s.r += s.offset
if __name__ == '__main__': if __name__ == '__main__':
if(len(argv) < 3): if(len(argv) < 4):
print("USAGE: ", argv[0], "<num_lines> <num_cols>") print("USAGE: ", argv[0], "<num_lines> <num_cols> <norm_ord>")
print("norm_ord is fro or 2")
exit(1) exit(1)
m = int(argv[1]) m = int(argv[1])
n = int(argv[2]) n = int(argv[2])
global norm_ord
if(argv[3] == "2"):
norm_ord=2
elif(argv[3] == "fro"):
norm_ord="fro"
else:
raise("Error: norm must be 2 or fro.")
nthreads = 8 nthreads = 8
errs = np.zeros([1, min(m, n)]) errs = np.zeros([1, min(m, n)])
rcs = np.zeros([1, min(m, n)]) rcs = np.zeros([1, min(m, n)])
...@@ -65,9 +88,9 @@ if __name__ == '__main__': ...@@ -65,9 +88,9 @@ if __name__ == '__main__':
plt.ylabel('Relative Error') plt.ylabel('Relative Error')
plt.title('Relative Error over RC of Truncated SVDs for a dense matrix M (' plt.title('Relative Error over RC of Truncated SVDs for a dense matrix M ('
+ str(m) + 'x' + str(n)+')') + str(m) + 'x' + str(n)+')')
# f = open("svd_err_vs_rcg_output_"+str(nthreads), "w") f = open("svd_err_vs_rc_output_"+str(nthreads), "w")
# for i in range(0,errs.shape[1]): for i in range(0,errs.shape[1]):
# f.write(str(errs[0,i])+" "+str(rcs[0,i])+"\n") f.write(str(errs[0,i])+" "+str(rcs[0,i])+"\n")
# f.close() f.close()
#plt.plot(rcs, errs) #plt.plot(rcs, errs)
plt.show() plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment