


  • 可以采用Faster-RCNN或SSD来实现交通灯的识别



  1. # import some libs
  2. import cv2
  3. import os
  4. import glob
  5. import random
  6. import numpy as np
  7. import matplotlib.pyplot as plt
  8. import matplotlib.image as mpimg
  9. %matplotlib inline
  1. # Image data directories
  2. IMAGEDIR_TRAINING = "traffic_light_images/training/"
  3. IMAGE_DIR_TEST = "traffic_light_images/test/"
  4. #load data
  5. def load_dataset(image_dir):
  6. '''
  7. This function loads in images and their labels and places them in a list
  8. image_dir:directions where images stored
  9. '''
  10. im_list =[]
  11. image_types= ['red','yellow','green']
  12. #Iterate through each color folder
  13. for im_type in image_types:
  14. file_lists = glob.glob(os.path.join(image_dir,im_type,'*'))
  15. print(len(file_lists))
  16. for file in file_lists:
  17. im = mpimg.imread(file)
  18. if not im is None:
  19. im_list.append((im,im_type))
  20. return im_list
  21. IMAGE_LIST = load_dataset(IMAGE_DIR_TRAINING)
  1. 723
  2. 35
  3. 429
Visualize the data


  • 显示图像
  • 打印出图片的大小
  • 打印出图片对应的标签

  1. ,ax = plt.subplots(1,3,figsize=(5,2))
  2. #red
  3. imgred = IMAGE_LIST[0][0]
  4. ax[0].imshow(img_red)
  5. ax[0].annotate(IMAGE_LIST[0][1],xy=(2,5),color='blue',fontsize='10')
  6. ax[0].axis('off')
  7. ax[0].set_title(img_red.shape,fontsize=10)
  8. #yellow
  9. img_yellow = IMAGE_LIST[730][0]
  10. ax[1].imshow(img_yellow)
  11. ax[1].annotate(IMAGE_LIST[730][1],xy=(2,5),color='blue',fontsize='10')
  12. ax[1].axis('off')
  13. ax[1].set_title(img_yellow.shape,fontsize=10)
  14. #green
  15. img_green = IMAGE_LIST[800][0]
  16. ax[2].imshow(img_green)
  17. ax[2].annotate(IMAGE_LIST[800][1],xy=(2,5),color='blue',fontsize='10')
  18. ax[2].axis('off')
  19. ax[2].set_title(img_green.shape,fontsize=10)
  20. plt.show()
Preprocess Data






  1. # 标准化输入图像,这里我们resize图片大小为32x32x3,这里我们也可以对图像进行裁剪、平移、旋转
  2. def standardize(image_list):
  3. '''
  4. This function takes a rgb image as input and return a standardized version
  5. image_list: image and label
  6. '''
  7. standard_list = []
  8. #Iterate through all the image-label pairs
  9. for item in image_list:
  10. image = item[0]
  11. label = item[1]
  12. # Standardize the input
  13. standardized_im = standardize_input(image)
  14. # Standardize the output(one hot)
  15. one_hot_label = one_hot_encode(label)
  16. # Append the image , and it's one hot encoded label to the full ,processed list of image data
  17. standard_list.append((standardized_im,one_hot_label))
  18. return standard_list
  19. def standardize_input(image):
  20. #Resize all images to be 32x32x3
  21. standard_im = cv2.resize(image,(32,32))
  22. return standard_im
  23. def one_hot_encode(label):
  24. #return the correct encoded label.
  25. '''
  26. # one_hot_encode("red") should return: [1, 0, 0]
  27. # one_hot_encode("yellow") should return: [0, 1, 0]
  28. # one_hot_encode("green") should return: [0, 0, 1]
  29. '''
  30. if label=='red':
  31. return [1,0,0]
  32. elif label=='yellow':
  33. return [0,1,0]
  34. else:
  35. return [0,0,1]
Test your code


  1. import unittest
  2. from IPython.display import Markdown,display
  3. # Helper function for printing markdown text(text in color/bold/etc)
  4. def printmd(string):
  5. display(Markdown(string))
  6. # Print a test falied message,given an error
  7. def print_fail():
  8. printmd('<span style=="color: red;">Test Failed</span>')
  9. def print_pass():
  10. printmd('<span style="color:green;">Test Passed</span>')
  11. # A class holding all tests
  12. class Tests(unittest.TestCase):
  13. #Tests the 'one_hot_encode' function,which is passed in as an argument
  14. def test_one_hot(self,one_hot_function):
  15. #test that the generate onr-hot lables match the expected one-hot label
  16. #for all three cases(red,yellow,green)
  17. try:
  18. self.assertEqual([1,0,0],one_hot_function('red'))
  19. self.assertEqual([0,1,0],one_hot_function('yellow'))
  20. self.assertEqual([0,0,1],one_hot_function('green'))
  21. #enter exception
  22. except self.failureException as e:
  23. #print out an error message
  24. print_fail()
  25. print('Your function did not return the excepted one-hot label')
  26. print('\n'+str(e))
  27. return
  28. print_pass()
  29. #Test if ay misclassified images are red but mistakenly classifed as green
  30. def test_red_aa_green(self,misclassified_images):
  31. #Loop through each misclassified image and the labels
  32. for im,predicted_label,true_label in misclassified_images:
  33. #check if the iamge is one of a red light
  34. if(true_label==[1,0,0]):
  35. try:
  36. self.assertNotEqual(true_label,[0,1,0])
  37. except self.failureException as e:
  38. print_fail()
  39. print('Warning:A red light is classified as green.')
  40. print('\n'+str(e))
  41. return
  42. print_pass()
  43. tests = Tests()
  44. tests.test_one_hot(one_hot_encode)
