Nguyễn Dương Thế Vĩ
30-12-2024

Training with a CNN

Convolutional Neural Networks (CNNs) are a powerful deep-learning tool, particularly well suited to image processing and recognition tasks. This article walks through the basic steps of designing and training a CNN from scratch.

#python #AI

A Convolutional Neural Network (CNN) is a type of deep neural network designed specifically to process data with a grid-like structure, such as images. CNNs have achieved major successes in image recognition and classification tasks. Download Jupyter Notebook
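
To make the core operation concrete, here is a minimal, standalone snippet (illustrative only, not part of the notebook below): a single convolution layer slides a bank of small learned filters over an RGB image and produces a stack of feature maps.

import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)  # 16 learned 3x3 filters
image = torch.randn(1, 3, 224, 224)  # a batch of one 224x224 RGB image
features = conv(image)
print(features.shape)  # torch.Size([1, 16, 224, 224]): one feature map per filter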

import argparse
import random
import shutil
import sys
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.models as models  # needed by the --pretrained branch in main()
import torchvision.transforms as transforms
from torchsummary import summary

# Mount Google Drive and make the project's helper modules importable
from google.colab import drive
drive.mount('/content/drive')
sys.path.append('/content/drive/MyDrive/ML/')
sys.path.append('/content/drive/MyDrive/ML/models/')
import writeLogAcc as wA  # custom logging helper stored on Drive
from cross_entropy import LabelSmoothingCrossEntropy  # custom loss stored on Drive
#from models.Nhom12Mang1 import *
# Change the network here (uncomment the one you want, or import it from models/)
# Network 1
class Net(nn.Module):
    def __init__(self, n_class=10):
        super(Net, self).__init__()
        self.conv0 = nn.Conv2d(3, 32, kernel_size=7, padding=3, stride=1)
        self.conv1 = nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv7 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv8 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.conv9 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv10 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.conv11 = nn.Conv2d(512, 64, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.maxpool2 = nn.MaxPool2d(kernel_size=4, stride=4)
        self.avgpool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.avgpool2 = nn.AvgPool2d(kernel_size=2, stride=4)
        self.fc1 = nn.Linear(9408, 1204)  # 9408 = 192 channels * 7 * 7 for 224x224 input
        self.fc2 = nn.Linear(1204, n_class)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv0(x)

        x1 = self.relu(self.conv1(x))
        x2 = self.relu(self.conv2(x1))
        x3 = self.relu(self.conv3(x2))
        x4 = self.relu(self.conv4(x3))

        # Skip connection 1: element-wise addition (both 128 channels)
        x_add1 = x2 + x4

        x5 = self.relu(self.conv5(x_add1))
        x6 = self.relu(self.conv6(x5))
        x7 = self.maxpool1(x6)
        x8 = self.relu(self.conv7(x7))
        x9 = self.relu(self.conv8(x8))

        # Skip connection 2: element-wise addition (both 256 channels)
        x_add2 = x7 + x9

        x10 = self.relu(self.conv9(x_add2))
        x11 = self.relu(self.conv10(x10))

        # Channel-wise concatenation: 256 + 256 = 512 channels
        x_cat1 = torch.cat((x_add2, x11), dim=1)

        x_avg1 = self.avgpool1(x_cat1)
        x_avg1 = self.conv11(x_avg1)

        # Downsample the first skip branch to match spatial size, then concatenate
        x_add1 = self.maxpool2(x_add1)
        x_cat2 = torch.cat((x_add1, x_avg1), dim=1)
        x_avg2 = self.avgpool2(x_cat2)

        x_flat = x_avg2.view(x_avg2.size(0), -1)
        x_fc1 = self.relu(self.fc1(x_flat))
        x_out = self.fc2(x_fc1)
        return x_out
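
# (Illustrative sanity check, not in the original notebook) With a 224x224 RGB input,
# the tensor entering fc1 has 192 channels at 7x7 spatial size, i.e. 192*7*7 = 9408:
# net = Net(n_class=10)
# out = net(torch.randn(1, 3, 224, 224))
# print(out.shape)  # torch.Size([1, 10])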

# Network 2 (alternative, commented out)
# class Net(nn.Module):
#     def __init__(self, n_classes=10):
#         super(Net, self).__init__()
#         # Network 2
#         self.conv0 = nn.Conv2d(3, 32, kernel_size=7, padding=3, stride=1)
#         self.conv1 = nn.Conv2d(32, 64, kernel_size=5, padding=2)  # Conv1
#         self.conv2 = nn.Conv2d(64, 64, kernel_size=5, padding=2)  # Conv2
#         self.conv3 = nn.Conv2d(64, 64, kernel_size=5, padding=2)  # Conv3
#         self.conv4 = nn.Conv2d(64, 128, kernel_size=5, padding=2)  # Conv4
#         self.conv5 = nn.Conv2d(128, 256, kernel_size=5, padding=2)  # Conv5
#         self.conv6 = nn.Conv2d(256, 256, kernel_size=5, padding=2)  # Conv6
#         self.conv7 = nn.Conv2d(256, 256, kernel_size=5, padding=2)  # Conv7
#         self.conv8 = nn.Conv2d(256, 512, kernel_size=5, padding=2)  # Conv8
#         self.conv9 = nn.Conv2d(512, 256, kernel_size=5, padding=2)  # Conv9
#         self.conv10 = nn.Conv2d(256, 128, kernel_size=5, padding=2)  # Conv10

#         # Pooling layers
#         self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
#         self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
#         self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)

#         # Fully connected layer
#         self.fc = nn.Linear(301056, n_classes)

#         # Activation
#         self.relu = nn.ReLU()

#     def forward(self, x):
#         # Stage 1
#         x = self.conv0(x)
#         x1 = self.relu(self.conv1(x))  # Conv1
#         x2 = self.relu(self.conv2(x1))  # Conv2
#         x3 = self.relu(self.conv3(x2))  # Conv3

#         # Tensor addition 1
#         x_add1 = x1 + x3

#         # Stage 2
#         x4 = self.relu(self.conv4(x_add1))  # Conv4
#         x5 = self.relu(self.conv5(x4))  # Conv5
#         x6 = self.maxpool1(x5)  # Maxpool1

#         # Stage 3
#         x7 = self.relu(self.conv6(x6))  # Conv6
#         x8 = self.relu(self.conv7(x7))  # Conv7

#         # Tensor addition 2
#         x_add2 = x6 + x8

#         # Stage 4
#         x9 = self.relu(self.conv8(x_add2))  # Conv8
#         x10 = self.relu(self.conv9(x9))  # Conv9
#         x11 = self.relu(self.conv10(x10))  # Conv10

#         # Concatenation
#         x_cat = torch.cat((x11, x10), dim=1)

#         # Average pooling
#         x_avg = self.avgpool(x_cat)

#         # Maxpool2
#         x_pool = self.maxpool2(x_avg)

#         x_flat = x_pool.view(x_pool.size(0), -1)

#         x_out = self.fc(x_flat)
#         return x_out

# Network 3 (alternative, commented out)

# class Net(nn.Module):
#       def __init__(self, n_class=10):
#           super(Net, self).__init__()
#           self.conv0 = nn.Conv2d(in_channels = 3, out_channels = 32, kernel_size = 7, stride=1, padding=3)
#           self.conv1 = nn.Conv2d(in_channels = 32, out_channels = 32, kernel_size = 3, stride=1, padding=0)
#           self.conv2 = nn.Conv2d(in_channels = 32, out_channels = 32, kernel_size = 3, stride=1, padding=0)
#           self.conv3 = nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = 5, stride=2, padding=4)
#           self.conv4 = nn.Conv2d(in_channels = 32, out_channels = 32, kernel_size = 3, stride=2, padding=1)
#           self.conv5 = nn.Conv2d(in_channels = 32, out_channels = 32, kernel_size = 5, stride=2, padding=2)
#           self.conv6 = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3, stride=2, padding=1)
#           self.conv7 = nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = 5, stride=2, padding=2)
#           self.conv8 = nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = 5, stride=2, padding=2)
#           self.conv9 = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3, stride=1, padding=2)
#           self.conv10 = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 5, stride=1, padding=1)

#           self.conv11 = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 5, stride=1, padding=4)
#           self.conv12 = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 5, stride=1, padding=4)
#           self.conv13 = nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = 3, stride=1, padding=3)
#           self.conv14 = nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = 3, stride=1, padding=3)
#           self.conv15 = nn.Conv2d(in_channels = 192, out_channels = 192, kernel_size = 5, stride=1, padding=2)
#           self.conv16 = nn.Conv2d(in_channels = 192, out_channels = 192, kernel_size = 5, stride=1, padding=2)
#           self.conv17 = nn.Conv2d(in_channels = 192, out_channels = 192, kernel_size = 3, stride=1, padding=2)
#           self.conv18 = nn.Conv2d(in_channels = 192, out_channels = 192, kernel_size = 3, stride=1, padding=2)
#           self.conv19 = nn.Conv2d(in_channels = 192, out_channels = 80, kernel_size = 3, stride=2, padding=3)
#           self.conv20 = nn.Conv2d(in_channels = 192, out_channels = 80, kernel_size = 5, stride=2, padding=4)

#           self.avgpool1 = nn.AvgPool2d(kernel_size=2, stride=2)
#           self.avgpool2 = nn.AvgPool2d(kernel_size=2, stride=2)

#           self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
#           self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)

#           self.fc1 = nn.Linear(9 * 9 * 160, 1024)
#           self.fc2 = nn.Linear(1024, 10)
#           self.relu = nn.ReLU()
#       def forward(self, x):
#           conv0_x = self.conv0(x)
#           conv1_x = F.relu(self.conv1(conv0_x))
#           conv2_x = F.relu(self.conv2(conv1_x))
#           conv3_x = F.relu(self.conv3(conv2_x))
#           conv4_x = F.relu(self.conv4(conv0_x))
#           conv7_x = F.relu(self.conv7(conv4_x))
#           conv5_x = F.relu(self.conv5(conv1_x))
#           conv6_x = F.relu(self.conv6(conv3_x))
#           conv8_x = F.relu(self.conv8(conv5_x))

#           top_left_sum_x = conv7_x + conv8_x
#           top_right_sum_x = conv6_x + conv8_x

#           conv9_x = F.relu(self.conv9(conv8_x))
#           conv10_x = F.relu(self.conv10(conv9_x))

#           bottom_left_sum_x = top_left_sum_x + conv10_x
#           bottom_right_sum_x = top_right_sum_x + conv10_x

#           conv11_x = F.relu(self.conv11(bottom_left_sum_x))
#           top_left_cat_x = torch.cat((bottom_left_sum_x, conv10_x), dim=1)
#           top_right_cat_x = torch.cat((bottom_right_sum_x, conv10_x), dim=1)

#           conv12_x = F.relu(self.conv12(bottom_right_sum_x))
#           conv13_x = F.relu(self.conv13(top_left_cat_x))
#           conv14_x = F.relu(self.conv14(top_right_cat_x))

#           bottom_left_cat_x = torch.cat((conv11_x, conv13_x), dim=1)
#           bottom_right_cat_x = torch.cat((conv12_x, conv14_x), dim=1)

#           conv15_x = F.relu(self.conv15(bottom_left_cat_x))
#           conv16_x = F.relu(self.conv16(bottom_right_cat_x))

#           avgpool1_x = self.avgpool1(conv15_x)
#           avgpool2_x = self.avgpool2(conv16_x)

#           conv17_x = F.relu(self.conv17(avgpool1_x))
#           conv18_x = F.relu(self.conv18(avgpool2_x))

#           conv19_x = F.relu(self.conv19(conv17_x))
#           conv20_x = F.relu(self.conv20(conv18_x))

#           maxpool1_x = self.maxpool1(conv19_x)
#           maxpool2_x = self.maxpool2(conv20_x)

#           final_cat_x = torch.cat((maxpool1_x, maxpool2_x), dim=1)

#           view_x = final_cat_x.view(-1, 9 * 9 * 160)

#           fc1_x = F.relu(self.fc1(view_x))
#           fc2_x = self.fc2(fc1_x)


#           return fc2_x

model = Net()
model = model.cuda()
print("Model:")
print(model)

# get the number of model parameters
print('Number of model parameters: {}'.format(
    sum([p.data.nelement() for p in model.parameters()])))
#print(model)
#model.cuda()
summary(model, (3, 224, 224))

# To run on CPU instead:
# model = model.to('cpu')  # or: model = model.cpu()

# print("Model")
# print(model)

# # get the number of model parameters
# print('Number of model parameters: {}'.format(
#     sum([p.data.nelement() for p in model.parameters()])))

# # Use device='cpu' when calling summary
# summary(model, (3, 224, 224), device='cpu')
import torch
from ptflops import get_model_complexity_info
from torchsummary import summary
#from models.Nhom12Mang1 import *
with torch.cuda.device(0):

  model = Net(3)
  #macs, params = get_model_complexity_info(model, (3, 32, 32), as_strings=True,
  macs, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True,
                                           print_per_layer_stat=True, verbose=True,
                                           #flops_units='MMac')
                                           flops_units='GMac')
  print('{:<30}  {:<8}'.format('Computational complexity (MACs): ', macs))
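  # Conversion note: the line below derives FLOPs as MACs / 2; be aware that many
  # references use the opposite convention, FLOPs = 2 x MACs.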
  macs1 = macs.split()
  strmacs1 = str(float(macs1[0]) / 2) + ' ' + macs1[1][0]
  print('{:<30}  {:<8}'.format('Floating-point operations (FLOPs): ', strmacs1))
  print('{:<30}  {:<8}'.format('Number of parameters: ', params))

  print('Number of model parameters (referred): {}'.format(
      sum([p.data.nelement() for p in model.parameters()])))
  #summary(model, (3, 224, 224))
# To switch to CPU:
# device = torch.device('cpu')

# # Initialize the model on CPU
# model = Net(3).to(device)

# # Call get_model_complexity_info with a suitable input size
# macs, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True,
#                                           print_per_layer_stat=True, verbose=True,
#                                           flops_units='GMac')

# print('{:<30} {:<8}'.format('Computational complexity (MACs): ', macs))
# macs1 = macs.split()
# strmacs1 = str(float(macs1[0])/2) + ' ' + macs1[1][0]
# print('{:<30} {:<8}'.format('Floating-point operations (FLOPs): ', strmacs1))
# print('{:<30} {:<8}'.format('Number of parameters: ', params))

# print('Number of model parameters (referred): {}'.format(
#     sum([p.data.nelement() for p in model.parameters()])))
import os
#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"  # specify which GPU(s) to be used
#os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # specify which GPU(s) to be used
os.environ["CUDA_VISIBLE_DEVICES"] = "0,2"  # specify which GPU(s) to be used
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('-r', '--data', type=str, default='/content/drive/MyDrive/ML/dataset/', help='path to dataset')
parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                    help='number of data loading workers (default: 16)')
parser.add_argument('--epochs', default=50, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
#parser.add_argument('-b', '--batch-size', default=256, type=int,
parser.add_argument('-b', '--batch-size', default=16, type=int,
                    metavar='N', help='mini-batch size (default: 16)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=100, type=int,
                    metavar='N', help='print frequency (default: 100)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=1, type=int,
                    help='number of distributed processes')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='gloo', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training.')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--ksize', default=None, type=list,
                    help='Manually select the eca module kernel size')
parser.add_argument('--action', default='', type=str,
                    help='other information.')
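
# Illustrative usage note (not in the original notebook): main() uses parse_known_args(),
# so from a notebook cell you can set hyperparameters by rewriting sys.argv before
# calling main(), e.g.:
#   sys.argv = ['train.py', '--epochs', '10', '-b', '32', '--lr', '0.01']
# Unrecognized Colab/Jupyter arguments are then ignored instead of raising an error.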
def main():
    global args, best_prec1
    #args = parser.parse_args()
    args, _ = parser.parse_known_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    #args.gpu = 1
    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    #args.distributed = args.world_size > 1
    args.distributed = False

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size)
    torch.autograd.set_detect_anomaly(True)
    # create model

    args.arch = 'Apple_banana_orange'
    filenameLOG = "/content/drive/MyDrive/ML/checkpoints/%s/" % (args.arch + '_' + args.action) + args.arch + '.txt'
    #if not os.path.exists(pathout):
    #    os.makedirs(pathout)
    # get model
    #model = get_model_new(args=args)
    if args.pretrained:
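        # Note: this branch expects args.arch to name a model constructor available in
        # `models`; with the custom arch name set above, the else branch (our Net) runs.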
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](k_size=args.ksize, pretrained=True)
    else:
        model = Net(4)
    if args.gpu is not None:
        model = model.cuda(args.gpu)
    elif args.distributed:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        model = torch.nn.DataParallel(model).cuda()

    print(model)

    # get the number of models parameters
    print('Number of models parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # define loss function (criterion) and optimizer
    train_loss_fn = LabelSmoothingCrossEntropy(smoothing=0.1).cuda(args.gpu)
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.evaluate:
        pathcheckpoint = "/content/drive/MyDrive/ML/checkpoints/%s/"%(args.arch + '_' + args.action) + "model_best.pth.tar"
        if os.path.isfile(pathcheckpoint):
            print("=> loading checkpoint '{}'".format(pathcheckpoint))
            checkpoint = torch.load(pathcheckpoint)
            model.load_state_dict(checkpoint['state_dict'])
            #optimizer.load_state_dict(checkpoint['optimizer'])
            del checkpoint
        else:
            print("=> no checkpoint found at '{}'".format(pathcheckpoint))
            return
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            del checkpoint
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'test')
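    # Standard ImageNet channel statistics (mean/std) for input normalization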
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        )
    # train_dataset = datasets.ImageFolder(
    #     traindir,
    #     transforms.Compose([
    #         transforms.Resize(size=(256, 256)),
    #         transforms.RandomCrop(224),
    #         transforms.RandomHorizontalFlip(),
    #         transforms.ColorJitter(0.4),
    #         transforms.ToTensor(),
    #         normalize
    #     ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)
    # val_loader = torch.utils.data.DataLoader(
    #     datasets.ImageFolder(valdir, transforms.Compose([
    #         transforms.Resize(size=(256, 256)),
    #         transforms.CenterCrop(224),
    #         transforms.ToTensor(),
    #         normalize
    #     ])),
    #     batch_size=args.batch_size, shuffle=False,
    #     num_workers=args.workers, pin_memory=True)
    if args.evaluate:
        m = time.time()
        _, _ = validate(val_loader, model, criterion)
        n = time.time()
        print((n-m)/3600)
        return

    directory = "/content/drive/MyDrive/ML/checkpoints/%s/"%(args.arch + '_' + args.action)
    if not os.path.exists(directory):
        os.makedirs(directory)

    Loss_plot = {}
    train_prec1_plot = {}
    train_prec5_plot = {}
    val_prec1_plot = {}
    val_prec5_plot = {}
    epoch_max = None
    best_prec1 = 0
    for epoch in range(args.start_epoch, args.epochs):
        start_time = time.time()
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        # train(train_loader, model, criterion, optimizer, epoch)
        #loss_temp, train_prec1_temp, train_prec5_temp = train(train_loader, model, criterion, optimizer, epoch)
        loss_temp, train_prec1_temp, train_prec5_temp = train(train_loader, model, train_loss_fn, optimizer, epoch)

        Loss_plot[epoch] = loss_temp
        train_prec1_plot[epoch] = train_prec1_temp
        train_prec5_plot[epoch] = train_prec5_temp

        # evaluate on validation set
        # prec1 = validate(val_loader, model, criterion)
        prec1, prec5 = validate(val_loader, model, criterion)
        val_prec1_plot[epoch] = prec1
        val_prec5_plot[epoch] = prec5

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        if is_best:
            epoch_max = epoch  # record the epoch that achieved the best accuracy
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer' : optimizer.state_dict(),
        }, is_best)

        # Save Loss, train_prec1, train_prec5, val_prec1, and val_prec5 to .txt files
        data_save(directory + 'Loss_plot.txt', Loss_plot)
        data_save(directory + 'train_prec1.txt', train_prec1_plot)
        data_save(directory + 'train_prec5.txt', train_prec5_plot)
        data_save(directory + 'val_prec1.txt', val_prec1_plot)
        data_save(directory + 'val_prec5.txt', val_prec5_plot)

        line = 'Epoch {}/{} summary: loss_train={:.5f}, acc_train={:.2f}%, loss_val={:.2f}, acc_val={:.2f}% (best: {:.2f}% @ epoch {})'.format(
            epoch, args.epochs, loss_temp, train_prec1_temp, 0, prec1, best_prec1, epoch_max)  # loss_val is not tracked; logged as 0
        wA.writeLogAcc(filenameLOG,line)
        end_time = time.time()
        time_value = (end_time - start_time) / 3600
        print("-" * 80)
        print(time_value)
        print("-" * 80)
def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    losses_batch = {}

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            input = input.cuda(args.gpu, non_blocking=True)
        target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        output = model(input)

        # Check for NaNs in output
        if torch.isnan(output).any():
            print("NaN detected in model output.")
            return None, None, None  # Returning None to stop further processing

        loss = criterion(output, target)

        # Check for NaNs in loss
        if torch.isnan(loss).any():
            print(f"NaN detected in loss at iteration {i}")
            return None, None, None

        # measure accuracy and record loss
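        # topk=(1, 2): with only a few classes, the 'top5' meters actually track top-2 accuracy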
        prec1, prec5 = accuracy(output, target, topk=(1, 2))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()

        loss.backward()

        # Apply gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)  # Adjust max_norm as needed

        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, i, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1, top5=top5))

    return losses.avg, top1.avg, top5.avg
def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (input, target) in enumerate(val_loader):
            if args.gpu is not None:
                input = input.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, target, topk=(1, 2))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1[0], input.size(0))
            top5.update(prec5[0], input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       i, len(val_loader), batch_time=batch_time, loss=losses,
                       top1=top1, top5=top5))

        print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg, top5.avg
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    directory = "/content/drive/MyDrive/ML/checkpoints/%s/"%(args.arch + '_' + args.action)

    filename = directory + filename
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, directory + 'model_best.pth.tar')
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            #correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
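
# Illustrative check of accuracy() with made-up values (not part of the notebook):
# logits = torch.tensor([[2.0, 1.0, 0.1],
#                        [0.5, 2.5, 0.2]])
# labels = torch.tensor([0, 2])
# accuracy(logits, labels, topk=(1, 2))  # -> [tensor([50.]), tensor([50.])]
# Sample 0 is correct at top-1; sample 1's label 2 is outside its top-2 predictions.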


def data_save(root, file):
    # Check if the file exists, and if not, create it
    if not os.path.exists(root):
        with open(root, 'w'): pass  # Create an empty file if it doesn't exist

    # Open the file and read lines
    with open(root, 'r') as file_temp:
        lines = file_temp.readlines()

    # Initialize epoch value
    if not lines:
        epoch = -1
    else:
        epoch = lines[-1][:lines[-1].index(' ')]  # Get the epoch from the last line
    epoch = int(epoch)

    # Append new data to the file
    with open(root, 'a') as file_temp:
        for line in file:
            if line > epoch:
                file_temp.write(str(line) + " " + str(file[line]) + '\n')


if __name__ == '__main__':
    main()
