在Tensorflow中训练一个模型后:

如何拯救训练过的模型? 您以后如何恢复这个保存的模型?


当前回答

下面是我对这两种基本情况的简单解决方案,这两种情况的不同之处在于您是想从文件加载图形还是在运行时构建它。

这个答案适用于Tensorflow 0.12+(包括1.0)。

在代码中重建图形

储蓄

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

加载

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    # now you can use the graph, continue training or whatever

还从文件中加载图形

当使用这种技术时,确保所有的层/变量都显式地设置了唯一的名称。否则Tensorflow将使名称本身是唯一的,因此它们将不同于存储在文件中的名称。在前一种技术中,这不是问题,因为名称在加载和保存时都以相同的方式“损坏”。

储蓄

graph = ... # build the graph

for op in [ ... ]:  # operators you want to use after restoring the model
    tf.add_to_collection('ops_to_restore', op)

saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

加载

with ... as sess:  # your session object
    saver = tf.train.import_meta_graph('my-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    ops = tf.get_collection('ops_to_restore')  # here are your operators in the same order in which you saved them to the collection

其他回答

最简单的方法是使用keras api,在线保存模型和一行加载模型

from keras.models import load_model

my_model.save('my_model.h5')  # creates a HDF5 file 'my_model.h5'

del my_model  # deletes the existing model


my_model = load_model('my_model.h5') # returns a compiled model identical to the previous one

对于TensorFlow版本< 0.11.0RC1:

保存的检查点包含模型中的变量值,而不是模型/图本身,这意味着当您恢复检查点时,图应该是相同的。

这里有一个线性回归的例子,其中有一个训练循环,保存变量检查点,还有一个评估部分,将恢复之前运行中保存的变量并计算预测。当然,如果你愿意,你也可以恢复变量并继续训练。

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)

w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
y_hat = tf.add(b, tf.matmul(x, w))

...more setup for optimization and what not...

saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    if FLAGS.train:
        for i in xrange(FLAGS.training_steps):
            ...training loop...
            if (i + 1) % FLAGS.checkpoint_steps == 0:
                saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt',
                           global_step=i+1)
    else:
        # Here's where you're restoring the variables w and b.
        # Note that the graph is exactly as it was when the variables were
        # saved in a prior training run.
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            ...no checkpoint found...

        # Now you can run the model to get predictions
        batch_x = ...load some data...
        predictions = sess.run(y_hat, feed_dict={x: batch_x})

下面是变量文档,涵盖了保存和恢复。这是保存程序的文档。

您可以保存网络中的变量使用

saver = tf.train.Saver() 
saver.save(sess, 'path of save/fileName.ckpt')

要恢复网络以供以后或在另一个脚本中重用,请使用:

saver = tf.train.Saver()
saver.restore(sess, tf.train.latest_checkpoint('path of save/')
sess.run(....) 

重要的几点:

第一次运行和以后运行之间的Sess必须相同(一致的结构)。 储蓄者。还原需要保存文件的文件夹路径,而不是单个文件路径。

你可以使用Tensorflow中的saver对象来保存你训练过的模型。该对象提供保存和恢复模型的方法。

在TensorFlow中保存一个训练好的模型:

tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None,
                    meta_graph_suffix='meta', write_meta_graph=True,
                    write_state=True, strip_default_attrs=False,
                    save_debug_info=False)

在TensorFlow中恢复已保存的模型:

tf.train.Saver.restore(sess, save_path, latest_filename=None,
                       meta_graph_suffix='meta', clear_devices=False,
                       import_scope=None)

你也可以在TensorFlow/skflow中查看例子,它提供了保存和恢复方法,可以帮助你轻松地管理模型。它具有一些参数,您还可以控制备份模型的频率。