1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
|
import threading

# Module-level lock shared by every MyThread instance: MyThread.run() holds it
# while calling save_models(), so two threads never run save_models() at once.
lock = threading.Lock()
def save_models(opt, net, epoch, train_loss, best_loss, test_loss):
    """Checkpoint ``net`` and return the (possibly improved) best test loss.

    Args:
        opt: Options object; reads ``SAVE_TEMP_MODEL``, ``NUM_TRAIN`` and
            ``NUM_TEST``.
        net: Model exposing ``save(epoch, loss, filename)``.
        epoch: Passed through unchanged to ``net.save``.
        train_loss: Training loss, divided by ``opt.NUM_TRAIN`` before being
            handed to ``net.save`` (presumably a sum over samples — confirm).
        best_loss: Best normalised test loss seen so far.
        test_loss: Test loss, divided by ``opt.NUM_TEST`` before comparison.

    Returns:
        ``test_loss / opt.NUM_TEST`` if it is strictly smaller than
        ``best_loss`` (a "best_model.dat" checkpoint is saved in that case);
        otherwise ``best_loss`` converted to float.
    """
    # Coerce to plain floats (callers may pass tensor/numpy scalars — TODO confirm).
    train_loss, best_loss, test_loss = map(float, (train_loss, best_loss, test_loss))

    if opt.SAVE_TEMP_MODEL:
        # Snapshot of the current epoch, written under the same name every call.
        net.save(epoch, train_loss / opt.NUM_TRAIN, "temp_model.dat")

    normalized_test_loss = test_loss / opt.NUM_TEST
    if normalized_test_loss < best_loss:
        # Strict improvement only; a NaN test loss never replaces the best.
        net.save(epoch, train_loss / opt.NUM_TRAIN, "best_model.dat")
        return normalized_test_loss
    return best_loss
class MyThread(threading.Thread):
    """Thread that runs save_models() while holding the module-level ``lock``.

    The constructor arguments are stored unchanged and forwarded to
    save_models() when the thread runs. On success, ``self.best_loss`` is
    replaced with save_models()'s return value; if save_models() raises,
    ``self.best_loss`` keeps the value passed to the constructor.
    Callers presumably start() and join() the thread before reading
    ``best_loss`` — confirm against the caller.
    """

    def __init__(self, opt, net, epoch, train_loss, best_loss, test_loss):
        super().__init__()
        self.opt = opt
        self.net = net
        self.epoch = epoch
        self.train_loss = train_loss
        self.best_loss = best_loss
        self.test_loss = test_loss

    def run(self):
        """Call save_models() under the shared lock and store the new best loss."""
        # ``with`` releases the lock even if save_models() raises, replacing the
        # manual acquire()/try/finally release() pair.
        with lock:
            self.best_loss = save_models(
                self.opt,
                self.net,
                self.epoch,
                self.train_loss,
                self.best_loss,
                self.test_loss,
            )
|