蒋学刚 發表於 2026-1-29 16:01:00

stable_baseline3 快速入门(一): 训练第一个强化学习模型

<h1 id="简介">简介</h1>
<p><strong>stable_baseline3</strong> 是一个基于 PyTorch 的强化学习算法开源库,里面集成了多种强化学习算法,使用这个开源库能够让我们不需要过度关注强化学习算法细节,专注于AI业务的开发。</p>
<h1 id="环境配置">环境配置</h1>
<pre><code class="language-bash">pip install stable-baselines3
pip install gymnasium
</code></pre>
<p>这里<code>stable-baselines3</code>会默认安装<code>pytroch</code>框架,但是是不带<code>cuda</code>版本的,这就意味着我们无法利用我们的显卡对模型进行训练。<br>
下载<code>cuda</code>版本的<code>pytroch</code>步骤如下:</p>
<ol>
<li>卸载原来版本的<code>pytroch</code>框架</li>
</ol>
<pre><code class="language-bash">pip uninstall torch torchvision torchaudio -y
#这个是针对RTX 30/40/50显卡的。
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
</code></pre>
<p>如果其他版本请参考官网: https://pytorch.org/get-started/locally/</p>
<h1 id="认识stable_baseline3">认识stable_baseline3</h1>
<p><code>stable_baseline3</code>提供了许多模型,如下列表:</p>
<table>
<thead>
<tr>
<th>名称</th>
<th>动作空间</th>
<th>建议应用场景</th>
<th>核心优势</th>
</tr>
</thead>
<tbody>
<tr>
<td><strong>PPO</strong></td>
<td>连续 &amp; 离散</td>
<td><strong>全能选手</strong>,如机器人走动、金融交易、游戏 AI</td>
<td>极其稳定,对超参数不敏感,支持大规模并行训练。</td>
</tr>
<tr>
<td><strong>DQN</strong></td>
<td>仅离散</td>
<td>经典游戏(Atari)、开关控制、迷宫寻路</td>
<td>理解简单,在离散控制领域非常经典且有效。</td>
</tr>
<tr>
<td><strong>SAC</strong></td>
<td>仅连续</td>
<td>复杂物理模拟、机械臂抓取、自动驾驶</td>
<td>探索效率极高,能自动寻找最优路径且不轻易陷入局部最优。</td>
</tr>
<tr>
<td><strong>TD3</strong></td>
<td>仅连续</td>
<td>工业控制、无人机飞行、精密动作</td>
<td>针对 DDPG 的缺陷做了改进,训练过程比 SAC 更平滑。</td>
</tr>
<tr>
<td><strong>A2C</strong></td>
<td>连续 &amp; 离散</td>
<td>简单逻辑测试、快速原型验证</td>
<td>结构简单,虽然不如 PPO 稳定,但在特定并行环境下速度极快。</td>
</tr>
</tbody>
</table>
<p>在<strong>声明模型</strong>中,可以设置多种参数,这里列出常用的:<br>
<mark>目前不需要搞懂都有什么作用,后面有文章会详细讲解</mark></p>
<ol>
<li>训练参数</li>
</ol>
<ul>
<li><code>learning_rate</code>:学习率</li>
<li><code>gamma</code>:折扣因子</li>
<li><code>batch_size</code>:更新模型使用数据量</li>
<li><code>verbose</code>:打印信息模式。0-静默模式,1-信息模式,2-调试模式</li>
<li><code>device</code>:指定训练设备<code>cuda</code>使用显卡,<code>cpu</code>使用cpu</li>
</ul>
<ol start="2">
<li>模型规则</li>
</ol>
<ul>
<li><code>MlpPolicy</code>:多层感知机。适用于状态是数值场景(传感器等)</li>
<li><code>CnnPolicy</code>:卷积神经网络。适用于状态是图像场景(游戏等)</li>
</ul>
<h1 id="训练第一个强化学习模型">训练第一个强化学习模型</h1>
<h3 id="案例">案例</h3>
<p>案例描述:训练一个<code>gymnasium</code>默认提供的游戏环境,平衡杆游戏。</p>
<pre><code class="language-python">import gymnasium as gym
from stable_baselines3 import PPO

