TensorFlow provides save and restore functions for us to save and reuse model parameters. If you have a trained VGG model, for example, it is helpful to restore the first few layers and apply them in your own network. This raises a question: how do we restore only a subset of the parameters? You can always check the official TF documentation here. In this post, I will take some code from the documentation and add some practical points.

### Save

Say we have a simple model which contains only two variables.

```python
import tensorflow as tf

# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer)

inc_v1 = v1.assign(v1 + 1)
dec_v2 = v2.assign(v2 - 1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()
```

In older versions of TF, I liked to use `tf.Variable()` to declare a variable. Here we use `tf.get_variable()`, a better way to declare a variable, and you must define a shape. This function can "get an existing variable with these parameters or create a new one." We will see how it "gets an existing variable" in a moment.

Then we add 1 to `v1` and subtract 1 from `v2`. The following is a way to check the variable names, which will be the keys when you save the model.

```python
vars = [v for v in tf.global_variables()]
print(vars)
```

The first line gets all the variables in the graph as a list.

The output looks like:

```
[<tf.Variable 'v1:0' shape=(3,) dtype=float32_ref>, <tf.Variable 'v2:0' shape=(5,) dtype=float32_ref>]
```

This shows that we now have two variables. The following gives you the name of a variable:

```python
print(vars[0].name)  # v1:0
```

Now let’s save them! The piece of code is from the official document.

```python
# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
    sess.run(init_op)
    # Do some work with the model.
    inc_v1.op.run()
    dec_v2.op.run()
    # Save the variables to disk.
    save_path = saver.save(sess, "/tmp/model.ckpt")
    print("Model saved in path: %s" % save_path)
```

If you write `saver = tf.train.Saver()` and then call `saver.save()`, you will save all the variables. If you want to save only `v1`, for example, simply change it to `saver = tf.train.Saver([v1])`. So you can pass a variable list into the constructor, and only those variables will be saved.

### Restore

Before you restore any model, you always need to define your own network first. Say we have exactly `v1` and `v2` and want to restore them from the previously saved model; we do the following:

```python
tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
    # Restore variables from disk.
    saver.restore(sess, "/tmp/model.ckpt")
    print("Model restored.")
    # Check the values of the variables
    print("v1 : %s" % v1.eval())
    print("v2 : %s" % v2.eval())
```

What if you only want to restore `v1`? Similarly, we can change it to `saver = tf.train.Saver([v1])`.

This is very helpful if you have a different network structure. For example, when you trained your network you might have had only `v1` and `v2`, and then saved the model. Later, you build a different network structure where a new `v3` comes in. In this case, you need to restore only `v1` and `v2`, or you will get a "key not found in the checkpoint" error.
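The selection logic can be sketched in plain Python. Note this is just an illustration of the idea: strings stand in for the actual `tf.Variable` objects, and the names are invented for the example.

```python
# Sketch of the selection logic, with plain strings standing in for
# tf.Variable objects. Suppose the checkpoint contains only v1 and v2,
# while the new graph also defines v3.
checkpoint_vars = {"v1", "v2"}   # names stored in the checkpoint
graph_vars = ["v1", "v2", "v3"]  # names defined in the new graph

# Restore only the variables that exist in the checkpoint; the rest
# must be initialized from scratch.
to_restore = [v for v in graph_vars if v in checkpoint_vars]
to_initialize = [v for v in graph_vars if v not in checkpoint_vars]

print(to_restore)     # ['v1', 'v2']
print(to_initialize)  # ['v3']
```

In real code, `to_restore` is what you would pass to `tf.train.Saver`, and everything in `to_initialize` needs an explicit initializer run before use.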

Besides passing in a variable list, we can also pass in a Python dictionary whose keys are variable names (strings) and whose values are the variables: `saver = tf.train.Saver({"v1": v1})`. This means you want the `v1` variable in your graph to be restored under the checkpoint name `"v1"`.

Now suppose you have three variables, and you want to restore the `v2` parameters from the old model into your new `v3` variable, while `v1` and `v2` should be freshly initialized random variables.

```python
# restore
tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", [3])
v2 = tf.get_variable("v2", [5])
v3 = tf.get_variable("v3", [5])

# Add ops to restore only v3, from the checkpoint entry named "v2".
saver = tf.train.Saver({"v2": v3})

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
    # Initialize the variables that will not be restored.
    v1.initializer.run()
    v2.initializer.run()
    # Restore variables from disk.
    saver.restore(sess, "/tmp/model.ckpt")
    print("Model restored.")
    # Check the values of the variables
    print("v1 : %s" % v1.eval())
    print("v2 : %s" % v2.eval())
    print("v3 : %s" % v3.eval())
```

So do remember that you always need to initialize the variables that you did not restore. `tf.get_variable()` offers a smart way of creating variables: for the variables you do not restore, it creates new ones without any errors; for the variables you want to restore and reuse, it lets the saver fill in the saved values. In both cases, we create the variables in the same way.

### Name scopes

If you have a larger network structure, it is natural to use name scopes, which means the variables will have longer names. But `v.name` can always get the name as a string.

If I have a collection of variables under “foobar” name scope, I can get them by using the following code:

```python
vars_to_restore = [v for v in tf.global_variables() if "foobar" in v.name]
vars_to_restore_dict = {}
# make the dictionary; note that every name here ends with ":0", which we strip
for v in vars_to_restore:
    vars_to_restore_dict[v.name[:-2]] = v
```

To restore them, we need to remove the special "tail" of the variable names. For example, `tf.Variable 'v1:0' shape=(3,) dtype=float32_ref` gives the name as `'v1:0'`. When you have name scopes, you might have many `v1` variables with different prefixes: `'foobar/v1:0'`, `'other_scope/v1:0'`. If you want to restore the first `v1`, the key of the dictionary you pass into the `Saver` constructor should be `'foobar/v1'`, so we apply something like `v.name[:-2]` here to drop the last two characters.
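As a plain-Python sketch of that stripping step (the names here are just examples): the `":0"` tail is the output index of the op that produced the tensor, so splitting on `":"` is a slightly safer alternative to slicing off the last two characters.

```python
# Strip the ":0" output-index suffix from TF-style variable names.
# Splitting on ":" also works if the index ever has more than one digit,
# where a fixed name[:-2] slice would break.
names = ["foobar/v1:0", "other_scope/v1:0", "v2:0"]
keys = [n.split(":")[0] for n in names]

print(keys)  # ['foobar/v1', 'other_scope/v1', 'v2']
```

These keys are exactly what you would use in the dictionary passed to `tf.train.Saver`.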

### Check saved models

How do you check if a certain variable has been saved in the model? Suppose you know the name of it, say “v1”.

The official code shows that it is easy to print all the variables:

```python
# import the inspect_checkpoint library
from tensorflow.python.tools import inspect_checkpoint as chkp

# print all tensors in checkpoint file
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='', all_tensors=True)
```

Then I looked into the official code on GitHub here, and found that we can use the following piece of code to check whether a query variable name is actually in the saved model.

```python
from tensorflow.python import pywrap_tensorflow

reader = pywrap_tensorflow.NewCheckpointReader("/tmp/model.ckpt")
var_to_shape_map = reader.get_variable_to_shape_map()
# 'var_to_shape_map' is a dictionary containing every tensor in the model
if 'v1' in var_to_shape_map:
    pass  # do something...
```

So `var_to_shape_map` is a dictionary whose keys are the names of the saved variables. We can use an `if` statement to check for a name. I tested this with TensorFlow version 1.2.
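To make the shape of that check concrete, here is a plain-Python stand-in for `var_to_shape_map` (the real one comes from the checkpoint reader; the names and shapes below just mirror the earlier example):

```python
# Stand-in for reader.get_variable_to_shape_map(): keys are the saved
# variable names, values are their shapes.
var_to_shape_map = {"v1": [3], "v2": [5]}

def is_saved(name, shape_map):
    """Return True if a variable with this name exists in the checkpoint."""
    return name in shape_map

print(is_saved("v1", var_to_shape_map))  # True
print(is_saved("v3", var_to_shape_map))  # False
```

This is handy when deciding programmatically which variables to put in the restore dictionary.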