本文共 1022 字,大约阅读时间需要 3 分钟。
pytorch打印自定义网络的每层的名称
import torchfrom torchvision import modelsfrom torchsummary import summaryfrom resnext_MulTask_clothes import resnext50_elasticdata_class=[8, 7]device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# vgg = models.vgg16().to(device)model = resnext50_elastic(num_classes=data_class) # 原模型model = torch.nn.DataParallel(model).cuda() # 并行处理# 已训练好的模型的pth文件checkpoint = torch.load('06-resnext50_elastic_checkpoint.pth.tar')model.load_state_dict(checkpoint['state_dict'], strict=False) # 参数加载summary(model, (3, 224, 224))
参考连接:https://www.jianshu.com/p/97c626d33924
另:
打印resnet152网络的每层的名称import torchfrom torchvision import modelsfrom torchsummary import summaryfrom resnet_pretrained import resnet152device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = resnet152() # 原模型model = torch.nn.DataParallel(model).cuda() # 并行处理# 已训练好的模型的pth文件checkpoint = torch.load('resnet152-b121ed2d.pth')model.load_state_dict(checkpoint, strict=False) # 参数加载summary(model, (3, 224, 224))
转载地址:http://cjfxf.baihongyu.com/