env = gym.make("CartPole-v1")

model = PPO("MlpPolicy", env, verbose=1, device="cuda")

print("开始训练...")
model.learn(total_timesteps=10000)

print("正在保存模型...")
model.save("ppo_cartpole")

print("正在读取模型...")
env = gym.make("CartPole-v1", render_mode="human")
loaded_model = PPO.load("ppo_cartpole", env=env)

print("训练结束,开始演示...")
obs, _ = env.reset()
for i in range(1000):
    action, _states = loaded_model.predict(obs, deterministic=True)

    obs, reward, terminated, truncated, info = env.step(action)
   
    if terminated or truncated:
      obs, _ = env.reset()

env.close()
</code></pre>
<h3 id="代码解释">代码解释</h3>
<p>代码流程如下:<br>
<strong>初始化环境模型-&gt;训练模型-&gt;保存模型-&gt;加载模型-&gt;模型预测</strong></p>
<h5 id="初始化环境模型">初始化环境模型</h5>
<p>初始化模型以及游戏的环境</p>
<pre><code class="language-python">env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, verbose=1, device="cuda")

env = gym.make("CartPole-v1", render_mode="human")
</code></pre>
<ul>
<li><code>gym</code>中的<code>make</code>方法利用默认的游戏环境,<code>CartPole-v1</code>是游戏名,下面有一个<code>render_mode="human"</code>参数,用于标识是否展示画面。<mark>训练时展示画面会降低训练的速度,一般在预测时才使用</mark></li>
</ul>
<h5 id="训练模型">训练模型</h5>
<pre><code class="language-python">model.learn(total_timesteps=10000)
</code></pre>
<ul>
<li><code>total_timesteps</code>:训练10000次</li>
</ul>
<h5 id="保存模型">保存模型</h5>
<pre><code class="language-python">model.save("ppo_cartpole")
</code></pre>
<ul>
<li><code>"ppo_cartpole"</code> 为保存模型的名字,这里是保存在当前文件夹中。</li>
</ul>
<h5 id="加载模型">加载模型</h5>
<pre><code class="language-python">loaded_model = PPO.load("ppo_cartpole", env=env)
</code></pre>
<ul>
<li>第一个参数:刚刚保存的模型路径</li>
<li>第二个参数:训练的环境</li>
</ul>
<h5 id="模型预测">模型预测</h5>
<pre><code class="language-python">obs, _ = env.reset()
for i in range(1000):
    action, _states = loaded_model.predict(obs, deterministic=True)

    obs, reward, terminated, truncated, info = env.step(action)
   
    if terminated or truncated:
      obs, _ = env.reset()
</code></pre>
<ul>
<li><code>env.reset()</code>重置环境,返回初始观测值<code>obs</code>和<code>info</code>(这里没用到)</li>
<li>模型的<code>predict</code>方法用于根据观测值<code>obs</code>预测下一步行动。<mark><strong>注意:deterministic参数要为True,不然会报错</strong></mark></li>
<li>模型的<code>step</code>方法根据行动值返回结果。(这些都是什么后面文章会讲)</li>
</ul>
<p><strong>如果❤喜欢❤本系列教程,就点个关注吧,后续不定期更新~</strong></p>


</div>
<div id="MySignature" role="contentinfo">
    <p>本文来自博客园,作者:ClownLMe,转载请注明原文链接:https://www.cnblogs.com/ClownLMe/p/19549111</p><br><br>
来源:https://www.cnblogs.com/ClownLMe/p/19549111
頁: [1]
查看完整版本: stable_baseline3 快速入门(一): 训练第一个强化学习模型