Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 9347725a authored by CARRIVAIN Pascal's avatar CARRIVAIN Pascal
Browse files

comment unused functions

parent 23f40c47
Branches
Tags
1 merge request!1review of fdb function
Pipeline #987769 passed
......@@ -15,6 +15,7 @@ except ImportError:
found_einops = False
warn("Did not find einops, therefore use NumPy.")
from pyfaust.fdb.GB_param_generate import *
from pyfaust.fdb.utils import param_mul_param
import os
......@@ -24,52 +25,41 @@ if "GB_DISABLE_EINOPS" in dict(os.environ).keys():
print("Disable einops.")
def param_mul_param(param1, param2):
return [
param1[0],
param1[1] * param2[1],
param1[2] * param2[2],
param2[3],
param1[4],
param2[5],
]
def random_generate(param, seed: int = None, backend: str = 'numpy'):
"""Random a set of butterfly factors.
Args:
param:
A list of params.
seed: ``int``, optional
backend: ``str``, optional
Use numpy (default) or pytorch to compute
SVD and QR decompositions.
"""
if backend == 'pytorch' and found_pytorch:
if seed:
gen = torch.Generator()
gen.manual_seed(seed)
else:
gen = None
if len(param) == 6:
return torch.rand(
param[0],
param[3],
param[1] * param[4],
param[2] * param[5],
generator=gen,
)
return torch.rand(
param[0], param[3], param[1], param[2], generator=gen
)
else:
np.random.seed(seed if seed else 1)
if len(param) == 6:
return np.random.rand(
param[0], param[3], param[1] * param[4], param[2] * param[5]
)
return np.random.rand(param[0], param[3], param[1], param[2])
# def random_generate(param, seed: int = None, backend: str = 'numpy'):
# """Random a set of butterfly factors.
# Args:
# param:
# A list of params.
# seed: ``int``, optional
# backend: ``str``, optional
# Use numpy (default) or pytorch to compute
# SVD and QR decompositions.
# """
# if backend == 'pytorch' and found_pytorch:
# if seed:
# gen = torch.Generator()
# gen.manual_seed(seed)
# else:
# gen = None
# if len(param) == 6:
# return torch.rand(
# param[0],
# param[3],
# param[1] * param[4],
# param[2] * param[5],
# generator=gen,
# )
# return torch.rand(
# param[0], param[3], param[1], param[2], generator=gen
# )
# else:
# np.random.seed(seed if seed else 1)
# if len(param) == 6:
# return np.random.rand(
# param[0], param[3], param[1] * param[4], param[2] * param[5]
# )
# return np.random.rand(param[0], param[3], param[1], param[2])
def twiddle_mul_twiddle(l_twiddle, r_twiddle, l_param,
......@@ -179,30 +169,30 @@ def twiddle_to_dense(twiddle, backend: str = 'numpy'):
return output.reshape(a, d, b, n).swapaxes(1, 2).reshape(a * d * b, n)
def densification(twiddle_list, param_list, backend: str = 'numpy'):
"""Compute product of twiddle.
Args:
twiddle_list: ``list``
List of twiddles.
param_list: ``list``
List of params.
backend: ``str``, optional
Use numpy (default) or pytorch to compute
SVD and QR decompositions.
Returns:
Product of twiddles (``numpy.ndarray`` or ``torch.tensor```).
"""
a, b, c, d, p, q = param_list[0]
n = a * b * d * p
if backend == 'pytorch' and found_pytorch:
output = torch.ones(1, n, 1, 1)
else:
output = np.full(n, 1.0).reshape(1, n, 1, 1)
current_param = [1, 1, 1, n, 1, 1]
for twiddle, param in zip(twiddle_list, param_list):
output = twiddle_mul_twiddle(output, twiddle,
current_param, param, backend)
current_param = param_mul_param(current_param, param)
return output
# def densification(twiddle_list, param_list, backend: str = 'numpy'):
# """Compute product of twiddle.
# Args:
# twiddle_list: ``list``
# List of twiddles.
# param_list: ``list``
# List of params.
# backend: ``str``, optional
# Use numpy (default) or pytorch to compute
# SVD and QR decompositions.
# Returns:
# Product of twiddles (``numpy.ndarray`` or ``torch.tensor```).
# """
# a, b, c, d, p, q = param_list[0]
# n = a * b * d * p
# if backend == 'pytorch' and found_pytorch:
# output = torch.ones(1, n, 1, 1)
# else:
# output = np.full(n, 1.0).reshape(1, n, 1, 1)
# current_param = [1, 1, 1, n, 1, 1]
# for twiddle, param in zip(twiddle_list, param_list):
# output = twiddle_mul_twiddle(output, twiddle,
# current_param, param, backend)
# current_param = param_mul_param(current_param, param)
# return output
......@@ -123,33 +123,33 @@ def format_conversion(m, n, chainbc, weight, format: str="abcd"):
return result
def factorize(n):
"""Return a dictionary storing all prime divisor
of n with their corresponding powers.
Args:
n: ``int``
Returns:
``dict``
"""
if found_sympy:
prime_ints = list(primerange(1, n + 1))
else:
prime_ints = _prime_range(1, n + 1)
print(n + 1)
print(prime_ints)
result = {}
index = 0
while n > 1:
if n % prime_ints[index] == 0:
k = 0
while n % prime_ints[index] == 0:
n = n // prime_ints[index]
k = k + 1
result[prime_ints[index]] = k
index = index + 1
return result
# def factorize(n):
# """Return a dictionary storing all prime divisor
# of n with their corresponding powers.
# Args:
# n: ``int``
# Returns:
# ``dict``
# """
# if found_sympy:
# prime_ints = list(primerange(1, n + 1))
# else:
# prime_ints = _prime_range(1, n + 1)
# print(n + 1)
# print(prime_ints)
# result = {}
# index = 0
# while n > 1:
# if n % prime_ints[index] == 0:
# k = 0
# while n % prime_ints[index] == 0:
# n = n // prime_ints[index]
# k = k + 1
# result[prime_ints[index]] = k
# index = index + 1
# return result
def random_Euler_sum(n, k):
......@@ -201,37 +201,37 @@ class DebflyGen:
self.dp_table = np.zeros((m + 1, n + 1))
self.dp_table_temp = np.zeros((m + 1, n + 1))
def random_debfly_chain(self, n_factors, format: str="abcd"):
"""Return an uniformly random deformable butterfly chain
whose product is of size m x n has ``n_factors`` factors.
Args:
n_factors: ``int``
The number of factors.
format: ``str``, optional
"abcd" is default.
Returns:
``list``
"""
decomp_m = factorize(self.m)
decomp_n = factorize(self.n)
b_chain = [1] * n_factors
c_chain = [1] * n_factors
weight = [1] + [self.rank] * (n_factors - 1) + [1]
for divisor, powers in decomp_m.items():
random_partition = random_Euler_sum(powers, n_factors)
for i in range(len(b_chain)):
b_chain[i] = b_chain[i] * (divisor ** random_partition[i])
for divisor, powers in decomp_n.items():
random_partition = random_Euler_sum(powers, n_factors)
for i in range(len(c_chain)):
c_chain[i] = c_chain[i] * (divisor ** random_partition[i])
chainbc = [(b_chain[i], c_chain[i]) for i in range(n_factors)]
return format_conversion(
self.m, self.n, chainbc, weight, format=format
)
# def random_debfly_chain(self, n_factors, format: str="abcd"):
# """Return an uniformly random deformable butterfly chain
# whose product is of size m x n has ``n_factors`` factors.
# Args:
# n_factors: ``int``
# The number of factors.
# format: ``str``, optional
# "abcd" is default.
# Returns:
# ``list``
# """
# decomp_m = factorize(self.m)
# decomp_n = factorize(self.n)
# b_chain = [1] * n_factors
# c_chain = [1] * n_factors
# weight = [1] + [self.rank] * (n_factors - 1) + [1]
# for divisor, powers in decomp_m.items():
# random_partition = random_Euler_sum(powers, n_factors)
# for i in range(len(b_chain)):
# b_chain[i] = b_chain[i] * (divisor ** random_partition[i])
# for divisor, powers in decomp_n.items():
# random_partition = random_Euler_sum(powers, n_factors)
# for i in range(len(c_chain)):
# c_chain[i] = c_chain[i] * (divisor ** random_partition[i])
# chainbc = [(b_chain[i], c_chain[i]) for i in range(n_factors)]
# return format_conversion(
# self.m, self.n, chainbc, weight, format=format
# )
@staticmethod
def enumeration_inner_chain(m, n_factors):
......@@ -248,23 +248,23 @@ class DebflyGen:
)
return results
def enumeration_debfly_chain(self, n_factors, format="abcd"):
results = []
weight = [1] + [self.rank] * (n_factors - 1) + [1]
chain_b = DebflyGen.enumeration_inner_chain(self.m, n_factors)
chain_c = DebflyGen.enumeration_inner_chain(self.n, n_factors)
for f1 in chain_b:
for f2 in chain_c:
results.append(
format_conversion(
self.m,
self.n,
list(zip(f1, f2)),
weight,
format=format,
)
)
return results
# def enumeration_debfly_chain(self, n_factors, format="abcd"):
# results = []
# weight = [1] + [self.rank] * (n_factors - 1) + [1]
# chain_b = DebflyGen.enumeration_inner_chain(self.m, n_factors)
# chain_c = DebflyGen.enumeration_inner_chain(self.n, n_factors)
# for f1 in chain_b:
# for f2 in chain_c:
# results.append(
# format_conversion(
# self.m,
# self.n,
# list(zip(f1, f2)),
# weight,
# format=format,
# )
# )
# return results
def smallest_monotone_debfly_chain(self, n_factors, format: str="abcd"):
"""Return a deformable butterfly chain whose product is of
......@@ -354,78 +354,78 @@ class DebflyGen:
)
def count_parameters(param_chain):
"""Return number of parameters.
Args:
param_chain: ``tuple``
A generalized butterfly chain.
Returns:
Number of parameters (``int``).
"""
assert len(param_chain) > 0
count = 0
for params in param_chain:
if len(params) == 4:
count += params[0] * params[1] * params[2] * params[3]
elif len(params) == 5:
count += params[0] * params[3]
else:
count += (
params[0]
* params[1]
* params[2]
* params[3]
* params[4]
* params[5]
)
return count
def check_monotone(param_chain, rank):
"""Decide if the chain is monotone
(defined as in the paper Deformable butterfly).
Args:
param_chain:
A generalized butterfly chain and the intended rank.
rank: ``int``
Expected rank.
Returns:
bool
"""
assert len(param_chain) > 0
weight = [1] + [rank] * (len(param_chain) - 1) + [1]
if len(param_chain[0]) == 4:
m = param_chain[0][0] * param_chain[0][1] * param_chain[0][3]
n = param_chain[-1][0] * param_chain[-1][2] * param_chain[-1][3]
else:
m = param_chain[0][0]
n = param_chain[-1][1]
if m == n:
type = "square"
elif m > n:
type = "shrinking"
else:
type = "expanding"
for i in range(len(param_chain)):
if len(param_chain[i]) == 4:
b = param_chain[i][1] // weight[i]
c = param_chain[i][2] // weight[i + 1]
if not check_compatibility(b, c, type):
return False
elif len(param_chain[i]) == 5:
b = param_chain[i][2] // weight[i]
c = param_chain[i][3] // weight[i + 1]
if not check_compatibility(b, c, type):
return False
else:
if not check_compatibility(
param_chain[i][1], param_chain[i][2], type
):
return False
return True
# def count_parameters(param_chain):
# """Return number of parameters.
# Args:
# param_chain: ``tuple``
# A generalized butterfly chain.
# Returns:
# Number of parameters (``int``).
# """
# assert len(param_chain) > 0
# count = 0
# for params in param_chain:
# if len(params) == 4:
# count += params[0] * params[1] * params[2] * params[3]
# elif len(params) == 5:
# count += params[0] * params[3]
# else:
# count += (
# params[0]
# * params[1]
# * params[2]
# * params[3]
# * params[4]
# * params[5]
# )
# return count
# def check_monotone(param_chain, rank):
# """Decide if the chain is monotone
# (defined as in the paper Deformable butterfly).
# Args:
# param_chain:
# A generalized butterfly chain and the intended rank.
# rank: ``int``
# Expected rank.
# Returns:
# bool
# """
# assert len(param_chain) > 0
# weight = [1] + [rank] * (len(param_chain) - 1) + [1]
# if len(param_chain[0]) == 4:
# m = param_chain[0][0] * param_chain[0][1] * param_chain[0][3]
# n = param_chain[-1][0] * param_chain[-1][2] * param_chain[-1][3]
# else:
# m = param_chain[0][0]
# n = param_chain[-1][1]
# if m == n:
# type = "square"
# elif m > n:
# type = "shrinking"
# else:
# type = "expanding"
# for i in range(len(param_chain)):
# if len(param_chain[i]) == 4:
# b = param_chain[i][1] // weight[i]
# c = param_chain[i][2] // weight[i + 1]
# if not check_compatibility(b, c, type):
# return False
# elif len(param_chain[i]) == 5:
# b = param_chain[i][2] // weight[i]
# c = param_chain[i][3] // weight[i + 1]
# if not check_compatibility(b, c, type):
# return False
# else:
# if not check_compatibility(
# param_chain[i][1], param_chain[i][2], type
# ):
# return False
# return True
......@@ -27,17 +27,22 @@ class Factor:
def partial_prod_deformable_butterfly_params(gb_params, low, high):
"""
Closed form expression of partial matrix_product
of butterfly supports.
We name S_L, ..., S_1 the butterfly supports of
size 2^L, represented as binary matrices.
Then, the method computes the partial matrix_product S_{high-1} ... S_low.
:param supscript: list of sizes of factors
:param subscript: list of sizes of blocks
:param low: int
:param high: int, excluded
:return: numpy array, binary matrix
r"""Return closed form expression of partial matrix_product
of butterfly supports. We name $S_L, \cdots, S_1$ the butterfly
supports of size $2^L$, represented as binary matrices.
Then, the method computes the partial matrix
product $S_{high-1} \cots S_{low}$.
Args:
gb_params: ``list``
List of sizes of factors.
low: ``int``
First factor.
high: ``int``
Last factor (not included).
Returns:
binary matrix (``np.ndarray``)
"""
params = gb_params[low: high + 1]
result = [1] * 6
......@@ -92,11 +97,11 @@ def compatible_chain_gb_params(gb_params):
return True
def redundant_chain_gb_params(gb_params):
for pm in gb_params:
if not redundant_gb_params(pm):
return False
return True
# def redundant_chain_gb_params(gb_params):
# for pm in gb_params:
# if not redundant_gb_params(pm):
# return False
# return True
def param_mul_param(param1, param2):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment