TensorFlow provides save and restore functions for us to save and reuse model parameters. If you have a trained VGG model, for example, it is helpful to restore the first few layers and apply them in your own network. This raises a question: how do we restore only a subset of the parameters? You can always check the official TF documentation here. In this post, I will take some code from the documentation and add some practical points.

### Save

Say we have a simple model which contains only two variables.

```python
import tensorflow as tf

# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer)

inc_v1 = v1.assign(v1 + 1)
dec_v2 = v2.assign(v2 - 1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()
```

In older versions of TF, I liked to use `tf.Variable()` to declare a variable. Here we use `tf.get_variable()`, a better way to declare a variable, and you must define a shape. This function can "get an existing variable with these parameters or create a new one." We will see how it "gets an existing variable" in a moment.

Then we add 1 to `v1` and subtract 1 from `v2`. The following is a way to check the variable names, which will be the keys when you save the model.

```python
vars = [v for v in tf.global_variables()]
print(vars)
```

The first line gets all the variables in the graph as a list.

The output looks like:

```
[<tf.Variable 'v1:0' shape=(3,) dtype=float32_ref>, <tf.Variable 'v2:0' shape=(5,) dtype=float32_ref>]
```

This shows that we now have two variables. The following gives you the name of a variable:

```python
print(vars[0].name)  # v1:0
```

Now let’s save them! The piece of code is from the official document.

```python
# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
    sess.run(init_op)
    # Do some work with the model.
    inc_v1.op.run()
    dec_v2.op.run()
    # Save the variables to disk.
    save_path = saver.save(sess, "/tmp/model.ckpt")
    print("Model saved in path: %s" % save_path)
```

If you write `saver = tf.train.Saver()` and then call `saver.save()`, you will save all the variables. If you want to save only `v1`, for example, simply change it to `saver = tf.train.Saver([v1])`. So you can pass a variable list into the constructor, and only those variables will be saved.

### Restore

Before you restore any model, you always need to define your own network first. Say we have exactly `v1` and `v2` and want to restore them from the previously saved model; we do the following:

```python
tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
    # Restore variables from disk.
    saver.restore(sess, "/tmp/model.ckpt")
    print("Model restored.")
    # Check the values of the variables
    print("v1 : %s" % v1.eval())
    print("v2 : %s" % v2.eval())
```

What if you only want to restore `v1`? Similarly, we can change it to `saver = tf.train.Saver([v1])`.

This is very helpful if you have a different network structure. For example, when you trained your network you might have had only `v1` and `v2`, and then saved the model. Later, you build a different network structure where a new `v3` comes in. In this case, you need to restore only `v1` and `v2`, or you will get a "key not found in the checkpoint" error.
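The selection logic can be sketched in plain Python. Note this is just an illustration of the idea: strings stand in for the actual `tf.Variable` objects, and the names are invented for the example.

```python
# Sketch of the selection logic, with plain strings standing in for
# tf.Variable objects. Suppose the checkpoint contains only v1 and v2,
# while the new graph also defines v3.
checkpoint_vars = {"v1", "v2"}   # names stored in the checkpoint
graph_vars = ["v1", "v2", "v3"]  # names defined in the new graph

# Restore only the variables that exist in the checkpoint; the rest
# must be initialized from scratch.
to_restore = [v for v in graph_vars if v in checkpoint_vars]
to_initialize = [v for v in graph_vars if v not in checkpoint_vars]

print(to_restore)     # ['v1', 'v2']
print(to_initialize)  # ['v3']
```

In real code, `to_restore` is what you would pass to `tf.train.Saver`, and everything in `to_initialize` needs an explicit initializer run before use.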

Besides passing in a variable list, we can also pass in a Python dictionary whose keys are variable names (strings) and whose values are the variables: `saver = tf.train.Saver({"v1": v1})`. This means you want the `v1` variable in your graph to be restored under the checkpoint name `"v1"`.

Now suppose you have three variables, and you want to restore the `v2` parameters from the old model into your new `v3` variable, while `v1` and `v2` should be freshly initialized random variables.

```python
# restore
tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", [3])
v2 = tf.get_variable("v2", [5])
v3 = tf.get_variable("v3", [5])

# Add ops to restore only v3, from the checkpoint entry named "v2".
saver = tf.train.Saver({"v2": v3})

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
    # Initialize the variables that will not be restored.
    v1.initializer.run()
    v2.initializer.run()
    # Restore variables from disk.
    saver.restore(sess, "/tmp/model.ckpt")
    print("Model restored.")
    # Check the values of the variables
    print("v1 : %s" % v1.eval())
    print("v2 : %s" % v2.eval())
    print("v3 : %s" % v3.eval())
```

So do remember that you always need to initialize the variables that you did not restore. `tf.get_variable()` offers a smart way of creating variables: for the variables you do not restore, it creates new ones without any errors; for the variables you want to restore and reuse, it lets the saver fill in the saved values. In both cases, we create the variables in the same way.

### Name scopes

If you have a larger network structure, it is natural to use name scopes, which means the variables will have longer names. But `v.name` can always get the name as a string.

If I have a collection of variables under “foobar” name scope, I can get them by using the following code:

```python
vars_to_restore = [v for v in tf.global_variables() if "foobar" in v.name]
vars_to_restore_dict = {}
# make the dictionary; note that every name here ends with ":0", which we strip
for v in vars_to_restore:
    vars_to_restore_dict[v.name[:-2]] = v
```

To restore them, we need to remove the special "tail" of the variable names. For example, `tf.Variable 'v1:0' shape=(3,) dtype=float32_ref` gives the name as `'v1:0'`. When you have name scopes, you might have many `v1` variables with different prefixes: `'foobar/v1:0'`, `'other_scope/v1:0'`. If you want to restore the first `v1`, the key of the dictionary you pass into the `Saver` constructor should be `'foobar/v1'`, so we apply something like `v.name[:-2]` here to drop the last two characters.
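As a plain-Python sketch of that stripping step (the names here are just examples): the `":0"` tail is the output index of the op that produced the tensor, so splitting on `":"` is a slightly safer alternative to slicing off the last two characters.

```python
# Strip the ":0" output-index suffix from TF-style variable names.
# Splitting on ":" also works if the index ever has more than one digit,
# where a fixed name[:-2] slice would break.
names = ["foobar/v1:0", "other_scope/v1:0", "v2:0"]
keys = [n.split(":")[0] for n in names]

print(keys)  # ['foobar/v1', 'other_scope/v1', 'v2']
```

These keys are exactly what you would use in the dictionary passed to `tf.train.Saver`.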

### Check saved models

How do you check if a certain variable has been saved in the model? Suppose you know the name of it, say “v1”.

The official code shows that it is easy to print all the variables:

```python
# import the inspect_checkpoint library
from tensorflow.python.tools import inspect_checkpoint as chkp

# print all tensors in checkpoint file
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='', all_tensors=True)
```

Then I looked into the official code on GitHub here, and found that we can use the following piece of code to check whether a query variable name is actually in the saved model.

```python
from tensorflow.python import pywrap_tensorflow

reader = pywrap_tensorflow.NewCheckpointReader("/tmp/model.ckpt")
var_to_shape_map = reader.get_variable_to_shape_map()
# 'var_to_shape_map' is a dictionary containing every tensor in the model
if 'v1' in var_to_shape_map:
    pass  # do something...
```

So `var_to_shape_map` is a dictionary whose keys are the names of the saved variables. We can use an `if` statement to check for a name. I tested this with TensorFlow version 1.2.
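To make the shape of that check concrete, here is a plain-Python stand-in for `var_to_shape_map` (the real one comes from the checkpoint reader; the names and shapes below just mirror the earlier example):

```python
# Stand-in for reader.get_variable_to_shape_map(): keys are the saved
# variable names, values are their shapes.
var_to_shape_map = {"v1": [3], "v2": [5]}

def is_saved(name, shape_map):
    """Return True if a variable with this name exists in the checkpoint."""
    return name in shape_map

print(is_saved("v1", var_to_shape_map))  # True
print(is_saved("v3", var_to_shape_map))  # False
```

This is handy when deciding programmatically which variables to put in the restore dictionary.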