一个倾听者 發表於 2026-1-30 17:46:00

stable_baseline3 快速入门(二): 训练自定义游戏,构建Gymnasium训练环境

<h1 id="简介">简介</h1>
<p><strong>Gymnasium</strong> 为强化学习提供了一个标准化的API,它定义了 Agent 应该如何观察世界、如何做出动作以及如何获得奖励,不管是游戏,还是工业设备,只需要满足<code>Gymnasium</code>标准都能使用同一套代码进行训练。</p>
<h1 id="认识gymnasium">认识Gymnasium</h1>
<p>使用<code>stable_baseline3</code>只需要定义好<code>Gymnasium</code>环境,关注训练的奖励机制,将重点放在业务的开发上而不是复杂的算法。</p>
<p><strong>Gymnasium</strong>提供了几个核心的api:</p>
<table>
<thead>
<tr>
<th><strong>方法</strong></th>
<th><strong>功能</strong></th>
<th><strong>返回值</strong></th>
</tr>
</thead>
<tbody>
<tr>
<td><strong><code>reset()</code></strong></td>
<td>将环境重置为初始状态,开始新回合。</td>
<td><code>obs, info</code></td>
</tr>
<tr>
<td><strong><code>step(action)</code></strong></td>
<td>环境向前推进一步,执行动作。</td>
<td><code>obs, reward, terminated, truncated, info</code></td>
</tr>
<tr>
<td><strong><code>render()</code></strong></td>
<td>可视化环境(根据 <code>render_mode</code> 渲染图像或弹出窗口)。</td>
<td>视配置而定(通常无或为 <code>np.array</code>)</td>
</tr>
<tr>
<td><strong><code>close()</code></strong></td>
<td>释放环境资源(关闭窗口、清理内存)。</td>
<td>无</td>
</tr>
</tbody>
</table>
<p>其中的各个返回值的含义:</p>
<ul>
<li><strong><code>observation</code> (Object)</strong>: 当前状态的描述。例如敌人,玩家的位置,玩家的状态等</li>
<li><strong><code>reward</code> (Float)</strong>: 上一步动作获得的奖励</li>
<li><strong><code>terminated</code> (Bool)</strong>: 是否由于<strong>任务逻辑</strong>结束。例如:到达终点、掉进岩浆等</li>
<li><strong><code>truncated</code> (Bool)</strong>: 是否由于<strong>外部限制</strong>结束。例如:达到最大步数 500 步</li>
<li><strong><code>info</code> (Dict)</strong>: 辅助诊断信息,模型训练通常不用,用于用户自定义调试或记录额外统计。</li>
</ul>
<h1 id="手动构建环境">手动构建环境</h1>
<h3 id="案例">案例</h3>
<p>案例描述:利用pygame构建一个简单的游戏,躲避掉落方块,利用构建的奖励机制,进行强化学习。</p>
<pre><code class="language-python">import gymnasium as gym
from gymnasium import spaces
import numpy as np
import pygame
import random
import cv2
import os
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.env_checker import check_env

class MyEnv(gym.Env):
    def __init__(self, render_mode=None):
      super(MyEnv, self).__init__()

      #初始化参数
      self.width = 400
      self.height = 300
      self.player_size = 30
      self.enemy_size = 30
      self.render_mode = render_mode

      self.action_space = spaces.Discrete(3)

      self.observation_space = spaces.Box(
            low=0, high=255, shape=(84, 84, 3), dtype=np.uint8
      )

      pygame.init()
      if self.render_mode == "human":
            self.screen = pygame.display.set_mode((self.width, self.height))
      
      self.canvas = pygame.Surface((self.width, self.height))
      self.font = pygame.font.SysFont("monospace", 15)

    def reset(self, seed=None, options=None):
      super().reset(seed=seed)

      self.player_x = self.width // 2 - self.player_size // 2
      self.player_y = self.height - self.player_size - 10
      self.enemies = []
      self.score = 0
      self.frame_count = 0

      self.current_speed = 5
      self.spawn_rate = 30

      return self._get_obs(), {}

    def step(self, action):
      reward = 0
      terminated = False
      truncated = False

      move_speed = 8
      if action == 1 and self.player_x &gt; 0: #
            self.player_x -= move_speed
            reward -= 0.05

      if action == 2 and self.player_x &lt; self.width - self.player_size:
            self.player_x += move_speed
            reward -= 0.05

      self.frame_count += 1

      level = self.score // 5
      self.current_speed = 5 + level
      self.spawn_rate = 30 - level * 2
      spawn_rate = max(10, 30 - level)

      if self.frame_count &gt;= spawn_rate:
            self.frame_count = 0
            enemy_x = random.randint(0, self.width - self.enemy_size)
            self.enemies.append() #

      for enemy in self.enemies:
            enemy += self.current_speed
            
            player_rect = pygame.Rect(self.player_x, self.player_y, self.player_size, self.player_size)
            enemy_rect = pygame.Rect(enemy, enemy, self.enemy_size, self.enemy_size)
            
            if player_rect.colliderect(enemy_rect):
                reward = -10
                terminated = True

            elif enemy &gt; self.height:
                self.enemies.remove(enemy)
                self.score += 1
                reward = 1
      
      if not terminated:
            if self.score &gt; 100:
                reward += 0.01
            reward += 0.01

      obs = self._get_obs()

      if self.render_mode == "human":
            self._render_window()

      return obs, reward, terminated, truncated, {}

    def _get_obs(self):
      self.canvas.fill((0, 0, 0))
      pygame.draw.rect(self.canvas, (50, 150, 255), (self.player_x, self.player_y, self.player_size, self.player_size))
      
      for enemy in self.enemies:
            pygame.draw.rect(self.canvas, (255, 50, 50), (enemy, enemy, self.enemy_size, self.enemy_size))

      img_array = pygame.surfarray.array3d(self.canvas)
      img_array = np.transpose(img_array, (1, 0, 2))
      obs = cv2.resize(img_array, (84, 84), interpolation=cv2.INTER_AREA)

      return obs.astype(np.uint8)

    def _render_window(self):
      self.screen.blit(self.canvas, (0, 0))
      text = self.font.render(f"Score: {self.score}", True, (255, 255, 255))
      self.screen.blit(text, (10, 10))
      pygame.display.flip()

      for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()