Test Passed

# Resize every training image and one-hot encode its label.
Standardized_Train_List = standardize(image_list=IMAGE_LIST)
  • 1

Feature Extraction



  1. #Visualize
  2. image_num = 0
  3. test_im = Standardized_Train_List[image_num][0]
  4. test_label = Standardized_Train_List[image_num][1]
  5. #convert to hsv
  6. hsv = cv2.cvtColor(test_im, cv2.COLOR_RGB2HSV)
  7. # Print image label
  8. print('Label [red, yellow, green]: ' + str(test_label))
  9. h = hsv[:,:,0]
  10. s = hsv[:,:,1]
  11. v = hsv[:,:,2]
  12. # Plot the original image and the three channels
  1. , ax = plt.subplots(1, 4, figsize=(20,10))
  2. ax[0].settitle('Standardized image')
  3. ax[0].imshow(test_im)
  4. ax[1].set_title('H channel')
  5. ax[1].imshow(h, cmap='gray')
  6. ax[2].set_title('S channel')
  7. ax[2].imshow(s, cmap='gray')
  8. ax[3].set_title('V channel')
  9. ax[3].imshow(v, cmap='gray')
  1. # create feature
  2. '''
  3. HSV即色相、饱和度、明度(英语:Hue, Saturation, Value),又称HSB,其中B即英语:Brightness。
  4. 色相(H)是色彩的基本属性,就是平常所说的颜色名称,如红色、黄色等。
  5. 饱和度(S)是指色彩的纯度,越高色彩越纯,低则逐渐变灰,取0-100%的数值。
  6. 明度(V),亮度(L),取0-100%。
  7. '''
  8. def create_feature(rgb_image):
  9. '''
  10. Basic brightness feature
  11. rgb_image : a rgb_image
  12. '''
  13. hsv = cv2.cvtColor(rgb_image,cv2.COLOR_RGB2HSV)
  14. sum_brightness = np.sum(hsv[:,:,2])
  15. area = 3232
  16. avg_brightness = sum_brightness / area#Find the average
  17. return avg_brightness
  18. def high_saturation_pixels(rgb_image,threshold=80):
  19. '''
  20. Returns average red and green content from high saturation pixels
  21. Usually, the traffic light contained the highest saturation pixels in the image.
  22. The threshold was experimentally determined to be 80
  23. '''
  24. high_sat_pixels = []
  25. hsv = cv2.cvtColor(rgb,cv2.COLOR_RGB2HSV)
  26. for i in range(32):
  27. for j in range(32):
  28. if hsv[i][j][1] > threshold:
  29. high_sat_pixels.append(rgb_image[i][j])
  30. if not high_sat_pixels:
  31. return highest_sat_pixel(rgb_image)
  32. sum_red = 0
  33. sum_green = 0
  34. for pixel in high_sat_pixels:
  35. sum_red+=pixel[0]
  36. sum_green+=pixel[1]
  37. # use sum() instead of manually adding them up
  38. avg_red = sum_red / len(high_sat_pixels)
  39. avg_green = sum_green / len(high_sat_pixels)0.8
  40. return avg_red,avg_green
  41. def highest_sat_pixel(rgb_image):
  42. '''
  43. Finds the highest saturation pixels, and checks if it has a higher green
  44. or a higher red content
  45. '''
  46. hsv = cv2.cvtColor(rgb_image,cv2.COLOR_RGB2HSV)
  47. s = hsv[:,:,1]
  48. x,y = (np.unravel_index(np.argmax(s),s.shape))
  49. if rgb_image[x,y,0] > rgb_image[x,y,1]*0.9:
  50. return 1,0 #red has a higher content
  51. return 0,1
Test dataset

