Mentions légales du service
Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
faust
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Admin message
GitLab upgrade completed. Current version is 17.11.4.
Show more breadcrumbs
faust group
faust
Commits
1bbd9ee8
Commit
1bbd9ee8
authored
1 year ago
by
CARRIVAIN Pascal
Browse files
Options
Downloads
Patches
Plain Diff
format docstring + remove unused function
parent
ef667304
No related branches found
No related tags found
1 merge request
!1
review of fdb function
Pipeline
#985791
failed
1 year ago
Stage: test
Stage: pkg_purepy
Stage: pkg
Stage: pkg_test
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
wrapper/python/pyfaust/fdb/GB_param_generate.py
+58
-94
58 additions, 94 deletions
wrapper/python/pyfaust/fdb/GB_param_generate.py
with
58 additions
and
94 deletions
wrapper/python/pyfaust/fdb/GB_param_generate.py
+
58
−
94
View file @
1bbd9ee8
...
...
@@ -18,8 +18,6 @@ try:
except
ImportError
:
found_pytorch
=
False
warn
(
"
Did not find PyTorch, therefore use NumPy/SciPy.
"
)
import
os
import
sys
MAX
=
1e18
...
...
@@ -34,7 +32,7 @@ def _prime_range(a: int, b: int = None):
If b is ``None`` (default) consider interval [2, a).
Returns:
list
``
list
``
"""
if
b
is
None
:
start
,
end
=
2
,
a
...
...
@@ -68,27 +66,34 @@ def check_compatibility(b, c, type):
corresponding to three type of monotonicity.
Returns:
bool
``
bool
``
"""
if
type
==
"
square
"
:
return
b
==
c
if
type
==
"
expanding
"
:
el
if
type
==
"
expanding
"
:
return
b
<=
c
if
type
==
"
shrinking
"
:
el
if
type
==
"
shrinking
"
:
return
b
>=
c
else
:
raise
Exception
(
"
type must be either
'
square
'
,
"
+
"
'
expanding
'
or
'
shrinking
'
.
"
)
def
format_conversion
(
m
,
n
,
chainbc
,
weight
,
format
=
"
abcd
"
):
def
format_conversion
(
m
,
n
,
chainbc
,
weight
,
format
:
str
=
"
abcd
"
):
"""
Return a sequence of deformable butterfly factors
using the infomation of b and c
using the infomation of b and c
.
Args:
m, n: ``int``
Size of the matrix.
chainbc:
chainbc:
``list``
A sequence of pairs (b,c).
format: ``str``
Support 2 formats (a,b,c,d) (default) and (p,q,r,s,t)
format: ``str``, optional
Support 2 formats (a,b,c,d) (
"
abcd
"
is default)
and (p,q,r,s,t).
Returns:
``list``
"""
a
=
1
d
=
m
...
...
@@ -111,6 +116,9 @@ def format_conversion(m, n, chainbc, weight, format="abcd"):
)
elif
format
==
"
abcdpq
"
:
result
.
append
((
a
,
b
,
c
,
d
,
weight
[
i
],
weight
[
i
+
1
]))
else
:
raise
Exception
(
"
format must be either
'
abcd
'
,
"
+
"
'
pqrst
'
or
'
abcdpq
'
.
"
)
a
=
a
*
c
return
result
...
...
@@ -118,6 +126,12 @@ def format_conversion(m, n, chainbc, weight, format="abcd"):
def
factorize
(
n
):
"""
Return a dictionary storing all prime divisor
of n with their corresponding powers.
Args:
n: ``int``
Returns:
``dict``
"""
if
found_sympy
:
prime_ints
=
list
(
primerange
(
1
,
n
+
1
))
...
...
@@ -139,7 +153,17 @@ def factorize(n):
def
random_Euler_sum
(
n
,
k
):
# Return k nonnegative integers whose sum equals to n
"""
Return k nonnegative integers whose sum equals to n.
Args:
n: ``int``
Target sum.
k: ``int``
Number of nonnegative integers.
Returns:
``list``
"""
result
=
[
0
]
*
k
sample
=
np
.
random
.
randint
(
0
,
k
,
n
)
for
i
in
sample
:
...
...
@@ -177,9 +201,18 @@ class DebflyGen:
self
.
dp_table
=
np
.
zeros
((
m
+
1
,
n
+
1
))
self
.
dp_table_temp
=
np
.
zeros
((
m
+
1
,
n
+
1
))
def
random_debfly_chain
(
self
,
n_factors
,
format
=
"
abcd
"
):
def
random_debfly_chain
(
self
,
n_factors
,
format
:
str
=
"
abcd
"
):
"""
Return an uniformly random deformable butterfly chain
whose product is of size m x n has n_factors factors.
whose product is of size m x n has ``n_factors`` factors.
Args:
n_factors: ``int``
The number of factors.
format: ``str``, optional
"
abcd
"
is default.
Returns:
``list``
"""
decomp_m
=
factorize
(
self
.
m
)
decomp_n
=
factorize
(
self
.
n
)
...
...
@@ -233,13 +266,15 @@ class DebflyGen:
)
return
results
def
smallest_monotone_debfly_chain
(
self
,
n_factors
,
format
=
"
abcd
"
):
def
smallest_monotone_debfly_chain
(
self
,
n_factors
,
format
:
str
=
"
abcd
"
):
"""
Return a deformable butterfly chain whose product is of
size m x n has n_factors factors.
Args:
n_factors: ``int``
The number of factors.
format: ``str``, optional
"
abcd
"
is default.
"""
try
:
assert
n_factors
>
0
...
...
@@ -319,87 +354,15 @@ class DebflyGen:
)
def
optimized_deform_butterfly_mult_torch
(
input
,
num_mat
,
R_parameters
,
R_shapes
,
return_intermediates
:
bool
=
False
,
version
:
str
=
"
bmm
"
,
backend
:
str
=
'
numpy
'
,
):
"""
Less reshape than the original version.
Assume that input is 2D (n, in_size).
"""
n
=
input
.
shape
[
0
]
output
=
input
.
contiguous
()
intermediates
=
[
output
]
temp_p
=
0
for
m
in
range
(
num_mat
):
R_shape
=
R_shapes
[
m
]
output_size
,
input_size
,
row
,
col
,
diag
=
R_shape
[:]
num_p
=
col
*
output_size
nb_blocks
=
input_size
//
(
col
*
diag
)
if
version
==
"
pointwise
"
:
t
=
(
R_parameters
[
temp_p
:
temp_p
+
num_p
]
.
view
(
nb_blocks
,
diag
,
row
,
col
)
.
permute
(
0
,
2
,
3
,
1
)
)
if
found_pytorch
:
output
=
output
.
view
(
n
,
nb_blocks
,
1
,
col
,
diag
)
else
:
output
=
output
.
reshape
(
n
,
nb_blocks
,
1
,
col
,
diag
)
output
=
(
t
*
output
).
sum
(
dim
=-
2
)
elif
version
==
"
bmm
"
:
t
=
R_parameters
[
temp_p
:
temp_p
+
num_p
].
view
(
nb_blocks
*
diag
,
row
,
col
)
# (nb_blocks * diag, row, col)
output
=
(
output
.
reshape
(
n
,
nb_blocks
,
col
,
diag
)
.
transpose
(
-
1
,
-
2
)
.
reshape
(
n
,
-
1
,
col
)
)
if
found_pytorch
:
output
=
torch
.
bmm
(
output
.
transpose
(
0
,
1
),
t
.
transpose
(
2
,
1
))
else
:
output
=
np
.
einsum
(
"
ijk,ikl->ijl
"
,
output
.
transpose
(
0
,
1
),
t
.
transpose
(
2
,
1
)
)
output
=
output
.
reshape
(
nb_blocks
,
diag
,
n
,
row
).
permute
(
2
,
0
,
3
,
1
)
# (n, nb_blocks, row, diag)
elif
found_pytorch
and
version
==
"
conv1d
"
:
t
=
R_parameters
[
temp_p
:
temp_p
+
num_p
].
view
(
-
1
,
col
,
1
)
output
=
(
output
.
reshape
(
n
,
nb_blocks
,
col
,
diag
)
.
transpose
(
-
1
,
-
2
)
.
reshape
(
n
,
-
1
,
1
)
)
output
=
torch
.
nn
.
functional
.
conv1d
(
output
,
t
,
groups
=
nb_blocks
*
diag
)
output
=
output
.
view
(
n
,
nb_blocks
,
diag
,
row
).
transpose
(
-
1
,
-
2
)
else
:
raise
NotImplementedError
temp_p
+=
num_p
intermediates
.
append
(
output
)
return
(
output
.
reshape
(
n
,
output_size
)
if
not
return_intermediates
else
intermediates
)
# ----- Useful function to handle generalized butterfly chains -----
def
count_parameters
(
param_chain
):
"""
Return number of parameters.
Args:
param_chain: ``tuple``
A generalized butterfly chain.
def
count_parameters
(
param_chain
):
"""
Input: A generalized butterfly chain
Output: Number of parameters
Returns:
Number of parameters (``int``).
"""
assert
len
(
param_chain
)
>
0
count
=
0
...
...
@@ -428,6 +391,7 @@ def check_monotone(param_chain, rank):
param_chain:
A generalized butterfly chain and the intended rank.
rank: ``int``
Expected rank.
Returns:
bool
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment