No Description
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

train.py 1.7KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  #!/usr/bin/env python
"""Train the network defined in ``model`` with an MSE loss and Adam (TF1 graph mode).

For every epoch, each ``.npy`` file listed in DATA_DIR is loaded through
``get_train_data.train_data`` and fed to the optimizer in mini-batches of
``params.batch`` samples. After every optimizer step the loss on that batch
and the loss on the whole validation set are printed. At the end of each
epoch a checkpoint is written under WEIGHT_DIR.
"""
import os
import tensorflow as tf
import model
import params
import time
import cv2
import numpy as np
from get_val_data import val_data
from get_train_data import train_data

# Directory scanned for training shards. Only the bare file name is passed to
# ``train_data``; presumably it resolves the name against this directory
# (TODO confirm in get_train_data).
DATA_DIR = '/raw_data/balanced_data'
# Checkpoint directory. tf.train.Saver.save does not create missing parent
# directories, so it is created before training starts (see below).
WEIGHT_DIR = './weight'

sess = tf.InteractiveSession()
LR = 1e-4
# Alternative L1 objective: tf.losses.absolute_difference(model.y_, model.y)
loss = tf.losses.mean_squared_error(model.y_, model.y)
train = tf.train.AdamOptimizer(LR).minimize(loss)
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())

# Load the validation set once; it is re-evaluated after every training step.
val_X, val_Y = val_data()
# keep_prob=1.0 is used for every evaluation and 0.8 for training steps.
# Presumably it is a dropout keep-probability inside ``model``.
val_feed = {model.x: val_X, model.y_: val_Y, model.keep_prob: 1.0}

# os.listdir returns entries in arbitrary order; sort so that runs are reproducible.
file_list = sorted(os.listdir(DATA_DIR))

if not os.path.isdir(WEIGHT_DIR):
    os.makedirs(WEIGHT_DIR)

for i in range(params.epoch):
    # Prepare the data for training, one shard file at a time.
    for file_name in file_list:
        if not file_name.endswith('.npy'):
            continue
        print("Start process on file : ", file_name)
        train_X, train_Y = train_data(file_name)

        # Split the shard into mini-batches. Round up so the trailing partial
        # batch is trained on too. Floor division would always skip the last
        # ``n_samples % params.batch`` samples of every shard (the data is not
        # shuffled), and would skip a shard smaller than one batch entirely.
        # Assumes model.x accepts a variable batch size, which the full-size
        # validation feed above already relies on.
        n_samples = train_X.shape[0]
        batch_iteration = (n_samples + params.batch - 1) // params.batch
        for iteration in range(batch_iteration):
            start = iteration * params.batch
            stop = start + params.batch
            batch_X = train_X[start:stop]
            batch_Y = train_Y[start:stop]
            train.run(feed_dict={model.x: batch_X, model.y_: batch_Y,
                                 model.keep_prob: 0.8})
            t_loss = loss.eval(feed_dict={model.x: batch_X, model.y_: batch_Y,
                                          model.keep_prob: 1.0})
            v_loss = loss.eval(feed_dict=val_feed)
            print ("epoch {} of {}, batch {} of {}, train loss {}, val loss {}".format(i, params.epoch,iteration,batch_iteration,t_loss, v_loss))

    # One checkpoint per epoch; the file name encodes the epoch index and LR.
    model_name = os.path.join(WEIGHT_DIR, "SSC_epoch_{}_LR_{}.model".format(i, LR))
    save_path = saver.save(sess, model_name)
    print("Model saved to {}".format(save_path))