徐霖凯奶奶 發表於 2026-1-31 21:54:00

从0到1,快速训练并使用YOLO模型

<h1 id="简介">简介</h1>
<p><strong>YOLO</strong>是目前计算机视觉领域最前沿、应用最广泛的<strong>目标检测</strong>算法框架,他能快速识别区分目标,广泛应用于游戏,无人驾驶,工业等领域。</p>
<p>以识别躲避掉落滑块的游戏的物体图片作为例子。</p>
<h1 id="一环境配置">一,环境配置</h1>
<pre><code class="language-bash">pip install ultralytics
</code></pre>
<h1 id="二准备数据集">二,准备数据集</h1>
<p>这个格式目录如下:</p>
<pre><code class="language-bash">my_dataset/
├── data.yaml # 配置文件(定义路径和类别)
├── train/ #训练数据集
│   ├── images/ # 训练图片
│   └── labels/ # 标注文件 (.txt)
└── val/ #验证数据集
    ├── images/
    └── labels/
</code></pre>
<h3 id="datayml">data.yml</h3>
<pre><code class="language-yaml">path: D:\D_MyProject\Ai\game_ai\my_dataset #数据集路径
train: train/images #训练集图片路径
val: val/images #验证集图片路径

nc: 3 #标记个数
names:#每个标记的名称
&nbsp; 0: player
&nbsp; 1: enemy
&nbsp; 2: game_over
</code></pre>
<h3 id="下面是用ai生成了数据集的生成脚本">下面是用AI生成了数据集的生成脚本</h3>
<p><mark>数据量太多了,这里为了演示,或者学习,可以直接使用下面脚本</mark></p>
<pre><code class="language-python">import pygame
import random
import sys
import os
import shutil

# =================配置区域=================
# 数据集根目录名称
DATASET_ROOT = "my_dataset"
# 采集总数量
MAX_IMAGES = 1000
# 训练集占比 (0.8 = 80% 训练, 20% 验证)
TRAIN_RATIO = 0.8

# ================= 1. 环境清理与目录创建 =================
print(f"🚀 正在初始化数据集目录: {DATASET_ROOT} ...")

# 如果目录已存在,先删除(防止旧数据混入),确保数据纯净
if os.path.exists(DATASET_ROOT):
    shutil.rmtree(DATASET_ROOT)

# 创建 YOLO 标准目录结构
for split in ['train', 'val']:
    os.makedirs(os.path.join(DATASET_ROOT, split, 'images'), exist_ok=True)
    os.makedirs(os.path.join(DATASET_ROOT, split, 'labels'), exist_ok=True)

# ================= 2. 自动生成 data.yaml =================
yaml_content = f"""
path: {os.path.abspath(DATASET_ROOT)} # 使用绝对路径,防止报错
train: train/images
val: val/images

nc: 3
names:
0: player
1: enemy
2: game_over
"""
with open(os.path.join(DATASET_ROOT, "data.yaml"), "w", encoding='utf-8') as f:
    f.write(yaml_content)
print("✅ data.yaml 配置文件已生成。")

# ================= 3. 游戏与采集逻辑 =================
pygame.init()
WIDTH, HEIGHT = 800, 600
# 使用 hidden 模式或正常模式均可,这里用正常模式方便你看到进度
screen = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.display.set_caption("Auto Data Generator")
clock = pygame.time.Clock()

font_big = pygame.font.SysFont("monospace", 50)
font_small = pygame.font.SysFont("monospace", 35)

# 游戏状态
player_size, enemy_size = 50, 50
player_x = WIDTH // 2
player_y = HEIGHT - player_size - 10
enemies = []

img_count = 0
GAMEPLAY_LIMIT = int(MAX_IMAGES * 0.9) # 90% 正常游戏,10% Game Over

print(f"📸 开始采集 {MAX_IMAGES} 张图片 (自动划分 Train/Val)...")

