Browse Source

3rd data collection

James Jeon 2 years ago
parent
commit
2db6403829
5 changed files with 25 additions and 21 deletions
  1. 21
    17
      balance_data.py
  2. 1
    1
      eval.py
  3. 1
    1
      get_train_data.py
  4. 1
    1
      get_val_data.py
  5. 1
    1
      train.py

+ 21
- 17
balance_data.py View File

@@ -6,6 +6,8 @@ from random import shuffle
6 6
 data_size = 512
7 7
 minus = []
8 8
 plus = []
9
+zero = []
10
+
9 11
 for i in range(0,16):
10 12
 	minus.append([])
11 13
 	plus.append([])
@@ -29,7 +31,7 @@ for file_name in file_list:
29 31
 
30 32
 			# if the steering wheel angle in in right to the center
31 33
 			if(int_data > 550):
32
-				int_data = int_data - 4095
34
+				int_data = int_data - 4096
33 35
 				int_decimal = 1 - int_decimal 
34 36
 				steering_angle = int_data - int_decimal
35 37
 			else:
@@ -42,33 +44,32 @@ for file_name in file_list:
42 44
 				continue
43 45
 
44 46
 			if int_data == 0:
45
-				if int_data > steering_angle:
46
-					minus[int_data].append([data[0], data[1]])
47
-				else:
48
-					plus[int_data].append([data[0], data[1]])
47
+				zero.append([data[0], data[1]])
49 48
 			elif int_data > 0:
50 49
 				plus[int_data].append([data[0], data[1]])
51 50
 			else:
52 51
 				int_data = int_data * -1
53 52
 				minus[int_data].append([data[0], data[1]])
54 53
 
55
-for i in range(0,16):
54
+print("zero :",len(zero))
55
+for i in range(1,16):
56 56
 	print("minus[{}] :".format(i),len(minus[i]),"plus[{}] :".format(i), len(plus[i]))
57 57
 
58 58
 
59
-for i in range(0,16):
59
+shuffle(zero)
60
+zero = zero[:data_size]
61
+for i in range(1,16):
60 62
 	shuffle(minus[i])
61 63
 	shuffle(plus[i]) 
62 64
 
63 65
 	minus[i] = minus[i][:data_size]
64 66
 	plus[i] = plus[i][:data_size]
65 67
 
66
-
67
-for i in range(0,16):
68
+print("zero :",len(zero))
69
+for i in range(1,16):
68 70
 	print("minus[{}] :".format(i),len(minus[i]),"plus[{}] :".format(i), len(plus[i]))
69 71
 
70
-
71
-for i in range(0,16):
72
+for i in range(1,16):
72 73
 	if len(minus[i]) != data_size:
73 74
 		while True:
74 75
 			minus[i] = minus[i] + minus[i] 
@@ -80,27 +81,30 @@ for i in range(0,16):
80 81
 			plus[i] = plus[i] + plus[i] 
81 82
 			if len(plus[i]) >= data_size:
82 83
 				break
83
-for i in range(0,16):
84
+
85
+print("zero :",len(zero))
86
+for i in range(1,16):
84 87
 	print("minus[{}] :".format(i),len(minus[i]),"plus[{}] :".format(i), len(plus[i]))
85 88
 
86
-for i in range(0,16):
89
+for i in range(1,16):
87 90
 	shuffle(minus[i])
88 91
 	shuffle(plus[i]) 
89 92
 
90 93
 	minus[i] = minus[i][:data_size]
91 94
 	plus[i] = plus[i][:data_size]
92 95
 
93
-for i in range(0,16):
96
+print("zero :",len(zero))
97
+for i in range(1,16):
94 98
 	print("minus[{}] :".format(i),len(minus[i]),"plus[{}] :".format(i), len(plus[i]))
95
-	
99
+
96 100
 
97 101
 batch_size = 64
98 102
 batch = int(data_size / batch_size)
99 103
 #print(data_size/batch_size , batch)
100 104
 
101 105
 for j in range(batch):
102
-	temp_list = []
103
-	for i in range(0,16):
106
+	temp_list = zero[j*batch_size:(j+1)*batch_size]
107
+	for i in range(1,16):
104 108
 		temp_list = temp_list + minus[i][j*batch_size:(j+1)*batch_size]
105 109
 		#print("temp_minus : " , len(temp_minus))
106 110
 		temp_list = temp_list + plus[i][j*batch_size:(j+1)*batch_size]

+ 1
- 1
eval.py View File

@@ -9,7 +9,7 @@ import numpy as np
9 9
 
10 10
 sess = tf.InteractiveSession()
11 11
 saver = tf.train.Saver()
12
-saver.restore(sess, "./model_1")
12
+saver.restore(sess, "./weight/MSE_without-0_4.model")
13 13
 
14 14
 file_list = os.listdir('/raw_data/eval_data')
15 15
 for file in file_list:

+ 1
- 1
get_train_data.py View File

@@ -26,7 +26,7 @@ def train_data(file_name):
26 26
 
27 27
 		# if the steering wheel angle in in right to the center
28 28
 		if(int_data > 550):
29
-			int_data = int_data - 4095
29
+			int_data = int_data - 4096
30 30
 			int_decimal = 1 - int_decimal 
31 31
 			final_data = int_data - int_decimal
32 32
 		else:

+ 1
- 1
get_val_data.py View File

@@ -32,7 +32,7 @@ def val_data():
32 32
 
33 33
 				# if the steering wheel angle in in right to the center
34 34
 				if(int_data > 550):
35
-					int_data = int_data - 4095
35
+					int_data = int_data - 4096
36 36
 					int_decimal = 1 - int_decimal 
37 37
 					final_data = int_data - int_decimal
38 38
 				else:

+ 1
- 1
train.py View File

@@ -51,7 +51,7 @@ for i in range(params.epoch):
51 51
 
52 52
 				print ("epoch {} of {}, batch {} of {}, train loss {}, val loss {}".format(i, params.epoch,iteration,batch_iteration,t_loss, v_loss))
53 53
 
54
-	model_name = "MSE_{}.model".format(i)
54
+	model_name = "./weight/MSE_without-0_{}.model".format(i)
55 55
 	save_path = saver.save(sess, model_name)
56 56
 
57 57
 

Loading…
Cancel
Save