本demo从pytorch官方的迁移学习示例修改而来,增加了以下功能:

  1. 根据AUC来迭代最优参数;
  2. 五折交叉验证;
  3. 输出验证集错误分类图片;
  4. 输出分类报告并保存AUC结果图片。
     import os
    import numpy as np
    import torch
    import torch.nn as nn
    from torch.optim import lr_scheduler
    import torchvision
    from torchvision import datasets, models, transforms
    from torch.utils.data import DataLoader
    from sklearn.metrics import roc_auc_score, classification_report
    from sklearn.model_selection import KFold
    from torch.autograd import Variable
    import torch.optim as optim
    import time
    import copy
    import shutil
    import sys
    import scikitplot as skplt
    import matplotlib.pyplot as plt
    import pandas as pd

    # Use the non-interactive Agg backend so figures can be saved without a display.
    plt.switch_backend('agg')

    # Number of output classes of the replaced classifier head.
    N_CLASSES = 2
    # Mini-batch size used by the data loaders.
    BATCH_SIZE = 8
    # Directory holding the raw images and the per-fold train/val copies.
    DATA_DIR = './data'
    # Maps the integer class index to the label written to reports and CSV files.
    LABEL_DICT = {0: 'class_1', 1: 'class_2'}


    def imshow(inp, title=None):
        """Show a normalized image tensor with matplotlib.

        Args:
            inp: tensor laid out channel-first; it is transposed to
                channel-last and de-normalized with the mean/std constants
                below before being drawn.
            title: optional title placed above the image.
        """
        inp = inp.numpy().transpose((1, 2, 0))
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        # Undo the Normalize transform and clamp to the displayable range.
        inp = np.clip(std * inp + mean, 0, 1)
        plt.imshow(inp)
        if title is not None:
            plt.title(title)
        plt.pause(100)


    def train_model(model, criterion, optimizer, scheduler, fold, name, num_epochs=25):
        """Train ``model`` and keep the weights of the best validation AUC.

        Reads the module-level globals ``dataloaders``, ``dataset_sizes``,
        ``class_names``, ``use_gpu``, ``report_file`` and ``label_file``.

        After training it writes the best epoch's per-sample outputs to
        ``./prob_result``, a ROC plot to ``./roc_img``, the summary report to
        ``report_file`` and the misclassified validation images to
        ``label_file``.

        Args:
            model: network to train.
            criterion: loss called as ``criterion(outputs, labels)``.
            optimizer: optimizer stepped once per training batch.
            scheduler: learning-rate scheduler stepped once per epoch.
            fold: fold number, used only in output file names.
            name: model name, used only in output file names.
            num_epochs: number of epochs to run.

        Returns:
            ``model`` with the best validation weights loaded.
        """
        since = time.time()
        # Snapshot the initial weights; replaced whenever a better epoch is seen.
        best_model_wts = copy.deepcopy(model.state_dict())
        best_auc = 0.0
        best_desc = [0, 0, None]
        best_img_name = None
        plt_auc = [None, None]

        # Validation runs over an ordered loader so that every prediction can
        # be matched to its image path. Zipping a shuffled loader with a second
        # pass over its batch sampler pairs batches with indices drawn from a
        # different permutation.
        val_source = dataloaders['val']
        val_samples = val_source.dataset.imgs
        ordered_val_loader = DataLoader(val_source.dataset,
                                        batch_size=val_source.batch_size or BATCH_SIZE,
                                        shuffle=False,
                                        num_workers=val_source.num_workers,
                                        pin_memory=val_source.pin_memory)

        for epoch in range(num_epochs):
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('- ' * 50)
            for phase in ['train', 'val']:
                is_train = phase == 'train'
                model.train(is_train)
                loader = dataloaders['train'] if is_train else ordered_val_loader

                phase_pred = []
                phase_label = []
                img_name = []
                prob_pred = []
                running_loss = 0.0
                running_corrects = 0
                seen = 0

                for inputs, labels in loader:
                    if use_gpu:
                        inputs = inputs.cuda()
                        labels = labels.cuda()

                    # Reset the gradients accumulated by the previous batch.
                    optimizer.zero_grad()

                    # Track gradients only while training.
                    with torch.set_grad_enabled(is_train):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)
                        if is_train:
                            loss.backward()
                            optimizer.step()

                    batch_len = inputs.size(0)
                    if not is_train:
                        # The ordered loader yields samples in dataset order.
                        img_name.extend(val_samples[seen:seen + batch_len])
                        prob_pred.append(outputs.detach().cpu().numpy())
                    seen += batch_len

                    phase_pred.extend(preds.cpu().numpy().tolist())
                    phase_label.extend(labels.cpu().numpy().tolist())
                    running_loss += loss.item() * batch_len
                    running_corrects += (preds == labels).sum().item()

                if is_train:
                    # Step the schedule after the optimizer updates of this epoch.
                    scheduler.step()

                print()
                phase_pred = np.asarray(phase_pred)
                phase_label = np.asarray(phase_label)
                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects / dataset_sizes[phase]
                # NOTE(review): the AUC is computed from hard class predictions,
                # not from scores; confirm this is the intended selection metric.
                epoch_auc = roc_auc_score(phase_label, phase_pred)
                print('{} Loss: {:.4f} Acc: {:.4f} Auc: {:.4f}'.format(
                    phase, epoch_loss, epoch_acc, epoch_auc))
                report = classification_report(phase_label, phase_pred, target_names=class_names)
                print(report)

                # Keep everything belonging to the best validation epoch so far.
                if not is_train and epoch_auc > best_auc:
                    best_auc = epoch_auc
                    best_desc = epoch_acc, epoch_auc, report
                    best_img_name = list(zip(img_name, phase_pred))
                    best_model_wts = copy.deepcopy(model.state_dict())
                    plt_auc = phase_label, np.concatenate(prob_pred, axis=0)
            print()

        # Nothing is recorded when no validation epoch beat the initial AUC.
        if plt_auc[0] is not None:
            print(plt_auc[0].shape, plt_auc[1].shape)
            csv_file = pd.DataFrame(plt_auc[1], columns=['class_1', 'class_2'])
            csv_file['true_label'] = [LABEL_DICT[int(x)] for x in plt_auc[0]]
            csv_file.to_csv(f'./prob_result/{name}_fold_{fold}_porb.csv', index=False)
            skplt.metrics.plot_roc_curve(plt_auc[0], plt_auc[1], curves=['each_class'])
            plt.savefig(f'./roc_img/{name}_fold_{fold}_roc.png', dpi=600)
            # Release the figure so folds do not accumulate open figures.
            plt.close()

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        reports = 'The Desc according to the Best val Auc: \nACC -> {:4f}\nAUC -> {:4f}\n\n{}'.format(
            best_desc[0], best_desc[1], best_desc[2])
        report_file.write(reports)
        print(reports)

        print('List the wrong judgement img ...')
        count = 0
        for (img_path, true_idx), pred in best_img_name or []:
            actual_label = int(true_idx)
            pred_label = int(pred)
            if actual_label != pred_label:
                tmp_word = (f'{os.path.basename(img_path)}, actual: {LABEL_DICT[actual_label]}, '
                            f'pred: {LABEL_DICT[pred_label]}')
                print(tmp_word)
                label_file.write(tmp_word + '\n')
                count += 1
        print(f'This fold has {count} wrong records ...')

        # Restore the weights of the best validation epoch.
        model.load_state_dict(best_model_wts)
        return model


    def plot_img():
        """Show every training batch as an image grid titled with its class names.

        Reads the module-level globals ``dataloaders`` and ``class_names``.
        """
        for inputs, classes in dataloaders['train']:
            out = torchvision.utils.make_grid(inputs)
            imshow(out, title=[class_names[x] for x in classes])


    # The function below can be adapted to the image file naming of your own project.
    def move_file(data, file_path, dir_path, root_path):
        """Copy the files of one split into per-class sub-directories.

        Changes into ``dir_path``, recreates the directory ``file_path`` there
        and copies every file whose name starts with a known class label into
        ``file_path/<label>/``. Files matching no label are skipped.

        The working directory is left inside the created directory, so callers
        have to change back themselves.

        Args:
            data: iterable of file names located directly in ``dir_path``.
            file_path: name of the split directory to (re)create, e.g. ``'train'``.
            dir_path: directory holding the source files.
            root_path: joined with ``dir_path`` and ``file_path`` to locate an
                existing split directory that has to be removed.
        """
        labels = ('class_2', 'class_1')
        print(f'start copy the {file_path} file ...')
        os.chdir(dir_path)
        source_dir = os.getcwd()
        if os.path.exists(file_path):
            print(f'Find exist {file_path} file, the file will be dropped.')
            shutil.rmtree(os.path.join(root_path, dir_path, file_path))
            print(f'Finish drop the {file_path} file.')
        os.mkdir(file_path)
        target_dir = os.path.join(source_dir, file_path)
        os.chdir(target_dir)
        for d in data:
            for label in labels:
                # Match on the full label prefix; a fixed two-character slice
                # can never equal these labels.
                if d.startswith(label):
                    label_dir = os.path.join(target_dir, label)
                    os.makedirs(label_dir, exist_ok=True)
                    shutil.copyfile(os.path.join(source_dir, d), os.path.join(label_dir, d))
        print('finish this work ...')


    if __name__ == "__main__":
        if len(sys.argv) < 2:
            sys.exit('usage: pass the model name (resnet or desnet) as the first argument')
        model_name = sys.argv[1]

        for out_dir in ('roc_img', 'prob_result', 'report', 'error_record', 'model'):
            os.makedirs(out_dir, exist_ok=True)

        # label_file and report_file are read as globals by train_model.
        with open(f'./error_record/{model_name}_img_name_actual_pred.txt', 'w') as label_file:
            kf = KFold(n_splits=5, shuffle=True, random_state=1)
            # NOTE(review): presumably the directory the script is launched
            # from; every relative path is resolved against it after the first
            # fold - confirm.
            origin_path = '/home/project/'
            dd_list = np.array([o for o in os.listdir(DATA_DIR)
                                if os.path.isfile(os.path.join(DATA_DIR, o))])

            for m, n in enumerate(kf.split(dd_list), start=1):
                with open(f'./report/{model_name}_fold_{m}_report.txt', 'w') as report_file:
                    print(f'The {m} fold for copy file and training ...')
                    move_file(dd_list[n[0]], 'train', DATA_DIR, origin_path)
                    os.chdir(origin_path)
                    move_file(dd_list[n[1]], 'val', DATA_DIR, origin_path)
                    os.chdir(origin_path)

                    data_transforms = {
                        'train': transforms.Compose([
                            # Random crop resized to 224x224.
                            transforms.RandomResizedCrop(224),
                            # Flip horizontally with probability 0.5.
                            transforms.RandomHorizontalFlip(),
                            # transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),  # HSV and contrast jitter
                            # Convert the PIL image to a [C, H, W] float tensor in [0, 1].
                            transforms.ToTensor(),
                            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                        ]),
                        'val': transforms.Compose([
                            transforms.Resize(256),
                            transforms.CenterCrop(224),
                            transforms.ToTensor(),
                            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                        ]),
                    }

                    image_datasets = {x: datasets.ImageFolder(os.path.join(DATA_DIR, x),
                                                              data_transforms[x])
                                      for x in ['train', 'val']}
                    # Only the training split is shuffled. Validation keeps the
                    # dataset order so batches stay aligned with the sampler
                    # indices used to look up image paths.
                    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                                  batch_size=BATCH_SIZE,
                                                                  shuffle=(x == 'train'),
                                                                  num_workers=8,
                                                                  pin_memory=False)
                                   for x in ['train', 'val']}
                    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
                    class_names = image_datasets['train'].classes
                    print('label mapping: ')
                    print(image_datasets['train'].class_to_idx)
                    use_gpu = torch.cuda.is_available()

                    # NOTE(review): the heads end in nn.Sigmoid while the loss
                    # is nn.CrossEntropyLoss, which normally takes raw logits.
                    # Left unchanged because the saved outputs depend on it.
                    if model_name == 'resnet':
                        model_ft = models.resnet50(pretrained=True)
                        num_ftrs = model_ft.fc.in_features
                        model_ft.fc = nn.Sequential(
                            nn.Linear(num_ftrs, N_CLASSES),
                            nn.Sigmoid()
                        )
                    elif model_name == 'inception':
                        # An inception model can be plugged in here.
                        # model_ft = models.inception_v3(pretrained=True)
                        raise Exception("not provide inception model ...")
                    elif model_name == 'desnet':
                        model_ft = models.densenet121(pretrained=True)
                        num_ftrs = model_ft.classifier.in_features
                        model_ft.classifier = nn.Sequential(
                            nn.Linear(num_ftrs, N_CLASSES),
                            nn.Sigmoid()
                        )
                    else:
                        raise ValueError(f'unknown model name: {model_name!r}')

                    # use_gpu = False  # uncomment to force CPU training
                    if use_gpu:
                        model_ft = model_ft.cuda()

                    criterion = nn.CrossEntropyLoss()
                    optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
                    # Multiply the learning rate by 0.1 every 7 epochs.
                    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
                    model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                                           m, model_name, num_epochs=25)
                    print('Start save the model ...')
                    torch.save(model_ft.state_dict(), f'./model/fold_{m}_{model_name}.pkl')
                    print(f'The mission of the fold {m} finished.')
                    print('# ' * 50)

