-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
116 lines (89 loc) · 3.64 KB
/
utils.py
File metadata and controls
116 lines (89 loc) · 3.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import datetime
import torch.nn as nn
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity, peak_signal_noise_ratio
from parameters import *
def init_weights(m):
    """
    Initialise the weights of convolutional layers
    :param m: layer to initialise; only layers whose class name contains
              "Conv" are touched, every other layer is left as it is
    :return: None
    """
    layer_type = type(m).__name__
    if "Conv" not in layer_type:
        return
    # Draw the weights from a normal distribution: mean 0.0, std 0.05.
    nn.init.normal_(m.weight.data, mean=0.0, std=0.05)
def plot_image(test_batch, reconstructed_image, num_images):
    """
    Show test images next to their reconstructions during validation
    :param test_batch: Test images set
    :param reconstructed_image: Reconstructed images set
    :param num_images: Number of images to plot (one row per image)
    :return: None
    """

    def _to_displayable(image):
        # Move to CPU, put the channel axis last and undo the STD/MEAN scaling.
        return (image.cpu().detach().permute(1, 2, 0) * STD) + MEAN

    # squeeze=False always yields a 2-D axes array, so a single row is
    # indexed exactly like several rows.
    figure, axes = plt.subplots(num_images, 2, squeeze=False)
    for row in range(num_images):
        original = _to_displayable(test_batch[row])
        reconstruction = _to_displayable(reconstructed_image[row])
        left_axis, right_axis = axes[row, 0], axes[row, 1]
        left_axis.imshow(original)
        right_axis.imshow(reconstruction)

    figure.set_figheight(20)
    figure.set_figwidth(20)
    plt.show()
def plot_image_grid(test_batch, reconstructed_images, num_images):
    """
    Plot image grid during validation and save it to ../results/result.png

    The first column shows the original images; every following column shows
    one entry of ``reconstructed_images``, in the dictionary's iteration order.
    Column titles are assigned by position, so the fixed compression labels
    below are only meaningful when the dictionary is ordered to match them.

    :param test_batch: Test images set
    :param reconstructed_images: Dictionary of reconstructed images sets,
                                 one set per grid column
    :param num_images: Number of images to plot (one row per image)
    :return: None
    """
    column_titles = (
        "Original \n Image",
        "Reconstructed with \n 96% Compression",
        "Reconstructed with \n 92% Compression",
        "Reconstructed with \n 84% Compression",
        "Reconstructed with \n 68% Compression",
        "Reconstructed with \n 43% Compression",
    )

    # squeeze=False keeps the axes array 2-D even for a single row, so the
    # ax[row, column] indexing below also works when num_images == 1.
    f, ax = plt.subplots(num_images, len(reconstructed_images) + 1, squeeze=False)

    # zip stops at the shorter sequence: grids with fewer columns than titles
    # no longer raise IndexError, and extra columns simply stay untitled.
    for title_axis, title in zip(ax[0], column_titles):
        title_axis.title.set_text(title)

    for i in range(num_images):
        test_image = (test_batch[i].cpu().detach().permute(1, 2, 0) * STD) + MEAN
        ax[i, 0].imshow(test_image)
        for ind, rec_batch in enumerate(reconstructed_images.values(), 1):
            rec_image = (rec_batch[i].cpu().detach().permute(1, 2, 0) * STD) + MEAN
            ax[i, ind].imshow(rec_image)

    f.set_figheight(20)
    f.set_figwidth(20)
    # savefig does not create missing directories; make sure the target exists.
    os.makedirs("../results", exist_ok=True)
    plt.savefig("../results/result.png")
    plt.show()
def save_images(test_batch, reconstructed_images):
    """
    Save the original images and their reconstructions as PNG files

    The files are written to a new ``../results/<timestamp>`` directory.
    For a single-image batch the files are named ``original_image.png`` and
    ``image_<channel>.png``. For larger batches the image index is appended
    (``original_image_<i>.png``, ``image_<channel>_<i>.png``) so that later
    images do not overwrite earlier ones.

    :param test_batch: Test images set
    :param reconstructed_images: Dictionary of reconstructed images sets
    :return: None
    """
    out_dir = f"../results/{datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}"
    os.makedirs(out_dir, exist_ok=True)

    batch_size = len(test_batch)
    for i in range(batch_size):
        # Keep the historical file names when there is only one image.
        suffix = "" if batch_size == 1 else f"_{i}"

        test_image = (test_batch[i].cpu().detach().permute(1, 2, 0) * STD) + MEAN
        # imsave rejects float RGB values outside [0, 1]; de-normalised data
        # can overshoot that range slightly, so clamp before saving.
        plt.imsave(f"{out_dir}/original_image{suffix}.png", test_image.clamp(0.0, 1.0).numpy())

        for channel, rec_batch in reconstructed_images.items():
            rec_image = (rec_batch[i].cpu().detach().permute(1, 2, 0) * STD) + MEAN
            plt.imsave(f"{out_dir}/image_{channel}{suffix}.png", rec_image.clamp(0.0, 1.0).numpy())
def metrics(firstImage, secondImage):
    """
    Calculate the evaluation metrics for the images

    Both metrics use the value range of ``firstImage`` (max - min) as
    ``data_range``. SSIM treats the last axis of the images as the channel
    axis.

    :param firstImage: First image (reference), channels on the last axis
    :param secondImage: Second image, same shape as the first
    :return: Metrics dictionary with the keys "SSIM" and "PSNR"
    """
    import inspect

    # Computed once and shared by both metrics.
    data_range = firstImage.max() - firstImage.min()

    # scikit-image 0.19 replaced `multichannel=True` with `channel_axis`.
    # Both variants accept **kwargs, so the unsupported keyword would be
    # swallowed silently rather than rejected; choose by signature instead.
    ssim_parameters = inspect.signature(structural_similarity).parameters
    if "channel_axis" in ssim_parameters:
        channel_kwargs = {"channel_axis": -1}
    else:
        channel_kwargs = {"multichannel": True}

    ssim = structural_similarity(firstImage, secondImage, data_range=data_range, **channel_kwargs)
    psnr = peak_signal_noise_ratio(firstImage, secondImage, data_range=data_range)
    return {"SSIM": ssim, "PSNR": psnr}