def train():
    log_dir = "logs/DodgeGame"
    os.makedirs(log_dir, exist_ok=True)

    env = MyEnv()
    check_env(env)
    print("环境检查通过...")

    model_path = "models/dodge_ai.zip"
    if not os.path.exists(model_path):
      print("🆕 未发现旧模型,从头开始训练...")
      model = PPO(
            "CnnPolicy",
            env,
            verbose=1,
            tensorboard_log=log_dir,
            learning_rate=0.0001,
            n_steps=4096,
            batch_size=256,
            device="cuda")
      reset_timesteps = True
    else:
      print("发现旧模型,加载并继续训练...")
      model = PPO.load(
            model_path,
            env=env,      
            device="cuda",
            custom_objects={"learning_rate": 0.0001, "n_steps": 4096, "batch_size": 256}
      )
      reset_timesteps = False
   
    print("开始训练...")

    model.learn(
      total_timesteps=50000,
      reset_num_timesteps=reset_timesteps
    )

    model.save("models/dodge_ai")
    print("模型已保存!")
    env.close()

def prodict():
    env = MyEnv(render_mode="human")
    model = PPO.load("models/dodge_ai", env=env, device="cuda")
    obs, _ = env.reset()

    while True:
      action, _states = model.predict(obs, deterministic=True)

      obs, reward, terminated, truncated, info = env.step(action)

      if terminated or truncated:
            obs, _ = env.reset()
      
      pygame.time.Clock().tick(30)

if __name__ == "__main__":
    train()

    prodict()

</code></pre>
<h3 id="代码解析">代码解析</h3>
<p>代码流程如下:<br>
<strong>构建游戏环境-&gt;训练模型-&gt;模型预测</strong><br>
本篇重点讲构<strong>建游戏环境</strong>,其中的<code>pygame</code>相关代码简略,另外两个流程参考之前文章。</p>
<h4 id="构建游戏环境">构建游戏环境</h4>
<h5 id="初始化类">初始化类</h5>
<p>该类继承<code>gym.Env</code>类</p>
<pre><code class="language-python">class MyEnv(gym.Env):
</code></pre>
<h5 id="构造函数__init__">构造函数__init__</h5>
<pre><code class="language-python">def __init__(self, render_mode=None):
      super(MyEnv, self).__init__()

      #初始化参数
      self.width = 400
      self.height = 300
      self.player_size = 30
      self.enemy_size = 30
      self.render_mode = render_mode

      self.action_space = spaces.Discrete(3)

      self.observation_space = spaces.Box(
            low=0, high=255, shape=(84, 84, 3), dtype=np.uint8
      )

      pygame.init()
      if self.render_mode == "human":
            self.screen = pygame.display.set_mode((self.width, self.height))
      
      self.canvas = pygame.Surface((self.width, self.height))
      self.font = pygame.font.SysFont("monospace", 15)
</code></pre>
<p>在构造函数中,我们主要完成的是声明训练的维度,和输入:</p>
<ul>
<li>输入:<code>self.action_space = spaces.Discrete(3)</code>其中的<code>self.action_space</code>是<strong>固定名称的父类变量</strong>。<code>spaces.Discrete(3)</code>声明输入的数量,例如:<code>向左 向右 和 不动</code>3个输入。</li>
<li>观测维度:<code>self.observation_space</code>也是<strong>固定名称的父类变量</strong>。<code>spaces.Box</code>声明观测维度。</li>
</ul>
<pre><code class="language-python">self.observation_space = spaces.Box(
    low=0, high=255, shape=(84, 84, 3), dtype=np.uint8
)
</code></pre>
<ol>
<li><code>low</code>:观测参数的最小值</li>
<li><code>high</code>:观测参数的最大值</li>
<li><code>shape</code>:声明维度。例如:观测图片<code>shape(高,宽,RGB)</code>,观测一个平面,<code>shape(高,宽)</code></li>
<li><code>dtype</code>:每个变量类型,这里选<code>np.uint8</code>能够节省训练成本,默认是浮点型的。</li>
</ol>
<h5 id="任务重置-reset">任务重置 reset</h5>
<p>相当于初始化游戏状态,游戏的重新开始。返回的是<strong>观测值</strong>和<strong>状态信息(用于调试日志)</strong></p>
<pre><code class="language-python">def reset(self, seed=None, options=None):
      super().reset(seed=seed)

      self.player_x = self.width // 2 - self.player_size // 2
      self.player_y = self.height - self.player_size - 10
      self.enemies = []
      self.score = 0
      self.frame_count = 0

      self.current_speed = 5
      self.spawn_rate = 30

      return self._get_obs(), {}
