cpp-reinforcement-learning


C++ Reinforcement Learning


Overview


This skill covers implementing reinforcement learning algorithms in C++ using LibTorch (PyTorch C++ frontend) and modern C++17/20 features. It provides patterns for building high-performance RL systems suitable for production deployment, robotics, game AI, and real-time applications.

When to Use


  • Implementing DQN, PPO, SAC, or other RL algorithms in C++
  • Building performance-critical RL training pipelines
  • Creating efficient replay buffers with proper memory management
  • Deploying trained models with ONNX Runtime
  • Parallelizing environment rollouts across threads
  • Integrating RL with existing C++ codebases (games, robotics, simulations)

Core Libraries


Primary: LibTorch (PyTorch C++ Frontend)


LibTorch exposes the same tensor operations and autograd engine as PyTorch, with a near-identical API in C++.
Installation: Download from https://pytorch.org/get-started/locally (select C++/LibTorch), then point CMake at the unpacked directory with -DCMAKE_PREFIX_PATH=/absolute/path/to/libtorch when configuring.
CMake Integration:
cmake
cmake_minimum_required(VERSION 3.18)
project(rl_project)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

add_executable(train_agent src/main.cpp)
target_link_libraries(train_agent "${TORCH_LIBRARIES}")

Secondary Libraries


  • ONNX Runtime - Cross-platform inference deployment (see the inference sketch after this list)
  • cpprl (mhubii/cpprl) - Reference PPO implementation
  • Gymnasium C++ bindings - Environment interfaces
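
For deployment, a common path is to export the trained policy to ONNX (for example via torch.onnx.export on the Python side) and load it with the ONNX Runtime C++ API. A minimal inference sketch, assuming a policy.onnx file whose input tensor is named "state" and output "q_values", with a 4-dimensional observation; all three are placeholders that depend on how you export:

cpp
#include <onnxruntime_cxx_api.h>
#include <array>
#include <vector>

int main() {
    Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "rl_inference");
    Ort::SessionOptions opts;
    Ort::Session session(env, "policy.onnx", opts);

    // One observation; the size must match the exported input shape.
    std::vector<float> state = {0.01f, -0.02f, 0.03f, 0.04f};
    std::array<int64_t, 2> shape{1, 4};

    auto mem = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    Ort::Value input = Ort::Value::CreateTensor<float>(
        mem, state.data(), state.size(), shape.data(), shape.size());

    // Names must match those chosen at export time.
    const char* in_names[] = {"state"};
    const char* out_names[] = {"q_values"};
    auto outputs = session.Run(Ort::RunOptions{nullptr},
                               in_names, &input, 1, out_names, 1);

    // Greedy action = index of the largest Q-value.
    float* q = outputs[0].GetTensorMutableData<float>();
}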

Quick Start: DQN Agent


cpp
#include <torch/torch.h>

struct DQNNet : torch::nn::Module {
    torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr};

    DQNNet(int64_t state_dim, int64_t action_dim) {
        fc1 = register_module("fc1", torch::nn::Linear(state_dim, 128));
        fc2 = register_module("fc2", torch::nn::Linear(128, 128));
        fc3 = register_module("fc3", torch::nn::Linear(128, action_dim));
    }

    torch::Tensor forward(torch::Tensor x) {
        x = torch::relu(fc1->forward(x));
        x = torch::relu(fc2->forward(x));
        return fc3->forward(x);  // one Q-value per action
    }
};

// Setup (state_dim, action_dim, lr defined by your problem)
auto policy_net = std::make_shared<DQNNet>(state_dim, action_dim);
auto target_net = std::make_shared<DQNNet>(state_dim, action_dim);
torch::optim::Adam optimizer(policy_net->parameters(), lr);

// Compute loss; actions must be kInt64 with shape [batch, 1] for gather
auto q_values = policy_net->forward(states).gather(1, actions);
// In C++, Tensor::max(dim) returns a (values, indices) tuple
auto next_q = std::get<0>(target_net->forward(next_states).max(1)).detach();
auto target = rewards + gamma * next_q * (1 - dones);
auto loss = torch::mse_loss(q_values.squeeze(1), target);

// Backward pass
optimizer.zero_grad();
loss.backward();
optimizer.step();
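
The snippet above never refreshes target_net. In standard DQN the target network is synchronized with the policy network every N training steps (a soft/Polyak variant instead blends the weights a little each step). A minimal sketch of a hard update for the DQNNet pair above:

cpp
// Hard update: copy every parameter from the policy into the target network.
// Call this every target_update_interval training steps.
void sync_target(DQNNet& policy, DQNNet& target) {
    torch::NoGradGuard no_grad;  // plain value copies, no autograd tracking
    auto target_params = target.named_parameters(/*recurse=*/true);
    for (const auto& p : policy.named_parameters(/*recurse=*/true)) {
        target_params[p.key()].copy_(p.value());
    }
}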

Essential Patterns


Replay Buffer (Ring Buffer)


cpp
#include <algorithm>
#include <cstddef>
#include <random>
#include <vector>

// Experience is assumed to hold one transition:
// state, action, reward, next_state, done.
class ReplayBuffer {
public:
    explicit ReplayBuffer(size_t capacity)
        : capacity_(capacity), position_(0), size_(0) {
        buffer_.reserve(capacity);
    }

    void push(Experience exp) {
        if (buffer_.size() < capacity_) {
            buffer_.push_back(std::move(exp));    // still filling
        } else {
            buffer_[position_] = std::move(exp);  // overwrite the oldest entry
        }
        position_ = (position_ + 1) % capacity_;
        size_ = std::min(size_ + 1, capacity_);
    }

    std::vector<Experience> sample(size_t batch_size);

private:
    std::vector<Experience> buffer_;
    size_t capacity_, position_, size_;
    std::mt19937 rng_{std::random_device{}()};
};
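
sample is only declared above; a minimal out-of-line definition, assuming uniform sampling with replacement (usually acceptable for DQN-sized buffers) and a non-empty buffer:

cpp
// Uniform sampling with replacement over the filled portion of the buffer.
std::vector<Experience> ReplayBuffer::sample(size_t batch_size) {
    std::uniform_int_distribution<size_t> dist(0, size_ - 1);  // requires size_ > 0
    std::vector<Experience> batch;
    batch.reserve(batch_size);
    for (size_t i = 0; i < batch_size; ++i) {
        batch.push_back(buffer_[dist(rng_)]);
    }
    return batch;
}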

GPU Device Management


cpp
torch::Device device = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
model->to(device);

// Create tensors on device
auto tensor = torch::zeros({batch_size, state_dim},
    torch::TensorOptions().device(device).dtype(torch::kFloat32));
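
Batches sampled from a CPU-side replay buffer must be moved to the model's device before the forward pass. A short sketch, assuming states is a CPU float tensor from the buffer and policy_net is the Quick Start network already moved to device:

cpp
// .to(device) copies only when needed; it is a no-op if already there.
auto states_dev = states.to(device);
auto q = policy_net->forward(states_dev);

// Bring a scalar back to the CPU side to read it from C++.
int64_t greedy = q.argmax(1)[0].item<int64_t>();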

Inference Mode


cpp
{
    torch::NoGradGuard no_grad;
    auto action_values = model->forward(state);
    auto action = action_values.argmax(1);
}
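
Putting the guard to work: a sketch of epsilon-greedy action selection for the DQNNet above (the epsilon schedule and RNG plumbing are illustrative):

cpp
#include <random>

// With probability epsilon take a random action, otherwise the greedy one.
int64_t select_action(DQNNet& net, const torch::Tensor& state,
                      double epsilon, int64_t action_dim, std::mt19937& rng) {
    std::uniform_real_distribution<double> coin(0.0, 1.0);
    if (coin(rng) < epsilon) {
        std::uniform_int_distribution<int64_t> pick(0, action_dim - 1);
        return pick(rng);
    }
    torch::NoGradGuard no_grad;  // no autograd graph during action selection
    return net.forward(state.unsqueeze(0)).argmax(1).item<int64_t>();
}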

Common Pitfalls

常见陷阱

  1. Forgetting train/eval mode - Call model->train() before updates and model->eval() for inference
  2. Missing NoGradGuard - Wrap inference in torch::NoGradGuard so no autograd graph is built
  3. Tensor accumulation - Call .detach() on tensors you store, or the autograd graph is kept alive with them
  4. Thread safety - Give each worker thread its own copy of the model (see the sketch after this list)
  5. Device mismatch - Verify that all tensors in an operation live on the same device
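
A sketch of pitfall 4: give each rollout thread a private copy of the network instead of sharing one instance with the trainer. The parameter copy mirrors the target-network sync shown in the Quick Start; names are illustrative:

cpp
#include <thread>

// Build a private copy of the policy for one worker thread.
auto worker_net = std::make_shared<DQNNet>(state_dim, action_dim);
{
    torch::NoGradGuard no_grad;
    auto dst = worker_net->named_parameters();
    for (const auto& p : policy_net->named_parameters())
        dst[p.key()].copy_(p.value());
}

std::thread worker([worker_net] {
    torch::NoGradGuard no_grad;
    // ... run rollouts with worker_net; the trainer can keep updating
    // policy_net without racing against this thread's copy ...
});
worker.join();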

Reference Files


  • references/libtorch.md - LibTorch setup and API guide
  • references/algorithms.md - DQN, PPO, SAC implementations
  • references/memory-management.md - Replay buffers, smart pointers, RAII
  • references/performance.md - Optimization, parallelization, GPU
  • references/testing.md - Testing and debugging strategies