在Tensorflow中训练一个模型后:

如何拯救训练过的模型? 您以后如何恢复这个保存的模型?


当前回答

在TensorFlow 0.11.0RC1版本中,你可以通过调用tf.train直接保存和恢复你的模型。Export_meta_graph和tf.train。根据https://www.tensorflow.org/programmers_guide/meta_graph的Import_meta_graph。

保存模型

w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# `save` method will call `export_meta_graph` implicitly.
# you will get saved graph files:my-model.meta

恢复模型

sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)

其他回答

无论你想把模型保存在哪里,

self.saver = tf.train.Saver()
with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            ...
            self.saver.save(sess, filename)

确保你所有的任务。变量有名称,因为您可能希望稍后使用它们的名称来恢复它们。 在你想预测的地方,

saver = tf.train.import_meta_graph(filename)
name = 'name given when you saved the file' 
with tf.Session() as sess:
      saver.restore(sess, name)
      print(sess.run('W1:0')) #example to retrieve by variable name

确保该保护程序在相应的会话中运行。 请记住,如果使用tf.train.latest_checkpoint('./'),那么将只使用最新的检查点。

下面是我对这两种基本情况的简单解决方案,这两种情况的不同之处在于您是想从文件加载图形还是在运行时构建它。

这个答案适用于Tensorflow 0.12+(包括1.0)。

在代码中重建图形

储蓄

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

加载

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    # now you can use the graph, continue training or whatever

还从文件中加载图形

当使用这种技术时,确保所有的层/变量都显式地设置了唯一的名称。否则Tensorflow将使名称本身是唯一的,因此它们将不同于存储在文件中的名称。在前一种技术中,这不是问题,因为名称在加载和保存时都以相同的方式“损坏”。

储蓄

graph = ... # build the graph

for op in [ ... ]:  # operators you want to use after restoring the model
    tf.add_to_collection('ops_to_restore', op)

saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

加载

with ... as sess:  # your session object
    saver = tf.train.import_meta_graph('my-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    ops = tf.get_collection('ops_to_restore')  # here are your operators in the same order in which you saved them to the collection

如第6255期所述:

use '**./**model_name.ckpt'
saver.restore(sess,'./my_model_final.ckpt')

而不是

saver.restore('my_model_final.ckpt')

下面是一个使用Tensorflow 2.0 SavedModel格式(根据文档,这是推荐的格式)的简单MNIST数据集分类器的简单示例,使用Keras函数式API,没有太多的花哨操作:

# Imports
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Flatten
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt

# Load data
mnist = tf.keras.datasets.mnist # 28 x 28
(x_train,y_train), (x_test, y_test) = mnist.load_data()

# Normalize pixels [0,255] -> [0,1]
x_train = tf.keras.utils.normalize(x_train,axis=1)
x_test = tf.keras.utils.normalize(x_test,axis=1)

# Create model
input = Input(shape=(28,28), dtype='float64', name='graph_input')
x = Flatten()(input)
x = Dense(128, activation='relu')(x)
x = Dense(128, activation='relu')(x)
output = Dense(10, activation='softmax', name='graph_output', dtype='float64')(x)
model = Model(inputs=input, outputs=output)

model.compile(optimizer='adam',
             loss='sparse_categorical_crossentropy',
             metrics=['accuracy'])

# Train
model.fit(x_train, y_train, epochs=3)

# Save model in SavedModel format (Tensorflow 2.0)
export_path = 'model'
tf.saved_model.save(model, export_path)

# ... possibly another python program 

# Reload model
loaded_model = tf.keras.models.load_model(export_path) 

# Get image sample for testing
index = 0
img = x_test[index] # I normalized the image on a previous step

# Predict using the signature definition (Tensorflow 2.0)
predict = loaded_model.signatures["serving_default"]
prediction = predict(tf.constant(img))

# Show results
print(np.argmax(prediction['graph_output']))  # prints the class number
plt.imshow(x_test[index], cmap=plt.cm.binary)  # prints the image

serving_default是什么?

它是所选标记的签名定义的名称(在本例中,选择了默认的服务标记)。此外,本文还解释了如何使用saved_model_cli查找模型的标记和签名。

免责声明

这只是一个基本的例子,如果你只是想让它运行起来,但这绝不是一个完整的答案-也许我可以在未来更新它。我只是想给出一个在TF 2.0中使用SavedModel的简单示例,因为我在任何地方都没有见过这样简单的SavedModel。

@Tom的回答是一个SavedModel的例子,但它在Tensorflow 2.0上不起作用,因为不幸的是有一些突破性的变化。

@Vishnuvardhan Janapati的回答是TF 2.0,但它不适合SavedModel格式。

在TensorFlow 0.11.0RC1版本中,你可以通过调用tf.train直接保存和恢复你的模型。Export_meta_graph和tf.train。根据https://www.tensorflow.org/programmers_guide/meta_graph的Import_meta_graph。

保存模型

w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# `save` method will call `export_meta_graph` implicitly.
# you will get saved graph files:my-model.meta

恢复模型

sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)