Mentions légales du service

Skip to content
Snippets Groups Projects
Commit e85d1184 authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Simplify server-side hyperparameters' collection into messages.

parent f814d6c7
No related branches found
No related tags found
1 merge request!14Refactor `FederatedServer` and `FederatedClient` code.
...@@ -349,10 +349,7 @@ class FederatedServer: ...@@ -349,10 +349,7 @@ class FederatedServer:
params = { params = {
"round_i": round_i, "round_i": round_i,
"weights": self.model.get_weights(), "weights": self.model.get_weights(),
"batches": train_cfg.batch_cfg, **train_cfg.message_params,
"n_epoch": train_cfg.n_epoch,
"n_steps": train_cfg.n_steps,
"timeout": train_cfg.timeout,
} # type: Dict[str, Any] } # type: Dict[str, Any]
messages = {} # type: Dict[str, messaging.Message] messages = {} # type: Dict[str, messaging.Message]
# Dispatch auxiliary variables (which may be client-specific). # Dispatch auxiliary variables (which may be client-specific).
...@@ -526,9 +523,7 @@ class FederatedServer: ...@@ -526,9 +523,7 @@ class FederatedServer:
message = messaging.EvaluationRequest( message = messaging.EvaluationRequest(
round_i=round_i, round_i=round_i,
weights=self.model.get_weights(), weights=self.model.get_weights(),
batches=valid_cfg.batch_cfg, **valid_cfg.message_params,
n_steps=valid_cfg.n_steps,
timeout=valid_cfg.timeout,
) )
await self.netwk.broadcast_message(message, clients) await self.netwk.broadcast_message(message, clients)
......
...@@ -92,6 +92,16 @@ class TrainingConfig: ...@@ -92,6 +92,16 @@ class TrainingConfig:
"drop_remainder": self.drop_remainder, "drop_remainder": self.drop_remainder,
} }
@property
def message_params(self) -> Dict[str, Any]:
"""TrainRequest message parameters from this config."""
return {
"batches": self.batch_cfg,
"n_epoch": self.n_epoch,
"n_steps": self.n_steps,
"timeout": self.timeout,
}
@dataclasses.dataclass @dataclasses.dataclass
class EvaluateConfig(TrainingConfig): class EvaluateConfig(TrainingConfig):
...@@ -103,3 +113,10 @@ class EvaluateConfig(TrainingConfig): ...@@ -103,3 +113,10 @@ class EvaluateConfig(TrainingConfig):
""" """
drop_remainder: bool = False drop_remainder: bool = False
@property
def message_params(self) -> Dict[str, Any]:
"""ValidRequest message parameters from this config."""
params = super().message_params
params.pop("n_epoch")
return params
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment