Differences in how the network is defined:
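In the Python code (network3.py), a network is defined not only by its structural parameters (`layers`) but also by the corresponding `mini_batch_size`. In other words, a network defined with network3.py is tied to its `mini_batch_size`: when the network is constructed, every layer is wired up for that batch size (the `set_inpt` calls in the constructor). If you want to change `mini_batch_size` during training, simply changing its value and re-running the computation will not work.
Therefore, the existing network structure has to be regenerated.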
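To make the point above concrete, here is a toy, framework-free sketch. `ToyLayer` and `ToyNetwork` are hypothetical stand-ins, not network3.py code, and `set_inpt` is simplified to take only the batch size. It shows that assigning a new value to the network's `mini_batch_size` attribute leaves the layers wired for the old size, while building a new network from the same layers re-wires them and keeps the learned parameters.

```python
class ToyLayer:
    """Hypothetical stand-in for a network3 layer (not the real class)."""

    def __init__(self, n_in, n_out):
        self.n_in, self.n_out = n_in, n_out
        # Plain lists stand in for Theano shared variables (weights, biases)
        self.params = [[0.0] * (n_in * n_out), [0.0] * n_out]
        self.input_shape = None

    def set_inpt(self, mini_batch_size):
        # network3's layers reshape their input to (mini_batch_size, n_in);
        # here we only record that shape
        self.input_shape = (mini_batch_size, self.n_in)


class ToyNetwork:
    """Hypothetical stand-in for network3's Network constructor."""

    def __init__(self, layers, mini_batch_size):
        self.layers = layers
        self.mini_batch_size = mini_batch_size
        # Same idea as network3: the network's params are the layers' params
        self.params = [p for layer in layers for p in layer.params]
        for layer in layers:
            layer.set_inpt(mini_batch_size)


net = ToyNetwork([ToyLayer(784, 100), ToyLayer(100, 10)], 50)
net.params[0][0] = 0.5                   # pretend training changed a weight

net.mini_batch_size = 100                # changes only the attribute...
stale_shape = net.layers[0].input_shape  # ...the layers still expect 50

net = ToyNetwork(net.layers, 100)        # rebuild: every layer re-wired for 100
print(stale_shape, net.layers[0].input_shape, net.params[0][0])
# (50, 784) (100, 784) 0.5
```

The weight survives the rebuild because `params` holds references to the layers' own objects, just as network3.py's Theano shared variables live on the layer objects.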
```python
import theano.tensor as T

# Network structure definition (excerpt from network3.py)
class Network(object):

    def __init__(self, layers, mini_batch_size):
        """Takes a list of `layers` and the `mini_batch_size` used in training."""
        self.layers = layers
        self.mini_batch_size = mini_batch_size
        # Collect the parameters (weights and biases) of every layer
        self.params = [param for layer in self.layers for param in layer.params]
        self.x = T.matrix("x")
        self.y = T.ivector("y")
        # Wire the layers together; each set_inpt call fixes mini_batch_size
        init_layer = self.layers[0]
        init_layer.set_inpt(self.x, self.x, self.mini_batch_size)
        for j in range(1, len(self.layers)):
            prev_layer, layer = self.layers[j-1], self.layers[j]
            layer.set_inpt(
                prev_layer.output, prev_layer.output_dropout, self.mini_batch_size)
        self.output = self.layers[-1].output
        self.output_dropout = self.layers[-1].output_dropout
```
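From the constructor above, we can infer that to change `mini_batch_size`, the existing network's `net.layers` should be passed into a newly built network. The learned parameters come along automatically, because the constructor collects `net.params` from the layers.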
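The constructor passes `mini_batch_size` to each layer's `set_inpt`, which is enough for fully connected layers. One caveat, assuming the saved network (for instance the `conv_net3` file loaded in step 4) contains network3.py's `ConvPoolLayer`: that layer also stores the batch size as the first entry of its `image_shape` tuple, `(mini_batch_size, feature_maps, height, width)`, and reshapes its input to that shape, so rebuilding alone would leave it expecting the old batch size. Below is a hedged sketch of a hypothetical helper, `set_batch_size_in_layers`, that patches `image_shape` before the rebuild; `SimpleNamespace` objects stand in for real layers.

```python
from types import SimpleNamespace


def set_batch_size_in_layers(layers, new_mini_batch_size):
    """Replace the mini-batch entry of each layer's image_shape, if it has one.

    Assumes, as in network3.py's ConvPoolLayer, that image_shape is
    (mini_batch_size, feature_maps, height, width). Layers without an
    image_shape attribute are left untouched. Returns the same list.
    """
    for layer in layers:
        shape = getattr(layer, "image_shape", None)
        if shape is not None:
            layer.image_shape = (new_mini_batch_size,) + tuple(shape[1:])
    return layers


conv = SimpleNamespace(image_shape=(50, 1, 28, 28))  # stand-in conv layer
fc = SimpleNamespace(n_in=784, n_out=100)            # stand-in dense layer
set_batch_size_in_layers([conv, fc], 100)
print(conv.image_shape)  # (100, 1, 28, 28)
```

With network3 the helper would be called just before the rebuild, e.g. `net = Network(set_batch_size_in_layers(net.layers, 60), 60)`.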
Loading an existing network
step 1:
Inspect the network's layer structure. `net.layers` returns a list of two layer objects:
```
>>> net.layers
[<…>, <…>]
```
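step 2:
Look up each layer's input and output sizes, `n_in` and `n_out`:
```
>>> net.layers[0].n_in
784
>>> net.layers[1].n_in
100
>>> net.layers[0].n_out
100
>>> net.layers[1].n_out
10
```
So the network maps 784 inputs to a hidden layer of 100 neurons, which in turn feeds 10 outputs.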
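Steps 1 and 2 can be folded into one check. The sketch below uses a hypothetical helper, `describe_layers`, that lists each layer's class name with its `n_in`/`n_out` and verifies that consecutive layers fit together. It assumes layers expose `n_in`/`n_out` as attributes, as the two layers above do; layers without them (such as network3's convolutional layer) are reported with `None`. `SimpleNamespace` objects stand in for the real layers.

```python
from types import SimpleNamespace


def describe_layers(layers):
    """Summarize layers as (class name, n_in, n_out) tuples.

    Layers without n_in/n_out get None. Raises ValueError if a layer's
    n_in does not match the previous layer's n_out.
    """
    summary = [(type(layer).__name__,
                getattr(layer, "n_in", None),
                getattr(layer, "n_out", None)) for layer in layers]
    for (_, _, prev_out), (name, n_in, _) in zip(summary, summary[1:]):
        if prev_out is not None and n_in is not None and prev_out != n_in:
            raise ValueError(
                f"{name} expects n_in={n_in}, but the previous layer outputs {prev_out}")
    return summary


layers = [SimpleNamespace(n_in=784, n_out=100), SimpleNamespace(n_in=100, n_out=10)]
for row in describe_layers(layers):
    print(row)
# ('SimpleNamespace', 784, 100)
# ('SimpleNamespace', 100, 10)
```

On the real network, `describe_layers(net.layers)` would print the actual layer class names instead of `SimpleNamespace`.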
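step 3:
After working through the original structure carefully, I found that the following code changes `mini_batch_size` and then continues training:
```python
net = Network(net.layers, NEW_mini_batch_size)
net.SGD(training_data, epochs, NEW_mini_batch_size, eta,
        validation_data, test_data)
```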
The result looks good:
```
the old mini_batch_size is: 50
the new mini_batch_size is: 100
Training mini-batch number 0
Epoch 0: validation accuracy 90.53%
This is the best validation accuracy to date.
The corresponding test accuracy is 90.43%
Finished training network.
Best validation accuracy of 90.53% obtained at iteration 499
Corresponding test accuracy of 90.43%
```
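The iteration number in this log can be checked by hand. The sketch below reproduces the counting logic as I read it in network3.py's `SGD` (one validation check per epoch, when `(iteration + 1) % num_training_batches == 0`; "Training mini-batch number" is printed only every 1000 iterations, hence just `0` here). It assumes `training_data` holds the 50,000 MNIST training images returned by `load_data_shared()`, and that this run used `epochs = 1`, which is inferred from the log showing only epoch 0. `validation_iterations` is a hypothetical helper.

```python
def validation_iterations(n_training, mini_batch_size, epochs):
    """Return the iteration numbers at which network3-style SGD validates.

    Assumed counting logic (as read from network3.py's SGD):
      num_training_batches = n_training // mini_batch_size
      iteration = num_training_batches * epoch + minibatch_index
      validate when (iteration + 1) % num_training_batches == 0
    """
    num_batches = n_training // mini_batch_size
    checks = []
    for epoch in range(epochs):
        for minibatch_index in range(num_batches):
            iteration = num_batches * epoch + minibatch_index
            if (iteration + 1) % num_batches == 0:
                checks.append(iteration)
    return checks


# The run in the log: 50,000 training images, mini_batch_size 100, 1 epoch
print(validation_iterations(50000, 100, 1))  # [499]

# Step 4's settings: mini_batch_size 60, 5 epochs
print(validation_iterations(50000, 60, 5))   # [832, 1665, 2498, 3331, 4164]
print(50000 - (50000 // 60) * 60)            # 20 images unused per epoch
```

So "iteration 499" is simply the last of the 500 mini-batches of size 100. With a batch size that does not divide 50,000, such as 60, the integer division leaves the last 20 training images out of every epoch.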
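step 4: approach and code
After each training run, save the network's learned parameters to disk. When you need them again, load them back into memory, adjust `mini_batch_size` (batch size), `eta` (learning rate) and `epochs` (number of epochs), and continue training.
Implementation: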
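A minimal sketch of the save/load half of this approach, with hypothetical helper names `save_net` and `load_net`. Writing to a temporary file and then renaming it with `os.replace` is my addition, not part of the original code; it keeps the previous checkpoint intact if a save is interrupted. Since the file holds pickle data rather than JSON, the demo uses a `.pkl` extension, and a plain dict stands in for a trained network.

```python
import os
import pickle
import tempfile


def save_net(net, path):
    """Pickle `net` to `path` via a temporary file plus rename, so an
    interrupted save leaves the previous checkpoint untouched."""
    directory = os.path.dirname(path) or "."
    os.makedirs(directory, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(dir=directory)
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(net, f, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(tmp_path, path)  # same directory, so same filesystem
    except BaseException:
        os.remove(tmp_path)
        raise


def load_net(path):
    """Load a network previously saved with save_net (or pickle.dump)."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Demo with a plain dict standing in for a trained network
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "rst", "conv_net3.pkl")
    save_net({"mini_batch_size": 50}, path)
    restored = load_net(path)
print(restored)  # {'mini_batch_size': 50}
```

Loading a real network3 object this way still requires network3's classes to be importable, since pickle looks them up by module name.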
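```python
import pickle

import network3
from network3 import Network

# Load the saved network (the file holds pickle data despite the .json extension)
rst_path = "rst/conv_net3.json"
with open(rst_path, "rb") as pfile:
    net = pickle.load(pfile)

# Set the mini-batch size, number of epochs and learning rate
mini_batch_size = 60
epochs = 5
eta = 0.1

# Rebuild the network for the new mini_batch_size,
# reusing the existing layers and their learned parameters
net = Network(net.layers, mini_batch_size)

# Load the training, validation and test data, then continue training
training_data, validation_data, test_data = network3.load_data_shared()
net.SGD(training_data, epochs, mini_batch_size, eta,
        validation_data, test_data)

# Save the updated network (the with block closes the file)
with open(rst_path, "wb") as save_file:
    pickle.dump(net, save_file)
```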
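The load → rebuild → train → save cycle can also be wrapped in one function, so that each new session only changes the hyperparameters. In the sketch below, `resume_training`, `fake_rebuild` and `fake_train` are hypothetical names. `rebuild` and `train` are injected so the workflow runs without Theano. With network3 you would pass `Network` and a small wrapper that calls `net.SGD` with the data loaded by `load_data_shared()`.

```python
import os
import pickle
import tempfile
from types import SimpleNamespace


def resume_training(path, mini_batch_size, epochs, eta, rebuild, train):
    """Load a pickled network, rebuild it for a new mini_batch_size,
    train it and pickle it back to the same path.

    rebuild(layers, mini_batch_size) -> new network  (e.g. network3's Network)
    train(net, epochs, mini_batch_size, eta)         (e.g. a wrapper around net.SGD)
    """
    with open(path, "rb") as f:
        net = pickle.load(f)
    net = rebuild(net.layers, mini_batch_size)
    train(net, epochs, mini_batch_size, eta)
    with open(path, "wb") as f:
        pickle.dump(net, f)
    return net


def fake_rebuild(layers, mini_batch_size):
    # Stand-in for network3's Network(layers, mini_batch_size)
    return SimpleNamespace(layers=layers, mini_batch_size=mini_batch_size)


def fake_train(net, epochs, mini_batch_size, eta):
    # Stand-in for net.SGD: log the run on the first layer, which survives
    # rebuilds just as real layer parameters do
    net.layers[0]["runs"].append((epochs, mini_batch_size, eta))


with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "net.pkl")
    with open(path, "wb") as f:
        pickle.dump(fake_rebuild([{"runs": []}], 50), f)
    resume_training(path, 60, 5, 0.1, fake_rebuild, fake_train)           # session 1
    net = resume_training(path, 100, 1, 0.05, fake_rebuild, fake_train)   # session 2
print(net.mini_batch_size, net.layers[0]["runs"])
# 100 [(5, 60, 0.1), (1, 100, 0.05)]
```

Injecting the two callables keeps the file handling in one place while leaving the choice of network class and training call to the caller.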