federated-learning-dqn
Federated Learning + DQN
Privacy-preserving distributed reinforcement learning for healthcare scheduling.
When to Use
- Multi-institution ML without sharing raw data
- Healthcare applications with privacy requirements
- Distributed optimization across organizations
Architecture Overview
┌─────────────┐     ┌─────────────┐     ┌─────────────┐
│ Hospital A  │     │ Hospital B  │     │ Hospital C  │
│  Local DQN  │     │  Local DQN  │     │  Local DQN  │
└──────┬──────┘     └──────┬──────┘     └──────┬──────┘
       │                   │                   │
       └───────────────────┼───────────────────┘
                           │
                    ┌──────▼──────┐
                    │ Aggregator  │
                    │  (Server)   │
                    └─────────────┘
Components
Federated Learning
FedAvg Algorithm:
python
# Server: weighted average of client model parameters.
# `weights` holds one scalar per client (commonly its local sample count).
def federated_averaging(models, weights):
    total = sum(weights)
    state_dicts = [model.state_dict() for model in models]
    averaged = {}
    for key in state_dicts[0]:
        averaged[key] = sum(
            w * sd[key] for sd, w in zip(state_dicts, weights)
        ) / total
    return averaged

# Round: selected clients train locally, then the server aggregates.
# `select_clients`, `client.train`, `global_model`, `num_rounds` and
# `local_epochs` are assumed to be defined by the surrounding system.
for round_idx in range(num_rounds):  # avoids shadowing the built-in `round`
    clients = select_clients()
    models, weights = [], []
    for client in clients:
        model, weight = client.train(local_epochs)
        models.append(model)
        weights.append(weight)
    global_model.load_state_dict(federated_averaging(models, weights))
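To make the aggregation step concrete, here is a minimal, self-contained sketch that averages two tiny `nn.Linear` models. The layer sizes, the constant parameter values, and the weights 100 and 300 (standing in for local sample counts) are illustrative assumptions, not values from this document.

```python
import torch
import torch.nn as nn

def federated_averaging(models, weights):
    """Weighted average of the models' parameters (FedAvg aggregation step)."""
    total = sum(weights)
    state_dicts = [m.state_dict() for m in models]
    return {
        key: sum(w * sd[key] for sd, w in zip(state_dicts, weights)) / total
        for key in state_dicts[0]
    }

# Two hypothetical clients whose parameters are all 1.0 and all 5.0.
client_a, client_b = nn.Linear(2, 1), nn.Linear(2, 1)
with torch.no_grad():
    for p in client_a.parameters():
        p.fill_(1.0)
    for p in client_b.parameters():
        p.fill_(5.0)

# Assumed weights: number of local training samples per client.
global_model = nn.Linear(2, 1)
global_model.load_state_dict(federated_averaging([client_a, client_b], [100, 300]))
```

Every averaged parameter equals (100 * 1.0 + 300 * 5.0) / 400 = 4.0, so the client with more data pulls the global model toward its own parameters.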
Deep Q-Network (DQN)
Network Architecture:
python
import torch.nn as nn

class DQN(nn.Module):
    """Two-hidden-layer MLP that maps a state vector to one Q-value per action."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim),
        )

    def forward(self, x):
        return self.net(x)
Training Loop:
python
import torch

# `env`, `optimizer`, `epsilon`, `gamma`, `batch_size`, `num_episodes` and
# `target_update` are assumed to be defined by the caller. `agent` is assumed
# to wrap a DQN and to provide `select_action`.
def train_dqn(agent, replay_buffer, target_net):
    step = 0  # global step counter across episodes
    for episode in range(num_episodes):
        state = env.reset()
        done = False
        while not done:
            # Epsilon-greedy action
            action = agent.select_action(state, epsilon)
            next_state, reward, done, _ = env.step(action)  # classic 4-tuple Gym API
            # Store transition
            replay_buffer.push(state, action, reward, next_state, done)
            state = next_state
            step += 1
            # Wait until the buffer can fill a batch
            if len(replay_buffer) < batch_size:
                continue
            # Sample batch (tensors; batch.action is assumed to have shape [batch, 1])
            batch = replay_buffer.sample(batch_size)
            # Compute loss on matching shapes [batch]
            q_values = agent(batch.state).gather(1, batch.action).squeeze(1)
            with torch.no_grad():  # no gradient flows through the target network
                next_q_values = target_net(batch.next_state).max(1)[0]
                target = batch.reward + gamma * next_q_values * (1 - batch.done)
            loss = nn.MSELoss()(q_values, target)
            # Update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Update target network every `target_update` environment steps
            if step % target_update == 0:
                target_net.load_state_dict(agent.state_dict())
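The training loop calls `replay_buffer.push(...)` and `replay_buffer.sample(...)` and reads fields such as `batch.state`, but the buffer itself is not shown. The following is one possible sketch using only the standard library; the class name `ReplayBuffer`, the `Transition` tuple, and the capacity are assumptions made for illustration.

```python
import random
from collections import deque, namedtuple

Transition = namedtuple("Transition", ["state", "action", "reward", "next_state", "done"])

class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling."""

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)  # oldest transitions are evicted first

    def push(self, state, action, reward, next_state, done):
        self.buffer.append(Transition(state, action, reward, next_state, done))

    def sample(self, batch_size):
        transitions = random.sample(list(self.buffer), batch_size)
        # Transpose a list of Transitions into one Transition of column tuples.
        return Transition(*zip(*transitions))

    def __len__(self):
        return len(self.buffer)

# Push five transitions into a buffer that holds three; the first two are evicted.
buffer = ReplayBuffer(capacity=3)
for t in range(5):
    buffer.push(state=[t], action=t % 2, reward=-1.0, next_state=[t + 1], done=False)
batch = buffer.sample(batch_size=2)
```

In this sketch each field of `batch` is a tuple of Python values; before computing the loss they would still need to be stacked into tensors of the shapes the network expects.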
Multi-Level Feedback Queue (MLFQ)
Integration with DQN:
python
class MLFQScheduler:
    def __init__(self, num_queues=3):
        self.queues = [[] for _ in range(num_queues)]  # index 0 = highest priority
        self.priority_boost = 10  # wait-time threshold for promotion

    def add_patient(self, patient, priority):
        queue_idx = min(priority, len(self.queues) - 1)
        self.queues[queue_idx].append(patient)

    def get_next_patient(self):
        # Boost priority of waiting patients before the agent observes the queues
        self.boost_priorities()
        # DQN selects which queue to serve; `dqn_agent` and `get_queue_state`
        # are assumed to be defined elsewhere
        queue_state = self.get_queue_state()
        action = dqn_agent.select_action(queue_state)
        # An action that does not index a queue, or an empty queue, yields None
        if action < len(self.queues) and self.queues[action]:
            return self.queues[action].pop(0)
        return None

    def boost_priorities(self):
        # Rebuild each queue rather than removing items while iterating over it,
        # which would skip patients. Each patient moves up at most one level per call.
        for i in range(1, len(self.queues)):
            promoted = [p for p in self.queues[i] if p.wait_time > self.priority_boost]
            self.queues[i] = [p for p in self.queues[i] if p.wait_time <= self.priority_boost]
            self.queues[i - 1].extend(promoted)
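For experimentation, a self-contained variant of the scheduler can be run without a trained agent. In this sketch the DQN is replaced by a hypothetical `select_queue` callable, `Patient` is an assumed dataclass, `wait_time` is assumed to be in minutes, and the queue state is assumed to be just the list of queue lengths.

```python
from dataclasses import dataclass

@dataclass
class Patient:
    name: str
    wait_time: int  # assumed unit: minutes

class MLFQScheduler:
    def __init__(self, select_queue, num_queues=3, priority_boost=10):
        self.queues = [[] for _ in range(num_queues)]  # index 0 = highest priority
        self.priority_boost = priority_boost
        self.select_queue = select_queue  # stand-in for dqn_agent.select_action

    def add_patient(self, patient, priority):
        idx = max(0, min(priority, len(self.queues) - 1))  # clamp to a valid queue
        self.queues[idx].append(patient)

    def get_queue_state(self):
        return [len(q) for q in self.queues]

    def boost_priorities(self):
        # Walk from the top so each patient moves up at most one level per call.
        for i in range(1, len(self.queues)):
            promoted = [p for p in self.queues[i] if p.wait_time > self.priority_boost]
            self.queues[i] = [p for p in self.queues[i] if p.wait_time <= self.priority_boost]
            self.queues[i - 1].extend(promoted)

    def get_next_patient(self):
        self.boost_priorities()
        action = self.select_queue(self.get_queue_state())
        if 0 <= action < len(self.queues) and self.queues[action]:
            return self.queues[action].pop(0)
        return None

def serve_first_non_empty(queue_lengths):
    # Hypothetical fixed policy used in place of a trained DQN.
    return next((i for i, n in enumerate(queue_lengths) if n > 0), 0)

scheduler = MLFQScheduler(select_queue=serve_first_non_empty)
scheduler.add_patient(Patient("p1", wait_time=25), priority=2)
scheduler.add_patient(Patient("p2", wait_time=3), priority=2)
scheduler.add_patient(Patient("p3", wait_time=12), priority=1)
first = scheduler.get_next_patient()
```

Here `p3` and `p1` have both waited longer than the threshold, so each is promoted one level; `p3` reaches the top queue and is served first, leaving one patient in each of the two lower queues.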
Privacy Guarantees
Differential Privacy
python
import numpy as np
import torch

def add_dp_noise(gradients, epsilon, delta, sensitivity):
    """Add Gaussian noise for (ε,δ)-differential privacy.

    `sensitivity` is the L2 sensitivity of `gradients`; clipping their
    norm is a common way to bound it.
    """
    sigma = float(sensitivity * np.sqrt(2 * np.log(1.25 / delta)) / epsilon)
    noise = torch.randn_like(gradients) * sigma
    return gradients + noise
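The noise scale depends on `sensitivity`, which is only meaningful if the update's norm is actually bounded. A common way to enforce that bound is to clip each update before adding noise. The sketch below assumes the sensitivity equals the clipping bound `clip_norm`; the function names and parameter values are illustrative. Note that this calibration of sigma is the classical Gaussian-mechanism bound, whose standard proof assumes ε < 1.

```python
import math
import torch

def clip_update(update, clip_norm):
    """Scale the update down so that its L2 norm is at most clip_norm."""
    norm = update.norm(p=2)
    factor = torch.clamp(clip_norm / (norm + 1e-12), max=1.0)  # never scale up
    return update * factor

def gaussian_sigma(sensitivity, epsilon, delta):
    """Noise standard deviation of the classical Gaussian mechanism."""
    return sensitivity * math.sqrt(2 * math.log(1.25 / delta)) / epsilon

def privatize(update, clip_norm, epsilon, delta, generator=None):
    clipped = clip_update(update, clip_norm)
    sigma = gaussian_sigma(clip_norm, epsilon, delta)  # assumes sensitivity == clip_norm
    noise = torch.randn(update.shape, generator=generator) * sigma
    return clipped + noise

gen = torch.Generator().manual_seed(0)
update = torch.tensor([3.0, 4.0])            # L2 norm is 5
clipped = clip_update(update, clip_norm=1.0)  # rescaled to norm 1
noisy = privatize(update, clip_norm=1.0, epsilon=0.5, delta=1e-5, generator=gen)
```

With small ε the noise can dominate a single clipped update; in federated settings it is typically the average over many clients that remains useful.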
Secure Aggregation
- Clients encrypt model updates
- Server aggregates without seeing individual updates
- Only decrypted aggregate is visible
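One way to obtain the property described in the list above is pairwise additive masking, shown here as a toy simulation rather than a secure protocol. Each pair of clients shares a random mask that one adds and the other subtracts, so the masks cancel in the sum. The single seeded RNG stands in for secrets that each pair would agree on privately, and the integer-encoded updates are hypothetical; a real deployment would also need key agreement and dropout handling.

```python
import random

MODULUS = 2**32  # all arithmetic is done modulo this value

def pairwise_masks(num_clients, length, seed=0):
    """Return one mask vector per client; the masks sum to zero modulo MODULUS."""
    rng = random.Random(seed)  # simulation only: stands in for pairwise shared secrets
    masks = [[0] * length for _ in range(num_clients)]
    for i in range(num_clients):
        for j in range(i + 1, num_clients):
            shared = [rng.randrange(MODULUS) for _ in range(length)]
            for k in range(length):
                masks[i][k] = (masks[i][k] + shared[k]) % MODULUS  # client i adds
                masks[j][k] = (masks[j][k] - shared[k]) % MODULUS  # client j subtracts
    return masks

def mask_update(update, mask):
    return [(u + m) % MODULUS for u, m in zip(update, mask)]

def aggregate(masked_updates):
    # The server only ever sums masked vectors.
    return [sum(column) % MODULUS for column in zip(*masked_updates)]

# Hypothetical integer-encoded (for example, quantized) updates from three hospitals.
updates = [[1, 2, 3], [10, 20, 30], [100, 200, 300]]
masks = pairwise_masks(num_clients=3, length=3)
masked = [mask_update(u, m) for u, m in zip(updates, masks)]
total = aggregate(masked)
```

The server sees only the masked vectors, yet `total` equals the element-wise sum of the original updates, `[111, 222, 333]`.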
Healthcare Scheduling Use Case
State Representation
python
state = {
    'queue_lengths': [len(q) for q in queues],    # Shape: (num_queues,)
    'patient_acuity': average_acuity_per_queue,   # Shape: (num_queues,)
    'resource_availability': [beds, staff, equipment],
    'time_features': [hour_of_day, day_of_week],
    'predicted_arrivals': next_hour_forecast,
}
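A network with a fixed `state_dim` needs this dictionary flattened into a vector with a stable feature order. The sketch below shows one way to do that; the normalization constant `max_queue_len`, the choice to scale hour and weekday to [0, 1], and treating `predicted_arrivals` as a list are all assumptions for illustration.

```python
def encode_state(state, max_queue_len=50.0):
    """Flatten the state dict into one list of floats in a fixed feature order."""
    hour, day = state["time_features"]
    features = []
    features += [n / max_queue_len for n in state["queue_lengths"]]
    features += list(state["patient_acuity"])
    features += list(state["resource_availability"])  # left unnormalized in this sketch
    features += [hour / 23.0, day / 6.0]               # hour in 0..23, weekday in 0..6
    features += list(state["predicted_arrivals"])
    return [float(x) for x in features]

# Hypothetical snapshot of one hospital.
example = {
    "queue_lengths": [5, 10, 0],
    "patient_acuity": [0.9, 0.5, 0.0],
    "resource_availability": [12, 8, 3],
    "time_features": [23, 6],
    "predicted_arrivals": [4],
}
vector = encode_state(example)
state_dim = len(vector)
```

For three queues and a single forecast value this gives `state_dim == 12`, which is the value that would be passed to the network's constructor.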
Action Space
python
actions = {
    0: 'Schedule from high-priority queue',
    1: 'Schedule from medium-priority queue',
    2: 'Schedule from low-priority queue',
    3: 'Allocate additional resource',
    4: 'Request transfer from other hospital',
}
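Actions 0 to 2 only make sense when the corresponding queue is non-empty. One common remedy, which this document does not prescribe, is to mask invalid actions before taking the greedy choice. The sketch below assumes three queues mapped to actions 0 to 2 and uses hypothetical Q-values.

```python
NUM_QUEUES = 3  # actions 0..2 serve queues; actions 3..4 are always allowed here

def valid_action_mask(queue_lengths, num_actions=5):
    """True where an action is currently allowed."""
    mask = [True] * num_actions
    for queue_idx in range(NUM_QUEUES):
        mask[queue_idx] = queue_lengths[queue_idx] > 0  # cannot serve an empty queue
    return mask

def greedy_valid_action(q_values, mask):
    """Index of the largest Q-value among the allowed actions."""
    allowed = [i for i, ok in enumerate(mask) if ok]
    return max(allowed, key=lambda i: q_values[i])

q_values = [2.5, 1.0, 0.3, 0.8, -0.2]  # hypothetical network output
mask = valid_action_mask(queue_lengths=[0, 4, 1])
action = greedy_valid_action(q_values, mask)
```

Although action 0 has the highest Q-value, the high-priority queue is empty, so the masked choice falls to action 1.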
Reward Function
python
# Pseudocode: `all_patients`, `queue_lengths`, `variance`, `completed_high_acuity`,
# `resource_utilization`, `threshold` and `overutilization_penalty` are assumed
# to be supplied by the environment.
def calculate_reward(state, action, next_state):
    reward = 0
    # Minimize wait time (weighted by acuity)
    reward -= sum(
        patient.wait_time * patient.acuity
        for patient in all_patients
    )
    # Penalize queue imbalance
    reward -= variance(queue_lengths) * 10
    # Reward completing high-acuity cases
    reward += completed_high_acuity * 50
    # Penalize resource overutilization
    if resource_utilization > threshold:
        reward -= overutilization_penalty
    return reward
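The reward relies on several names that the snippet does not define. A runnable variant can take them as explicit arguments, as sketched below. The `Patient` dataclass, the default `threshold` of 0.9, the default penalty of 20, and the use of population variance are assumptions; only the coefficients 10 and 50 come from the text.

```python
from dataclasses import dataclass
from statistics import pvariance

@dataclass
class Patient:
    wait_time: float
    acuity: float

def calculate_reward(patients, queue_lengths, completed_high_acuity,
                     resource_utilization, threshold=0.9, overutilization_penalty=20.0):
    reward = 0.0
    # Minimize wait time, weighted by acuity.
    reward -= sum(p.wait_time * p.acuity for p in patients)
    # Penalize queue imbalance (population variance of the queue lengths).
    reward -= pvariance(queue_lengths) * 10
    # Reward completing high-acuity cases.
    reward += completed_high_acuity * 50
    # Penalize resource overutilization.
    if resource_utilization > threshold:
        reward -= overutilization_penalty
    return reward

reward = calculate_reward(
    patients=[Patient(wait_time=10, acuity=2), Patient(wait_time=5, acuity=1)],
    queue_lengths=[2, 2, 2],
    completed_high_acuity=1,
    resource_utilization=0.95,
)
```

For this input the terms are -25 for waiting, 0 for imbalance, +50 for the completed case, and -20 for overutilization, giving a reward of 5.0. Because the terms live on very different scales, the coefficients usually need tuning against real queue data.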
Implementation Considerations
Communication Efficiency
- Compression: Quantize model updates
- Federated Dropout: Train smaller subnetworks
- Asynchronous Updates: No synchronization barrier
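As an illustration of the compression item above, the following sketch applies uniform 8-bit quantization to an update tensor. It is one simple scheme among many; the function names and the synthetic update are assumptions.

```python
import torch

def quantize_uint8(update):
    """Uniform 8-bit quantization; returns codes plus the offset and scale to invert it."""
    lo, hi = update.min(), update.max()
    scale = (hi - lo) / 255.0
    if scale <= 0:  # constant tensor: any positive scale works
        scale = torch.tensor(1.0)
    codes = torch.round((update - lo) / scale).to(torch.uint8)
    return codes, lo, scale

def dequantize_uint8(codes, lo, scale):
    return codes.to(torch.float32) * scale + lo

update = torch.linspace(-1.0, 1.0, steps=1000)  # synthetic stand-in for a model update
codes, lo, scale = quantize_uint8(update)
restored = dequantize_uint8(codes, lo, scale)
max_error = (restored - update).abs().max().item()
```

Each value is sent as one byte instead of four (for float32), at the cost of a rounding error of at most about half a quantization step.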
Handling Non-IID Data
- Personalization: Fine-tune global model locally
- Clustered FL: Group similar hospitals
- Multi-task Learning: Shared representation + task-specific heads
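The personalization item above can be sketched as follows: copy the global model, freeze the shared layers, and fine-tune only the output layer on local data. The model shape, the synthetic data, the loss, and the hyperparameters are placeholders; the sketch also assumes an `nn.Sequential` whose last module is the head (for the `DQN` class in this document that would be `model.net[-1]`).

```python
import copy
import torch
import torch.nn as nn

def personalize(global_model, inputs, targets, steps=20, lr=0.05):
    """Copy the global model and fine-tune only its last layer on local data."""
    local_model = copy.deepcopy(global_model)  # the global model stays untouched
    for p in local_model.parameters():
        p.requires_grad = False
    head = local_model[-1]  # assumes nn.Sequential with the output layer last
    for p in head.parameters():
        p.requires_grad = True
    optimizer = torch.optim.SGD(head.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(local_model(inputs), targets)
        loss.backward()
        optimizer.step()
    return local_model

torch.manual_seed(0)
global_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
# Synthetic stand-in for one hospital's local data.
local_x, local_y = torch.randn(32, 4), torch.randn(32, 2)
local_model = personalize(global_model, local_x, local_y)
```

Keeping the shared layers frozen means the hospital adapts only a small number of parameters, which limits overfitting when local data is scarce.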
System Heterogeneity
- Straggler Handling: Async aggregation or timeout
- Variable Resources: Adaptive local epochs
- Device Selection: Probabilistic client sampling
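Two of the items above, probabilistic client sampling and a timeout for stragglers, can be sketched with the standard library alone. Sampling proportionally to local data size is one possible policy, not one stated in this document; the hospital names, data sizes, and timings are hypothetical.

```python
import random

def sample_clients(data_sizes, num_selected, rng):
    """Sample distinct clients with probability proportional to local data size."""
    candidates = dict(data_sizes)
    chosen = []
    for _ in range(min(num_selected, len(candidates))):
        names = list(candidates)
        pick = rng.choices(names, weights=[candidates[n] for n in names], k=1)[0]
        chosen.append(pick)
        del candidates[pick]  # sample without replacement
    return chosen

def drop_stragglers(round_times, timeout):
    """Keep only the clients that reported within the timeout (seconds)."""
    return [name for name, seconds in round_times.items() if seconds <= timeout]

rng = random.Random(0)
data_sizes = {"hospital_a": 1200, "hospital_b": 300, "hospital_c": 4500}  # hypothetical
selected = sample_clients(data_sizes, num_selected=2, rng=rng)

# Hypothetical measured training times for this round, in seconds.
round_times = {"hospital_a": 42.0, "hospital_b": 310.0, "hospital_c": 55.0}
on_time = drop_stragglers({n: round_times[n] for n in selected}, timeout=120.0)
```

Only the updates from `on_time` would be passed to aggregation; a client that misses the deadline simply sits out the round.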
Evaluation Metrics
| Metric | Description |
|---|---|
| Privacy Budget (ε) | Differential privacy guarantee |
| Model Accuracy | Comparison to centralized training |
| Communication Rounds | Convergence speed |
| Patient Wait Time | Scheduling effectiveness |
| Resource Utilization | System efficiency |
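Two of these metrics are easy to compute directly. The sketch below tracks the privacy budget with basic sequential composition, under which the per-round (ε, δ) values add up across rounds; this is a loose upper bound, and tighter accounting methods exist. The per-round values and wait times are hypothetical.

```python
def basic_composition(epsilon_per_round, delta_per_round, num_rounds):
    """Upper bound on the total (epsilon, delta) after num_rounds private releases."""
    return epsilon_per_round * num_rounds, delta_per_round * num_rounds

def mean_wait_time(wait_times):
    """Average patient wait time, in whatever unit the inputs use."""
    return sum(wait_times) / len(wait_times)

total_epsilon, total_delta = basic_composition(0.5, 1e-5, num_rounds=20)
average_wait = mean_wait_time([12.0, 30.0, 18.0])  # hypothetical minutes
```

With these inputs the total budget is ε = 10 after 20 rounds, which shows how quickly the guarantee weakens as the number of communication rounds grows.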