No Description
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

balance_data.py 2.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import os
  2. import numpy as np
  3. import time
  4. from random import shuffle
  5. data_size = 512
  6. minus = []
  7. plus = []
  8. zero = []
  9. for i in range(0,16):
  10. minus.append([])
  11. plus.append([])
  12. file_list = os.listdir('/raw_data')
  13. #print(file_list)
  14. for file_name in file_list:
  15. if file_name.endswith('.npy'):
  16. print(file_name)
  17. file_location = "/raw_data/" + file_name
  18. loaded_data = np.load(file_location)
  19. loaded_data = loaded_data[30:]
  20. for data in loaded_data:
  21. # change the can data (HEX) to numerical data
  22. tmp = data[1]
  23. hex_data = tmp[-23:-21] + tmp[-20:-18]
  24. hex_decimal = tmp[-3:-1]
  25. int_data = int(hex_data, 16)
  26. int_decimal = int(hex_decimal, 16) / 256
  27. # if the steering wheel angle in in right to the center
  28. if(int_data > 550):
  29. int_data = int_data - 4096
  30. int_decimal = 1 - int_decimal
  31. steering_angle = int_data - int_decimal
  32. else:
  33. # put the int and the decimal together
  34. steering_angle = int_data + int_decimal
  35. int_data = int(steering_angle)
  36. if int_data >= 16 or int_data <= -16:
  37. continue
  38. if int_data == 0:
  39. zero.append([data[0], data[1]])
  40. elif int_data > 0:
  41. plus[int_data].append([data[0], data[1]])
  42. else:
  43. int_data = int_data * -1
  44. minus[int_data].append([data[0], data[1]])
  45. print("zero :",len(zero))
  46. for i in range(1,16):
  47. print("minus[{}] :".format(i),len(minus[i]),"plus[{}] :".format(i), len(plus[i]))
  48. shuffle(zero)
  49. zero = zero[:data_size]
  50. for i in range(1,16):
  51. shuffle(minus[i])
  52. shuffle(plus[i])
  53. minus[i] = minus[i][:data_size]
  54. plus[i] = plus[i][:data_size]
  55. print("zero :",len(zero))
  56. for i in range(1,16):
  57. print("minus[{}] :".format(i),len(minus[i]),"plus[{}] :".format(i), len(plus[i]))
  58. for i in range(1,16):
  59. if len(minus[i]) != data_size:
  60. while True:
  61. minus[i] = minus[i] + minus[i]
  62. if len(minus[i]) >= data_size:
  63. break
  64. if len(plus[i]) != data_size:
  65. while True:
  66. plus[i] = plus[i] + plus[i]
  67. if len(plus[i]) >= data_size:
  68. break
  69. print("zero :",len(zero))
  70. for i in range(1,16):
  71. print("minus[{}] :".format(i),len(minus[i]),"plus[{}] :".format(i), len(plus[i]))
  72. for i in range(1,16):
  73. shuffle(minus[i])
  74. shuffle(plus[i])
  75. minus[i] = minus[i][:data_size]
  76. plus[i] = plus[i][:data_size]
  77. print("zero :",len(zero))
  78. for i in range(1,16):
  79. print("minus[{}] :".format(i),len(minus[i]),"plus[{}] :".format(i), len(plus[i]))
  80. batch_size = 64
  81. batch = int(data_size / batch_size)
  82. #print(data_size/batch_size , batch)
  83. for j in range(batch):
  84. temp_list = zero[j*batch_size:(j+1)*batch_size]
  85. for i in range(1,16):
  86. temp_list = temp_list + minus[i][j*batch_size:(j+1)*batch_size]
  87. #print("temp_minus : " , len(temp_minus))
  88. temp_list = temp_list + plus[i][j*batch_size:(j+1)*batch_size]
  89. #print("temp_plus : ", len(temp_plus))
  90. name = str(time.time())
  91. file_name = "/raw_data/balanced_data/" + name + ".npy"
  92. print(len(temp_list))
  93. shuffle(temp_list)
  94. np.save(file_name,temp_list)
  95. print("Saved file name : ",name + ".npy")