Mentions légales du service

Skip to content
Snippets Groups Projects
Commit ef667304 authored by CARRIVAIN Pascal's avatar CARRIVAIN Pascal
Browse files

add benchmark of fdb backend

parent 5f09f81c
No related branches found
No related tags found
1 merge request: !1 "review of fdb function"
Pipeline #985703 failed
#!/usr/bin/python
# -*- coding: utf-8 -*-
import gc
import getopt
import json
import os
import sys
import time
import matplotlib
# matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.colors as clr
# Colour map from which each backend's marker colour is sampled in the plots.
cmap = plt.get_cmap("rainbow")
# Matplotlib defaults applied to every figure produced by this script.
plt.rcParams.update(
    {
        "font.size": 14,
        "lines.linewidth": 3,
        "lines.markersize": 6,
    }
)
from pyfaust import fact
import numpy as np
import torch
def argitem(x, v):
    """Return the index of the first element of ``x`` equal to ``v``.

    Args:
        x: NumPy array to search; it is compared element-wise with ``==``.
        v: Value to look for.

    Returns:
        Index (along the first axis) of the first element equal to ``v``.

    Raises:
        IndexError: If no element of ``x`` equals ``v``.
    """
    hits = np.where(x == v)[0]
    if hits.size == 0:
        # Same exception type as indexing the empty result, but with a
        # message that names the missing value.
        raise IndexError("value {0!r} not found in array".format(v))
    return hits[0]
def _parse_args(argv):
    """Parse the command line and return ``(run_fdb, seed)``.

    Recognised options are ``-h``, ``--run_fdb`` and ``--seed <int>``.
    Exits with status 2 on an unknown option, and exits normally after
    printing a usage example when ``-h`` is given.
    """
    run_fdb = False
    seed = 1
    try:
        opts, _ = getopt.getopt(argv, "h", ["run_fdb", "seed="])
    except getopt.GetoptError as err:
        print(err)
        sys.exit(2)
    for opt, arg in opts:
        if opt == "-h":
            print("example:\n")
            print("python3 benchmark_fdb.py --run_fdb --seed 1\n")
            sys.exit()
        elif opt == "--run_fdb":
            run_fdb = True
        elif opt == "--seed":
            seed = int(arg)
    return run_fdb, seed


def _run_benchmark(M, N, n_factors, backend, rank, nrepeats):
    """Time ``fact.fdb`` for every (rows, columns, factors, backend) case.

    Args:
        M: 1-D array of row counts.
        N: 1-D array of column counts.
        n_factors: 1-D array of factor counts passed as ``n_factors``.
        backend: List of backend names passed as ``backend``.
        rank: Value passed as ``rank``.
        nrepeats: Number of timed calls per case.

    Returns:
        ``(data, elapsed_time, iqr)`` where ``data`` is a nested dict
        ``rows -> columns -> factors -> backend -> list of durations`` and
        ``elapsed_time`` / ``iqr`` are arrays of shape
        ``(len(M), len(N), len(n_factors), len(backend))`` holding the median
        and the interquartile range of the durations (in seconds).
    """
    shape = (len(M), len(N), len(n_factors), len(backend))
    elapsed_time = np.full(shape, 1e9)
    iqr = np.full(shape, 1e9)
    # Results always start from an empty dict: an earlier version read
    # "fdb.json" and then immediately discarded its content.
    data = {}
    # ``tolist`` yields plain Python ints for sizes and dictionary keys.
    for im, m in enumerate(M.tolist()):
        per_rows = data.setdefault("{0:d}rows".format(m), {})
        for jn, n in enumerate(N.tolist()):
            per_cols = per_rows.setdefault("{0:d}columns".format(n), {})
            for kf, f in enumerate(n_factors.tolist()):
                per_factors = per_cols.setdefault("{0:d}factors".format(f), {})
                # Every backend is timed on the same random matrix; the
                # NumPy array is a view of the torch tensor.
                matrix0 = torch.randn(m, n)
                matrix1 = matrix0.numpy()
                for j, name in enumerate(backend):
                    durations = per_factors.setdefault(name, [])
                    matrix = matrix0 if name == "pytorch" else matrix1
                    for i in range(nrepeats):
                        # perf_counter is the monotonic high-resolution
                        # clock intended for measuring short durations.
                        start = time.perf_counter()
                        fact.fdb(matrix, n_factors=f, rank=rank, backend=name)
                        end = time.perf_counter()
                        print(
                            "{0:s}, repeat={1:d}, shape={2:d}x{3:d}, {4:d} factors, time={5:f}".format(
                                name,
                                i,
                                m,
                                n,
                                f,
                                end - start,
                            )
                        )
                        durations.append(end - start)
                    tmp = np.array(durations)
                    elapsed_time[im, jn, kf, j] = np.median(tmp)
                    q75, q25 = np.percentile(tmp, [75, 25])
                    iqr[im, jn, kf, j] = q75 - q25
                # Release the matrices before the next allocation.
                del matrix0, matrix1
                gc.collect()
    return data, elapsed_time, iqr


