Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 1bbd9ee8 authored by CARRIVAIN Pascal's avatar CARRIVAIN Pascal
Browse files

format docstring + remove unused function

parent ef667304
No related branches found
No related tags found
1 merge request!1review of fdb function
Pipeline #985791 failed
......@@ -18,8 +18,6 @@ try:
except ImportError:
found_pytorch = False
warn("Did not find PyTorch, therefore use NumPy/SciPy.")
import os
import sys
MAX = 1e18
......@@ -34,7 +32,7 @@ def _prime_range(a: int, b: int = None):
If b is ``None`` (default) consider interval [2, a).
Returns:
list
``list``
"""
if b is None:
start, end = 2, a
......@@ -68,27 +66,34 @@ def check_compatibility(b, c, type):
corresponding to three type of monotonicity.
Returns:
bool
``bool``
"""
if type == "square":
return b == c
if type == "expanding":
elif type == "expanding":
return b <= c
if type == "shrinking":
elif type == "shrinking":
return b >= c
else:
raise Exception("type must be either 'square'," +
" 'expanding' or 'shrinking'.")
def format_conversion(m, n, chainbc, weight, format="abcd"):
def format_conversion(m, n, chainbc, weight, format: str="abcd"):
"""Return a sequence of deformable butterfly factors
using the infomation of b and c
using the infomation of b and c.
Args:
m, n: ``int``
Size of the matrix.
chainbc:
chainbc: ``list``
A sequence of pairs (b,c).
format: ``str``
Support 2 formats (a,b,c,d) (default) and (p,q,r,s,t)
format: ``str``, optional
Support 2 formats (a,b,c,d) ("abcd" is default)
and (p,q,r,s,t).
Returns:
``list``
"""
a = 1
d = m
......@@ -111,6 +116,9 @@ def format_conversion(m, n, chainbc, weight, format="abcd"):
)
elif format == "abcdpq":
result.append((a, b, c, d, weight[i], weight[i + 1]))
else:
raise Exception("format must be either 'abcd'," +
" 'pqrst' or 'abcdpq'.")
a = a * c
return result
......@@ -118,6 +126,12 @@ def format_conversion(m, n, chainbc, weight, format="abcd"):
def factorize(n):
"""Return a dictionary storing all prime divisor
of n with their corresponding powers.
Args:
n: ``int``
Returns:
``dict``
"""
if found_sympy:
prime_ints = list(primerange(1, n + 1))
......@@ -139,7 +153,17 @@ def factorize(n):
def random_Euler_sum(n, k):
# Return k nonnegative integers whose sum equals to n
"""Return k nonnegative integers whose sum equals to n.
Args:
n: ``int``
Target sum.
k: ``int``
Number of nonnegative integers.
Returns:
``list``
"""
result = [0] * k
sample = np.random.randint(0, k, n)
for i in sample:
......@@ -177,9 +201,18 @@ class DebflyGen:
self.dp_table = np.zeros((m + 1, n + 1))
self.dp_table_temp = np.zeros((m + 1, n + 1))
def random_debfly_chain(self, n_factors, format="abcd"):
def random_debfly_chain(self, n_factors, format: str="abcd"):
"""Return an uniformly random deformable butterfly chain
whose product is of size m x n has n_factors factors.
whose product is of size m x n has ``n_factors`` factors.
Args:
n_factors: ``int``
The number of factors.
format: ``str``, optional
"abcd" is default.
Returns:
``list``
"""
decomp_m = factorize(self.m)
decomp_n = factorize(self.n)
......@@ -233,13 +266,15 @@ class DebflyGen:
)
return results
def smallest_monotone_debfly_chain(self, n_factors, format="abcd"):
def smallest_monotone_debfly_chain(self, n_factors, format: str="abcd"):
"""Return a deformable butterfly chain whose product is of
size m x n has n_factors factors.
Args:
n_factors: ``int``
The number of factors.
format: ``str``, optional
"abcd" is default.
"""
try:
assert n_factors > 0
......@@ -319,87 +354,15 @@ class DebflyGen:
)
def optimized_deform_butterfly_mult_torch(
input,
num_mat,
R_parameters,
R_shapes,
return_intermediates: bool = False,
version: str = "bmm",
backend: str = 'numpy',
):
"""
Less reshape than the original version.
Assume that input is 2D (n, in_size).
"""
n = input.shape[0]
output = input.contiguous()
intermediates = [output]
temp_p = 0
for m in range(num_mat):
R_shape = R_shapes[m]
output_size, input_size, row, col, diag = R_shape[:]
num_p = col * output_size
nb_blocks = input_size // (col * diag)
if version == "pointwise":
t = (
R_parameters[temp_p: temp_p + num_p]
.view(nb_blocks, diag, row, col)
.permute(0, 2, 3, 1)
)
if found_pytorch:
output = output.view(n, nb_blocks, 1, col, diag)
else:
output = output.reshape(n, nb_blocks, 1, col, diag)
output = (t * output).sum(dim=-2)
elif version == "bmm":
t = R_parameters[temp_p: temp_p + num_p].view(
nb_blocks * diag, row, col
) # (nb_blocks * diag, row, col)
output = (
output.reshape(n, nb_blocks, col, diag)
.transpose(-1, -2)
.reshape(n, -1, col)
)
if found_pytorch:
output = torch.bmm(output.transpose(0, 1), t.transpose(2, 1))
else:
output = np.einsum(
"ijk,ikl->ijl", output.transpose(0, 1), t.transpose(2, 1)
)
output = output.reshape(nb_blocks, diag, n, row).permute(
2, 0, 3, 1
) # (n, nb_blocks, row, diag)
elif found_pytorch and version == "conv1d":
t = R_parameters[temp_p: temp_p + num_p].view(-1, col, 1)
output = (
output.reshape(n, nb_blocks, col, diag)
.transpose(-1, -2)
.reshape(n, -1, 1)
)
output = torch.nn.functional.conv1d(
output, t, groups=nb_blocks * diag
)
output = output.view(n, nb_blocks, diag, row).transpose(-1, -2)
else:
raise NotImplementedError
temp_p += num_p
intermediates.append(output)
return (
output.reshape(n, output_size)
if not return_intermediates
else intermediates
)
# ----- Useful function to handle generalized butterfly chains -----
def count_parameters(param_chain):
"""Return number of parameters.
Args:
param_chain: ``tuple``
A generalized butterfly chain.
def count_parameters(param_chain):
"""
Input: A generalized butterfly chain
Output: Number of parameters
Returns:
Number of parameters (``int``).
"""
assert len(param_chain) > 0
count = 0
......@@ -428,6 +391,7 @@ def check_monotone(param_chain, rank):
param_chain:
A generalized butterfly chain and the intended rank.
rank: ``int``
Expected rank.
Returns:
bool
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment