Introduce and deploy `Aggregate` API
This MR introduces major (and, to some extent, API-breaking) revisions to the way aggregation occurs in DecLearn.
It paves the way towards:
- enabling the use of Secure Aggregation (SecAgg) on top of most computations (notably, it aims to bridge DecLearn APIs with Fed-BioMed's SecAgg tools in the backend of the latter framework).
- implementing decentralized learning processes where peers' data may be aggregated across a chain of communicating peers rather than be centralized by an aggregating server.
Major changes
`Aggregate` API (new)
This MR introduces the `Aggregate` API, which is an abstract base dataclass acting as a template for data structures that need to be shared across peers and aggregated.
The `declearn.utils.Aggregate` ABC acts as a shared ancestor, providing a base API and shared backend code to define data structures that:
- are serializable to and deserializable from JSON, and may therefore be preserved across network communications
- are aggregatable into an instance of the same structure
- use summation as the default aggregation rule for fields, which is overridable by redefining the `default_aggregate` method
- can implement custom `aggregate_<field.name>` methods to override the default summation rule
- implement a `prepare_for_secagg` method that
  - enables defining which fields merely require sum-aggregation and need encryption when using SecAgg, and which fields are to be preserved in cleartext (and therefore go through the usual default or custom aggregation methods)
  - can be made to raise a `NotImplementedError` when SecAgg cannot be achieved on a data structure
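
The aggregation rules listed above can be illustrated with a minimal, self-contained sketch. It does not use or reproduce DecLearn code: `SketchAggregate`, `SketchState`, their fields, and the return format of `prepare_for_secagg` are all assumptions made for illustration; only the method names `default_aggregate`, `aggregate_<field.name>` and `prepare_for_secagg` come from the description above.

```python
import dataclasses
import json


@dataclasses.dataclass
class SketchAggregate:
    """Hypothetical stand-in for an aggregatable dataclass (not DecLearn code)."""

    def __add__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented
        results = {}
        for field in dataclasses.fields(self):
            # Use `aggregate_<field.name>` when defined, else the default rule.
            rule = getattr(self, f"aggregate_{field.name}", self.default_aggregate)
            results[field.name] = rule(
                getattr(self, field.name), getattr(other, field.name)
            )
        return type(self)(**results)

    @staticmethod
    def default_aggregate(val_a, val_b):
        """Default field-wise rule: summation."""
        return val_a + val_b

    def to_json(self) -> str:
        return json.dumps(dataclasses.asdict(self))

    @classmethod
    def from_json(cls, dump: str):
        return cls(**json.loads(dump))


@dataclasses.dataclass
class SketchState(SketchAggregate):
    """Hypothetical child structure with one custom aggregation rule."""

    total: float
    count: int
    max_seen: float

    @staticmethod
    def aggregate_max_seen(val_a, val_b):
        """Custom rule: keep the maximum rather than the sum."""
        return max(val_a, val_b)

    def prepare_for_secagg(self):
        # Assumed return format: (fields to encrypt and sum, cleartext fields).
        secagg = {"total": self.total, "count": self.count}
        clrtxt = {"max_seen": self.max_seen}
        return secagg, clrtxt


state_a = SketchState(total=3.0, count=2, max_seen=2.5)
state_b = SketchState(total=4.5, count=3, max_seen=1.5)
# Round-trip through JSON, as if `state_b` had been sent over the network.
received = SketchState.from_json(state_b.to_json())
merged = state_a + received
```

In this sketch, `total` and `count` are summed while `max_seen` goes through its custom rule, which is why it is also the field kept in cleartext by `prepare_for_secagg`.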
This new ABC currently has three main children:
- `AuxVar`: replaces the plain dict used for `Optimizer` auxiliary variables
- `MetricState`: replaces the plain dict used for `Metric` intermediate states
- `ModelUpdates`: replaces the sharing of updates as a `Vector` and `n_steps`
Each of these is defined jointly with another (pre-existing, revised) API for components that (a) produce `Aggregate` data structures based on some input data and/or computations; (b) produce some output results based on a received `Aggregate` structure, meant to result from the aggregation of multiple peers' produced data.
`Aggregator` API (revised)
The `Aggregator` API was revised to make use of the new `ModelUpdates` data structure (inheriting `Aggregate`).
- `Aggregator.prepare_for_sharing` pre-processes an input `Vector` containing raw model updates and an integer indicating the number of local SGD steps into a `ModelUpdates` structure.
- `Aggregator.finalize_updates` receives a `ModelUpdates` resulting from the aggregation of peers' instances, and performs final computations to produce a `Vector` of aggregated model updates.
- The legacy `Aggregator.aggregate` method is deprecated (but still works).
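
The two-step flow described above can be sketched as follows, assuming a step-weighted averaging rule. This is not the DecLearn implementation: `SketchModelUpdates` and `SketchAveragingAggregator` are hypothetical names, and a plain list of floats stands in for a `Vector`.

```python
import dataclasses
from typing import List


@dataclasses.dataclass
class SketchModelUpdates:
    """Hypothetical stand-in for `ModelUpdates` (sum-aggregatable)."""

    updates: List[float]  # updates, pre-multiplied by their weight
    weights: float  # total weight backing the updates

    def __add__(self, other):
        summed = [a + b for a, b in zip(self.updates, other.updates)]
        return SketchModelUpdates(summed, self.weights + other.weights)


class SketchAveragingAggregator:
    """Hypothetical aggregator averaging updates weighted by local steps."""

    def prepare_for_sharing(
        self, updates: List[float], n_steps: int
    ) -> SketchModelUpdates:
        # Pre-weight the updates so that plain summation across peers
        # yields the numerator of a weighted average.
        weighted = [n_steps * value for value in updates]
        return SketchModelUpdates(weighted, float(n_steps))

    def finalize_updates(self, updates: SketchModelUpdates) -> List[float]:
        # Divide the summed, weighted updates by the summed weights.
        return [value / updates.weights for value in updates.updates]


aggregator = SketchAveragingAggregator()
client_a = aggregator.prepare_for_sharing([1.0, 2.0], n_steps=1)
client_b = aggregator.prepare_for_sharing([3.0, 6.0], n_steps=3)
# The sum may be computed by a server, along a chain of peers, or via SecAgg.
aggregated = aggregator.finalize_updates(client_a + client_b)
```

Because the shared structures only need to be summed, the aggregation step itself no longer depends on the aggregator's rule; only the pre- and post-processing do.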
`OptiModule` auxiliary variables API (revised)
The `OptiModule` API (and, consequently, `Optimizer`) was revised as to the design and signature of the methods related to auxiliary variables, to make use of the new `AuxVar` data structure (inheriting `Aggregate`).
- `OptiModule.collect_aux_var` now emits either `None` or an `AuxVar` instance (the precise type of which is module-dependent), instead of a mere dict.
- `OptiModule.process_aux_var` now expects a proper-type `AuxVar` instance that already aggregates clients' data, externalizing the aggregation rules to the `AuxVar` subtypes, while keeping the finalization logic part of the `OptiModule` subclasses.
- `Optimizer.collect_aux_var` therefore emits a `{name: aux_var}` dict.
- `Optimizer.process_aux_var` therefore expects a `{name: aux_var}` dict, rather than having distinct signatures on the client and server sides.
- It is now expected that server-side components will send the same data to all clients, rather than allow sending client-wise values.
The backend code of `ScaffoldClientModule` and `ScaffoldServerModule` was heavily revised to alter the distribution of information and computations:
- Client-side modules are now the sole owners of their local state, and send sum-aggregatable updates to the server, which are therefore SecAgg-compatible.
- The server consequently shares the same information with all clients, namely the current global state.
- To keep track of the (possibly growing with time) number of unique clients, each client generates a random uuid that is sent with its state updates and preserved in cleartext when SecAgg is used.
- As a consequence, the server component knows which clients contributed to a given round, but receives an aggregate of local updates rather than the client-wise state values.
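
The revised distribution of information can be sketched as follows. This is an illustration under stated assumptions rather than the actual Scaffold modules' code: the class names, the field names, and the exact update formula (the global state is incremented by the summed local state updates divided by the number of known clients) are assumptions, and plain lists of floats stand in for `Vector` instances.

```python
import dataclasses
import uuid
from typing import FrozenSet, List, Set


@dataclasses.dataclass
class SketchScaffoldAuxVar:
    """Hypothetical auxiliary variables shared by clients."""

    delta: List[float]  # sum-aggregatable state update (encryptable)
    clients: FrozenSet[str]  # contributing clients' uuids (kept in cleartext)

    def __add__(self, other):
        delta = [a + b for a, b in zip(self.delta, other.delta)]
        return SketchScaffoldAuxVar(delta, self.clients | other.clients)


class SketchClient:
    """Hypothetical client module, sole owner of its local state."""

    def __init__(self, n_coefs: int) -> None:
        self.uuid = str(uuid.uuid4())
        self.state = [0.0] * n_coefs

    def collect_aux_var(self, new_state: List[float]) -> SketchScaffoldAuxVar:
        # Share the state *update* rather than the state itself.
        delta = [new - old for new, old in zip(new_state, self.state)]
        self.state = list(new_state)
        return SketchScaffoldAuxVar(delta, frozenset({self.uuid}))


class SketchServer:
    """Hypothetical server module, holding the global state only."""

    def __init__(self, n_coefs: int) -> None:
        self.known_clients: Set[str] = set()
        self.global_state = [0.0] * n_coefs

    def process_aux_var(self, aux_var: SketchScaffoldAuxVar) -> List[float]:
        # Track unique clients, then apply the aggregated update.
        self.known_clients.update(aux_var.clients)
        n_clients = len(self.known_clients)
        self.global_state = [
            glob + delta / n_clients
            for glob, delta in zip(self.global_state, aux_var.delta)
        ]
        return self.global_state  # the same value is sent to all clients


client_a, client_b = SketchClient(2), SketchClient(2)
server = SketchServer(2)
round_1 = client_a.collect_aux_var([1.0, 2.0]) + client_b.collect_aux_var([3.0, 4.0])
state_1 = list(server.process_aux_var(round_1))
# In the second round, only one client participates.
state_2 = list(server.process_aux_var(client_a.collect_aux_var([2.0, 2.0])))
```

Note that the server never sees an individual client's state: it only receives summed updates together with the set of contributing uuids.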
`Metric` API (revised)
The `Metric` API was revised to make use of the new `MetricState` data structure (inheriting `Aggregate`).
- `Metric.build_initial_states` generates a "zero-state" `MetricState` instance (it replaces the previously-private `_build_states` method that returned a dict).
- `Metric.get_states` returns a (Metric-type-dependent) `MetricState` instance, instead of a mere dict.
- `Metric.set_states` assigns an incoming `MetricState` into the instance, which may then be finalized into results using the unchanged `get_result` method.
- The legacy `Metric.agg_states` is deprecated, in favor of `set_states` (but it still works).
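
The revised life cycle of metric states can be sketched with a toy mean-absolute-error metric. `SketchMaeState` and `SketchMaeMetric` are hypothetical names and their `update` signature is an assumption; only the method names `build_initial_states`, `get_states`, `set_states` and `get_result` are taken from the list above.

```python
import dataclasses
from typing import Dict, List


@dataclasses.dataclass
class SketchMaeState:
    """Hypothetical stand-in for a `MetricState` subtype."""

    abs_err: float
    n_samples: int

    def __add__(self, other):
        return SketchMaeState(
            self.abs_err + other.abs_err, self.n_samples + other.n_samples
        )


class SketchMaeMetric:
    """Hypothetical metric computing a mean absolute error."""

    def __init__(self) -> None:
        self._states = self.build_initial_states()

    def build_initial_states(self) -> SketchMaeState:
        # "Zero-state": neutral element for the aggregation rule.
        return SketchMaeState(abs_err=0.0, n_samples=0)

    def update(self, y_true: List[float], y_pred: List[float]) -> None:
        errors = sum(abs(t - p) for t, p in zip(y_true, y_pred))
        self._states = self._states + SketchMaeState(errors, len(y_true))

    def get_states(self) -> SketchMaeState:
        return self._states

    def set_states(self, states: SketchMaeState) -> None:
        self._states = states

    def get_result(self) -> Dict[str, float]:
        return {"mae": self._states.abs_err / max(self._states.n_samples, 1)}


# Two peers accumulate local states on their own data.
peer_a, peer_b = SketchMaeMetric(), SketchMaeMetric()
peer_a.update(y_true=[1.0, 2.0], y_pred=[1.5, 2.0])
peer_b.update(y_true=[0.0], y_pred=[1.0])
# The aggregated states are assigned to a metric, then finalized.
server_metric = SketchMaeMetric()
server_metric.set_states(peer_a.get_states() + peer_b.get_states())
result = server_metric.get_result()
```

Here the aggregation rule lives in the state structure, while the metric only knows how to produce states and how to turn them into results.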
`Vector` (un)flattening API
The `Vector` API was extended to enable transforming any instance to and from a list of float values and a `VectorSpec` data structure that provides metadata (some of which may be framework-specific). This was done in order to enable using encryption and secure aggregation of `Vector`-wrapped data.
- `VectorSpec` is a dataclass that stores key metadata about a `Vector`: its coefficients' names, shapes and dtypes; its type registration info; and optionally a dict with framework-specific information (e.g. indications about `tensorflow.IndexedSlices`).
- `Vector.get_vector_specs` returns a `VectorSpec` instance matching the `Vector`.
- `Vector.flatten` returns a list of float values and a `VectorSpec` from a `Vector`.
- `Vector.unflatten` is a classmethod to rebuild a given `Vector` subtype from its flattened values and specs.
- `Vector.build_from_specs` is a generic builder to unflatten a `Vector` without specifying its exact subtype (which is retrieved from the input `VectorSpec`).
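
The idea behind flattening can be sketched in pure Python, with a dict of nested lists standing in for a `Vector`. The names `SketchSpec`, `flatten` and `unflatten` below are hypothetical module-level helpers, not the DecLearn methods, and the sketch assumes non-empty, rectangular coefficients and ignores dtypes and framework-specific metadata.

```python
import dataclasses
import math
from typing import Dict, List, Tuple


@dataclasses.dataclass
class SketchSpec:
    """Hypothetical, reduced stand-in for `VectorSpec`."""

    names: List[str]
    shapes: Dict[str, Tuple[int, ...]]


def _shape(array) -> Tuple[int, ...]:
    """Return the shape of a non-empty rectangular nested list."""
    shape = []
    while isinstance(array, list):
        shape.append(len(array))
        array = array[0]
    return tuple(shape)


def _ravel(array) -> List[float]:
    """Flatten a nested list into a list of floats (row-major order)."""
    if not isinstance(array, list):
        return [float(array)]
    return [value for item in array for value in _ravel(item)]


def _reshape(values: List[float], shape: Tuple[int, ...]):
    """Rebuild a nested list of a given shape from flat values."""
    if not shape:
        return values[0]
    if len(shape) == 1:
        return list(values)
    step = len(values) // shape[0]
    return [
        _reshape(values[i * step:(i + 1) * step], shape[1:])
        for i in range(shape[0])
    ]


def flatten(coefs: Dict[str, list]) -> Tuple[List[float], SketchSpec]:
    names = list(coefs)
    shapes = {name: _shape(coefs[name]) for name in names}
    values = [value for name in names for value in _ravel(coefs[name])]
    return values, SketchSpec(names, shapes)


def unflatten(values: List[float], spec: SketchSpec) -> Dict[str, list]:
    coefs, start = {}, 0
    for name in spec.names:
        size = math.prod(spec.shapes[name])
        coefs[name] = _reshape(values[start:start + size], spec.shapes[name])
        start += size
    return coefs


coefs = {"kernel": [[1.0, 2.0], [3.0, 4.0]], "bias": [0.5, -0.5]}
values, spec = flatten(coefs)
rebuilt = unflatten(values, spec)
```

The flat list of floats is what an encryption or secure aggregation layer would operate on, while the specs travel alongside it so that the receiver can rebuild the structured data.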
New version policy
This MR introduces breaking changes to the APIs of some components that are not very high-level.
- High-level orchestration classes, and components' instantiation or parametrization by end-users, are left unchanged.
- Some key component API methods were deprecated in favor of new ones (e.g. `Aggregator.aggregate`); in that case, whenever possible, the deprecated method was kept functional (and is still being tested), with a warning, as is standard for deprecated features.
- The type of information emitted and/or ingested by some methods was changed, namely to use `Aggregate` structures to wrap up, better specify and enable aggregation of partial quantities being shared across peers. This is notably the case for metric states and optimizer auxiliary variables. However, they remain compatible with the deprecated-but-conserved methods of components.
- In conclusion, the main point where issues might arise is custom components, which would need revision to be compatible with the upcoming new DecLearn version. Given our small, close-knit current user base, this is deemed uncritical, and a tolerable affront to SemVer in an effort to progressively design and deploy changes that pave the way towards a future major version.
From now on, the server and clients are expected and verified to use the same `major.minor` version of DecLearn.
- This restrictive policy will enable conducting further framework revisions without having to worry about (and test for) potential incompatibilities within low-level components (e.g. due to renaming some quantities or changing some computation details).
- This policy should not come in the way of trying to abide by SemVer as to major APIs: users (including developers) should not have to worry about their custom components, high-level or application code breaking unexpectedly with minor versions being introduced. As a demonstration, the current MR breaks neither functional tests nor repository examples, even though these were not revised at all.
- This policy may be updated in the future, e.g. to specify that clients may have a newer minor version than the server (and most probably not the other way around).
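
The version check described above can be sketched as follows. The function names are hypothetical, the version strings are arbitrary examples, and this is not the verification code used by DecLearn; the second function only illustrates the possible future relaxation mentioned in the last point.

```python
from typing import Tuple


def major_minor(version: str) -> Tuple[int, int]:
    """Parse the 'major.minor' prefix of a version string."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def same_major_minor(server: str, client: str) -> bool:
    """Current policy: both peers must share the same major.minor version."""
    return major_minor(server) == major_minor(client)


def client_not_older(server: str, client: str) -> bool:
    """Possible future policy: same major, client minor >= server minor."""
    s_major, s_minor = major_minor(server)
    c_major, c_minor = major_minor(client)
    return s_major == c_major and c_minor >= s_minor


strict_ok = same_major_minor("2.4.0", "2.4.1")  # patch versions may differ
strict_ko = same_major_minor("2.4.0", "2.3.2")  # minor versions may not
relaxed_ok = client_not_older("2.4.0", "2.5.0")
relaxed_ko = client_not_older("2.4.0", "2.3.2")
```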