在Tensorflow中训练一个模型后:
如何拯救训练过的模型? 您以后如何恢复这个保存的模型?
在Tensorflow中训练一个模型后:
如何拯救训练过的模型? 您以后如何恢复这个保存的模型?
当前回答
你可以使用Tensorflow中的saver对象来保存你训练过的模型。该对象提供保存和恢复模型的方法。
在TensorFlow中保存一个训练好的模型:
tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None,
meta_graph_suffix='meta', write_meta_graph=True,
write_state=True, strip_default_attrs=False,
save_debug_info=False)
在TensorFlow中恢复已保存的模型:
tf.train.Saver.restore(sess, save_path, latest_filename=None,
meta_graph_suffix='meta', clear_devices=False,
import_scope=None)
其他回答
这里所有的答案都很棒,但我想补充两点。
首先,详细说明@user7505159的答案,“。添加到要恢复的文件名的开头可能很重要。
例如,您可以保存没有“的图形。/"在文件名中如下所示:
# Some graph defined up here with specific names
saver = tf.train.Saver()
save_file = 'model.ckpt'
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.save(sess, save_file)
但是为了恢复图形,您可能需要在前面加上一个"。/"到file_name:
# Same graph defined up here
saver = tf.train.Saver()
save_file = './' + 'model.ckpt' # String addition used for emphasis
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess, save_file)
你并不总是需要“。/”,但是它会根据你的环境和TensorFlow版本而导致问题。
它还想提到sess.run(tf.global_variables_initializer())在恢复会话之前可能很重要。
如果在尝试恢复保存的会话时收到关于未初始化变量的错误,请确保在保存程序之前包含sess.run(tf.global_variables_initializer())。恢复(sess, save_file)行。这样你就不用头疼了。
最简单的方法是使用keras api,在线保存模型和一行加载模型
from keras.models import load_model
my_model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
del my_model # deletes the existing model
my_model = load_model('my_model.h5') # returns a compiled model identical to the previous one
我正在改进我的回答,以添加更多关于保存和恢复模型的细节。
在Tensorflow 0.11版本中(及之后):
保存模型:
import tensorflow as tf
#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}
#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())
#Create a saver object which will save all the variables
saver = tf.train.Saver()
#Run the operation by feeding input
print sess.run(w4,feed_dict)
#Prints 24 which is sum of (w1+w2)*b1
#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)
恢复模型:
import tensorflow as tf
sess=tf.Session()
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
# Access saved Variables directly
print(sess.run('bias:0'))
# This will print 2, which is the value of bias that we saved
# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}
#Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated
这里已经很好地解释了这一点和一些更高级的用例。
一个快速完整的教程,保存和恢复Tensorflow模型
在@Vishnuvardhan Janapati的回答之后,这里是另一种在TensorFlow 2.0.0下保存和重载自定义层/度量/损失模型的方法
import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras.utils.generic_utils import get_custom_objects
# custom loss (for example)
def custom_loss(y_true,y_pred):
return tf.reduce_mean(y_true - y_pred)
get_custom_objects().update({'custom_loss': custom_loss})
# custom loss (for example)
class CustomLayer(Layer):
def __init__(self, ...):
...
# define custom layer and all necessary custom operations inside custom layer
get_custom_objects().update({'CustomLayer': CustomLayer})
通过这种方式,一旦您执行了这些代码,并使用tf.keras.models保存了您的模型。Save_model或model。save或ModelCheckpoint回调,您可以重新加载您的模型,而不需要精确的自定义对象,就像这样简单
new_model = tf.keras.models.load_model("./model.h5"})
在新版本的tensorflow 2.0中,保存/加载模型的过程要容易得多。因为Keras API的实现,一个TensorFlow的高级API。
保存一个模型: 请查阅相关文档以作参考: https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/save_model
tf.keras.models.save_model(model_name, filepath, save_format)
加载一个模型:
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/load_model
model = tf.keras.models.load_model(filepath)