YOLOV1_pytorch_(1)
---
[toc]
---
# 1 YOLOV1
1. The YOLOv1 pipeline. YOLOv1 downsamples a 448x448x3 image by a factor of 2 six times (448/2^6 = 7) to get a 7x7x30 feature map. Each grid cell predicts two boxes, the probability that each box contains an object, and 20 class probabilities: 30 = B x (loc + conf) + cls = 2 x (4 + 1) + 20. B = 2 is the number of predicted boxes per cell. loc = 4 holds a box's center offset plus its width and height. conf = 1 per box is the probability that the box contains an object. cls = 20 is the number of object classes. If a ground-truth box falls inside a grid cell, the cell's box with the higher confidence is the one responsible for predicting it.
2. Strengths and weaknesses of YOLOv1. YOLOv1 is an end-to-end model with a simple structure; it is fast: its accuracy trails Faster R-CNN, but inference speed improves by a large margin; validation on different datasets shows it generalizes well. It needs no sliding windows to obtain region boxes, and its accuracy improves on models of that kind.
Because YOLOv1 defines no prior (anchor) boxes, its localization accuracy suffers and it transfers poorly; with only one feature level it predicts small objects, and groups of small objects, poorly. This also means there is plenty of room for improvement.
3. A computer-vision model has three main parts: data processing, training, and prediction, each with many details of its own. Studying a model by breaking it into pieces and tackling them one at a time not only deepens the understanding of that model but also lays a solid foundation for learning others: most models follow the same recipe and differ only in details such as the data augmentation, the backbone, or the loss function.
4. Reference: https://github.com/abeardear/pytorch-YOLO-v1
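Before diving in, a quick sketch of how the 30 channels of one grid cell split up. This is illustrative only (random values, not from the reference code):
```python
import torch

pred = torch.rand(7, 7, 30)        # network output for one image
cell = pred[3, 4]                  # one grid cell, 30 values
box1, conf1 = cell[0:4], cell[4]   # 1st box: dx, dy, w, h + its confidence
box2, conf2 = cell[5:9], cell[9]   # 2nd box: dx, dy, w, h + its confidence
cls_scores = cell[10:30]           # 20 class scores
print(box1, conf1.item(), cls_scores.argmax().item())
```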
## 1 Data Processing
Data processing has three parts: dataset splitting, reading the xml files, and data augmentation.
First, split the data into training, validation, and test sets; each set stores image names without extensions, e.g. the images 2002001.jpg, 2002002.jpg, 2002003.jpg appear in a set as 2002001, 2002002, 2002003. Second, read the image paths and ground-truth box information for each set. Third, apply data augmentation to the images and boxes and encode them into labels.
The code for the three parts lives in data_split_11.py, data_jpgxlm_12.py, and data_label_13.py; run each to produce its output. Then run train.py to train the model, and with the trained weights run predict.py to predict.
### 1.1 Dataset Splitting
Dataset splitting is implemented in data_split_11.py.
Inputs: the per-image xml files stored in VOCdevkit/VOC2007/Annotations, plus the train/val/test split ratios.
Output: the image names of each split.
Splitting procedure:
1. Collect the names of all files ending in '.xml'.
2. From the split ratios and the sample count of the previous step, sample the size and the sample indices of each split.
3. Write each split to disk according to the indices.
```python
'''
dataSplit
'''
import os
import random
xml_path = r'/Users/ls/PycharmProjects/YOLOV1_LS/VOCdevkit/VOC2007/Annotations'
base_path = r'/Users/ls/PycharmProjects/YOLOV1_LS/VOCdevkit/VOC2007/ImageSets/Main'
# 1 collect sample names
tmp = []
img_names = os.listdir(xml_path)
for i in img_names:
    if i.endswith('.xml'):
        tmp.append(i[:-4])
# 2 split the dataset
trainval_ratio = 0.9
train_ratio = 0.9
N = len(tmp)
trainval_num = int(trainval_ratio*N)
train_num = int(train_ratio*trainval_num)
trainval_idx = random.sample(range(N),trainval_num)
train_idx = random.sample(trainval_idx,train_num)
ftrainval = open(os.path.join(base_path,'LS_trainval.txt'),'w')
ftrain = open(os.path.join(base_path,'LS_train.txt'),'w')
fval = open(os.path.join(base_path,'LS_val.txt'),'w')
ftest = open(os.path.join(base_path,'LS_test.txt'),'w')
# 3 write out each split
for i in range(N):
    name = tmp[i]+'\n'
    if i in trainval_idx:
        ftrainval.write(name)
        if i in train_idx:
            ftrain.write(name)
        else:
            fval.write(name)
    else:
        ftest.write(name)
```
### 1.2 Reading the xml Files
data_jpgxlm_12.py gathers the image info and xml info into one file.
1. Open each split file, create the output file, iterate over the images, and write each image path together with its xml info.
2. For each xml file, decide from 'difficult' and 'name' whether to keep an object. Map the object name to a class index and read the box coordinates.
```python
import xml.etree.ElementTree as ET
# Each tuple names a dataset split; used to open the split files created above
# and to create the parsed output files (image path plus boxes and classes).
sets = [('2007','train'),('2007','val'),('2007','test')]
classes = ["aeroplane", "bicycle", "bird", "boat", "bottle",
           "bus", "car", "cat", "chair", "cow", "diningtable",
           "dog", "horse", "motorbike", "person", "pottedplant",
           "sheep", "sofa", "train", "tvmonitor"]
# 1 parse one xml file; write the image path, ground-truth boxes and class ids
def convert_annotation(year,img_id,list_file):
    list_file.write("/Users/ls/PycharmProjects/YOLOV1_LS/VOCdevkit/VOC%s/JPEGImages/%s.jpg"%(year,img_id))  # write the image path
    in_file = open('/Users/ls/PycharmProjects/YOLOV1_LS/VOCdevkit/VOC%s/Annotations/%s.xml'%(year,img_id))  # open this image's xml file
    root = ET.parse(in_file).getroot()
    for obj in root.iter('object'):
        difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult)==1:
            continue
        xml_box = obj.find('bndbox')
        b = (int(xml_box.find('xmin').text),
             int(xml_box.find('ymin').text),
             int(xml_box.find('xmax').text),
             int(xml_box.find('ymax').text))  # box corners
        cls_id = classes.index(cls)  # class index
        list_file.write(' '+','.join([str(i) for i in b])+','+str(cls_id))  # write box and class
# 2 iterate over the splits
for year,img_set in sets:
    # read the image ids of this split
    img_ids = open('/Users/ls/PycharmProjects/YOLOV1_LS/VOCdevkit/VOC%s/ImageSets/Main/%s.txt'%(year,img_set)).read().strip().split()
    list_file = open('/Users/ls/PycharmProjects/YOLOV1_LS/%s_%sLS.txt'%(year,img_set),'w')  # create the output file
    for img_id in img_ids:  # iterate over the images
        convert_annotation(year,img_id,list_file)
        list_file.write('\n')
    list_file.close()
```
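For orientation, one line of the generated file has the form "image path" followed by one x1,y1,x2,y2,class group per object (the same format the dataset class below parses). Note the file written here is named 2007_trainLS.txt, while the dataset and train.py later read 2007_train.txt; rename one or the other to match.
```
/Users/ls/PycharmProjects/YOLOV1_LS/VOCdevkit/VOC2007/JPEGImages/000022.jpg 68,103,368,283,12 186,44,255,230,14
```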
### 1.3 Data Augmentation
data_label_13.py augments the images, enlarging the sample set and improving generalization.
1. First store the image names, boxes, and class labels separately.
2. Build an iterable dataset and augment each image: flipping, scaling, blurring, random brightness, random hue, random saturation, random shifts, random crops.
3. Encoding. For an augmented image, divide the boxes by the image width and height to get relative positions, subtract the channel means, resize the image to a fixed size, then encode. The key is locating the ground-truth box on the feature map: convert the corner coordinates [x1,y1,x2,y2] to center and size; multiply the center by the feature-map size, take the ceiling and subtract 1 to get the cell indices ij; the center times the feature-map size minus ij is the true offset. Write the offset, width/height, and class at position ij.
(1) Data-augmentation code
```python
''' data augmentation '''
import os
import sys
import cv2
import torch
import random
import os.path
import numpy as np
import matplotlib.pyplot as plt
import torch.utils.data as data
import torchvision.transforms as transforms
class yoloDataset(data.Dataset):
    image_size = 448
    def __init__(self,root,list_file,train,transform):
        print('data init')
        self.root = root            # directory of the jpg files
        self.train = train          # bool: True enables augmentation (training mode)
        self.transform = transform  # e.g. [transforms.ToTensor()]
        self.fnames = []   # image paths, e.g. .../000014.jpg
        self.boxes = []    # ground-truth boxes
        self.labels = []   # class labels
        self.mean = (123,117,104)  # per-channel means, subtracted to speed up training
        with open(list_file) as f:
            lines = f.readlines()  # read the parsed annotations
        # line: /Users/ls/PycharmProjects/YOLOV1_LS/VOCdevkit/VOC2007/JPEGImages/000022.jpg 68,103,368,283,12 186,44,255,230,14
        for line in lines:
            splited = line.strip().split()
            self.fnames.append(splited[0])  # image path
            num_boxes = (len(splited)-1)    # number of ground-truth boxes in this image
            box = []     # boxes of this image
            label = []   # labels of this image
            for i in range(1,num_boxes+1):  # iterate over the boxes
                tmp = [float(j) for j in splited[i].split(',')]  # split on ',' and convert the strings to float
                box.append(tmp[:4])
                label.append(int(tmp[4])+1)
            self.boxes.append(torch.Tensor(box))
            self.labels.append(torch.LongTensor(label))
        self.num_samples = len(self.boxes)
    def __getitem__(self,idx):
        fname = self.fnames[idx]
        img = cv2.imread(fname)  # read the image (fname already holds the full path)
        boxes = self.boxes[idx].clone()
        labels = self.labels[idx].clone()
        if self.train:
            img,boxes = self.random_flip(img,boxes)
            img,boxes = self.randomScale(img,boxes)
            img = self.randomBlur(img)
            img = self.RandomBrightness(img)
            img = self.RandomHue(img)
            img = self.RandomSaturation(img)
            img,boxes,labels = self.randomShift(img,boxes,labels)
            img,boxes,labels = self.randomCrop(img,boxes,labels)
        h,w,_ = img.shape
        boxes /= torch.Tensor([w,h,w,h]).expand_as(boxes)  # boxes as fractions of the image size
        img = self.BGR2RGB(img)
        img = self.subMean(img,self.mean)
        img = cv2.resize(img,(self.image_size,self.image_size))
        target = self.encoder(boxes,labels)
        for t in self.transform:
            img = t(img)
        return img,target
    def __len__(self):
        return self.num_samples
    def encoder(self,boxes,labels):
        grid_num = 7
        target = torch.zeros((grid_num,grid_num,30))
        wh = boxes[:,2:] - boxes[:,:2]
        cxcy = (boxes[:,2:] + boxes[:,:2])/2
        for i in range(cxcy.size()[0]):
            cxcy_sample = cxcy[i]
            ij = (cxcy_sample*grid_num).ceil()-1
            dxy = cxcy_sample*grid_num-ij
            target[int(ij[1]),int(ij[0]),:2] = target[int(ij[1]),int(ij[0]),5:7] = dxy
            target[int(ij[1]),int(ij[0]),2:4] = target[int(ij[1]),int(ij[0]),7:9] = wh[i]
            target[int(ij[1]),int(ij[0]),4] = target[int(ij[1]),int(ij[0]),9] = 1
            target[int(ij[1]),int(ij[0]),int(labels[i])+9] = 1
        return target
    def BGR2RGB(self,img):
        return cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
    def BGR2HSV(self,img):
        return cv2.cvtColor(img,cv2.COLOR_BGR2HSV)
    def HSV2BGR(self,img):
        return cv2.cvtColor(img,cv2.COLOR_HSV2BGR)
    def RandomBrightness(self,bgr):
        if random.random()<0.5:
            hsv = self.BGR2HSV(bgr)
            h,s,v = cv2.split(hsv)
            v = v.astype(float)
            v *= random.choice([0.5,1.5])
            v = np.clip(v,0,255).astype(hsv.dtype)
            hsv = cv2.merge((h,s,v))
            bgr = self.HSV2BGR(hsv)
        return bgr
    def RandomSaturation(self,bgr):
        if random.random()<0.5:
            hsv = self.BGR2HSV(bgr)
            h,s,v = cv2.split(hsv)
            s = s.astype(float)
            s *= random.choice([0.5,1.5])
            s = np.clip(s,0,255).astype(hsv.dtype)
            hsv = cv2.merge((h,s,v))
            bgr = self.HSV2BGR(hsv)
        return bgr
    def RandomHue(self,bgr):
        if random.random() < 0.5:
            hsv = self.BGR2HSV(bgr)
            h,s,v = cv2.split(hsv)
            h = h.astype(float)
            h *= random.choice([0.5,1.5])
            h = np.clip(h,0,255).astype(hsv.dtype)
            hsv = cv2.merge((h,s,v))
            bgr = self.HSV2BGR(hsv)
        return bgr
    def randomBlur(self,bgr):
        if random.random() < 0.5:
            bgr = cv2.blur(bgr,(5,5))
        return bgr
    def randomShift(self,bgr,boxes,labels):
        center = (boxes[:,2:]+boxes[:,:2])/2
        if random.random()<0.5:
            height,width,c = bgr.shape
            after_shift_imge = np.zeros((height,width,c),dtype=bgr.dtype)
            after_shift_imge[:,:,:] = (104,117,123)
            shift_x = random.uniform(-width*0.2,width*0.2)
            shift_y = random.uniform(-height*0.2,height*0.2)
            if shift_x>=0 and shift_y>=0:
                after_shift_imge[int(shift_y):,int(shift_x):,:] = bgr[:height-int(shift_y),:width-int(shift_x),:]
            elif shift_x>=0 and shift_y<0:
                after_shift_imge[:height+int(shift_y),int(shift_x):,:] = bgr[-int(shift_y):,:width-int(shift_x),:]
            elif shift_x <0 and shift_y >=0:
                after_shift_imge[int(shift_y):,:width+int(shift_x),:] = bgr[:height-int(shift_y),-int(shift_x):,:]
            elif shift_x<0 and shift_y<0:
                after_shift_imge[:height+int(shift_y),:width+int(shift_x),:] = bgr[-int(shift_y):,-int(shift_x):,:]
            shift_xy = torch.FloatTensor([[int(shift_x),int(shift_y)]]).expand_as(center)
            center = center + shift_xy
            mask1 = (center[:,0]>0) & (center[:,0]<width)   # center x still inside the image
            mask2 = (center[:,1]>0) & (center[:,1]<height)  # center y still inside the image
            mask = (mask1 & mask2).view(-1,1)
            boxes_in = boxes[mask.expand_as(boxes)].view(-1,4)
            if len(boxes_in) == 0:
                return bgr,boxes,labels
            box_shift = torch.FloatTensor([[int(shift_x),int(shift_y),int(shift_x),int(shift_y)]]).expand_as(boxes_in)
            boxes_in = boxes_in+box_shift
            labels_in = labels[mask.view(-1)]
            return after_shift_imge,boxes_in,labels_in
        return bgr,boxes,labels
    def randomScale(self,bgr,boxes):
        if random.random() < 0.5:
            scale = random.uniform(0.8,1.2)
            h,w,c = bgr.shape
            bgr = cv2.resize(bgr,(int(w*scale),h))
            scale_tensor = torch.FloatTensor([[scale,1,scale,1]]).expand_as(boxes)
            boxes = boxes*scale_tensor
            return bgr,boxes
        return bgr,boxes
    def randomCrop(self,bgr,boxes,labels):
        if random.random() < 0.5:
            center = (boxes[:,:2]+boxes[:,2:])/2
            height,width,c = bgr.shape
            h = random.uniform(0.6*height,height)
            w = random.uniform(0.6*width,width)
            x = random.uniform(0,width-w)
            y = random.uniform(0,height-h)
            x,y,h,w = int(x),int(y),int(h),int(w)
            center = center - torch.FloatTensor([[x,y]]).expand_as(center)
            mask1 = (center[:,0]>0) & (center[:,0]<w)
            mask2 = (center[:,1]>0) & (center[:,1]<h)
            mask = (mask1 & mask2).view(-1,1)
            boxes_in = boxes[mask.expand_as(boxes)].view(-1,4)
            if(len(boxes_in)==0):
                return bgr,boxes,labels
            box_shift = torch.FloatTensor([[x,y,x,y]]).expand_as(boxes_in)
            boxes_in = boxes_in-box_shift
            boxes_in[:,0]=boxes_in[:,0].clamp(0,w)
            boxes_in[:,2]=boxes_in[:,2].clamp(0,w)
            boxes_in[:,1]=boxes_in[:,1].clamp(0,h)
            boxes_in[:,3]=boxes_in[:,3].clamp(0,h)
            labels_in = labels[mask.view(-1)]
            img_croped = bgr[y:y+h,x:x+w,:]
            return img_croped,boxes_in,labels_in
        return bgr,boxes,labels
    def subMean(self,bgr,mean):
        mean = np.array(mean,dtype=np.float32)
        bgr = bgr - mean
        return bgr
    def random_flip(self,im,boxes):
        if random.random() < 0.5:
            im_lr = np.fliplr(im).copy()
            h,w,_ = im.shape
            xmin = w - boxes[:,2]
            xmax = w - boxes[:,0]
            boxes[:,0] = xmin
            boxes[:,2] = xmax
            return im_lr,boxes
        return im,boxes
    def random_bright(self,im,delta=16):
        alpha = random.random()
        if alpha > 0.3:
            im = im * alpha + random.randrange(-delta,delta)
            im = im.clip(min=0,max=255).astype(np.uint8)
        return im
if __name__=='__main__':
    from torch.utils.data import DataLoader
    import torchvision.transforms as transforms
    file_root = '/Users/ls/PycharmProjects/YOLOV1_LS'  # directory containing the jpg files
    list_f = r'/Users/ls/PycharmProjects/YOLOV1_LS/2007_train.txt'
    train_dataset = yoloDataset(root=file_root,list_file=list_f,train=True,transform=[transforms.ToTensor()])
    train_loader = DataLoader(train_dataset,batch_size=1,shuffle=False,num_workers=0)
    train_iter = iter(train_loader)
    # for i in range(5):
    #     img,target = next(train_iter)
    #     print(target[target[...,0]>0])
    for i,(images,target) in enumerate(train_loader):
        print(1111111111111111111111)  # marker between samples
        print(target)
        print(images)
```
(2) The main augmentation functions explained. The key point is that whenever the image is transformed, the ground-truth boxes must be transformed the same way. For example, when the image is flipped the objects move, so the boxes must be flipped too; otherwise the labels no longer match the image and training suffers.
1. Encoding essentials
I. The boxes fed to the encoder are relative positions: the corner coordinates divided by the image width and height. We need the box's position on the feature map; the original image and the feature map differ in size, but the box occupies the same relative position in both, so the relative coordinates map the box onto the feature map.
```
img.shape: [400,500,3]        # original image, h=400, w=500
box: [100,120,200,250]        # ground-truth box [x1,y1,x2,y2]
box_img: [100/500,120/400,200/500,250/400]    # box relative to the image [x1/w,y1/h,x2/w,y2/h]
feature.shape: [7,7,30]       # feature map
box_feature: [100/500*7,120/400*7,200/500*7,250/400*7]  # box position on the feature map
```
II. dxy
```
cxcy_sample = cxcy[i]                 # box center relative to the image
ij = (cxcy_sample*grid_num).ceil()-1  # grid cell holding the center; ceil()-1 (rather than floor) keeps a center sitting exactly on a cell boundary from producing an out-of-range index
dxy = cxcy_sample*grid_num-ij         # offset of the center inside the cell
```
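To make these two lines concrete, a minimal numeric check (the center value is made up for illustration):
```python
import torch

cxcy_sample = torch.tensor([0.55, 0.60])  # box center, relative to the image
grid_num = 7
ij = (cxcy_sample * grid_num).ceil() - 1  # tensor([3., 4.]) -> cell column i=3, row j=4
dxy = cxcy_sample * grid_num - ij         # tensor([0.8500, 0.2000]) -> offset inside the cell
print(ij, dxy)
```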
III. target.shape is [7,7,30]: the first 10 entries hold the two boxes with their confidences, the remaining 20 are the class in one-hot form.
```
# 2 boxes x 5 values each, followed by 20 class entries (one-hot)
idx:  0   1   2  3   4     5   6   7  8   9     10  11  12  13  14 ... 29
val:  dx  dy  w  h  conf   dx  dy  w  h  conf    0   0   0   0   1 ...  0
```
Encoding procedure:
I. Convert the corner coordinates [x1,y1,x2,y2] of the ground-truth boxes to center and size [cx,cy,w,h].
II. For each ground-truth box, compute its cell on the feature map and the center offset, then write the box, confidence, and class into the corresponding position of the label target.
```python
def encoder(self,boxes,labels):
    grid_num = 7  # feature map side length, feature.shape: [7,7,30]
    target = torch.zeros((grid_num,grid_num,30))  # label tensor
    wh = boxes[:,2:] - boxes[:,:2]        # box size relative to the image
    cxcy = (boxes[:,2:] + boxes[:,:2])/2  # box center relative to the image
    for i in range(cxcy.size()[0]):       # iterate over the boxes
        cxcy_sample = cxcy[i]
        ij = (cxcy_sample*grid_num).ceil()-1  # grid cell holding the center
        dxy = cxcy_sample*grid_num-ij         # center offset inside the cell
        target[int(ij[1]),int(ij[0]),:2] = target[int(ij[1]),int(ij[0]),5:7] = dxy
        target[int(ij[1]),int(ij[0]),2:4] = target[int(ij[1]),int(ij[0]),7:9] = wh[i]
        target[int(ij[1]),int(ij[0]),4] = target[int(ij[1]),int(ij[0]),9] = 1  # confidence
        target[int(ij[1]),int(ij[0]),int(labels[i])+9] = 1  # class one-hot (labels start at 1, so indices 10..29)
    return target
```
2. Random brightness
```python
def RandomBrightness(self,bgr):
    if random.random()<0.5:            # apply with probability 0.5
        hsv = self.BGR2HSV(bgr)        # bgr --> hsv
        h,s,v = cv2.split(hsv)         # split the channels
        v = v.astype(float)            # value (brightness) channel as float
        v *= random.choice([0.5,1.5])  # darken or brighten at random
        v = np.clip(v,0,255).astype(hsv.dtype)  # clip to the valid range, restore dtype
        hsv = cv2.merge((h,s,v))       # merge the channels
        bgr = self.HSV2BGR(hsv)        # hsv --> bgr
    return bgr
```
3. Random shift
I. Draw the shift distances shift_x, shift_y and shift the image.
II. Check which boxes remain valid after the shift.
III. Shift the surviving boxes.
```
bgr.shape: [100,100,3]
1: shift_x= 20, shift_y= 30, after_shift_imge[30:,20:,:]=bgr[:70,:80,:]  shift right and down
2: shift_x= 20, shift_y=-30, after_shift_imge[:70,20:,:]=bgr[30:,:80,:]  shift right and up
3: shift_x=-20, shift_y= 30, after_shift_imge[30:,:80,:]=bgr[:70,20:,:]  shift left and down
4: shift_x=-20, shift_y=-30, after_shift_imge[:70,:80,:]=bgr[30:,20:,:]  shift left and up
```
```python
def randomShift(self,bgr,boxes,labels):
    center = (boxes[:,2:]+boxes[:,:2])/2
    if random.random()<0.5:
        height,width,c = bgr.shape
        after_shift_imge = np.zeros((height,width,c),dtype=bgr.dtype)
        after_shift_imge[:,:,:] = (104,117,123)  # fill with the channel means
        shift_x = random.uniform(-width*0.2,width*0.2)
        shift_y = random.uniform(-height*0.2,height*0.2)
        if shift_x>=0 and shift_y>=0:
            after_shift_imge[int(shift_y):,int(shift_x):,:] = bgr[:height-int(shift_y),:width-int(shift_x),:]
        elif shift_x>=0 and shift_y<0:
            after_shift_imge[:height+int(shift_y),int(shift_x):,:] = bgr[-int(shift_y):,:width-int(shift_x),:]
        elif shift_x <0 and shift_y >=0:
            after_shift_imge[int(shift_y):,:width+int(shift_x),:] = bgr[:height-int(shift_y),-int(shift_x):,:]
        elif shift_x<0 and shift_y<0:
            after_shift_imge[:height+int(shift_y),:width+int(shift_x),:] = bgr[-int(shift_y):,-int(shift_x):,:]
        shift_xy = torch.FloatTensor([[int(shift_x),int(shift_y)]]).expand_as(center)
        center = center + shift_xy  # box centers after the shift
        mask1 = (center[:,0]>0) & (center[:,0]<width)   # center x still inside the image
        mask2 = (center[:,1]>0) & (center[:,1]<height)  # center y still inside the image
        mask = (mask1 & mask2).view(-1,1)
        boxes_in = boxes[mask.expand_as(boxes)].view(-1,4)  # keep the boxes that survive the shift
        if len(boxes_in) == 0:
            return bgr,boxes,labels
        box_shift = torch.FloatTensor([[int(shift_x),int(shift_y),int(shift_x),int(shift_y)]]).expand_as(boxes_in)
        boxes_in = boxes_in+box_shift
        labels_in = labels[mask.view(-1)]  # keep the matching labels
        return after_shift_imge,boxes_in,labels_in
    return bgr,boxes,labels
```
4. Random scaling
```python
def randomScale(self,bgr,boxes):
    if random.random() < 0.5:
        scale = random.uniform(0.8,1.2)  # scale factor
        h,w,c = bgr.shape
        bgr = cv2.resize(bgr,(int(w*scale),h))  # scale the image width only; height unchanged
        scale_tensor = torch.FloatTensor([[scale,1,scale,1]]).expand_as(boxes)
        boxes = boxes*scale_tensor  # scale the box x-coordinates accordingly
        return bgr,boxes
    return bgr,boxes
```
5. Random crop
I. Draw the crop height and width h, w, then the crop origin x, y.
II. Use the box centers to decide which boxes fall inside the crop.
III. Shift the surviving boxes into the crop's coordinates and return the cropped image, boxes, and labels.
```python
def randomCrop(self,bgr,boxes,labels):
    if random.random() < 0.5:
        center = (boxes[:,:2]+boxes[:,2:])/2
        height,width,c = bgr.shape
        h = random.uniform(0.6*height,height)  # crop size
        w = random.uniform(0.6*width,width)
        x = random.uniform(0,width-w)          # crop origin
        y = random.uniform(0,height-h)
        x,y,h,w = int(x),int(y),int(h),int(w)
        center = center - torch.FloatTensor([[x,y]]).expand_as(center)
        mask1 = (center[:,0]>0) & (center[:,0]<w)  # center x inside the crop
        mask2 = (center[:,1]>0) & (center[:,1]<h)  # center y inside the crop
        mask = (mask1 & mask2).view(-1,1)
        boxes_in = boxes[mask.expand_as(boxes)].view(-1,4)
        if(len(boxes_in)==0):
            return bgr,boxes,labels
        box_shift = torch.FloatTensor([[x,y,x,y]]).expand_as(boxes_in)
        boxes_in = boxes_in-box_shift  # move the boxes into the crop's coordinates
        boxes_in[:,0]=boxes_in[:,0].clamp(0,w)
        boxes_in[:,2]=boxes_in[:,2].clamp(0,w)
        boxes_in[:,1]=boxes_in[:,1].clamp(0,h)
        boxes_in[:,3]=boxes_in[:,3].clamp(0,h)
        labels_in = labels[mask.view(-1)]
        img_croped = bgr[y:y+h,x:x+w,:]
        return img_croped,boxes_in,labels_in
    return bgr,boxes,labels
```
6. Random horizontal flip
After the image is flipped left-right, a box's distances to the left and right borders swap.
```python
def random_flip(self,im,boxes):
    if random.random() < 0.5:
        im_lr = np.fliplr(im).copy()
        h,w,_ = im.shape
        xmin = w - boxes[:,2]  # distance of the old right edge from the right border
        xmax = w - boxes[:,0]  # distance of the old left edge from the left border
        # swapping the two distances gives the flipped box
        boxes[:,0] = xmin
        boxes[:,2] = xmax
        return im_lr,boxes
    return im,boxes
```
## 2 Training
In the paper the backbone is 24 convolutional layers followed by two fully connected layers; the code here swaps in an off-the-shelf backbone instead.
Training covers the backbone (ResNet or VGG), the loss, and fitting the model to the data. The authors pretrain the backbone on ImageNet, reaching 88% top-5 accuracy.
### 2.1 Backbone
(1) ResNet
The ResNet variant has three stages. First the stem convolution and max-pooling each downsample with stride 2; then the residual stages layer1 to layer5 widen the channels and apply three more stride-2 downsamplings; finally a convolution, batch normalization, and sigmoid produce the output feature map. Note that this totals five stride-2 downsamplings, so a 448x448 input yields a 14x14x30 map (the dilated layer5 keeps the resolution), which is why the loss test below uses a 14x14 grid.
resnet_yolo.py
```python
import math
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
__all__ = ['ResNet','resnet18','resnet34','resnet50','resnet101','resnet152']
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
def conv3x3(in_planes,out_planes,stride=1):
    return nn.Conv2d(in_planes,out_planes,kernel_size=3,stride=stride,padding=1,bias=False)
class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self,inplanes,planes,stride=1,downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
    def forward(self,x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
class Bottleneck(nn.Module):
    expansion = 4
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes,planes,kernel_size=1,bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes,planes,kernel_size=3,stride=stride,padding=1,bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,bias=False)  # input is the conv2 output (planes channels), not inplanes
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
    def forward(self,x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
class detnet_bottleneck(nn.Module):
    # no expansion
    # dilation = 2
    # type B use 1x1 conv
    expansion = 1
    def __init__(self, in_planes, planes, stride=1, block_type='A'):
        super(detnet_bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=2, bias=False,dilation=2)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)
        self.downsample = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes or block_type == 'B':
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.downsample(x)
        out = F.relu(out)
        return out
class ResNet(nn.Module):  # e.g. model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    def __init__(self, block, layers, num_classes=1470):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1)  # torch.Size([2, 64, 112, 112]) for a 448 input
        self.layer1 = self._make_layer(block,64,layers[0])
        self.layer2 = self._make_layer(block,128,layers[1],stride=2)
        self.layer3 = self._make_layer(block,256,layers[2],stride=2)
        self.layer4 = self._make_layer(block,512,layers[3],stride=2)
        self.layer5 = self._make_detnet_layer(in_channels=512)  # with Bottleneck blocks this would need in_channels=2048
        self.conv_end = nn.Conv2d(256,30,kernel_size=3,stride=1,padding=1,bias=False)
        self.bn_end = nn.BatchNorm2d(30)
        for m in self.modules():
            if isinstance(m,nn.Conv2d):
                n = m.kernel_size[0]*m.kernel_size[1]*m.out_channels
                m.weight.data.normal_(0,math.sqrt(2./n))
            elif isinstance(m,nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
    def _make_layer(self,block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes,planes*block.expansion,kernel_size=1,stride=stride,bias=False),
                nn.BatchNorm2d(planes*block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes,planes,stride,downsample))
        self.inplanes = planes*block.expansion
        for i in range(1,blocks):
            layers.append(block(self.inplanes,planes))
        return nn.Sequential(*layers)
    def _make_detnet_layer(self,in_channels):
        layers = []
        layers.append(detnet_bottleneck(in_planes=in_channels, planes=256, block_type='B'))
        layers.append(detnet_bottleneck(in_planes=256, planes=256, block_type='A'))
        layers.append(detnet_bottleneck(in_planes=256, planes=256, block_type='A'))
        return nn.Sequential(*layers)
    def forward(self,x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # print(x.shape)
        x = self.layer5(x)
        # print(x.shape)
        x = self.conv_end(x)
        x = self.bn_end(x)
        x = F.sigmoid(x)
        x = x.permute(0,2,3,1)  # [b,30,14,14] --> [b,14,14,30]
        return x
def resnet18(pretrained=False,**kwargs):
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model
def resnet34(pretrained=False,**kwargs):
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
    return model
# Note: the constructors below reuse BasicBlock; a faithful ResNet-50/101/152
# would use Bottleneck, which also requires layer5 to take in_channels=2048.
def resnet50(pretrained=False,**kwargs):
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model
def resnet101(pretrained=False,**kwargs):
    model = ResNet(BasicBlock, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model
def resnet152(pretrained=False,**kwargs):
    model = ResNet(BasicBlock, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return model
def test():
    import torch
    from torch.autograd import Variable
    model = resnet18()
    x = torch.rand(2, 3, 448, 448)  # a batch of two 448x448 RGB images
    x = Variable(x)
    out = model(x)
    print(out.shape)  # torch.Size([2, 14, 14, 30])
if __name__ == '__main__':
    test()
```
(2) VGG
The VGG model has two parts: convolution and pooling layers extract the features, and fully connected layers map them to the 7x7x30 output.
net.py
```python
#encoding:utf-8
import math
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
__all__ = ['VGG','vgg11','vgg11_bn',
           'vgg13','vgg13_bn', 'vgg16',
           'vgg16_bn','vgg19','vgg19_bn']
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}
class VGG(nn.Module):
    def __init__(self,features,num_classes=1000,image_size=448):
        super(VGG,self).__init__()
        self.features = features
        self.image_size = image_size
        self.classifier = nn.Sequential(
            nn.Linear(512*7*7,4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096,1470))  # 1470 = 7*7*30
        self._initialize_weights()
    def forward(self,x):
        x = self.features(x)
        x = x.view(x.size(0),-1)
        x = self.classifier(x)
        x = F.sigmoid(x)
        x = x.view(-1,7,7,30)
        return x
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m,nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0,math.sqrt(2./n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
def make_layers(cfg,batch_norm=False):
    layers = []
    in_channels = 3
    first_flag = True
    for v in cfg:
        s = 1
        if (v == 64 and first_flag):  # the very first conv uses stride 2 (448 input instead of 224)
            s = 2
            first_flag = False
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2,stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels,v,kernel_size=3,stride=s,padding=1)
            if batch_norm:
                layers += [conv2d,nn.BatchNorm2d(v),nn.ReLU(inplace=True)]
            else:
                layers += [conv2d,nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)
cfg = { 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
        'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
        'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
        'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']}
def vgg11(pretrained=False,**kwargs):
    model = VGG(make_layers(cfg['A']),**kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg11']))
    return model
def vgg11_bn(pretrained=False,**kwargs):
    model = VGG(make_layers(cfg['A'],batch_norm=True),**kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn']))  # fixed: was load_url([vgg11_bn])
    return model
def vgg13(pretrained=False,**kwargs):
    model = VGG(make_layers(cfg['B']),**kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg13']))
    return model
def vgg13_bn(pretrained=False, **kwargs):
    model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn']))
    return model
def vgg16(pretrained=False, **kwargs):
    model = VGG(make_layers(cfg['D']), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
    return model
def vgg16_bn(pretrained=False, **kwargs):
    model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn']))
    return model
def vgg19(pretrained=False, **kwargs):
    model = VGG(make_layers(cfg['E']), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
    return model
def vgg19_bn(pretrained=False, **kwargs):
    model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn']))
    return model
def test():
    import torch
    from torch.autograd import Variable
    model = vgg16()
    img = torch.rand(2,3,448,448)
    img = Variable(img)
    output = model(img)
    print(output.size())
if __name__ == '__main__':
    test()
```
### 2.2 Loss
What makes the YOLOv1 loss special: every term is an MSE. For box regression, the widths and heights are square-rooted before the MSE, which shrinks the loss gap between large and small objects. Regression and object/no-object confidence terms get different weights, which counters the imbalance between positive and negative samples.
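For reference, the loss from the YOLOv1 paper, where $\mathbb{1}_{ij}^{obj}$ selects the box $j$ of cell $i$ responsible for an object; $\lambda_{coord}=5$ and $\lambda_{noobj}=0.5$ correspond to l_coord and l_noobj in the code below (the code additionally weights contain_loss by 2 and divides by the batch size):

$$
\begin{aligned}
\mathcal{L} ={}& \lambda_{coord}\sum_{i=0}^{S^2}\sum_{j=0}^{B}\mathbb{1}_{ij}^{obj}\Big[(x_i-\hat{x}_i)^2+(y_i-\hat{y}_i)^2\Big] \\
&+ \lambda_{coord}\sum_{i=0}^{S^2}\sum_{j=0}^{B}\mathbb{1}_{ij}^{obj}\Big[(\sqrt{w_i}-\sqrt{\hat{w}_i})^2+(\sqrt{h_i}-\sqrt{\hat{h}_i})^2\Big] \\
&+ \sum_{i=0}^{S^2}\sum_{j=0}^{B}\mathbb{1}_{ij}^{obj}\,(C_i-\hat{C}_i)^2
 + \lambda_{noobj}\sum_{i=0}^{S^2}\sum_{j=0}^{B}\mathbb{1}_{ij}^{noobj}\,(C_i-\hat{C}_i)^2 \\
&+ \sum_{i=0}^{S^2}\mathbb{1}_{i}^{obj}\sum_{c\in \text{classes}}\big(p_i(c)-\hat{p}_i(c)\big)^2
\end{aligned}
$$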
yoloLoss.py
The loss takes the encoded ground truth target_tensor and the network prediction pred_tensor, both shaped [batch_size,S,S,30].
(1) Loss computation flow:
1. Using the target confidences, gather the cells without objects from target_tensor and pred_tensor into sample_nobj[-1,30]; take the two confidence entries (indices 4 and 9) and compute the negative-sample loss noobj_loss with MSE.
2. Likewise gather the cells that do contain objects into sample_obj[-1,30]; from them extract the boxes and class probabilities of target_tensor and pred_tensor and compute the classification loss.
3. From sample_obj, compute the IoU between the predicted boxes and the ground-truth box; the predicted box with the higher IoU is the match and contributes the box regression loss and the positive-sample confidence loss. The target for the positive confidence is the IoU itself.
4. Weight the individual losses to offset the positive/negative imbalance.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class yoloLoss(nn.Module):
    def __init__(self,S,B,l_coord,l_noobj):
        super(yoloLoss, self).__init__()
        self.S = S
        self.B = B
        self.l_coord = l_coord
        self.l_noobj = l_noobj
    def compute_iou(self,box1,box2):
        '''
        Args:
            box1[N,4],box2[M,4]
        Return:
            iou, sized [N,M].
        '''
        N = box1.size()[0]
        M = box2.size()[0]
        lt = torch.max(
            box1[:,:2].unsqueeze(1).expand(N,M,2),
            box2[:,:2].unsqueeze(0).expand(N,M,2)
        )
        rd = torch.min(
            box1[:,2:].unsqueeze(1).expand(N,M,2),
            box2[:,2:].unsqueeze(0).expand(N,M,2)
        )
        wh = rd-lt
        wh[wh<0] = 0
        inter = wh[...,0] * wh[...,1]
        area1 = ((box1[:,2]-box1[:,0])*(box1[:,3]-box1[:,1])).unsqueeze(1).expand_as(inter)
        area2 = ((box2[:,2]-box2[:,0])*(box2[:,3]-box2[:,1])).unsqueeze(0).expand_as(inter)
        iou = inter/(area1+area2-inter)
        return iou
    def forward(self,pred_tensor,target_tensor):
        ''' pred_tensor[b,S,S,B*5+20] ; target_tensor[b,S,S,30] '''
        # 1 masks for cells with / without objects
        N = pred_tensor.size(0)
        coo_mask = target_tensor[...,4] > 0   # cells containing an object [batch_size,S,S]
        noo_mask = target_tensor[...,4] == 0  # cells without an object
        coo_mask = coo_mask.unsqueeze(-1).expand_as(target_tensor)  # [b,S,S,30]
        noo_mask = noo_mask.unsqueeze(-1).expand_as(target_tensor)
        # 2 no-object loss
        noo_pred = pred_tensor[noo_mask].view(-1,30)      # predictions for cells without objects
        noo_target = target_tensor[noo_mask].view(-1,30)  # targets for cells without objects
        noo_pred_c = noo_pred[:,[4,9]].flatten()          # predicted confidences of the negative samples
        noo_target_c = noo_target[:,[4,9]].flatten()      # target confidences (all zero)
        noobj_loss = F.mse_loss(noo_pred_c,noo_target_c,size_average=False)  # negative-sample loss
        # 3 object cells: boxes and classes
        coo_pred = pred_tensor[coo_mask].view(-1,30)        # predictions for cells with objects
        box_pred = coo_pred[:,:10].contiguous().view(-1,5)  # predicted boxes
        class_pred = coo_pred[:,10:]                        # predicted classes
        coo_target = target_tensor[coo_mask].view(-1,30)    # targets for cells with objects
        box_target = coo_target[:,:10].contiguous().view(-1,5)  # ground-truth boxes
        class_target = coo_target[:,10:]                    # ground-truth classes
        # 3.1 class loss
        class_loss = F.mse_loss(class_pred,class_target,size_average=False)
        # 4 box matching: each cell predicts two boxes; the one with the higher IoU
        #   against the ground truth carries the regression and positive-sample loss
        coo_response_mask = torch.ByteTensor(box_target.size()).zero_()
        # coo_response_mask = torch.tensor(coo_response_mask,dtype=torch.bool)  # for newer PyTorch
        box_target_iou = torch.zeros(box_target.size())
        for i in range(0,box_target.size(0),2):  # step over cells, two predicted boxes each
            box1 = box_pred[i:i+2]  # the two predicted boxes of this cell
            box1_xy = Variable(torch.FloatTensor(box1.size()))
            box1_xy[:,:2] = box1[:,:2] / 14. - 0.5*box1[:,2:4]  # decode (offset,w,h) to corners; /14. matches a 14x14 grid
            box1_xy[:,2:4] = box1[:,:2] / 14. + 0.5*box1[:,2:4]
            box2 = box_target[i].view(-1,5)  # the single ground-truth box of this cell
            box2_xy = Variable(torch.FloatTensor(box2.size()))
            box2_xy[:,:2] = box2[:,:2] / 14. - 0.5*box2[:,2:4]
            box2_xy[:,2:4] = box2[:,:2] / 14. + 0.5*box2[:,2:4]
            iou = self.compute_iou(box1_xy[:,:4],box2_xy[:,:4])
            max_iou,max_index = iou.max(0)  # the better-matching predicted box and its index
            coo_response_mask[i+max_index] = 1
            box_target_iou[i+max_index,4] = max_iou
        box_target_iou = Variable(box_target_iou)
        # 4.1 losses of the responsible boxes
        box_pred_response = box_pred[coo_response_mask].view(-1,5)              # matched predicted boxes
        box_target_response = box_target[coo_response_mask].view(-1,5)          # matched ground-truth boxes (used in loc_loss below)
        box_target_response_iou = box_target_iou[coo_response_mask].view(-1,5)  # IoU targets for the positive confidences
        # 4.1.1 contain_loss
        contain_loss = F.mse_loss(box_pred_response[:,4],box_target_response_iou[:,4],size_average = False)  # positive-sample loss
        # 4.1.2 loc_loss (sqrt over w,h only; [:,2:] in the original also swept in the confidence column)
        loc_loss = F.mse_loss(box_pred_response[:,:2],box_target_response[:,:2],size_average = False)+ \
                   F.mse_loss(torch.sqrt(box_pred_response[:,2:4]),torch.sqrt(box_target_response[:,2:4]),size_average = False)  # box regression loss
        return (self.l_noobj*noobj_loss + class_loss + 2*contain_loss + self.l_coord*loc_loss)/N  # weighted average loss
if __name__ == '__main__':
    pred_tensor = torch.randn(2,14,14,30)
    target_tensor = pred_tensor+0.01
    yolo_loss = yoloLoss(14,8,5,0.5)
    loss = yolo_loss(pred_tensor,target_tensor)
    print(loss)
```
(2) IoU
I. Compute the top-left and bottom-right corners lt, rd of the intersection.
II. Compute the intersection area inter and the areas area1, area2 of the two boxes.
III. Compute the intersection over union: iou = inter/(area1+area2-inter).
```python
def compute_iou(self,box1,box2):
    '''
    Args:
        box1[N,4],box2[M,4]
    Return:
        iou, sized [N,M].
    '''
    N = box1.size()[0]
    M = box2.size()[0]
    lt = torch.max(
        box1[:,:2].unsqueeze(1).expand(N,M,2),  # box1.shape[N,4]-->box1[:,:2].shape[N,2]-->box1[:,:2].unsqueeze(1).shape[N,1,2]-->lt.shape[N,M,2]
        box2[:,:2].unsqueeze(0).expand(N,M,2)
    )
    rd = torch.min(
        box1[:,2:].unsqueeze(1).expand(N,M,2),
        box2[:,2:].unsqueeze(0).expand(N,M,2)
    )
    wh = rd-lt     # wh.shape(N,M,2)
    wh[wh<0] = 0
    inter = wh[...,0] * wh[...,1]  # [N,M]
    area1 = ((box1[:,2]-box1[:,0])*(box1[:,3]-box1[:,1])).unsqueeze(1).expand_as(inter)  # area1.shape[N,M]
    area2 = ((box2[:,2]-box2[:,0])*(box2[:,3]-box2[:,1])).unsqueeze(0).expand_as(inter)
    iou = inter/(area1+area2-inter)  # iou.shape[N,M]
    return iou
```
### 2.3 Training
train.py
Training flow:
1. Imports
2. Hyperparameters
3. Backbone
4. Pretrained weights
5. Loss function
6. Optimizer
7. Data loaders
8. Training loop
```python
# 1 imports
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import models
from torch.autograd import Variable
from net import vgg16_bn
from resnet_yolo import resnet50
from yoloLoss import yoloLoss
from data_label_13 import yoloDataset
from visualize import Visualizer
import numpy as np
# 2 hyperparameters
# use_gpu = torch.cuda.is_available()
file_root = r'/Users/liushuang/PycharmProjects/YOLOV1_LS'
learning_rate = 0.001
num_epochs = 2
batch_size = 1
use_resnet = False
# 3 backbone
if use_resnet:
    net = resnet50()
else:
    net = vgg16_bn()
# print(net)
# 3.1 load pretrained weights
if use_resnet:
    resnet = models.resnet50(pretrained=False)  # set True to download the ImageNet weights
    new_state_dict = resnet.state_dict()
    dd = net.state_dict()
    for k in new_state_dict.keys():
        if k in dd.keys() and not k.startswith('fc'):  # copy everything except the fc head
            dd[k] = new_state_dict[k]
    net.load_state_dict(dd)
else:
    vgg = models.vgg16_bn(pretrained=False)
    new_state_dict = vgg.state_dict()
    dd = net.state_dict()
    for k in new_state_dict.keys():
        if k in dd.keys() and k.startswith('features'):  # copy the convolutional features only
            dd[k] = new_state_dict[k]
    net.load_state_dict(dd)
if False:  # set True to resume from a checkpoint
    net.load_state_dict(torch.load('best.pth'))
# 4 loss
criterion = yoloLoss(7,2,5,0.5)
# if use_gpu:
#     net.cuda()
# training mode
net.train()
# 5 parameter groups
params = []
params_dict = dict(net.named_parameters())
for k,v in params_dict.items():
    if k.startswith('features'):
        params += [{'params':[v],'lr':learning_rate*1}]
    else:
        params += [{'params':[v],'lr':learning_rate*1}]  # same lr here; the split is kept so the two groups could differ
# 6 optimizer
optimizer = torch.optim.SGD(params,lr=learning_rate,momentum=0.9,weight_decay=5e-4)
# 7 data loaders
train_dataset = yoloDataset(root=file_root,list_file='2007_train.txt',train=True,transform=[transforms.ToTensor()])
train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=0)
test_dataset = yoloDataset(root=file_root,list_file='2007_test.txt',train=False,transform=[transforms.ToTensor()])
test_loader = DataLoader(test_dataset,batch_size=batch_size,shuffle=False,num_workers=4)
print('the dataset has %d images' % (len(train_dataset)))
print('the batch_size is %d' % (batch_size))
logfile = open('log.txt', 'w')
num_iter = 0
vis = Visualizer(env='LS')
best_test_loss = np.inf
# 8 training loop
for epoch in range(num_epochs):
    net.train()
    if epoch == 30:
        learning_rate = 0.0001
    if epoch == 40:
        learning_rate = 0.00001
    for params_group in optimizer.param_groups:
        params_group['lr'] = learning_rate
    print('\n\nStarting epoch %d / %d' % (epoch + 1, num_epochs))
    print('Learning Rate for this epoch: {}'.format(learning_rate))
    total_loss = 0.
    for i,(images,target) in enumerate(train_loader):
        images = Variable(images)
        target = Variable(target)
        # if use_gpu:
        #     images,target = images.cuda(),target.cuda()
        pred = net(images)
        # print('pred.shape:',pred.shape)
        # print('target.shape:',target.shape)
        loss = criterion(pred,target)
        total_loss += loss.data.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if(i+1)%5 == 0:
            print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f, average_loss: %.4f'
                  %(epoch+1, num_epochs, i+1, len(train_loader), loss.data.item(), total_loss / (i+1)))
            num_iter += 1
            vis.plot_train_val(loss_train = total_loss/(i+1))
    validation_loss = 0.0
    net.eval()
    for i,(images,target) in enumerate(test_loader):
        images = Variable(images,volatile=True)
        target = Variable(target,volatile=True)
        # if use_gpu:
        #     images,target = images.cuda(),target.cuda()
        pred = net(images)
        loss = criterion(pred,target)
        validation_loss += loss.data.item()
    validation_loss /= len(test_loader)
    vis.plot_train_val(loss_val=validation_loss)
    if best_test_loss > validation_loss:
        best_test_loss = validation_loss
        print('get best test loss %.5f' % best_test_loss)
        torch.save(net.state_dict(),'best.pth')
    logfile.writelines(str(epoch) + '\t' + str(validation_loss) + '\n')
    logfile.flush()
    torch.save(net.state_dict(),'yolo.pth')
```
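train.py imports Visualizer from visualize.py, which is not listed in this part. Below is a minimal stand-in, assuming only the interface used above (Visualizer(env=...) and plot_train_val(loss_train=..., loss_val=...)); the real project plots via visdom, so this stub merely keeps train.py runnable without it:
```python
# visualize.py -- minimal stand-in, assuming only the calls used in train.py.
# The original project plots via visdom; here we just print the values.
class Visualizer:
    def __init__(self, env='default'):
        self.env = env   # kept for interface compatibility
        self.step = 0    # number of points logged so far

    def plot_train_val(self, loss_train=None, loss_val=None):
        # called with either loss_train or loss_val, as in the training loop
        self.step += 1
        if loss_train is not None:
            print('[%s] step %d train loss: %.4f' % (self.env, self.step, loss_train))
        if loss_val is not None:
            print('[%s] step %d val loss: %.4f' % (self.env, self.step, loss_val))
```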