def _plot_best_backend(M, N, n_factors, backend, elapsed_time, iqr):
    """Save one "best backend" figure per number of factors.

    For each (rows, columns) pair a marker is drawn per backend, with a size
    proportional to ``min(duration) / duration(backend)``. A backend whose
    interquartile range exceeds one tenth of its median is considered too
    noisy and is not drawn for that pair.
    """
    n_backends = len(backend)
    # Guard the colour normalisation against a single-backend list.
    color_scale = max(n_backends - 1, 1)
    for kf, f in enumerate(n_factors.tolist()):
        fig, ax = plt.subplots(1, 1, constrained_layout=False)
        for im, m in enumerate(M):
            for jn, n in enumerate(N):
                times = elapsed_time[im, jn, kf, :]
                # Discard a median when its IQR is greater than median / 10.
                reliable = iqr[im, jn, kf, :] <= times / 10.0
                if not reliable.any():
                    continue
                best = times[reliable].min()
                # Fastest first, so smaller markers are drawn on top of
                # larger ones.
                for j in np.argsort(times):
                    if not reliable[j]:
                        continue
                    ax.plot(
                        n,
                        m,
                        linestyle="",
                        marker="o",
                        markersize=12.5 * best / times[j],
                        color=cmap(j / color_scale),
                    )
        # Empty artists that only provide the legend entries.
        for j, name in enumerate(backend):
            ax.plot(
                [],
                [],
                linestyle="",
                marker="o",
                markersize=8.0,
                color=cmap(j / color_scale),
                label=name,
            )
        ax.set_xlim(1, 2 * np.max(N))
        ax.set_xscale("log")
        ax.set_xlabel("# of columns")
        ax.set_ylim(1, 10 * np.max(M))
        ax.set_yscale("log")
        ax.set_ylabel("# of rows")
        ax.set_title("min(duration) / duration(backend)")
        ax.legend(title="backend", ncol=2, loc="upper left",
                  columnspacing=1*0.25, handletextpad=4*0.01)
        fig.savefig("best_backend_fdb_{0:d}factors.png".format(f), dpi=600)
        plt.close(fig)


def main(argv):
    """Benchmark ``fact.fdb`` over matrix shapes, factor counts and backends.

    Args:
        argv: Command-line arguments without the program name. Supported
            options: ``-h`` (print an example and exit), ``--run_fdb`` (run
            the benchmark) and ``--seed <int>`` (random seed, default 1).

    Without ``--run_fdb`` only the random generators are seeded. With it,
    the timings are written to ``fdb.json`` and one
    ``best_backend_fdb_<k>factors.png`` figure is saved per factor count.
    """
    run_fdb, seed = _parse_args(argv)
    np.random.seed(seed)
    # The benchmark matrices are drawn with torch.randn, so torch has to be
    # seeded as well for --seed to have an effect on them.
    torch.manual_seed(seed)
    if not run_fdb:
        return
    # Number of timed calls of fdb per configuration.
    nrepeats = 30
    M = np.power(2, np.arange(1, 11 + 1, 1))
    N = np.power(2, np.arange(1, 11 + 1, 1))
    n_factors = np.array([2, 4, 8, 16])
    rank = 1
    backend = ["numpy", "pytorch"]
    data, elapsed_time, iqr = _run_benchmark(
        M, N, n_factors, backend, rank, nrepeats
    )
    # Save the timings first so they survive a failure while plotting.
    with open("fdb.json", "w") as out_file:
        json.dump(data, out_file, indent=1)
    _plot_best_backend(M, N, n_factors, backend, elapsed_time, iqr)
if __name__ == "__main__":
    # Run the benchmark with the command-line arguments (program name removed).
    main(sys.argv[1:])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment