autowzry-agent 模块功能规格说明¶
版本: v1.0 日期: 2025-01-12
目录¶
Config 模块¶
模块路径¶
config/config.py
职责¶
统一管理项目配置,支持 YAML 文件、命令行参数、默认配置生成
核心类¶
HyperParameters¶
超参数定义
属性:
- input_channels: int - 输入图像通道数 (默认 4,堆叠帧)
- learning_rate: float - 学习率 (默认 1e-4)
- gamma: float - 折扣因子 (默认 0.99)
- batch_size: int - 批次大小 (默认 32)
- buffer_size: int - replay buffer 容量 (默认 100000)
- target_update_freq: int - target 网络更新频率 (默认 1000 步)
- num_epochs: int - 训练轮数 (默认 100)
- epsilon_start: float - 初始探索率 (默认 1.0)
- epsilon_end: float - 最终探索率 (默认 0.1)
- epsilon_decay: int - 探索率衰减步数 (默认 10000)
- max_episode_steps: int - 单局最大步数 (默认 10000)
- num_episodes: int - 收集的局数 (默认 10)
- reward_kill: float - 击杀奖励 (默认 5.0)
- reward_death: float - 死亡惩罚 (默认 -2.0)
- reward_assist: float - 助攻奖励 (默认 1.0)
- reward_default: float - 每帧存活奖励 (默认 0.01)
PathConfig¶
路径配置
属性:
- data_dir: str - 数据根目录
- episodes_dir: str - episode 存储目录
- models_dir: str - 模型存储目录
- logs_dir: str - 日志目录
EnvironmentConfig¶
环境配置
属性:
- mode: str - 运行模式 (spectate/battle/offline)
- device_backend: str - 设备后端 (autowzry/autowzry-lite)
- image_width: int - 预处理图像宽度 (默认 84)
- image_height: int - 预处理图像高度 (默认 84)
- frame_stack: int - 堆叠帧数 (默认 4)
- enabled_detections: list - 启用的检测项 (默认 ['kill', 'death', 'assist'])
- enabled_rewards: list - 启用的奖励项 (默认 ['kill', 'death', 'assist'])
Config¶
配置管理器
方法:
- from_yaml(yaml_path): 从 YAML 文件加载配置
- from_args(args): 从命令行参数构建配置
- save_yaml(yaml_path): 保存配置到 YAML 文件
- create_argument_parser(): 创建命令行参数解析器
- generate_default_config(output_path): 生成默认配置文件
使用示例:
# 从 YAML 加载
config = Config.from_yaml('config/my_config.yaml')
# 从命令行加载
config, args = Config.from_args()
# 生成默认配置
Config.generate_default_config('config/default_config.yaml')
Core 模块¶
agent.py¶
职责¶
Agent 主控制器,协调各模块完成不同模式的任务
核心类: Agent¶
初始化参数:
- config: Config - 配置对象
- model: BaseModel - 神经网络模型
- trainer: DQNTrainer - 训练器
- collector: ExperienceCollector - 经验收集器
- replay_buffer: ReplayBuffer - 经验缓冲区
方法:
- run_collect_mode(): 运行数据收集模式
- run_train_mode(): 运行训练模式
- run_battle_mode(): 运行对战模式(在线学习)
- run_evaluate_mode(): 运行评估模式
- save_checkpoint(path): 保存模型检查点
- load_checkpoint(path): 加载模型检查点
工作流程: 1. 根据配置初始化所有依赖模块 2. 根据模式调用对应的运行方法 3. 管理模型的保存和加载
model.py¶
职责¶
定义神经网络模型
核心类¶
BaseModel¶
模型基类
方法:
- forward(x): 前向传播(抽象方法)
- save(path): 保存模型权重
- load(path): 加载模型权重
- predict(state, epsilon): 推理并选择动作(含探索)
SimpleConvNet¶
简单的卷积神经网络
初始化参数:
- input_channels: int - 输入通道数(帧堆叠数)
- action_dim: int - 动作空间维度
网络结构: - Conv2d(input_channels, 32, kernel=8, stride=4) + ReLU - Conv2d(32, 64, kernel=4, stride=2) + ReLU - Conv2d(64, 64, kernel=3, stride=1) + ReLU - Flatten - Linear(conv_out_size, 512) + ReLU - Linear(512, action_dim)
输入输出: - 输入: (batch, input_channels, 84, 84) - 输出: (batch, action_dim) - Q值向量
trainer.py¶
职责¶
DQN 训练逻辑
核心类: DQNTrainer¶
初始化参数:
- model: BaseModel - 主网络
- target_model: BaseModel - 目标网络
- learning_rate: float - 学习率
- gamma: float - 折扣因子
方法:
- train_step(batch): 单步训练
- 输入: batch = (states, actions, rewards, next_states, dones)
- 输出: loss 值
- 流程:
1. 计算当前 Q 值: Q(s, a)
2. 计算目标 Q 值: r + γ * max Q'(s', a')
3. 计算 MSE loss
4. 反向传播并优化
- update_target_network(): 同步 target 网络权重
- compute_loss(batch): 计算 DQN loss
DQN Loss 公式:
Loss = MSE(Q(s,a), r + γ * max Q_target(s', a') * (1 - done))
policy.py¶
职责¶
策略选择器,控制探索/利用行为
核心类¶
Policy¶
策略基类
方法:
- select_action(q_values): 根据 Q 值选择动作(抽象方法)
GreedyPolicy¶
贪心策略
方法:
- select_action(q_values): 返回 Q 值最大的动作
EpsilonGreedyPolicy¶
ε-贪心策略
初始化参数:
- epsilon_start: float - 初始探索率
- epsilon_end: float - 最终探索率
- epsilon_decay: int - 衰减步数
方法:
- get_epsilon(): 计算当前探索率(指数衰减)
- select_action(q_values): 以 ε 概率随机探索,(1-ε) 概率贪心选择
探索率衰减公式:
epsilon = epsilon_end + (epsilon_start - epsilon_end) * exp(-step / decay)
SoftmaxPolicy¶
Softmax 策略(Boltzmann 探索)
初始化参数:
- temperature: float - 温度参数
方法:
- select_action(q_values): 按 softmax 概率采样
ModelWithPolicy¶
模型与策略的组合
初始化参数:
- model: BaseModel - 神经网络模型
- policy: Policy - 策略选择器
方法:
- predict(state): 结合模型推理和策略选择
Environment 模块¶
compatibility.py¶
职责¶
封装游戏设备接口,支持 autowzry 和 autowzry-lite
核心类: CompatibilityModule¶
初始化参数:
- backend: str - 后端类型 (autowzry/autowzry-lite)
- image_width: int - 预处理后的图像宽度
- image_height: int - 预处理后的图像高度
方法:
- init_device(): 初始化设备连接
- 尝试导入 autowzry
- 失败则使用 autowzry-lite
- capture_screen(): 截图并预处理
- 返回: numpy array, shape (H, W, C) 或 (C, H, W)
- 流程: 截图 → resize → 归一化
- execute_action(action): 执行游戏动作
- 输入: action (int 或 dict)
- 调用对应的游戏接口(移动、攻击、技能等)
- is_battle_running(): 判断对战是否进行中
- 返回: bool
- enter_battle(): 进入对战(仅 autowzry 支持)
- 自动启动游戏、进入对局
设计要点: - 统一的接口,隐藏底层差异 - 异常处理(设备断开、游戏崩溃) - 支持扩展到其他设备(如模拟器、PC端)
action_space.py¶
职责¶
定义动作空间,提供动作编码/解码
核心类: ActionSpace¶
初始化参数:
- enabled_actions: list - 启用的动作类型 (默认 ['move', 'attack'])
方法:
- define_action_space(): 定义动作空间
- 返回动作字典
- encode_action(command): 动作命令 → 索引
- 输入: dict, 例如 {'type': 'move', 'direction': 0}
- 输出: int (action index)
- decode_action(index): 索引 → 动作命令
- 输入: int
- 输出: dict
- get_dim(): 返回动作空间维度
动作定义示例:
actions = {
0: {'type': 'move', 'direction': 0}, # 向上移动
1: {'type': 'move', 'direction': 45}, # 东北移动
2: {'type': 'move', 'direction': 90}, # 向右移动
...
8: {'type': 'attack', 'target': 'auto'}, # 攻击
9: {'type': 'skill', 'slot': 1}, # 释放技能1
}
设计要点: - 动作空间可配置 - 支持连续动作的离散化 - 便于扩展新动作类型
reward_evaluator.py¶
职责¶
状态识别 + 奖励计算的统一模块
核心类: RewardEvaluator¶
初始化参数:
- enabled_detections: list - 启用的检测项
- enabled_rewards: list - 启用的奖励项
- reward_dict: dict - 奖励值字典
- template_dir: str - 模板图像目录
方法:
- detect_events(frame): 检测游戏事件
- 输入: frame (numpy array)
- 输出: dict, 例如:
python
{
'kill_count': 3,
'death_count': 1,
'assist_count': 2,
'is_dead': False,
'hp': 80, # 后期扩展
'position': (120, 340) # 后期扩展
}
- compute_reward(prev_state, curr_state): 计算奖励
- 输入: 前一帧和当前帧
- 输出: float (奖励值)
- 流程:
1. 检测两帧的事件
2. 比较差异(击杀+1、死亡等)
3. 根据奖励字典计算总奖励
- _detect_death(frame): 检测死亡状态
- 使用灰度画面特征或死亡图标匹配
- _detect_kill_count(frame): 检测击杀数
- 使用 OCR 或模板匹配
- _detect_assist_count(frame): 检测助攻数
奖励字典示例:
reward_dict = {
'default': 0.01, # 每帧存活
'kill': 5.0, # 击杀
'death': -2.0, # 死亡
'assist': 1.0, # 助攻
'hp_loss': -0.01, # 掉血 (后期)
'gold_gain': 0.05, # 获得金币 (后期)
}
设计要点: - 感知和奖励统一管理,便于同步修改 - 可配置的启用项,支持渐进式开发 - 使用 OpenCV 和模板匹配,避免依赖复杂模型
Data 模块¶
experience_collector.py¶
职责¶
收集游戏经验并保存为 episode 文件
核心类: ExperienceCollector¶
初始化参数:
- compatibility_module: CompatibilityModule
- action_space: ActionSpace
- reward_evaluator: RewardEvaluator
- mode: str - 运行模式 (spectate/battle/offline)
- save_dir: str - episode 保存目录
方法:
- collect_one_step(): 收集单步经验
- 返回: (state, action, reward, next_state, done)
- 流程:
1. 获取当前状态(截图)
2. 根据模式决定动作
- 观战/离线: 从图像差分推断
- 对战: 模型推理
3. 执行动作(对战模式)
4. 获取下一状态
5. 计算奖励
6. 判断是否结束
- collect_episode(max_steps): 收集完整一局
- 返回: episode 列表
- 循环调用 collect_one_step 直到游戏结束
- save_episode(episode, filename): 保存为 HDF5 文件
- 分离 (s, a, r, s', done)
- 使用 gzip 压缩
- 保存元数据(模式、步数、总奖励)
- collect_and_save_multiple_episodes(num_episodes): 批量收集
- 返回: 保存的文件路径列表
HDF5 文件结构:
episode_xxx.h5
├── states: (N, C, H, W) - 状态图像
├── actions: (N,) - 动作索引
├── rewards: (N,) - 奖励值
├── next_states: (N, C, H, W) - 下一状态
├── dones: (N,) - 结束标志
└── attrs:
├── mode: 'spectate'
├── num_steps: 1234
└── total_reward: 15.6
replay_buffer.py¶
职责¶
经验回放缓冲区,支持多文件加载和随机采样
核心类: ReplayBuffer¶
初始化参数:
- max_size: int - 最大容量(超出则覆盖旧数据)
属性:
- buffer: deque - 存储经验的双端队列
方法:
- push(experience): 添加单条经验
- 输入: (state, action, reward, next_state, done)
- load_episode_file(filepath): 加载单个 HDF5 文件
- 读取文件中的所有经验
- 逐条 push 到 buffer
- load_from_directory(directory, pattern, max_files): 批量加载
- 扫描目录匹配文件
- 加载多个 episode
- 支持限制最大文件数
- sample(batch_size): 随机采样
- 返回: (states, actions, rewards, next_states, dones)
- numpy array 格式
- save_consolidated(filepath): 保存整合后的 buffer
- 用于分发数据集
- get_statistics(): 获取统计信息
- 返回: size、平均奖励、最大/最小奖励
- __len__(): 返回当前 buffer 大小
设计要点: - 使用 deque 自动管理容量 - 支持增量加载和随机采样 - 提供统计信息便于调试
dataset.py¶
职责¶
PyTorch Dataset 封装,配合 DataLoader 使用
核心类: ExperienceDataset¶
继承: torch.utils.data.Dataset
初始化参数:
- replay_buffer: ReplayBuffer - 数据来源
方法:
- __len__(): 返回数据集大小
- __getitem__(idx): 获取单条经验
- 返回: dict
python
{
'state': torch.FloatTensor, # (C, H, W)
'action': torch.LongTensor, # (1,)
'reward': torch.FloatTensor, # (1,)
'next_state': torch.FloatTensor, # (C, H, W)
'done': torch.FloatTensor # (1,)
}
使用示例:
dataset = ExperienceDataset(replay_buffer)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
for batch in dataloader:
states = batch['state'] # (32, C, H, W)
actions = batch['action'] # (32, 1)
# 训练...
设计要点: - 固定数据快照(不动态更新) - 返回 tensor 格式,直接用于训练 - 支持多进程加载
Utils 模块¶
image_processing.py¶
职责¶
图像预处理工具函数
函数列表:
resize_frame(frame, width, height): 调整图像大小- 输入: numpy array
-
输出: numpy array, shape (height, width, channels)
-
to_grayscale(frame): 转换为灰度图 - 输入: RGB 图像
-
输出: 灰度图像, shape (H, W)
-
normalize(frame): 归一化到 [0, 1] - 输入: numpy array, dtype uint8
-
输出: numpy array, dtype float32
-
stack_frames(frames): 堆叠多帧 - 输入: list of frames, 例如 [frame1, frame2, frame3, frame4]
-
输出: numpy array, shape (num_frames, H, W)
-
preprocess_frame(frame, width, height, grayscale): 完整预处理流程 - 流程: resize → (可选) 灰度化 → 归一化
action_inference.py¶
职责¶
从图像差分推断执行的动作
函数列表:
infer_action(prev_frame, curr_frame): 推断动作- 输入: 前一帧和当前帧
- 输出: action index
-
流程:
- 检测位置变化 → 推断移动方向
- 检测特效(攻击/技能) → 推断操作
- 默认返回"无操作"
-
detect_position_change(prev_frame, curr_frame): 检测位置变化 - 使用光流法或特征点匹配
-
返回: 移动方向角度
-
detect_attack_effect(frame): 检测攻击特效 - 返回: bool
设计要点: - 仅用于观战/离线模式 - 不需要 100% 准确,提供初始训练数据即可 - 后期可以用模型替代
logger.py¶
职责¶
日志记录和输出
核心类: Logger
初始化参数:
- log_dir: str - 日志目录
- use_tensorboard: bool - 是否使用 TensorBoard
方法:
- log_scalar(tag, value, step): 记录标量
- 例如: loss、reward
- log_text(message): 记录文本日志
- log_histogram(tag, values, step): 记录直方图
- 例如: Q 值分布
- close(): 关闭日志
设计要点: - 同时输出到文件和控制台 - 可选 TensorBoard 集成 - 支持多种数据类型
Scripts 模块¶
collect.py¶
职责¶
数据收集脚本
功能: - 解析命令行参数 - 初始化模块 - 根据模式收集数据 - 保存 episode 文件
命令行参数:
- --config: 配置文件路径
- --mode: 运行模式 (spectate/battle/offline)
- --num-episodes: 收集局数
- --model-path: 模型路径(对战模式需要)
输出: - episode 文件列表 - 收集统计信息
train.py¶
职责¶
训练脚本
功能: - 加载配置 - 加载 replay buffer - 初始化模型和训练器 - 训练循环 - 保存模型
命令行参数:
- --config: 配置文件路径
- --episodes-dir: episode 目录
- --epochs: 训练轮数
- --batch-size: 批次大小
- --device: 训练设备 (cpu/cuda)
输出: - 训练日志 - 模型检查点
evaluate.py¶
职责¶
评估脚本
功能: - 加载训练好的模型 - 运行多局游戏 - 统计性能指标
命令行参数:
- --model-path: 模型路径
- --num-episodes: 评估局数
- --mode: 评估模式(通常是 battle)
输出: - 平均奖励 - 胜率(如果能识别) - 平均局长
generate_config.py¶
职责¶
生成默认配置文件
功能: - 调用 Config.generate_default_config() - 输出到指定路径
命令行参数:
- --output: 输出路径
文档结束