Federated Learning
Train machine learning models across distributed data sources without centralizing sensitive data. Covers federated averaging, privacy-preserving computation, communication efficiency, heterogeneous data handling, and the patterns that make federated learning practical.
Federated learning trains models across multiple devices or organizations without sharing raw data. Instead of sending data to a central server, each participant trains locally and sends back only a model update (new weights or gradients). This enables ML on data that cannot leave its source: healthcare records, financial data, mobile device data.
How Federated Learning Works
Traditional ML:

```
Device 1 data ─┐
Device 2 data ─┼─→ Central Server → Train Model
Device 3 data ─┘
```

Problem: raw data must leave the device (privacy risk).

Federated Learning:

```
Device 1: Train locally → Send update ─┐
Device 2: Train locally → Send update ─┼─→ Aggregate → Global Model
Device 3: Train locally → Send update ─┘
```

Raw data never leaves the device; only model updates are transmitted.
Federated Averaging (FedAvg)
```python
import torch

# Server-side aggregation
class FederatedServer:
    def __init__(self, model):
        self.global_model = model

    def aggregate_round(self, client_updates: list):
        """Average client weights, weighted by local dataset size (FedAvg).

        Assumes every state_dict entry is a float tensor.
        """
        total_samples = sum(u["num_samples"] for u in client_updates)
        aggregated_weights = {}
        for layer_name in self.global_model.state_dict():
            aggregated_weights[layer_name] = sum(
                u["weights"][layer_name] * (u["num_samples"] / total_samples)
                for u in client_updates
            )
        self.global_model.load_state_dict(aggregated_weights)
        return self.global_model


# Client-side training
class FederatedClient:
    def __init__(self, local_data, model):
        self.data = local_data  # a DataLoader over this client's private dataset
        self.model = model      # assumed to expose a compute_loss(batch) helper

    def local_train(self, global_weights, epochs=5, lr=0.01):
        """Start from the global weights, train locally, return the update."""
        self.model.load_state_dict(global_weights)
        optimizer = torch.optim.SGD(self.model.parameters(), lr=lr)
        for epoch in range(epochs):
            for batch in self.data:
                optimizer.zero_grad()
                loss = self.model.compute_loss(batch)
                loss.backward()
                optimizer.step()
        return {
            "weights": self.model.state_dict(),
            # len(self.data) would count batches; we need the sample count.
            "num_samples": len(self.data.dataset),
        }
```
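Putting the two classes together, one communication round is: broadcast the global weights, train locally on each client, then aggregate. A minimal driver sketch; the `run_federated_training` name and default round count are illustrative, not from any library:

```python
# Hypothetical driver: one global model, several clients, a few rounds.
def run_federated_training(server, clients, rounds=10):
    for rnd in range(rounds):
        # Every client starts this round from the same global weights...
        global_weights = server.global_model.state_dict()
        updates = [client.local_train(global_weights) for client in clients]
        # ...and the server averages the results into the next global model.
        server.aggregate_round(updates)
    return server.global_model
```

In a real deployment only a sampled subset of clients participates each round, since mobile or cross-org clients are rarely all online at once.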
Privacy Enhancements
Differential Privacy:
Add calibrated noise to gradients before sending
Guarantee: an observer cannot determine whether any individual's data was included
Trade-off: more noise gives stronger privacy but lower accuracy
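In practice this is usually done DP-SGD style: clip each client's update to a fixed L2 norm, then add Gaussian noise calibrated to that norm. A minimal sketch over a flat update vector; `clip_norm` and `noise_multiplier` are illustrative parameter names, not from any particular library:

```python
import numpy as np

def privatize_update(update, clip_norm=1.0, noise_multiplier=1.1, rng=None):
    """Clip an update to a fixed L2 norm, then add Gaussian noise.

    Clipping bounds any one participant's influence; the noise scale is
    calibrated to that bound, which is what yields the DP guarantee.
    """
    rng = rng or np.random.default_rng()
    update = np.asarray(update, dtype=np.float64)
    norm = np.linalg.norm(update)
    # Scale down (never up) so the L2 norm is at most clip_norm.
    clipped = update / max(1.0, norm / clip_norm)
    noise = rng.normal(0.0, noise_multiplier * clip_norm, size=clipped.shape)
    return clipped + noise
```

Each client runs this before uploading; the server averages the noisy updates as usual.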
Secure Aggregation:
Encrypt individual updates, server only sees aggregate
No single client's update visible to server
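One classic construction uses pairwise additive masks: each pair of clients agrees on a random mask that one adds and the other subtracts, so every individual upload looks random while the masks cancel exactly in the server's sum. A toy sketch, in which a shared RNG stands in for the pairwise key agreement and dropout handling a real protocol would need:

```python
import numpy as np

def masked_updates(updates, seed=0):
    """Pairwise additive masking: for each pair (i, j) with i < j, client i
    adds a shared random mask and client j subtracts it. Individual masked
    updates reveal nothing on their own; the masks cancel in the sum."""
    rng = np.random.default_rng(seed)  # stand-in for pairwise key agreement
    n, dim = len(updates), len(updates[0])
    masked = [np.asarray(u, dtype=np.float64).copy() for u in updates]
    for i in range(n):
        for j in range(i + 1, n):
            mask = rng.normal(size=dim)
            masked[i] += mask
            masked[j] -= mask
    return masked
```

Summing the masked uploads recovers exactly the sum of the true updates, which is all FedAvg needs.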
Homomorphic Encryption:
Compute on encrypted data
Server aggregates without decrypting individual updates
Use Cases
| Domain | Data Source | Why Federated |
|---|---|---|
| Healthcare | Hospital patient records | HIPAA, data cannot leave hospital |
| Finance | Bank transaction data | Regulatory, competitive sensitivity |
| Mobile | Device usage patterns | Privacy, bandwidth constraints |
| IoT | Sensor data | Volume, connectivity constraints |
| Cross-org | Competing companies sharing insights | Antitrust, IP protection |
Anti-Patterns
| Anti-Pattern | Consequence | Fix |
|---|---|---|
| No differential privacy | Model memorizes individual data | Add DP noise to gradients |
| Ignoring data heterogeneity | Model biased toward largest participants | Weighted aggregation, FedProx |
| Too many communication rounds | Bandwidth cost, latency | Local epochs, compression |
| No Byzantine fault tolerance | Malicious client poisons model | Robust aggregation (median, trimmed mean) |
| Homogeneous model assumption | Devices have different capabilities | Model heterogeneity support |
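The robust-aggregation fix in the table above can be sketched directly: replace the weighted mean with a per-coordinate median or trimmed mean, so one poisoned update cannot move the result arbitrarily. A minimal sketch with NumPy:

```python
import numpy as np

def coordinate_median(updates):
    """Per-coordinate median of client updates. A single outlier (or
    malicious client) cannot shift any coordinate beyond the honest range."""
    return np.median(np.stack([np.asarray(u, float) for u in updates]), axis=0)

def trimmed_mean(updates, trim=1):
    """Drop the `trim` largest and smallest values per coordinate, then
    average the rest. Tolerates up to `trim` Byzantine clients per side."""
    stacked = np.sort(np.stack([np.asarray(u, float) for u in updates]), axis=0)
    return stacked[trim:len(updates) - trim].mean(axis=0)
```

Both trade some statistical efficiency for robustness: the median ignores the data-size weighting FedAvg normally uses, which matters when client datasets differ greatly in size.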
Federated learning is not just an ML technique — it is a paradigm shift in how we think about data access. The data stays where it is. The model comes to the data.