Saving nested layers in TensorFlow
How to serialize nested layers in TF 2.5.0.
A custom TF layer is one that subclasses tf.keras.layers.Layer. This is powerful on its own, but a particularly desirable feature is the ability to nest layers. Serializing nested layers is a bit of a headache, but it is necessary in order to save models that contain them using model.save(...).

Saving a custom layer
Let’s first make a custom layer:
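A minimal sketch of such a layer, assuming it is called InnerLayer and simply scales its input by a constant x (the scaling behavior is an illustrative assumption, not necessarily what the original layer does):

```python
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    """Custom layer that scales its input by a constant x (assumed behavior)."""

    def __init__(self, x, **kwargs):
        # **kwargs forwards standard arguments such as name, trainable and dtype.
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return inputs * self.x

    def get_config(self):
        # Start from the base config and add the custom attributes.
        config = super().get_config()
        config.update({"x": self.x})
        return config

    @classmethod
    def from_config(cls, config):
        # Every key in the config maps to a constructor argument.
        return cls(**config)


# Round trip: rebuild the layer from its own config.
layer = InnerLayer(0.5)
restored = InnerLayer.from_config(layer.get_config())
print(restored.x)
```

The round trip at the end is a quick way to check that get_config and from_config agree with the constructor signature.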
The layer includes the get_config and from_config methods, which are used to serialize the custom layer. Custom attributes like self.x are included by first calling the superclass’s get_config() and then using config.update({...}). Note that if you have a tf.Variable like:
self.abc = self.add_weight(
    name="abc",
    shape=3,
    trainable=False,
    initializer=tf.constant_initializer(np.random.rand(3)),
    dtype='float32'
)
you can add it to the config using the NumPy conversion:
config.update({"abc": self.abc.numpy()})
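Put together, this could look like the sketch below. The class name WeightedLayer, the abc constructor argument and the call behavior are hypothetical. The sketch also stores the values with .tolist() instead of the raw array, since Python's json module cannot encode a NumPy array; whether Keras's own encoder accepts the raw array may depend on the version.

```python
import numpy as np
import tensorflow as tf


class WeightedLayer(tf.keras.layers.Layer):  # hypothetical name
    def __init__(self, abc=None, **kwargs):
        super().__init__(**kwargs)
        # Reuse stored values when restoring, otherwise draw random ones.
        init_values = np.random.rand(3) if abc is None else np.asarray(abc)
        self.abc = self.add_weight(
            name="abc",
            shape=(3,),  # tuple form of shape=3
            trainable=False,
            initializer=tf.constant_initializer(init_values),
            dtype="float32",
        )

    def call(self, inputs):
        return inputs + self.abc  # illustrative behavior

    def get_config(self):
        config = super().get_config()
        # Store plain Python floats so the config stays JSON-friendly.
        config.update({"abc": self.abc.numpy().tolist()})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


layer = WeightedLayer()
restored = WeightedLayer.from_config(layer.get_config())
print(np.allclose(layer.abc.numpy(), restored.abc.numpy()))
```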
Let’s save a model with the custom layer:
This should work without error. Note that you have to build the model, e.g. by passing some data through it, before you save it.
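A sketch of the build step, followed by a quick in-memory check that the model can be rebuilt from its config. This check exercises the same get_config/from_config pair that saving and loading rely on, but it is not a substitute for calling model.save(...) itself. The InnerLayer definition repeats the assumed layer from above so the example runs on its own.

```python
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    def __init__(self, x, **kwargs):
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return inputs * self.x  # assumed behavior

    def get_config(self):
        config = super().get_config()
        config.update({"x": self.x})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


model = tf.keras.Sequential([InnerLayer(0.5)])

# Passing some data through the model builds it.
outputs = model(tf.ones((2, 3)))
print(model.built)

# Rebuild the model from its config; custom classes go in custom_objects.
restored = tf.keras.Sequential.from_config(
    model.get_config(), custom_objects={"InnerLayer": InnerLayer}
)
print(restored.layers[0].x)
```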
Saving a model with a nested layer
Next, let’s create a nested layer that has an InnerLayer:
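A minimal sketch of what such a nested layer could look like. The class name NestedLayer, the x_outer argument and the call behavior are assumptions for illustration. The important part is get_config, which puts the layer object itself into the config and leaves its serialization to Keras at save time.

```python
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    def __init__(self, x, **kwargs):
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return inputs * self.x  # assumed behavior

    def get_config(self):
        config = super().get_config()
        config.update({"x": self.x})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


class NestedLayer(tf.keras.layers.Layer):  # hypothetical name
    def __init__(self, x_outer, inner_layer, **kwargs):
        super().__init__(**kwargs)
        self.x_outer = x_outer
        self.inner_layer = inner_layer

    def call(self, inputs):
        return self.x_outer * self.inner_layer(inputs)  # assumed behavior

    def get_config(self):
        config = super().get_config()
        # Pass the layer object itself, not self.inner_layer.get_config().
        config.update({"x_outer": self.x_outer, "inner_layer": self.inner_layer})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


layer = NestedLayer(0.2, InnerLayer(0.5))
out = layer(tf.ones((2, 3)))
print(type(layer.get_config()["inner_layer"]).__name__)
```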
It’s almost the same as before, but we have to add inner_layer to the config returned by get_config. Note that we do not write:
"inner_layer": self.inner_layer.get_config()
as this leads to the error:
AttributeError: 'dict' object has no attribute '_serialized_attributes'
To save the nested model:
Here we have called:
model.save("test_save_nested", save_traces=False)
If we instead use save_traces=True, we get the warning:
WARNING:absl:Found untraced functions such as inner_layer_1_layer_call_and_return_conditional_losses, inner_layer_1_layer_call_fn, inner_layer_1_layer_call_fn, inner_layer_1_layer_call_and_return_conditional_losses, inner_layer_1_layer_call_and_return_conditional_losses while saving (showing 5 of 5). These functions will not be directly callable after loading.
As discussed in the documentation, using save_traces=True is not needed since we have defined custom get_config/from_config methods:
New in TensorFlow 2.4: The argument save_traces has been added to model.save, which allows you to toggle SavedModel function tracing. Functions are saved to allow Keras to re-load custom objects without the original class definitions, so when save_traces=False, all custom objects must have defined get_config/from_config methods. When loading, the custom objects must be passed to the custom_objects argument. save_traces=False reduces the disk space used by the SavedModel and saving time.
Saving models with many nested layers in a dictionary
Finally, let’s look at a nested layer that contains multiple inner layers, stored in a dictionary:
Here we have two inner layers in the dictionary:
{
    "lyr1": InnerLayer(x_inner),
    "lyr2": InnerLayer(x_inner)
}
If we just use a naive from_config like this:
@classmethod
def from_config(cls, config):
    return cls(**config)
we get the error:
AttributeError: 'dict' object has no attribute '_serialized_attributes'
This is because the inner layers reach from_config as serialized dictionaries rather than as layer objects, as we can see if we print out the config dictionary:
Config input: {'name': 'nested_dict_layer', 'trainable': True, 'dtype': 'float32', 'x_outer': 0.2, 'inner_layers': {'lyr1': {'class_name': 'InnerLayer', 'config': {'name': 'inner_layer_1', 'trainable': True, 'dtype': 'float32', 'x': 0.5}}, 'lyr2': {'class_name': 'InnerLayer', 'config': {'name': 'inner_layer_2', 'trainable': True, 'dtype': 'float32', 'x': 0.5}}}}
Instead, we must recreate the inner layers in from_config:
@classmethod
def from_config(cls, config):
    print("Config input: ", config)
    inner_layers = {}
    for key, val in config["inner_layers"].items():
        inner_layers[key] = InnerLayer(**val['config'])
    config["inner_layers"] = inner_layers
    print("Config recreated: ", config)
    return cls(**config)
Then the recreated config is correct:
Config recreated: {'name': 'nested_dict_layer', 'trainable': True, 'dtype': 'float32', 'x_outer': 0.2, 'inner_layers': {'lyr1': <__main__.InnerLayer object at 0x7fd208fe0d50>, 'lyr2': <__main__.InnerLayer object at 0x7fd1d851af90>}}
And saving/loading the model works as expected:
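As a quick in-memory check of the same logic, the sketch below puts the pieces together. The call behavior and the as_saved helper are hypothetical: as_saved only mimics the serialized form printed above, it is not a Keras function. The from_config here is a slightly defensive variant that accepts either layer objects or serialized dictionaries and copies the config instead of mutating it.

```python
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    def __init__(self, x, **kwargs):
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return inputs * self.x  # assumed behavior

    def get_config(self):
        config = super().get_config()
        config.update({"x": self.x})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


class NestedDictLayer(tf.keras.layers.Layer):
    def __init__(self, x_outer, inner_layers, **kwargs):
        super().__init__(**kwargs)
        self.x_outer = x_outer
        self.inner_layers = inner_layers  # dict: key -> InnerLayer

    def call(self, inputs):
        # Assumed behavior: sum the inner outputs and scale by x_outer.
        return self.x_outer * sum(lyr(inputs) for lyr in self.inner_layers.values())

    def get_config(self):
        config = super().get_config()
        config.update({"x_outer": self.x_outer, "inner_layers": self.inner_layers})
        return config

    @classmethod
    def from_config(cls, config):
        config = dict(config)  # shallow copy, leave the caller's dict alone
        inner_layers = {}
        for key, val in config["inner_layers"].items():
            # Recreate the layer if it arrives in serialized form.
            if isinstance(val, dict):
                val = InnerLayer.from_config(val["config"])
            inner_layers[key] = val
        config["inner_layers"] = inner_layers
        return cls(**config)


def as_saved(config):
    """Hypothetical helper: replace each inner layer by its serialized form."""
    saved = dict(config)
    saved["inner_layers"] = {
        key: {"class_name": type(lyr).__name__, "config": lyr.get_config()}
        for key, lyr in config["inner_layers"].items()
    }
    return saved


layer = NestedDictLayer(0.2, {"lyr1": InnerLayer(0.5), "lyr2": InnerLayer(0.5)})
restored = NestedDictLayer.from_config(as_saved(layer.get_config()))
print(sorted(restored.inner_layers.keys()))
print(restored.inner_layers["lyr1"].x)
```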
Conclusion
Done! Small tricks to help serialize your nested layers for saving and loading.