reference url


  1. def estimate_label(rgb_image,display=False):
  2. '''
  3. rgb_image:Standardized RGB image
  4. '''
  5. return red_green_yellow(rgb_image,display)
  6. def findNoneZero(rgb_image):
  7. rows,cols,
  1. = rgbimage.shape
  2. counter = 0
  3. for row in range(rows):
  4. for col in range(cols):
  5. pixels = rgb_image[row,col]
  6. if sum(pixels)!=0:
  7. counter = counter+1
  8. return counter
  9. def red_green_yellow(rgb_image,display):
  10. '''
  11. Determines the red , green and yellow content in each image using HSV and experimentally
  12. determined thresholds. Returns a Classification based on the values
  13. '''
  14. hsv = cv2.cvtColor(rgb_image,cv2.COLOR_RGB2HSV)
  15. sum_saturation = np.sum(hsv[:,:,1])# Sum the brightness values
  16. area = 3232
  17. avg_saturation = sum_saturation / area #find average
  18. sat_low = int(avg_saturation1.3)#均值的1.3倍,工程经验
  19. val_low = 140
  20. #Green
  21. lower_green = np.array([70,sat_low,val_low])
  22. upper_green = np.array([100,255,255])
  23. green_mask = cv2.inRange(hsv,lower_green,upper_green)
  24. green_result = cv2.bitwise_and(rgb_image,rgb_image,mask = green_mask)
  25. #Yellow
  26. lower_yellow = np.array([10,sat_low,val_low])
  27. upper_yellow = np.array([60,255,255])
  28. yellow_mask = cv2.inRange(hsv,lower_yellow,upper_yellow)
  29. yellow_result = cv2.bitwise_and(rgb_image,rgb_image,mask=yellow_mask)
  30. # Red
  31. lower_red = np.array([150,sat_low,val_low])
  32. upper_red = np.array([180,255,255])
  33. red_mask = cv2.inRange(hsv,lower_red,upper_red)
  34. red_result = cv2.bitwise_and(rgb_image,rgb_image,mask = red_mask)
  35. if display==True:
  36. ,ax = plt.subplots(1,5,figsize=(20,10))
  37. ax[0].set_title('rgb image')
  38. ax[0].imshow(rgb_image)
  39. ax[1].set_title('red result')
  40. ax[1].imshow(red_result)
  41. ax[2].set_title('yellow result')
  42. ax[2].imshow(yellow_result)
  43. ax[3].set_title('green result')
  44. ax[3].imshow(green_result)
  45. ax[4].set_title('hsv image')
  46. ax[4].imshow(hsv)
  47. plt.show()
  48. sum_green = findNoneZero(green_result)
  49. sum_red = findNoneZero(red_result)
  50. sum_yellow = findNoneZero(yellow_result)
  51. if sum_red >= sum_yellow and sum_red>=sum_green:
  52. return [1,0,0]#Red
  53. if sum_yellow>=sum_green:
  54. return [0,1,0]#yellow
  55. return [0,0,1]#green
  1. img_test = [(img_red,'red'),(img_yellow,'yellow'),(img_green,'green')]
  2. standardtest = standardize(img_test)
  3. for img in standardtest:
  4. predicted_label = estimate_label(img[0],display = True)
  5. print('Predict label :',predicted_label)
  6. print('True label:',img[1])
  1. Predict label : [1, 0, 0]
  2. True label: [1, 0, 0]
  1. Predict label : [0, 1, 0]
  2. True label: [0, 1, 0]
  1. Predict label : [0, 0, 1]
  2. True label: [0, 0, 1]
  1. # Using the load_dataset function in helpers.py
  2. # Load test data
  3. TEST_IMAGE_LIST = load_dataset(IMAGE_DIR_TEST)
  4. # Standardize the test data
  6. # Shuffle the standardized test data
  7. random.shuffle(STANDARDIZED_TEST_LIST)
  1. 181
  2. 9
  3. 107
Determine the Accuracy


  1. # COnstructs a list of misclassfied iamges given a list of test images and their labels
  2. # This will throw an assertionerror if labels are not standardized(one hot encode)
  3. def get_misclassified_images(test_images,display=False):
  4. misclassified_images_labels = []
  5. #Iterate through all the test images
  6. #Classify each image and compare to the true label
  7. for image in test_images:
  8. # Get true data
  9. im = image[0]
  10. true_label = image[1]
  11. assert (len(true_label)==3),'This true_label is not the excepted length (3).'
  12. #Get predicted label from your classifier
  13. predicted_label = estimate_label(im,display=False)
  14. assert(len(predicted_label)==3),'This predicted_label is not the excepted length (3).'
  15. #compare true and predicted labels
  16. if(predicted_label!=true_label):
  17. #if these labels are ot equal, the image has been misclassified
  18. misclassified_images_labels.append((im,predicted_label,true_label))
  19. # return the list of misclassified [image,predicted_label,true_label] values
  20. return misclassified_images_labels
  21. # Find all misclassified images in a given test set
  22. MISCLASSIFIED = get_misclassified_images(STANDARDIZED_TEST_LIST,display=False)
  23. #Accuracy calcuations
  24. total = len(STANDARDIZED_TEST_LIST)
  25. num_correct = total-len(MISCLASSIFIED)
  26. accuracy = num_correct / total
  27. print('Accuracy:'+str(accuracy))
  28. print('Number of misclassfied images = '+str(len(MISCLASSIFIED))+' out of '+str(total))
  1. Accuracy:0.9797979797979798
  2. Number of misclassfied images = 6 out of 297
