python通过特征图画热图_heatmap

对模型输出的特征图y进行上采样:

features=nn.Upsample(scale_factor=32, mode='bicubic', align_corners=None)(y.view(-1,12,12,1536).permute(0,3,1,2))
 #scale_factor是上采样倍数
 #y是输入,需要permute一下是转换一下通道的位置
##以下对特征图逐个画出来
for i in range(features.shape[0]):
    feature = features[i,:,:].cpu().detach().numpy()
    save_feature(im_resize, feature, i, detected_classes)
    
#保存图的函数
#im_resize为原始通入到模型的图像,detected_classes是模型输出的这个图像的类别,
def save_feature(im_resize, feature, i, detected_classes,):
    print('save {}'.format(i))
    fig = plt.figure()
    plt.imshow(im_resize)
    plt.imshow(feature, alpha=0.65)
    plt.axis('off')
    plt.axis('tight')
    plt.title("detected classes: {}".format(detected_classes))

    outpath = ""#设置一下你要保存的路径
    plt.savefig(outpath, format='png', transparent=True, dpi=100, pad_inches = 0)
    plt.show()
    print('done\n')

未经允许不得转载!python通过特征图画热图_heatmap

如遇到无法显示的问题,请先尝试刷新页面

客服联系邮箱:ai52learn@foxmail.com

本文地址:https://ai.52learn.online/11959