Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Tusser156/yzmgan

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

1 Commit
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

CycleGAN 低光域到高光域图像转换

基于CycleGAN架构的深度学习模型,用于将低光域图像转换为高光域图像。该模型集成了以下特性:

  • CycleGAN架构:无配对的图像到图像转换
  • SSIM损失函数:增强图片的对比度和亮度一致性
  • VGG19感知损失:利用预训练VGG19模型提取特征,提升感知质量
  • STN空间变换网络:增加输入图像的旋转不变性

项目结构

yzmgan/
├── models/              # 网络模型
│   ├── __init__.py
│   ├── generator.py    # CycleGAN生成器(包含STN)
│   ├── discriminator.py # CycleGAN判别器
│   └── stn.py          # 空间变换网络
├── losses/             # 损失函数
│   ├── __init__.py
│   ├── gan_loss.py     # GAN损失
│   ├── ssim_loss.py    # SSIM损失
│   ├── perceptual_loss.py # VGG19感知损失
│   └── cycle_loss.py   # 循环一致性损失
├── utils/              # 工具函数
│   ├── __init__.py
│   ├── dataset.py      # 数据加载器
│   └── device.py       # 设备工具(支持MPS/CUDA/CPU)
├── datasets/          # 数据集目录
│   └── dark2light/   # 低光到高光数据集
│       ├── trainA/  # 训练集低光图像
│       ├── trainB/  # 训练集高光图像
│       ├── testA/   # 测试集低光图像
│       └── testB/   # 测试集高光图像
├── train.py          # 训练脚本
├── train_quick.py    # 快速启动训练脚本
├── test.py           # 测试/推理脚本
├── evaluate.py       # 测试集评估脚本
├── requirements.txt  # 依赖包
└── README.md         # 说明文档

安装

  1. 克隆仓库并进入项目目录

  2. 安装依赖:

pip install -r requirements.txt

Mac M4 芯片(Apple Silicon)用户注意事项

对于Mac M4系列芯片用户,项目已自动支持MPS(Metal Performance Shaders)GPU加速:

  • 自动检测:训练和测试脚本会自动检测并使用MPS设备(如果可用)
  • PyTorch版本:确保使用PyTorch 2.0+版本,该版本原生支持MPS
  • 安装PyTorch:如果通过pip安装,可以使用:
    pip install torch torchvision
  • 设备选择:可以通过--device mps参数强制使用MPS,或--device cpu使用CPU
  • 内存管理:如果遇到显存不足问题,可以减小--batch_size(建议从2开始)

数据准备

项目已经包含了数据集,位于 datasets/dark2light/ 目录下:

  • 训练集
    • trainA/:低光域训练图像(814张)
    • trainB/:高光域训练图像(902张)
  • 测试集
    • testA/:低光域测试图像(200张)
    • testB/:高光域测试图像(200张)

注意:两个域的图像不需要配对,CycleGAN可以学习无配对的转换。

快速开始

方式一:使用默认配置直接训练(推荐)

python train.py

这将使用默认的数据集路径 datasets/dark2light/trainAdatasets/dark2light/trainB

方式二:快速启动脚本

python train_quick.py

使用预设的参数配置快速开始训练。

方式三:自定义参数训练

python train.py \
    --batch_size 2 \
    --n_epochs 200 \
    --lambda_cycle 10.0 \
    --lambda_ssim 1.0 \
    --lambda_perceptual 0.1 \
    --use_stn \
    --checkpoint_dir ./checkpoints \
    --log_dir ./logs

提示:如果使用自定义数据集路径,可以通过以下参数指定:

python train.py \
    --dataset_dir_A /path/to/your/low_light/images \
    --dataset_dir_B /path/to/your/normal_light/images

主要参数说明

  • --dataset_dir_A: 低光域图像目录路径(默认:datasets/dark2light/trainA
  • --dataset_dir_B: 高光域图像目录路径(默认:datasets/dark2light/trainB
  • --batch_size: 批次大小(默认:4)
  • --n_epochs: 训练轮数(默认:200)
  • --lr: 学习率(默认:0.0002)
  • --lambda_cycle: 循环一致性损失权重(默认:10.0)
  • --lambda_ssim: SSIM损失权重(默认:1.0)
  • --lambda_perceptual: 感知损失权重(默认:0.1)
  • --use_stn: 启用STN空间变换网络(默认:True)
  • --checkpoint_dir: 模型保存目录(默认:./checkpoints)
  • --log_dir: TensorBoard日志目录(默认:./logs)

训练过程会自动保存检查点,并在TensorBoard中记录损失曲线。

测试/推理

单张图像测试

python test.py \
    --checkpoint ./checkpoints/checkpoint_epoch_200.pth \
    --input /path/to/input/image.jpg \
    --output /path/to/output/image.jpg

批量处理

python test.py \
    --checkpoint ./checkpoints/checkpoint_epoch_200.pth \
    --input /path/to/input/directory \
    --output /path/to/output/directory

测试集评估

使用内置测试集评估训练好的模型:

python evaluate.py \
    --checkpoint ./checkpoints/checkpoint_epoch_200.pth \
    --output_dir ./evaluation_results

这将:

  • 在测试集上运行模型
  • 生成转换结果图像
  • 保存到 evaluation_results/ 目录

模型架构说明

1. 生成器(Generator)

  • 编码器:使用卷积层和InstanceNorm进行下采样
  • 残差块:9个残差块用于特征提取和转换
  • 解码器:使用转置卷积进行上采样
  • STN模块:在输入前应用空间变换,增强旋转不变性

2. 判别器(Discriminator)

  • PatchGAN架构:对图像的局部patch进行判别
  • 多尺度特征:逐步下采样提取多尺度特征

3. 损失函数

  • GAN损失:对抗训练损失(LSGAN)
  • 循环一致性损失:确保A→B→A和B→A→B的循环一致性
  • SSIM损失:结构相似性指数,增强对比度和亮度一致性
  • 感知损失:基于VGG19的特征损失,提升感知质量

4. STN(空间变换网络)

  • 定位网络:学习仿射变换参数
  • 网格采样:对输入图像进行空间变换
  • 旋转不变性:增强模型对图像旋转的鲁棒性

训练技巧

  1. 损失权重调整

    • 如果图像质量不佳,可以增加--lambda_perceptual
    • 如果对比度不够,可以增加--lambda_ssim
    • 如果循环一致性有问题,可以增加--lambda_cycle
  2. 学习率

    • 初始学习率为0.0002,会在训练后半段自动衰减
    • 可以根据训练情况手动调整
  3. 批次大小

    • 如果显存不足,可以减小--batch_size
    • 建议至少为2
  4. 训练时间

    • 200个epoch通常足够,但可以根据数据集大小调整

注意事项

通用注意事项

  • 确保有足够的GPU显存(建议至少6GB),Mac M4用户建议至少16GB统一内存
  • 图像会被调整到256x256进行训练
  • STN模块会增加一定的计算开销,但能提升模型鲁棒性
  • VGG19模型会在首次运行时自动下载

Mac M4芯片特定注意事项

  • MPS加速:项目已完全支持MPS,训练时会自动使用Apple Silicon GPU加速
  • 批次大小:M4芯片建议从较小的batch_size开始(如2或4),根据内存情况调整
  • 性能优化:MPS在某些操作上可能比CUDA稍慢,但整体性能仍然很好
  • 内存管理:Mac的统一内存架构允许更灵活的内存分配,但建议监控内存使用情况
  • 数据加载:如果遇到数据加载瓶颈,可以调整--num_workers参数(Mac建议使用2-4个worker)

许可证

本项目仅供学习和研究使用。

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages