-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtrain.py
More file actions
executable file
·120 lines (98 loc) · 3.44 KB
/
train.py
File metadata and controls
executable file
·120 lines (98 loc) · 3.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/python3
import tensorflow as tf
import numpy as np
from generator import BatchGenerator, PrecomputeBatchGenerator
from model import Model
from config import network
def get_number_parameters(variables):
    """Return the total number of scalar parameters across ``variables``.

    Works with both TF1-style shapes (iterating a shape yields ``Dimension``
    objects carrying a ``.value`` attribute) and TF2-style shapes (iterating
    yields plain ints). The original assumed ``.value`` always exists and
    crashed on plain-int dimensions.

    Args:
        variables: iterable of objects exposing ``get_shape()`` returning an
            iterable of dimensions (e.g. the result of
            ``tf.trainable_variables()``).

    Returns:
        int: sum over all variables of the product of their dimension sizes.

    Raises:
        TypeError: if a dimension is unknown (``None``) — same behavior as
            the original multiply-by-None failure.
    """
    total_parameters = 0
    for variable in variables:
        variable_parameters = 1
        for dim in variable.get_shape():
            # TF1 yields Dimension objects (use .value); TF2 yields ints.
            variable_parameters *= getattr(dim, 'value', dim)
        total_parameters += variable_parameters
    return total_parameters
def train(epochs, steps, batch_size, image_size, alphabet, max_sequence_length, max_lines):
    """Build the TF1 graph for Model and train it for `epochs` epochs.

    Each epoch consumes `steps` batches from a training generator, then runs
    one validation batch to produce TensorBoard summaries and a sample
    prediction printout.

    Args:
        epochs: number of training epochs.
        steps: number of training batches per epoch.
        batch_size: fixed batch size baked into the placeholders.
        image_size: (width, height) of the input images.
        alphabet: character set forwarded to Model and the generators.
        max_sequence_length: padded length of the target sequences.
        max_lines: forwarded to the batch generators.

    Side effects: writes TensorBoard logs under logs/<8 random digits>/,
    restores from the latest checkpoint in train/ if one exists, and saves a
    single checkpoint to train/model after all epochs complete.
    """
    import time
    # image_size is (width, height); the placeholder is NHWC with 1 channel.
    img_w, img_h = image_size
    images_input = tf.placeholder(shape=(batch_size, img_h, img_w, 1), dtype=tf.float32)
    sequences_input = tf.placeholder(shape=(batch_size, max_sequence_length), dtype=tf.int32)
    # Scalar flags fed per sess.run: training mode, and whether the targets
    # include an end-of-sequence token this epoch (see add_eos_epoch below).
    is_training = tf.placeholder(shape=(), dtype=tf.bool)
    add_eos = tf.placeholder(shape=(), dtype=tf.bool)
    model = Model(
        images_input,
        sequences_input,
        is_training,
        add_eos,
        max_sequence_length,
        alphabet)
    endpoints = model.endpoints()
    trainable = get_number_parameters(tf.trainable_variables())
    print('Model has {} trainable parameters'.format(trainable))
    # NOTE(review): tf.train.get_global_step() returns None unless a global
    # step variable was created elsewhere (possibly inside Model) — confirm.
    train_op = tf.contrib.layers.optimize_loss(
        endpoints['loss'],
        tf.train.get_global_step(),
        optimizer='Adam',
        learning_rate=0.0001,
        summaries=['loss', 'learning_rate'])
    # Image summaries: raw inputs plus the attention alignments from Model.
    tf.summary.image('input_images', images_input)
    tf.summary.image('alignments', endpoints['alignments'])
    merged = tf.summary.merge_all()
    train_generator = PrecomputeBatchGenerator(
        size=image_size,
        alphabet=alphabet,
        max_sequence_length=max_sequence_length,
        max_lines=max_lines,
        batch_size=batch_size)
    # Separate generator used only to draw one validation batch per epoch.
    val_generator = PrecomputeBatchGenerator(
        size=image_size,
        alphabet=alphabet,
        max_sequence_length=max_sequence_length,
        max_lines=max_lines,
        batch_size=batch_size,
        precompute_size=1000)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        def random_name():
            # 8 random digits forming a per-run TensorBoard log directory.
            return ''.join([np.random.choice(list('0123456789')) for _ in range(8)])
        train_writer = tf.summary.FileWriter('logs/{}'.format(random_name()), sess.graph)
        # Resume from the latest checkpoint in train/, if one exists.
        ckpt = tf.train.get_checkpoint_state('train/')
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        # EOS targets are only fed from this epoch onward — presumably a
        # curriculum to let the model stabilise first; confirm against Model.
        add_eos_epoch = 60
        for e in range(epochs):
            t = time.time()
            # Take exactly `steps` batches from the (possibly endless) generator.
            for step, (imgs, seqs) in enumerate(train_generator.generate_batch()):
                if step < steps:
                    sess.run([train_op], feed_dict={
                        images_input: imgs,
                        sequences_input: seqs,
                        is_training: True,
                        add_eos: e >= add_eos_epoch})
                else:
                    break
            # One validation batch per epoch for summaries and a sample print.
            for imgs, seqs in val_generator.generate_batch():
                summary, predictions = sess.run([merged, endpoints['predictions']], feed_dict={
                    images_input: imgs,
                    sequences_input: seqs,
                    is_training: False,
                    add_eos: e >= add_eos_epoch})
                sequences = seqs
                break
            train_writer.add_summary(summary, e)
            print("Epoch {} ends with time {:.4f}".format(e, time.time() - t))
            # Compare ground truth against the model's prediction for sample 0.
            print("Expectation: {}".format(sequences[0]))
            print("Reality: {}".format(predictions[0]))
        # NOTE(review): the model is saved only once, after ALL epochs — a
        # crash mid-run loses every epoch despite the restore logic above;
        # consider saving inside the epoch loop.
        saver.save(sess, 'train/model', global_step=epochs)
if __name__ == '__main__':
    # Pull every hyperparameter straight from the network config and kick
    # off training; argument order matches train()'s signature.
    train(
        network['epochs'],
        network['steps'],
        network['batch_size'],
        network['image_size'],
        network['alphabet'],
        network['max_sequence_length'],
        network['max_lines'])