Python, Theory

TensorFlow 08: save and restore a subset of variables

TensorFlow provides save and restore functions so we can save and re-use model parameters. If you have a trained VGG model, for example, it is often useful to restore the first few layers and apply them in your own network. This raises a question: how do we restore only a subset of the parameters? You can always check the TF official document here. In this post, I will take some code from the document and add some practical points.

Save

Say we have a simple model which contains only two variables.

import tensorflow as tf
# Create some variables.

v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

In older versions of TF, I liked to use tf.Variable() to declare a variable. Here we use tf.get_variable(), a better way to declare a variable, and you must define a shape. This function can “get an existing variable with these parameters or create a new one.” We will see how it “gets an existing variable” in a moment.
Then we add 1 to v1 and subtract 1 from v2. The following is a way to check the variable names, which will be the keys when you save the model.

vars = [v for v in tf.global_variables()]
print(vars)

In the first line, we get all the variables in the graph as a list.
The output looks like:

[<tf.Variable 'v1:0' shape=(3,) dtype=float32_ref>, <tf.Variable 'v2:0' shape=(5,) dtype=float32_ref>]

It shows that we now have two variables. The following gives you the name of a variable:

print(vars[0].name)
>>> v1:0

Now let’s save them! This piece of code is from the official document.

# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
    sess.run(init_op)
    # Do some work with the model.
    inc_v1.op.run()
    dec_v2.op.run()
    # Save the variables to disk.
    save_path = saver.save(sess, "/tmp/model.ckpt")
    print("Model saved in path: %s" % save_path)

If you create the saver as saver = tf.train.Saver() and then call saver.save(), you will save all the variables.

If you want to save only v1, for example, simply change this to:
saver = tf.train.Saver([v1]). You can pass a variable list into the constructor, and only those variables will be saved.

Restore

Before you restore any model, you always need to define your own network first. Say we have exactly v1 and v2 and want to restore them from the previously saved model; we do the following:

tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
    # Restore variables from disk.
    saver.restore(sess, "/tmp/model.ckpt")
    print("Model restored.")
    # Check the values of the variables
    print("v1 : %s" % v1.eval())
    print("v2 : %s" % v2.eval())

What if you only want to restore v1? Similarly, we can change this to:
saver = tf.train.Saver([v1])

This is very helpful if you have a different network structure. For example, when you train your network you might have only v1 and v2. Then you save the model. Later, you build a different network structure in which a new v3 comes in. In this case you need to restore only v1 and v2, or you will hit the error “Key not found in checkpoint”.

Besides passing in a variable list, we can also pass in a Python dictionary whose keys are variable names (strings) and whose values are the variables: saver = tf.train.Saver({"v1": v1}). That means you want the variable v1 in your graph to be restored via the checkpoint name "v1".

Now suppose you have three variables and want to restore the v2 parameters from the old model into your new v3 variable, while v1 and v2 should be freshly initialized random variables.

# restore
tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", [3])
v2 = tf.get_variable("v2", [5])
v3 = tf.get_variable("v3", [5])

# Add ops to save and restore all the variables.

# saver = tf.train.Saver([v1])
saver = tf.train.Saver({"v2": v3})

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
    v1.initializer.run()
    v2.initializer.run()
    # Restore variables from disk.
    saver.restore(sess, "/tmp/model.ckpt")
    print("Model restored.")
    # Check the values of the variables
    print("v1 : %s" % v1.eval())
    print("v2 : %s" % v2.eval())
    print("v3 : %s" % v3.eval())

So remember: you always need to initialize the variables that you did not restore. tf.get_variable() gives a smart way of creating variables: for variables you do not restore, it creates new ones without any errors; for variables you want to restore and reuse, it lets the saver fill in the stored values. In both cases we declare the variables the same way.
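The bookkeeping above (restore what exists in the checkpoint, initialize the rest) can be sketched with plain Python over the variable names. This is just an illustration of the idea, not TF API code; partition_vars and the name lists here are hypothetical:

```python
def partition_vars(graph_var_names, checkpoint_keys):
    """Split graph variables into (to_restore, to_initialize) by checkpoint membership."""
    to_restore = [n for n in graph_var_names if n in checkpoint_keys]
    to_init = [n for n in graph_var_names if n not in checkpoint_keys]
    return to_restore, to_init

# Variables in the new graph vs. keys saved in the old checkpoint.
restore, init = partition_vars(["v1", "v2", "v3"], {"v1", "v2"})
print(restore)  # ['v1', 'v2'] -> pass these to tf.train.Saver
print(init)     # ['v3']      -> run their initializers yourself
```

The first list is what you would hand to the saver; everything in the second list must get its initializer run before use, exactly as v1 and v2 do in the session above.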

Name scopes

If you have a larger network structure, it is natural to use name scopes, which means the variables will have longer names. But v.name can always get the name as a string.
If I have a collection of variables under the “foobar” name scope, I can collect them with the following code:

vars_to_restore=[v for v in tf.global_variables() if "foobar" in v.name]
vars_to_restore_dict = {}

# Build the dictionary. Note that every name here ends with ":0"; strip it.
for v in vars_to_restore:
    vars_to_restore_dict[v.name[:-2]] = v

To restore them, we need to remove the special “tail” of the variable names. For example, tf.Variable 'v1:0' shape=(3,) dtype=float32_ref gives the name 'v1:0'. When you have name scopes, you might have many v1s with different prefixes: 'foobar/v1:0', 'other_scope/v1:0'. If you want to restore the first v1, the key of the dictionary you pass into the Saver must be 'foobar/v1', so we apply something like v.name[:-2] to drop the last two characters.
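The stripping itself is plain string work and easy to check outside a session. A small sketch (the names below are just example strings):

```python
def strip_output_index(name):
    """Drop the trailing ':0'-style output index from a TF variable name."""
    return name.rsplit(":", 1)[0]

print(strip_output_index("foobar/v1:0"))       # foobar/v1
print(strip_output_index("other_scope/v1:0"))  # other_scope/v1
```

Using rsplit(":", 1) instead of name[:-2] also survives output indices with more than one digit, such as ':10'.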

Check saved models

How do you check if a certain variable has been saved in the model? Suppose you know the name of it, say “v1”.

The official code shows that it is easy to print all the variables:

# import the inspect_checkpoint library
from tensorflow.python.tools import inspect_checkpoint as chkp

# print all tensors in checkpoint file
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='', all_tensors=True)

Then I looked into the official code on GitHub here and found that we can use the following piece of code to check whether a query variable name is actually in the saved model.

from tensorflow.python import pywrap_tensorflow

reader = pywrap_tensorflow.NewCheckpointReader("/tmp/model.ckpt")
# 'var_to_shape_map' is a dictionary that contains every tensor in the model.
var_to_shape_map = reader.get_variable_to_shape_map()

if 'v1' in var_to_shape_map:
    pass  # do something...

So var_to_shape_map is a dictionary whose keys are the names of the saved variables, and we can check membership with a simple if statement. I tested this with TensorFlow version 1.2.
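Since var_to_shape_map maps names to shapes, you can also check that the saved shape matches your new variable before trying to restore it. A quick sketch, with a hand-written dict standing in for the reader's output (can_restore is a hypothetical helper):

```python
def can_restore(shape_map, name, expected_shape):
    """True if 'name' is in the checkpoint and its saved shape matches."""
    return name in shape_map and list(shape_map[name]) == list(expected_shape)

# A stand-in for reader.get_variable_to_shape_map() on the model saved above.
shape_map = {"v1": [3], "v2": [5]}
print(can_restore(shape_map, "v1", [3]))  # True
print(can_restore(shape_map, "v1", [4]))  # False: shape mismatch
print(can_restore(shape_map, "v3", [5]))  # False: not saved
```

A shape mismatch would otherwise only surface as an error inside saver.restore(), so this kind of pre-check gives a clearer failure message.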
