首先，感谢师兄的帮助。师兄的代码封装成类，流畅精美，容易调试；我的代码则是堆砌出来的，被师兄嘲笑说是在写脚本。好吧！我的代码只有我自己懂，哈哈！希望以后代码能写得工整一点，但现在还是先保证自己能看懂。这里我做了一个简单的任务：对 0、1、2 三个数字进行分类，准确率为 0.9806666666666667。
（部分）代码分为以下几个文件：
1 train_net.py
"""Fine-tune the pretrained CaffeNet weights on the 0/1/2 digit dataset.

Training batches are produced by the Python DataLayer, which receives the
dataset root through ``SetDataConfig``. The trained weights are written to
``tmp.caffemodel`` (the file test_net.py loads).
"""
import time
import os
import numpy as np
import sys
import cv2
sys.path.append("/home/wang/Downloads/caffe-master/python")
import caffe

# Run on the GPU; comment this line out to train on the CPU.
caffe.set_mode_gpu()

# Root of the training set: one sub-directory per class label ('0', '1', '2').
data_config = '/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/train/'

NUM_ROUNDS = 10000        # outer loop runs rounds 1 .. NUM_ROUNDS - 1
STEPS_PER_ROUND = 5       # SGD iterations per round
SNAPSHOT_EVERY = 100      # save a snapshot every this many rounds
SNAPSHOT_PATH = 'tmp.caffemodel'

# Solver hyper-parameters and the training net come from solver.prototxt.
solver = caffe.SGDSolver('models/solver.prototxt')

# Initialise from the pretrained reference CaffeNet weights.
print('load pretrain model')
solver.net.copy_from('models/bvlc_reference_caffenet.caffemodel')

# NOTE(review): assumes layer 0 of the training net is the Python DataLayer
# (it must expose SetDataConfig); confirm against the net in solver.prototxt.
solver.net.layers[0].SetDataConfig(data_config)

for i in range(1, NUM_ROUNDS):
    solver.step(STEPS_PER_ROUND)
    if i % SNAPSHOT_EVERY == 0:
        solver.net.save(SNAPSHOT_PATH)

# The periodic snapshot last fires at i == 9900, so the final rounds would
# otherwise be lost; always persist the fully trained weights.
solver.net.save(SNAPSHOT_PATH)

# TODO: test code
2 test_net.py
"""Evaluate the trained model (tmp.caffemodel) on the 0/1/2 test set and
print the classification accuracy."""
import time
import os
import random
import sys
sys.path.append("/home/wang/Downloads/caffe-master/python")
import caffe
import cv2
import numpy as np
import random

from utils import PrepareImage
from test_data import test_data_pre

# Number of images pushed through the network per forward pass.
test_num_once = 10

# Uncomment the line below to run inference on the GPU.
# caffe.set_mode_gpu()


def _load_test_image(img_path):
    """Read one image and convert it with PrepareImage to the net's 227x227 input layout."""
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise IOError('Cannot read image: ' + img_path)
    # Same channel swap as pre_data.prepare_data applies during training.
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    return PrepareImage(img, (227, 227))


# Build the network from the deploy definition and load the trained weights.
net = caffe.Net('models/deploy.prototxt', caffe.TEST)

print('load pretrain model')
net.copy_from('tmp.caffemodel')

# Collect every test image path and its ground-truth class.
data_pre = '/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/test/'
(imgPaths, gt_label) = test_data_pre(data_pre)
num_img = len(imgPaths)
if num_img == 0:
    # Avoid a ZeroDivisionError when computing the accuracy below.
    raise SystemExit('No test images found under ' + data_pre)

correct_num = 0
for start in range(0, num_img, test_num_once):
    stop = min(start + test_num_once, num_img)
    batch = np.array([_load_test_image(p) for p in imgPaths[start:stop]],
                     dtype=np.float32)
    # The last batch may be smaller than test_num_once.
    net.blobs['data'].reshape(*batch.shape)
    net.blobs['data'].data[...] = batch
    net.forward()
    # One row of class scores per image; flatten any trailing singleton dims
    # so the argmax is taken per image, not over the whole output blob.
    scores = net.blobs['cls_prob'].data.reshape(batch.shape[0], -1)
    predicted = scores.argmax(axis=1)
    correct_num += int(np.sum(predicted == gt_label[start:stop]))
    # Progress message roughly every 100 images.
    if any(idx % 100 == 0 for idx in range(start, stop)):
        print("Please wait some minutes...")

correct_rate = correct_num * 1.0 / num_img
print('The correct rate is :', correct_rate)
3 test_data.py
import os
import numpy as np
from random import randint
import cv2
from utils import PrepareImage,CatImage


def test_data_pre(path, num_classes=3):
    """List every test image under ``path`` together with its class label.

    ``path`` must contain one sub-directory per class, named '0', '1', ...,
    up to ``num_classes - 1``. Paths are returned grouped by class, in
    ``os.listdir`` order within each class. Works whether or not ``path``
    ends with a slash.

    Returns:
        (img_list, label): list of image file paths, and a float32 numpy
        array holding the class index of each path.
    """
    img_list = []
    labels = []
    for idf in range(num_classes):
        class_dir = os.path.join(path, str(idf))
        # List each class directory exactly once (it used to be listed twice).
        for name in os.listdir(class_dir):
            img_list.append(os.path.join(class_dir, name))
            labels.append(idf)
    label = np.array(labels, dtype=np.float32)
    return (img_list, label)
4 pre_data.py
import os
import numpy as np
from random import randint
import cv2
from utils import PrepareImage,CatImage


def prepare_data(path, batchsize, num_classes=3):
    """Sample a random, augmented training batch from ``path``.

    ``path`` must contain one sub-directory per class ('0', '1', ...). For
    each of the ``batchsize`` samples, a class is drawn uniformly, then an
    image uniformly from that class. The image is channel-swapped, randomly
    mirrored left-to-right and resized to 227x227 via PrepareImage.

    Returns:
        (imgData, label): the float32 blob built by CatImage from the
        prepared images, and a float32 array of the sampled class indices.

    Raises:
        ValueError: a class directory is empty.
        IOError: an image file cannot be decoded by cv2.imread.
    """
    # List every class directory once per call instead of once per sample.
    class_files = []
    for idf in range(num_classes):
        class_dir = os.path.join(path, str(idf))
        names = os.listdir(class_dir)
        if not names:
            raise ValueError('No images found in ' + class_dir)
        class_files.append((class_dir, names))

    img_list = []
    label = np.zeros(batchsize, dtype=np.float32)
    for i in range(batchsize):
        # Randomly pick a class, then an image within it.
        idf = randint(0, num_classes - 1)
        class_dir, names = class_files[idf]
        img_path = os.path.join(class_dir, names[randint(0, len(names) - 1)])

        img = cv2.imread(img_path)
        if img is None:
            raise IOError('Cannot read image: ' + img_path)
        # NOTE(review): cv2.imread already returns BGR, so this swap yields
        # RGB. test_net.py applies the same swap, so keep the two in sync.
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        if randint(0, 1):
            img = img[:, ::-1, :]  # flip left to right (augmentation)

        img_list.append(PrepareImage(img, (227, 227)))
        label[i] = idf
    imgData = CatImage(img_list)
    return (imgData, label)
5 utils.py
import os
import cv2
import numpy as np


def PrepareImage(im, size):
    """Resize an HxWxC image to ``size`` (width, height) and return it as a
    CxHxW float32 array."""
    resized = cv2.resize(im, (size[0], size[1]))
    chw = resized.transpose(2, 0, 1)
    return chw.astype(np.float32, copy=False)


def CatImage(im_list):
    """Stack CxHxW images into one float32 blob of shape (N, 3, maxH, maxW).

    Each image is placed in the top-left corner of its slot; any uncovered
    area keeps a fixed per-channel fill value.
    """
    # NOTE(review): these fill values look like per-channel pixel means, but
    # they are only used as padding here - images are never mean-subtracted.
    fill = np.array([102.9801, 115.9465, 122.7717], dtype=np.float32)
    largest = np.array([im.shape for im in im_list]).max(axis=0)
    max_h, max_w = largest[1], largest[2]

    blob = np.empty((len(im_list), 3, max_h, max_w), dtype=np.float32)
    blob[...] = fill.reshape(1, 3, 1, 1)
    for slot, image in enumerate(im_list):
        h, w = image.shape[1], image.shape[2]
        blob[slot, :, :h, :w] = image
    return blob
6 layer/data_layer.py
import caffe
import numpy as np

from pre_data import prepare_data


class DataLayer(caffe.Layer):
    """Python Caffe data layer that samples a random training batch on every
    forward pass.

    top[0]: image blob produced by ``pre_data.prepare_data``.
    top[1]: class labels, one per image.
    The dataset root must be supplied with ``SetDataConfig`` before the first
    forward pass.
    """

    # Number of images sampled per forward pass.
    batch_size = 128

    def SetDataConfig(self, data_config):
        """Store the dataset root directory handed to prepare_data."""
        self._data_config = data_config

    def GetDataConfig(self):
        """Return the dataset root stored by SetDataConfig.

        Raises:
            AttributeError: SetDataConfig has not been called yet.
        """
        try:
            return self._data_config
        except AttributeError:
            raise AttributeError(
                'DataLayer: call SetDataConfig(path) before running the net')

    def setup(self, bottom, top):
        # Placeholder shapes; forward() reshapes both tops to the real batch.
        top[0].reshape(1, 3, 227, 227)
        top[1].reshape(1, 1)

    def reshape(self, bootom, top):
        # Nothing to do: forward() reshapes the tops once the batch is known.
        # (The parameter name 'bootom' is kept for interface compatibility.)
        pass

    def forward(self, bottom, top):
        imgs, label = prepare_data(self.GetDataConfig(), self.batch_size)
        num = imgs.shape[0]
        # Image data; imgs is laid out as (N, C, H, W).
        top[0].reshape(*imgs.shape)
        top[0].data[...] = imgs
        # Object type label.
        top[1].reshape(num)
        top[1].data[...] = label

    def backward(self, top, propagate_down, bottom):
        # Data layer: nothing to back-propagate.
        pass
7 layer/__init__.py
# Import the DataLayer module as part of this package. The explicit relative
# form works on Python 2.5+ and Python 3; the old implicit relative
# `import data_layer` fails on Python 3.
from . import data_layer
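还有一些 Caffe 中的经典文件没有放进来，例如代码中引用的 models/solver.prototxt、models/deploy.prototxt 以及预训练模型 models/bvlc_reference_caffenet.caffemodel。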
代码和数据:
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:python caffe 在师兄的代码上修改成自己风格的代码 - Python技术站