基于CycleGAN架构的深度学习模型,用于将低光域图像转换为高光域图像。该模型集成了以下特性:
- CycleGAN架构:无配对的图像到图像转换
- SSIM损失函数:增强图片的对比度和亮度一致性
- VGG19感知损失:利用预训练VGG19模型提取特征,提升感知质量
- STN空间变换网络:增加输入图像的旋转不变性
yzmgan/
├── models/ # 网络模型
│ ├── __init__.py
│ ├── generator.py # CycleGAN生成器(包含STN)
│ ├── discriminator.py # CycleGAN判别器
│ └── stn.py # 空间变换网络
├── losses/ # 损失函数
│ ├── __init__.py
│ ├── gan_loss.py # GAN损失
│ ├── ssim_loss.py # SSIM损失
│ ├── perceptual_loss.py # VGG19感知损失
│ └── cycle_loss.py # 循环一致性损失
├── utils/ # 工具函数
│ ├── __init__.py
│ ├── dataset.py # 数据加载器
│ └── device.py # 设备工具(支持MPS/CUDA/CPU)
├── datasets/ # 数据集目录
│ └── dark2light/ # 低光到高光数据集
│ ├── trainA/ # 训练集低光图像
│ ├── trainB/ # 训练集高光图像
│ ├── testA/ # 测试集低光图像
│ └── testB/ # 测试集高光图像
├── train.py # 训练脚本
├── train_quick.py # 快速启动训练脚本
├── test.py # 测试/推理脚本
├── evaluate.py # 测试集评估脚本
├── requirements.txt # 依赖包
└── README.md # 说明文档
-
克隆仓库并进入项目目录
-
安装依赖:
pip install -r requirements.txt对于Mac M4系列芯片用户,项目已自动支持MPS(Metal Performance Shaders)GPU加速:
- 自动检测:训练和测试脚本会自动检测并使用MPS设备(如果可用)
- PyTorch版本:确保使用PyTorch 2.0+版本,该版本原生支持MPS
- 安装PyTorch:如果通过pip安装,可以使用:
pip install torch torchvision
- 设备选择:可以通过
--device mps参数强制使用MPS,或--device cpu使用CPU - 内存管理:如果遇到显存不足问题,可以减小
--batch_size(建议从2开始)
项目已经包含了数据集,位于 datasets/dark2light/ 目录下:
- 训练集:
trainA/:低光域训练图像(814张)trainB/:高光域训练图像(902张)
- 测试集:
testA/:低光域测试图像(200张)testB/:高光域测试图像(200张)
注意:两个域的图像不需要配对,CycleGAN可以学习无配对的转换。
python train.py这将使用默认的数据集路径 datasets/dark2light/trainA 和 datasets/dark2light/trainB。
python train_quick.py使用预设的参数配置快速开始训练。
python train.py \
--batch_size 2 \
--n_epochs 200 \
--lambda_cycle 10.0 \
--lambda_ssim 1.0 \
--lambda_perceptual 0.1 \
--use_stn \
--checkpoint_dir ./checkpoints \
--log_dir ./logs提示:如果使用自定义数据集路径,可以通过以下参数指定:
python train.py \
--dataset_dir_A /path/to/your/low_light/images \
--dataset_dir_B /path/to/your/normal_light/images--dataset_dir_A: 低光域图像目录路径(默认:datasets/dark2light/trainA)--dataset_dir_B: 高光域图像目录路径(默认:datasets/dark2light/trainB)--batch_size: 批次大小(默认:4)--n_epochs: 训练轮数(默认:200)--lr: 学习率(默认:0.0002)--lambda_cycle: 循环一致性损失权重(默认:10.0)--lambda_ssim: SSIM损失权重(默认:1.0)--lambda_perceptual: 感知损失权重(默认:0.1)--use_stn: 启用STN空间变换网络(默认:True)--checkpoint_dir: 模型保存目录(默认:./checkpoints)--log_dir: TensorBoard日志目录(默认:./logs)
训练过程会自动保存检查点,并在TensorBoard中记录损失曲线。
python test.py \
--checkpoint ./checkpoints/checkpoint_epoch_200.pth \
--input /path/to/input/image.jpg \
--output /path/to/output/image.jpgpython test.py \
--checkpoint ./checkpoints/checkpoint_epoch_200.pth \
--input /path/to/input/directory \
--output /path/to/output/directory使用内置测试集评估训练好的模型:
python evaluate.py \
--checkpoint ./checkpoints/checkpoint_epoch_200.pth \
--output_dir ./evaluation_results这将:
- 在测试集上运行模型
- 生成转换结果图像
- 保存到
evaluation_results/目录
- 编码器:使用卷积层和InstanceNorm进行下采样
- 残差块:9个残差块用于特征提取和转换
- 解码器:使用转置卷积进行上采样
- STN模块:在输入前应用空间变换,增强旋转不变性
- PatchGAN架构:对图像的局部patch进行判别
- 多尺度特征:逐步下采样提取多尺度特征
- GAN损失:对抗训练损失(LSGAN)
- 循环一致性损失:确保A→B→A和B→A→B的循环一致性
- SSIM损失:结构相似性指数,增强对比度和亮度一致性
- 感知损失:基于VGG19的特征损失,提升感知质量
- 定位网络:学习仿射变换参数
- 网格采样:对输入图像进行空间变换
- 旋转不变性:增强模型对图像旋转的鲁棒性
-
损失权重调整:
- 如果图像质量不佳,可以增加
--lambda_perceptual - 如果对比度不够,可以增加
--lambda_ssim - 如果循环一致性有问题,可以增加
--lambda_cycle
- 如果图像质量不佳,可以增加
-
学习率:
- 初始学习率为0.0002,会在训练后半段自动衰减
- 可以根据训练情况手动调整
-
批次大小:
- 如果显存不足,可以减小
--batch_size - 建议至少为2
- 如果显存不足,可以减小
-
训练时间:
- 200个epoch通常足够,但可以根据数据集大小调整
- 确保有足够的GPU显存(建议至少6GB),Mac M4用户建议至少16GB统一内存
- 图像会被调整到256x256进行训练
- STN模块会增加一定的计算开销,但能提升模型鲁棒性
- VGG19模型会在首次运行时自动下载
- MPS加速:项目已完全支持MPS,训练时会自动使用Apple Silicon GPU加速
- 批次大小:M4芯片建议从较小的batch_size开始(如2或4),根据内存情况调整
- 性能优化:MPS在某些操作上可能比CUDA稍慢,但整体性能仍然很好
- 内存管理:Mac的统一内存架构允许更灵活的内存分配,但建议监控内存使用情况
- 数据加载:如果遇到数据加载瓶颈,可以调整
--num_workers参数(Mac建议使用2-4个worker)
本项目仅供学习和研究使用。