最新文章

  1. 什么情况下可以不写PHP的结束标签“?>”
  2. Form_Form Builder Export导出为Excel(案例)
  3. JS的基础语法
  4. HW1.7
  5. SVN不同图标的不同意义
  6. Winform ComBox模糊查询
  7. PAT (Advanced Level) 1112. Stucked Keyboard (20)
  8. [iOS开发]WKWebView加载JS
  9. 2017年的golang、python、php、c++、c、java、Nodejs性能对比(golang python php c++ java Nodejs Performance)
  10. html查看器android
  11. centos7镜像文件
  12. DB2常见错误信息
  13. 2018牛客网暑期ACM多校训练营(第一场)J Different Integers(树状数组)
  14. Java 中的泛型
  15. C# 图片处理方法 整理汇总
  16. windows本地eclispe运行linux上hadoop的maperduce程序
  17. 【Spark调优】:尽量避免使用shuffle类算子
  18. JavaWeb学习 (十三)————JSP
  19. html网页采集
  20. Http请求帮助类

热门文章

  1. BZOJ3033: 太鼓达人(欧拉回路)
  2. HTML行内元素、块状元素和行内块状元素的区分
  3. java线程安全问题原理性分析
  4. 01_JMS概述
  5. Android自定义验证码输入框
  6. 关联关系的CRUD
  7. tr标签是什么
  8. solidity语言5
  9. SpringMvc-自定义视图
  10. tampermonkey利用@require调用本地脚本的方法