可见光与红外图像融合

写了一个批量计算图像融合指标的脚本,主要用到了 scikit-image 模块(官方文档:https://scikit-image.org/docs/stable/)。

安装scikit-image

建议在 Python 虚拟环境中运行本脚本。虚拟环境的作用类似于虚拟机:不同环境中的 Python 依赖包相互独立、互不干扰。

创建虚拟环境的步骤:

  • 第一步:安装Virtualenv

    1
    
    pip3 install virtualenv -i https://pypi.python.org/simple/
    
  • 第二步:cd 到存放虚拟环境的目录地址,执行以下代码创建虚拟环境

    1
    2
    3
    4
    
    virtualenv python3_env 
    # 你也可以指定版本
    virtualenv -p /usr/bin/python2.7 python2_env
    virtualenv -p /usr/bin/python3.8 python3_env
    
  • 第三步:激活虚拟环境

    ubuntu执行以下命令:

    1
    
    source python3_env/bin/activate
    

    Windows执行以下命令(需先 cd 进入虚拟环境目录 python3_env):

    1
    
    .\Scripts\activate.bat
    
  • 第四步:退出虚拟环境

    ubuntu执行:

    1
    
    deactivate
    

    Windows执行:

    1
    
    .\Scripts\deactivate.bat
    

接下来进入正文:激活虚拟环境后,使用 pip 安装 scikit-image。

1
2
3
4
# 更新pip
python -m pip install -U pip
# 安装 scikit-image
python -m pip install -U scikit-image

批量计算图像融合指标SSIM、RMSE、PSNR、EN、AG

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import cv2
import math
import os
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import mean_squared_error as compare_mse
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim
from skimage.measure import shannon_entropy as compare_entropy

def compare_img(img_raw_path, img_fusion_path):
    """Score every fused image against its corresponding source image.

    Files in the two folders are paired by position after sorting the file
    names, so both folders are expected to list their images in the same
    order. If the folders hold a different number of files, the surplus
    files of the longer listing are ignored (``zip`` semantics).

    Args:
        img_raw_path: Folder containing the source images.
        img_fusion_path: Folder containing the fused images.

    Returns:
        A tuple ``(ssim_list, rmse_list, psnr_list, en_list, ag_list)`` with
        one entry per image pair. SSIM, RMSE and PSNR compare source and
        fused image; EN (Shannon entropy) and AG (average gradient) are
        computed on the fused image only.

    Raises:
        ValueError: If OpenCV cannot read one of the paired files.
    """
    ssim_list = []
    rmse_list = []
    psnr_list = []
    en_list = []
    ag_list = []
    # os.listdir() returns names in arbitrary order; sort both listings so
    # the pairing is deterministic and matching names line up.
    raw_names = sorted(os.listdir(img_raw_path))
    fusion_names = sorted(os.listdir(img_fusion_path))
    for filename_raw, filename_fusion in zip(raw_names, fusion_names):
        raw_file = os.path.join(img_raw_path, filename_raw)
        fusion_file = os.path.join(img_fusion_path, filename_fusion)
        img_raw = cv2.imread(raw_file)
        img_fusion = cv2.imread(fusion_file)
        # cv2.imread reports failure by returning None instead of raising.
        if img_raw is None:
            raise ValueError("cannot read image: " + raw_file)
        if img_fusion is None:
            raise ValueError("cannot read image: " + fusion_file)
        # Images are loaded with the colour channels on the last axis, so
        # tell SSIM to treat axis 2 as the channel axis.
        ssim = compare_ssim(img_raw, img_fusion, channel_axis=2)
        rmse = math.sqrt(compare_mse(img_raw, img_fusion))
        psnr = compare_psnr(img_raw, img_fusion)
        en = compare_entropy(img_fusion)
        ag = compare_ag(img_fusion)
        ssim_list.append(ssim)
        rmse_list.append(rmse)
        psnr_list.append(psnr)
        en_list.append(en)
        ag_list.append(ag)
    return ssim_list, rmse_list, psnr_list, en_list, ag_list

def compare_ag(image):
    """Return the average gradient (AG) of an image.

    For every pixel that has both a right and a lower neighbour, the forward
    differences ``dx`` (to the right) and ``dy`` (downwards) are combined as
    ``sqrt((dx**2 + dy**2) / 2)``; AG is the mean of that value over those
    ``(H - 1) * (W - 1)`` pixels.

    Args:
        image: A 3-channel BGR image (converted to grayscale first), or an
            already single-channel 2-D array.

    Returns:
        The average gradient as a Python float. Images with fewer than two
        rows or two columns have no gradient samples and yield ``0.0``.
    """
    image = np.asarray(image)
    if image.ndim == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # Convert before subtracting so unsigned pixel values cannot wrap around.
    gray = image.astype(np.float64)
    if gray.shape[0] < 2 or gray.shape[1] < 2:
        return 0.0
    # Start at row/column 0 so the number of summed samples matches the
    # (H - 1) * (W - 1) normalisation.
    base = gray[:-1, :-1]
    dx = gray[:-1, 1:] - base
    dy = gray[1:, :-1] - base
    return float(np.mean(np.sqrt((dx * dx + dy * dy) / 2)))

def cal_mean_metric(list1, list2):
    """Return the element-wise mean of two equally long metric lists.

    Args:
        list1: Metric values from the first comparison, one per image.
        list2: Metric values from the second comparison, one per image.

    Returns:
        A list whose i-th entry is the mean of ``list1[i]`` and ``list2[i]``.
    """
    # Stack the two series as rows and average down the columns.
    stacked = [list1, list2]
    column_means = np.mean(stacked, axis=0)
    return list(column_means)

def plot_metric(list1, list2, name):
    """Plot one metric for the original and the improved method.

    A new figure is created and both series are drawn against the 1-based
    image-pair index. The figure is not shown here; the caller is expected
    to call ``plt.show()``.

    Args:
        list1: Metric values of the original method, one per image pair.
        list2: Metric values of the improved method, one per image pair.
        name: Metric name appended to both legend labels.
    """
    plt.figure()
    # The x axis is sized from list1; both series share it.
    positions = range(1, len(list1) + 1)
    series = (
        (list1, 'b*--', "Original method " + name),
        (list2, 'r.--', "Improved method " + name),
    )
    for values, line_format, legend_text in series:
        plt.plot(positions, values, line_format, label=legend_text, linewidth=1)
    plt.legend()
    plt.xticks(positions)
    axis_label_font = {"fontsize": 13, "fontweight": 'bold'}
    plt.xlabel("Image pair", **axis_label_font)
    plt.ylabel("Metrics Value", **axis_label_font)


if __name__ == "__main__":
    # Input folders: the two source sets and the fused results produced by
    # the original and by the improved fusion method.
    img_rgb_path = './img/rgb'
    img_red_path = './img/red'
    img_fusion_path = './img/img_fusion/'
    img_improve_path = './img/img_improve_fusion'

    # --- Original method: score the fused images against each source set.
    (rgb_ssim_list, rgb_rmse_list, rgb_psnr_list,
     rgb_en_list, rgb_ag_list) = compare_img(img_rgb_path, img_fusion_path)
    (red_ssim_list, red_rmse_list, red_psnr_list,
     red_en_list, red_ag_list) = compare_img(img_red_path, img_fusion_path)
    # Per-image mean of the two source comparisons (used by the plots).
    average_ssim, average_rmse, average_psnr = (
        cal_mean_metric(rgb_values, red_values)
        for rgb_values, red_values in (
            (rgb_ssim_list, red_ssim_list),
            (rgb_rmse_list, red_rmse_list),
            (rgb_psnr_list, red_psnr_list),
        )
    )

    # NOTE(review): this summary averages only the rgb_* lists, while the
    # SSIM/RMSE/PSNR plots below use the rgb/red per-image means — confirm
    # that the difference is intended.
    print(" ".join((
        "改进前:Average SSIM: %.4f" % np.mean(rgb_ssim_list),
        "Average RMSE: %.4f" % np.mean(rgb_rmse_list),
        "Average PSNR: %.4f" % np.mean(rgb_psnr_list),
        "EN: %.4f" % np.mean(rgb_en_list),
        "AG: %.4f" % np.mean(rgb_ag_list),
    )))

    # --- Improved method: same procedure on the improved fusion results.
    (rgb_ssim_list1, rgb_rmse_list1, rgb_psnr_list1,
     rgb_en_list1, rgb_ag_list1) = compare_img(img_rgb_path, img_improve_path)
    (red_ssim_list1, red_rmse_list1, red_psnr_list1,
     red_en_list1, red_ag_list1) = compare_img(img_red_path, img_improve_path)
    average_ssim1, average_rmse1, average_psnr1 = (
        cal_mean_metric(rgb_values, red_values)
        for rgb_values, red_values in (
            (rgb_ssim_list1, red_ssim_list1),
            (rgb_rmse_list1, red_rmse_list1),
            (rgb_psnr_list1, red_psnr_list1),
        )
    )

    print(" ".join((
        "改进后:Average SSIM: %.4f" % np.mean(rgb_ssim_list1),
        "Average RMSE: %.4f" % np.mean(rgb_rmse_list1),
        "Average PSNR: %.4f" % np.mean(rgb_psnr_list1),
        "EN: %.4f" % np.mean(rgb_en_list1),
        "AG: %.4f" % np.mean(rgb_ag_list1),
    )))

    # One figure per metric, original method versus improved method.
    for before, after, metric_name in (
        (average_ssim, average_ssim1, "SSIM"),
        (average_rmse, average_rmse1, "RMSE"),
        (average_psnr, average_psnr1, "PSNR"),
        (rgb_en_list, rgb_en_list1, "EN"),
        (rgb_ag_list, rgb_ag_list1, "AG"),
    ):
        plot_metric(before, after, metric_name)
    plt.show()

运行结果如下

https://github.com/DyedBamboo/DyedBamboo.github.io-blog-img/raw/main/img_fus/Figure_1.png

https://github.com/DyedBamboo/DyedBamboo.github.io-blog-img/raw/main/img_fus/Figure_2.png

https://github.com/DyedBamboo/DyedBamboo.github.io-blog-img/raw/main/img_fus/Figure_3.png

https://github.com/DyedBamboo/DyedBamboo.github.io-blog-img/raw/main/img_fus/Figure_4.png

https://github.com/DyedBamboo/DyedBamboo.github.io-blog-img/raw/main/img_fus/Figure_5.png

https://github.com/DyedBamboo/DyedBamboo.github.io-blog-img/raw/main/img_fus/捕获.JPG