In [1]:
import torch
from torchvision.models import resnet50
import cv2
from IPython.display import Image, display
In [23]:
# Seed torch's global RNG so resnet50's random init (and any later random
# projections drawn from the global generator) are reproducible on Run All.
SEED = 0
torch.manual_seed(SEED)

origimg = cv2.imread("dog_and_cat.jpg")
if origimg is None:
    # cv2.imread signals a missing/unreadable file by returning None, not by raising.
    raise FileNotFoundError("could not read dog_and_cat.jpg")
# Upsample 2x. The size must come from origimg itself: `img` does not exist yet
# on a fresh kernel. cv2.resize takes (width, height).
img = cv2.resize(origimg, (origimg.shape[1] * 2, origimg.shape[0] * 2))
# (H, W, C) uint8 array -> (1, C, H, W) tensor.
# NOTE(review): `[::-1]` reverses the first (row) axis, i.e. flips the image
# vertically; it does not swap BGR->RGB (that would be `img[..., ::-1]`).
# display_image flips rows back with `[::-1]` before saving.
image = torch.from_numpy(img[::-1].copy()).permute(2, 0, 1).unsqueeze(0)
model = resnet50()
In [24]:
# Counter used to give each rendered feature map its own output file.
displayed_img_num = 0


def display_image(out, seed=None):
    """Render a feature map as a false-colour image, save it and show it inline.

    The C channels of ``out`` are mixed down to 3 with a random 1x1 projection
    drawn from a standard normal. Each of the 3 resulting channels is
    min-max normalised to [0, 255] and cast to uint8. The result is resized
    with nearest-neighbour to the spatial size of the global ``image`` tensor,
    written to ``fname<N>.png``, and displayed at the size of the global
    ``origimg``. ``N`` is the global ``displayed_img_num``, incremented after
    every call so successive renders do not overwrite each other.

    Args:
        out: feature-map tensor of shape (1, C, H, W).
        seed: optional int; seeds the random projection so a particular
            rendering can be reproduced. ``None`` (default) draws from
            torch's global RNG, as before.

    Raises:
        ValueError: if ``out`` is not a single-image (1, C, H, W) tensor.
        OSError: if cv2 fails to write the PNG.
    """
    global displayed_img_num
    if out.dim() != 4 or out.shape[0] != 1:
        raise ValueError(f"expected a (1, C, H, W) tensor, got shape {tuple(out.shape)}")
    in_dims = out.shape[1]
    out_dims = 3
    generator = torch.Generator().manual_seed(seed) if seed is not None else None

    with torch.no_grad():  # visualisation only; no autograd graph needed
        # Random (out_dims, in_dims, 1, 1) kernel == per-pixel linear map C -> 3.
        projection = torch.randn(out_dims, in_dims, 1, 1, generator=generator).to(out)
        img3d = torch.conv2d(out, projection)
        # Per-channel extrema over the spatial dims, shaped (1, 3, 1, 1) for broadcasting.
        min_v = img3d.amin(dim=(2, 3), keepdim=True)
        max_v = img3d.amax(dim=(2, 3), keepdim=True)
        # A constant channel has max == min; dividing by 0 would give NaN, whose
        # uint8 cast is undefined. Clamping the span maps such a channel to 0.
        span = (max_v - min_v).clamp_min(torch.finfo(img3d.dtype).eps)
        img3d = (img3d - min_v) / span
        img3d = (img3d * 255).clamp(0, 255).to(dtype=torch.uint8)

    # (1, 3, H, W) -> (H, W, 3) numpy array.
    rendered = img3d.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # Flip rows back (the input `image` was built from img[::-1]); copy so cv2
    # gets a contiguous array rather than a negative-stride view.
    rendered = rendered[::-1].copy()
    rendered = cv2.resize(rendered, (image.shape[3], image.shape[2]), interpolation=cv2.INTER_NEAREST)

    fname = f"fname{displayed_img_num}.png"
    if not cv2.imwrite(fname, rendered):
        raise OSError(f"cv2.imwrite failed to write {fname}")
    display(Image(filename=fname, width=origimg.shape[1], height=origimg.shape[0]))
    displayed_img_num += 1
In [25]:
# Show the unmodified input image for comparison with the feature maps below.
source_preview = Image(filename="dog_and_cat.jpg")
display(source_preview)
In [26]:
# How many independent random 3-channel projections of conv1's output to show.
NUM_PROJECTIONS = 5

# Forward pass through the first conv layer only. This is inference, so skip
# recording an autograd graph.
# NOTE(review): `image` is raw 0-255 pixel values cast to float, with no
# ImageNet normalisation; confirm that is intended.
with torch.no_grad():
    conv1out = model.conv1(image.float())

for _ in range(NUM_PROJECTIONS):
    display_image(conv1out)