</code></pre>
<p><strong>观测值 <code>_get_obs</code>:</strong><br>
通过<code>pygame</code>画出的画面,然后用<code>opencv</code>进行简单处理:</p>
<ol>
<li>转换坐标轴(由于<code>opencv</code>坐标xy轴跟<code>pygame</code>的xy是颠倒的)</li>
<li>将画面缩放到<code>84 * 84</code>(可以提高训练效率)</li>
</ol>
<pre><code class="language-python">def _get_obs(self):
      self.canvas.fill((0, 0, 0))
      pygame.draw.rect(self.canvas, (50, 150, 255), (self.player_x, self.player_y, self.player_size, self.player_size))
      
      for enemy in self.enemies:
            pygame.draw.rect(self.canvas, (255, 50, 50), (enemy, enemy, self.enemy_size, self.enemy_size))

      img_array = pygame.surfarray.array3d(self.canvas)
      img_array = np.transpose(img_array, (1, 0, 2))
      obs = cv2.resize(img_array, (84, 84), interpolation=cv2.INTER_AREA)

      return obs.astype(np.uint8)
</code></pre>
<h5 id="步-step重要">步 step(重要)</h5>
<p>这个函数是强化训练的<strong>核心</strong>,规定了在<code>一帧</code>或者<code>一步</code>,我们给AI的分数。<br>
<mark>分数的设置至关重要,这直接决定了训练出来AI的质量</mark><br>
根据下面代码(大部分都是游戏逻辑),主要讲设置<strong>奖励分数</strong>:</p>
<ol>
<li>在AI进行移动时 惩罚 0.05 分</li>
<li>在AI存活时 奖励 0.01分,游戏分数大于100时 存活奖励 0.02分</li>
<li>在障碍物完全下落时 奖励 1 分</li>
<li>在与障碍物碰撞时 惩罚 10 分</li>
</ol>
<pre><code class="language-python">def step(self, action):
      reward = 0
      terminated = False
      truncated = False

      move_speed = 8
      if action == 1 and self.player_x &gt; 0: #
            self.player_x -= move_speed
            reward -= 0.05

      if action == 2 and self.player_x &lt; self.width - self.player_size:
            self.player_x += move_speed
            reward -= 0.05

      self.frame_count += 1

      level = self.score // 5
      self.current_speed = 5 + level
      self.spawn_rate = 30 - level * 2
      spawn_rate = max(10, 30 - level)

      if self.frame_count &gt;= spawn_rate:
            self.frame_count = 0
            enemy_x = random.randint(0, self.width - self.enemy_size)
            self.enemies.append() #

      for enemy in self.enemies:
            enemy += self.current_speed
            
            player_rect = pygame.Rect(self.player_x, self.player_y, self.player_size, self.player_size)
            enemy_rect = pygame.Rect(enemy, enemy, self.enemy_size, self.enemy_size)
            
            if player_rect.colliderect(enemy_rect):
                reward = -10
                terminated = True

            elif enemy &gt; self.height:
                self.enemies.remove(enemy)
                self.score += 1
                reward = 1
      
      if not terminated:
            if self.score &gt; 100:
                reward += 0.01
            reward += 0.01

      obs = self._get_obs()

      if self.render_mode == "human":
            self._render_window()

      return obs, reward, terminated, truncated, {}
</code></pre>
<h5 id="展示游戏画面">展示游戏画面</h5>
<p>下面完全是<code>pygame</code>代码,用于显示游戏画面,这里就不解释了。</p>
<pre><code class="language-python">def _render_window(self):
      self.screen.blit(self.canvas, (0, 0))
      text = self.font.render(f"Score: {self.score}", True, (255, 255, 255))
      self.screen.blit(text, (10, 10))
      pygame.display.flip()

      for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
</code></pre>
<p><strong>你成功成为了一名调参侠了,快来试试吧!</strong></p>
<p><strong>如果❤喜欢❤本系列教程,就点个关注吧,后续不定期更新~</strong></p>


</div>
<div id="MySignature" role="contentinfo">
    <p>本文来自博客园,作者:ClownLMe,转载请注明原文链接:https://www.cnblogs.com/ClownLMe/p/19554865</p><br><br>
来源:https://www.cnblogs.com/ClownLMe/p/19554865
頁: [1]
查看完整版本: stable_baseline3 快速入门(二): 训练自定义游戏,构建Gymnasium训练环境