while img_count &lt; MAX_IMAGES:
    # 处理退出事件,防止卡死
    for event in pygame.event.get():
      if event.type == pygame.QUIT:
            pygame.quit()
            sys.exit()

    screen.fill((0, 0, 0))
    labels = [] # 存储当前帧的标签
    dw, dh = 1.0 / WIDTH, 1.0 / HEIGHT # 归一化系数

    # --- 逻辑分支:正常游戏 vs Game Over ---
    if img_count &lt; GAMEPLAY_LIMIT:
      # A. 正常游戏画面
      # 1. 玩家移动
      if random.random() &lt; 0.2: # 增加移动频率
            player_x += random.choice([-15, 15])
            player_x = max(0, min(WIDTH - player_size, player_x))

      # 2. 敌人生成与移动
      if random.randint(0, 20) == 0: # 增加敌人密度
            enemies.append()
      
      for enemy in enemies: enemy += 15 # 加快下落速度
      enemies = &lt; HEIGHT]

      # 3. 绘制玩家 (Class 0)
      pygame.draw.rect(screen, (50, 150, 255), (player_x, player_y, player_size, player_size))
      # 计算 YOLO 坐标 (class x_center y_center width height)
      px, py = (player_x + player_size/2) * dw, (player_y + player_size/2) * dh
      labels.append(f"0 {px:.6f} {py:.6f} {player_size*dw:.6f} {player_size*dh:.6f}")

      # 4. 绘制敌人 (Class 1)
      for e in enemies:
            pygame.draw.rect(screen, (255, 50, 50), (e, e, enemy_size, enemy_size))
            ex, ey = (e + enemy_size/2) * dw, (e + enemy_size/2) * dh
            labels.append(f"1 {ex:.6f} {ey:.6f} {enemy_size*dw:.6f} {enemy_size*dh:.6f}")

    else:
      # B. Game Over 画面 (Class 2)
      text_surf = font_big.render("GAME OVER", True, (255, 50, 50))
      
      # 随机抖动位置,防止过拟合
      off_x, off_y = random.randint(-50, 50), random.randint(-50, 50)
      text_x = WIDTH // 2 - text_surf.get_width() // 2 + off_x
      text_y = HEIGHT // 2 - 100 + off_y
      screen.blit(text_surf, (text_x, text_y))

      # 干扰项
      score_surf = font_small.render(f"Score: {random.randint(0,999)}", True, (255,255,255))
      screen.blit(score_surf, (WIDTH//2 - score_surf.get_width()//2, HEIGHT//2 + 50))

      # 记录标签
      tw, th = text_surf.get_width(), text_surf.get_height()
      tx, ty = (text_x + tw/2) * dw, (text_y + th/2) * dh
      labels.append(f"2 {tx:.6f} {ty:.6f} {tw*dw:.6f} {th*dh:.6f}")

    # ================= 4. 保存逻辑 (核心修改) =================
    pygame.display.flip()
   
    # 采样率:不是每一帧都保存,防止重复度过高 (这里设为 30% 概率保存)
    if random.random() &lt; 0.3:
      # A. 决定是去 Train 还是 Val
      split_folder = "train" if random.random() &lt; TRAIN_RATIO else "val"
      
      # B. 生成文件名
      filename = f"{img_count:06d}"
      img_save_path = os.path.join(DATASET_ROOT, split_folder, "images", f"{filename}.jpg")
      lbl_save_path = os.path.join(DATASET_ROOT, split_folder, "labels", f"{filename}.txt")

      # C. 保存图片
      pygame.image.save(screen, img_save_path)

      # D. 保存标签
      with open(lbl_save_path, "w") as f:
            f.write("\n".join(labels))

      img_count += 1
      
      # 打印进度条
      print(f"[{split_folder.upper()}] 进度: {img_count}/{MAX_IMAGES}", end="\r")

    # 加速模拟,不要垂直同步,越快越好
    clock.tick(0)

pygame.quit()
print(f"\n\n✨ 全部完成!数据集已就绪:{os.path.abspath(DATASET_ROOT)}")
print("💡 下一步:直接运行 model.train(data='my_dataset/data.yaml')")
</code></pre>
<h1 id="三训练yolo模型">三,训练YOLO模型</h1>
<p>可以看到,使用<code>ultralytics</code>框架训练YOLO的代码非常简单,只需要几行<br>
<mark><strong>注意:这里YOLO会自动下载模型并训练,下载时失败可能需要挂梯子</strong></mark></p>
<pre><code class="language-python">from ultralytics import YOLO

def train_model():
    #加载模型
    model = YOLO("yolo11n.pt")

    #开始训练
    print("开始训练...")
    results = model.train(
      data="my_dataset/data.yaml", #数据集配置文件
      epochs=30, #训练轮数
      imgsz=640, #图片输入尺寸
      batch=16, #显存够大可以改大,比如 32 或 64
      device=0, #强制使用第一张显卡 (需要CUDA)
      workers=0, #Windows下设为0防止多进程报错
      project="dodge_project",#保存路径
      name="ai_model" #训练运行名称
    )
    print(f"训练完成!最佳模型保存在: {results.save_dir}/weights/best.pt")

if __name__ == "__main__":
    train_model()
</code></pre>
<h1 id="四使用yolo模型">四,使用YOLO模型</h1>
<pre><code class="language-python">from ultralytics import YOLO

#配置路径
MODEL_PATH = "dodge_project/ai_model/weights/best.pt"
PIC_PATH = "my_dataset/val/images/000045.jpg"

#加载模型
model = YOLO(MODEL_PATH)
print(f"模型加载成功")

#识别图片
results = model.predict(PIC_PATH, verbose=False, conf=0.4, imgsz=640)
result = results

#获取信息并打印
for box in result.boxes:
    # 获取坐标 (x1, y1, x2, y2)、类别 ID 和置信度
    x1, y1, x2, y2 = map(int, box.xyxy)
    cls_id = int(box.cls)
    conf = float(box.conf)
   
    # 类别名称映射
    names = {0: "Player", 1: "Enemy", 2: "Game Over"}
    label = names.get(cls_id, f"Unknown({cls_id})")
    print(f"找到物体: [{label:10}] | 置信度: {conf:.2f} | 坐标: ({x1}, {y1}) -&gt; ({x2}, {y2})")

#展示图片
result.show()
</code></pre>
<p><strong>如果❤喜欢❤本系列教程,就点个关注吧,后续不定期更新~</strong></p>


</div>
<div id="MySignature" role="contentinfo">
    <p>本文来自博客园,作者:ClownLMe,转载请注明原文链接:https://www.cnblogs.com/ClownLMe/p/19559180</p><br><br>
来源:https://www.cnblogs.com/ClownLMe/p/19559180
頁: [1]
查看完整版本: 从0到1,快速训练并使用YOLO模型