Project: Magnet / DecLearn / declearn2

Commit 580ce463 (verified), authored 9 months ago by ANDREY Paul
Parent: 25c7b773
Merge request: !70 "Finalize version 2.6.0"

    Extend unit tests for 'FederatedServer'.

Showing 1 changed file: test/main/test_main_server.py (+282 additions, −15 deletions)
```diff
@@ -18,9 +18,10 @@
 """
 Unit tests for 'FederatedServer'.
 """

 import logging
+import math
 import os
 from unittest import mock
-from typing import Optional, Type
+from typing import Dict, List, Optional, Type

 import pytest  # type: ignore
```
```diff
@@ -31,8 +32,10 @@ from declearn.fairness.api import FairnessControllerServer
 from declearn.main import FederatedServer
 from declearn.main.config import (
     FLOptimConfig,
     FLRunConfig,
+    EvaluateConfig,
+    FairnessConfig,
     RegisterConfig,
     TrainingConfig,
 )
 from declearn.main.utils import Checkpointer
```
```diff
@@ -41,8 +44,15 @@ from declearn.messaging import (
     EvaluationReply,
     EvaluationRequest,
     FairnessQuery,
+    InitReply,
+    InitRequest,
     Message,
+    MetadataQuery,
+    MetadataReply,
+    PrivacyReply,
+    PrivacyRequest,
     SerializedMessage,
+    StopTraining,
     TrainRequest,
     TrainReply,
 )
```
```diff
@@ -372,8 +382,8 @@ class TestFederatedServerInit:  # pylint: disable=too-many-public-methods
 class TestFederatedServerRoutines:
     """Unit tests for 'FederatedServer' main unitary routines."""

-    async def setup_server(
-        self,
+    @staticmethod
+    async def setup_test_server(
         use_secagg: bool = False,
         use_fairness: bool = False,
     ) -> FederatedServer:
```
```diff
@@ -391,11 +401,10 @@ class TestFederatedServerRoutines:
                 else None
             ),
         )
-        secagg = (
-            mock.create_autospec(SecaggConfigServer, instance=True)
-            if use_secagg
-            else None
-        )
+        secagg = None  # type: Optional[SecaggConfigServer]
+        if use_secagg:
+            secagg = mock.create_autospec(SecaggConfigServer, instance=True)
+            secagg.secagg_type = "mock_secagg"  # type: ignore
         return FederatedServer(
             model=mock.create_autospec(Model, instance=True),
             netwk=netwk,
```
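A side note on the rewritten SecAgg setup: `create_autospec(..., instance=True)` produces a spec-bound mock, so plain attribute assignment only works for attributes the spec actually defines (and it is a mock assignment, hence the `# type: ignore` to quiet static checkers). A minimal sketch of that behaviour, using a hypothetical `SecaggConfig` class rather than declearn's real `SecaggConfigServer`:

```python
from unittest import mock


class SecaggConfig:
    """Hypothetical stand-in for declearn's SecaggConfigServer."""

    secagg_type: str = "base"


cfg = mock.create_autospec(SecaggConfig, instance=True)
cfg.secagg_type = "mock_secagg"  # fine: 'secagg_type' exists on the spec
assert cfg.secagg_type == "mock_secagg"

try:
    cfg.unknown_field = 1  # not on the spec: the mock rejects it
except AttributeError:
    print("autospec mocks reject attributes absent from the spec")
```

Rejecting unknown attributes is what makes autospec mocks safer than bare `MagicMock`s in a setup helper like this one.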
```diff
@@ -409,7 +418,7 @@ class TestFederatedServerRoutines:
     def setup_mock_serialized_message(
         msg_cls: Type[Message],
         wrapped: Optional[Message] = None,
-    ) -> mock.Base:
+    ) -> mock.NonCallableMagicMock:
         """Set up a mock SerializedMessage with given wrapped message type."""
         message = mock.create_autospec(SerializedMessage, instance=True)
         message.message_cls = msg_cls
```
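The return-annotation fix is worth a gloss: `mock.Base` is a private helper class of `unittest.mock`, whereas `create_autospec(cls, instance=True)` returns a non-callable mock. A quick sketch with an illustrative `Message` class (not declearn's):

```python
from unittest import mock


class Message:
    """Illustrative spec class."""

    def deserialize(self) -> str:
        return "payload"


msg = mock.create_autospec(Message, instance=True)
# instance=True yields a non-callable mock (calling msg() would raise TypeError).
assert isinstance(msg, mock.NonCallableMagicMock)
msg.deserialize.return_value = "mocked"
assert msg.deserialize() == "mocked"
```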
```diff
@@ -418,6 +427,120 @@ class TestFederatedServerRoutines:
         message.deserialize.return_value = wrapped
         return message

+    @pytest.mark.parametrize(
+        "metadata", [False, True], ids=["nometa", "metadata"]
+    )
+    @pytest.mark.parametrize("privacy", [False, True], ids=["nodp", "dpsgd"])
+    @pytest.mark.parametrize(
+        "fairness", [False, True], ids=["unfair", "fairness"]
+    )
+    @pytest.mark.parametrize("secagg", [False, True], ids=["clrtxt", "secagg"])
+    @pytest.mark.asyncio
+    async def test_initialization(
+        self,
+        secagg: bool,
+        fairness: bool,
+        privacy: bool,
+        metadata: bool,
+    ) -> None:
+        """Test that the 'initialization' routine triggers expected calls."""
+        # Set up a server with mocked attributes.
+        server = await self.setup_test_server(
+            use_secagg=secagg, use_fairness=fairness
+        )
+        assert isinstance(server.netwk, mock.NonCallableMagicMock)
+        assert isinstance(server.model, mock.NonCallableMagicMock)
+        server.model.required_data_info = {"n_samples"} if metadata else {}
+        aggrg = server.aggrg
+        # Run the initialization routine.
+        config = FLRunConfig.from_params(
+            rounds=10,
+            register=RegisterConfig(0, 2, 120),
+            training={"batch_size": 8},
+            privacy=(
+                {"budget": (1e-3, 0.0), "sclip_norm": 1.0} if privacy else None
+            ),
+        )
+        server.netwk.wait_for_messages.side_effect = self._setup_init_replies(
+            metadata, privacy
+        )
+        await server.initialization(config)
+        # Verify that the clients-registration routine was called.
+        server.netwk.wait_for_clients.assert_awaited_once_with(0, 2, 120)
+        # Verify that the expected number of message exchanges occurred.
+        assert server.netwk.broadcast_message.await_count == (
+            1 + metadata + privacy
+        )
+        queries = server.netwk.broadcast_message.await_args_list.copy()
+        # When configured, verify that metadata were queried and used.
+        if metadata:
+            query = queries.pop(0)[0][0]
+            assert isinstance(query, MetadataQuery)
+            assert query.fields == ["n_samples"]
+            server.model.initialize.assert_called_once_with({"n_samples": 200})
+        # Verify that an InitRequest was sent with expected parameters.
+        query = queries.pop(0)[0][0]
+        assert isinstance(query, InitRequest)
+        assert query.dpsgd is privacy
+        if secagg:
+            assert query.secagg is not None
+        else:
+            assert query.secagg is None
+        assert query.fairness is fairness
+        # Verify that DP-SGD setup occurred when expected.
+        if privacy:
+            query = queries.pop(0)[0][0]
+            assert isinstance(query, PrivacyRequest)
+            assert query.budget == (1e-3, 0.0)
+            assert query.sclip_norm == 1.0
+            assert query.rounds == 10
+        # Verify that SecAgg setup occurred when expected.
+        decrypter = None  # type: Optional[Decrypter]
+        if secagg:
+            assert isinstance(server.secagg, mock.NonCallableMagicMock)
+            if fairness:
+                server.secagg.setup_decrypter.assert_awaited_once()
+                decrypter = server.secagg.setup_decrypter.return_value
+            else:
+                server.secagg.setup_decrypter.assert_not_called()
+        # Verify that fairness setup occurred when expected.
+        if fairness:
+            assert isinstance(server.fairness, mock.NonCallableMagicMock)
+            server.fairness.setup_fairness.assert_awaited_once_with(
+                netwk=server.netwk, aggregator=aggrg, secagg=decrypter
+            )
+            assert server.aggrg is server.fairness.setup_fairness.return_value
+
+    def _setup_init_replies(
+        self,
+        metadata: bool,
+        privacy: bool,
+    ) -> List[Dict[str, mock.NonCallableMagicMock]]:
+        clients = ("client_a", "client_b")
+        messages = []  # type: List[Dict[str, mock.NonCallableMagicMock]]
+        if metadata:
+            msg = MetadataReply({"n_samples": 100})
+            messages.append(
+                {
+                    key: self.setup_mock_serialized_message(MetadataReply, msg)
+                    for key in clients
+                }
+            )
+        messages.append(
+            {
+                key: self.setup_mock_serialized_message(InitReply)
+                for key in clients
+            }
+        )
+        if privacy:
+            messages.append(
+                {
+                    key: self.setup_mock_serialized_message(PrivacyReply)
+                    for key in clients
+                }
+            )
+        return messages
+
     @pytest.mark.parametrize("secagg", [False, True], ids=["clrtxt", "secagg"])
     @pytest.mark.asyncio
     async def test_training_round(
```
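Two pytest/mock mechanisms carry this new test. First, the four stacked boolean `parametrize` decorators expand into 2^4 = 16 test cases (with ids like `test_initialization[clrtxt-unfair-dpsgd-metadata]`). Second, assigning a list to an async mock's `side_effect` makes each successive `await` return the next item, which is how `_setup_init_replies` scripts one round of client replies per server broadcast. A minimal self-contained sketch of the `side_effect` behaviour (the names are illustrative, not declearn's):

```python
import asyncio
from unittest import mock


async def demo() -> None:
    netwk = mock.AsyncMock()
    # Each awaited call consumes the next item of the side_effect list,
    # mimicking successive rounds of client replies.
    netwk.wait_for_messages.side_effect = [
        {"client_a": "metadata", "client_b": "metadata"},
        {"client_a": "init-ok", "client_b": "init-ok"},
    ]
    first = await netwk.wait_for_messages()
    second = await netwk.wait_for_messages()
    assert first["client_a"] == "metadata"
    assert second["client_b"] == "init-ok"


asyncio.run(demo())
```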
```diff
@@ -426,7 +549,7 @@ class TestFederatedServerRoutines:
     ) -> None:
         """Test that the 'training_round' routine triggers expected calls."""
         # Set up a server with mocked attributes.
-        server = await self.setup_server(use_secagg=secagg)
+        server = await self.setup_test_server(use_secagg=secagg)
         assert isinstance(server.netwk, mock.NonCallableMagicMock)
         assert isinstance(server.model, mock.NonCallableMagicMock)
         assert isinstance(server.optim, mock.NonCallableMagicMock)
```
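The repeated `isinstance(..., mock.NonCallableMagicMock)` assertions do double duty: they guard the assumption that `setup_test_server` mocked the attribute, and they let static type checkers narrow attributes such as `server.netwk` so that mock-only helpers like `assert_awaited_once_with` type-check. A small sketch of the pattern (the `Network` class here is illustrative):

```python
from unittest import mock


class Network:
    """Illustrative spec class with an async method."""

    async def broadcast_message(self, msg: str) -> None:
        ...


netwk = mock.create_autospec(Network, instance=True)
# The assertion guards the mocking assumption and narrows the static type.
assert isinstance(netwk, mock.NonCallableMagicMock)
netwk.broadcast_message.assert_not_called()
```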
```diff
@@ -484,7 +607,7 @@ class TestFederatedServerRoutines:
     ) -> None:
         """Test that the 'evaluation_round' routine triggers expected calls."""
         # Set up a server with mocked attributes.
-        server = await self.setup_server(use_secagg=secagg)
+        server = await self.setup_test_server(use_secagg=secagg)
         assert isinstance(server.netwk, mock.NonCallableMagicMock)
         assert isinstance(server.model, mock.NonCallableMagicMock)
         assert isinstance(server.metrics, mock.NonCallableMagicMock)
```
```diff
@@ -546,7 +669,7 @@ class TestFederatedServerRoutines:
     ) -> None:
         """Test that 'evaluation_round' skips rounds when configured."""
         # Set up a server with mocked attributes.
-        server = await self.setup_server()
+        server = await self.setup_test_server()
         assert isinstance(server.netwk, mock.NonCallableMagicMock)
         # Mock a call that should result in skipping the round.
         await server.evaluation_round(
```
```diff
@@ -566,7 +689,9 @@ class TestFederatedServerRoutines:
     ) -> None:
         """Test that the 'fairness_round' routine triggers expected calls."""
         # Set up a server with mocked attributes.
-        server = await self.setup_server(use_secagg=secagg, use_fairness=True)
+        server = await self.setup_test_server(
+            use_secagg=secagg, use_fairness=True
+        )
         assert isinstance(server.netwk, mock.NonCallableMagicMock)
         assert isinstance(server.model, mock.NonCallableMagicMock)
         assert isinstance(server.fairness, mock.NonCallableMagicMock)
```
```diff
@@ -609,7 +734,7 @@ class TestFederatedServerRoutines:
     ) -> None:
         """Test that 'fairness_round' early-exits when fairness is not set."""
         # Set up a server with mocked attributes and no fairness controller.
-        server = await self.setup_server(use_fairness=False)
+        server = await self.setup_test_server(use_fairness=False)
         assert isinstance(server.netwk, mock.NonCallableMagicMock)
         assert server.fairness is None
         # Call the fairness round routine.
```
```diff
@@ -628,7 +753,7 @@ class TestFederatedServerRoutines:
     ) -> None:
         """Test that 'fairness_round' skips rounds when configured."""
         # Set up a server with a mocked fairness controller.
-        server = await self.setup_server(use_fairness=True)
+        server = await self.setup_test_server(use_fairness=True)
         assert isinstance(server.fairness, mock.NonCallableMagicMock)
         # Mock a call that should result in skipping the round.
         await server.fairness_round(
```
```diff
@@ -637,3 +762,145 @@ class TestFederatedServerRoutines:
         )
         # Assert that the round was skipped.
         server.fairness.run_fairness_round.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_stop_training(
+        self,
+    ) -> None:
+        """Test that 'stop_training' triggers expected actions."""
+        # Set up a server with mocked attributes.
+        server = await self.setup_test_server()
+        assert isinstance(server.netwk, mock.NonCallableMagicMock)
+        assert isinstance(server.model, mock.NonCallableMagicMock)
+        assert isinstance(server.ckptr, mock.NonCallableMagicMock)
+        server.ckptr.folder = "mock_folder"
+        # Call the 'stop_training' routine.
+        await server.stop_training(rounds=5)
+        # Verify that the expected message was broadcast.
+        server.netwk.broadcast_message.assert_awaited_once()
+        message = server.netwk.broadcast_message.await_args[0][0]
+        assert isinstance(message, StopTraining)
+        assert message.weights is server.model.get_weights.return_value
+        assert math.isnan(message.loss)
+        assert message.rounds == 5
+        # Verify that the expected checkpointing occurred.
+        server.ckptr.save_model.assert_called_once_with(
+            server.model, timestamp="best"
+        )
+
+
+class TestFederatedServerRun:
+    """Unit tests for 'FederatedServer.run' and 'async_run' routines."""
+
+    # Unit tests for FLRunConfig parsing via synchronous 'run' method.
+
+    def test_run_from_dict(
+        self,
+    ) -> None:
+        """Test that 'run' properly parses input dict config.
+
+        Mock the actual underlying routine.
+        """
+        server = FederatedServer(
+            model=MOCK_MODEL, netwk=MOCK_NETWK, optim=MOCK_OPTIM
+        )
+        config = mock.create_autospec(dict, instance=True)
+        with mock.patch.object(
+            FLRunConfig,
+            "from_params",
+            return_value=mock.create_autospec(FLRunConfig, instance=True),
+        ) as patch_flrunconfig_from_params:
+            with mock.patch.object(server, "async_run") as patch_async_run:
+                server.run(config)
+        patch_flrunconfig_from_params.assert_called_once_with(**config)
+        patch_async_run.assert_called_once_with(
+            patch_flrunconfig_from_params.return_value
+        )
+
+    def test_run_from_toml(
+        self,
+    ) -> None:
+        """Test that 'run' properly parses input TOML file.
+
+        Mock the actual underlying routine.
+        """
+        server = FederatedServer(
+            model=MOCK_MODEL, netwk=MOCK_NETWK, optim=MOCK_OPTIM
+        )
+        config = "mock_path.toml"
+        with mock.patch.object(
+            FLRunConfig,
+            "from_toml",
+            return_value=mock.create_autospec(FLRunConfig, instance=True),
+        ) as patch_flrunconfig_from_toml:
+            with mock.patch.object(server, "async_run") as patch_async_run:
+                server.run(config)
+        patch_flrunconfig_from_toml.assert_called_once_with(config)
+        patch_async_run.assert_called_once_with(
+            patch_flrunconfig_from_toml.return_value
+        )
+
+    def test_run_from_config(
+        self,
+    ) -> None:
+        """Test that 'run' properly uses input FLRunConfig.
+
+        Mock the actual underlying routine.
+        """
+        server = FederatedServer(
+            model=MOCK_MODEL, netwk=MOCK_NETWK, optim=MOCK_OPTIM
+        )
+        config = mock.create_autospec(FLRunConfig, instance=True)
+        with mock.patch.object(server, "async_run") as patch_async_run:
+            server.run(config)
+        patch_async_run.assert_called_once_with(config)
+
+    # Unit tests for overall actions sequence in 'async_run'.
+
+    @pytest.mark.asyncio
+    async def test_async_run_actions_sequence(self) -> None:
+        """Test that 'async_run' triggers expected routines."""
+        # Set up a server and a run config with mock attributes.
+        server = FederatedServer(
+            model=MOCK_MODEL,
+            netwk=MOCK_NETWK,
+            optim=MOCK_OPTIM,
+            checkpoint=mock.create_autospec(Checkpointer, instance=True),
+        )
+        config = FLRunConfig(
+            rounds=10,
+            register=mock.create_autospec(RegisterConfig, instance=True),
+            training=mock.create_autospec(TrainingConfig, instance=True),
+            evaluate=mock.create_autospec(EvaluateConfig, instance=True),
+            fairness=mock.create_autospec(FairnessConfig, instance=True),
+            privacy=None,
+            early_stop=None,
+        )
+        # Call 'async_run', mocking all underlying routines.
+        with mock.patch.object(server, "initialization") as patch_initialization:
+            with mock.patch.object(server, "training_round") as patch_training:
+                with mock.patch.object(
+                    server, "evaluation_round"
+                ) as patch_evaluation:
+                    with mock.patch.object(
+                        server, "fairness_round"
+                    ) as patch_fairness:
+                        with mock.patch.object(
+                            server, "stop_training"
+                        ) as patch_stop_training:
+                            await server.async_run(config)
+        # Verify that expected calls occurred.
+        patch_initialization.assert_called_once_with(config)
+        patch_training.assert_has_calls(
+            [mock.call(idx, config.training) for idx in range(1, 11)]
+        )
+        patch_evaluation.assert_has_calls(
+            [mock.call(idx, config.evaluate) for idx in range(1, 11)]
+        )
+        patch_fairness.assert_has_calls(
+            [mock.call(idx, config.fairness) for idx in range(0, 10)]
+            + [mock.call(10, config.fairness, force_run=True)]
+        )
+        patch_stop_training.assert_called_once_with(10)
```
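The `async_run` test leans on two more `unittest.mock` idioms: `mock.patch.object(server, "name")` swaps a bound method on one instance for the duration of a `with` block, and `assert_has_calls` checks that an expected ordered sub-sequence of calls occurred. A condensed, self-contained sketch of the same shape (the `Runner` class is invented for illustration):

```python
from unittest import mock


class Runner:
    """Illustrative stand-in for a class whose routines get patched."""

    def run(self, rounds: int) -> None:
        for idx in range(1, rounds + 1):
            self.training_round(idx)

    def training_round(self, idx: int) -> None:
        raise RuntimeError("should be patched out in tests")


runner = Runner()
with mock.patch.object(runner, "training_round") as patch_training:
    runner.run(rounds=3)
# assert_has_calls verifies the expected ordered sub-sequence of calls.
patch_training.assert_has_calls([mock.call(idx) for idx in range(1, 4)])
```

Patching at the instance level, as the test does with `server`, keeps the class untouched and restores the original method when the `with` block exits.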