Read this story on medium.

Saving nested layers in TensorFlow

How to serialize nested layers in TF 2.5.0.

A custom TF layer is one that subclasses tf.keras.layers.Layer. This is powerful on its own, but a particularly desirable feature is the ability to nest layers inside one another. Serializing nested layers is a bit of a headache, however, and it is necessary if you want to save models that contain them.

Nested layers. Image credit: author.

Saving a custom layer

Let’s first make a custom layer:
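The sketch below shows what such a layer can look like. The class name InnerLayer matches the rest of the article. The constructor argument x and the call method (multiplying the input by x) are assumptions made for illustration. The sketch assumes a TensorFlow 2.x install with tf.keras.

```python
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    """Hypothetical custom layer that multiplies its input by a constant x."""

    def __init__(self, x, **kwargs):
        # **kwargs carries the standard layer arguments (name, trainable, dtype).
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return inputs * self.x

    def get_config(self):
        # Start from the base config, then add the custom attribute.
        config = super().get_config()
        config.update({"x": self.x})
        return config

    @classmethod
    def from_config(cls, config):
        # The config keys match the __init__ arguments, so unpack them directly.
        return cls(**config)


layer = InnerLayer(0.5, name="inner_layer")
config = layer.get_config()
print(config)

# Recreate the layer from its config and run it on some data.
rebuilt = InnerLayer.from_config(config)
print(rebuilt(tf.constant([[2.0, 4.0]])))
```

The printed config should contain the base keys (name, trainable, dtype) alongside x. These are the keys that from_config unpacks back into the constructor.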

This includes the get_config and from_config methods, which are used to serialize the custom layer. Custom attributes like self.x are included by first calling the superclass’s get_config() and then using config.update({...}). Note that if the layer holds a tf.Variable, for example one created with self.add_weight(...), you can add it to the config using its numpy conversion.
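Here is a sketch of that idea. It uses a hypothetical ScaleLayer whose only state is a trainable scalar named w, created with add_weight. Its get_config reads the variable's current value with .numpy() and stores it under a hypothetical scale argument. The inherited from_config can then recreate the layer with that value as its initial weight.

```python
import tensorflow as tf


class ScaleLayer(tf.keras.layers.Layer):
    """Hypothetical layer with one trainable scalar created via add_weight."""

    def __init__(self, scale=1.0, **kwargs):
        super().__init__(**kwargs)
        self.w = self.add_weight(
            name="w",
            shape=(),
            initializer=tf.keras.initializers.Constant(scale),
            trainable=True,
        )

    def call(self, inputs):
        return inputs * self.w

    def get_config(self):
        config = super().get_config()
        # .numpy() reads the variable's current value as a NumPy scalar;
        # float() turns it into a plain Python number that JSON can store.
        config.update({"scale": float(self.w.numpy())})
        return config


layer = ScaleLayer(2.0)
layer.w.assign(3.0)  # stand-in for a value changed during training

# The inherited from_config calls ScaleLayer(**config).
rebuilt = ScaleLayer.from_config(layer.get_config())
print(float(rebuilt.w.numpy()))
```

Converting to float keeps the config JSON-friendly. For a non-scalar variable, .numpy().tolist() would do the same job. Storing values this way is only practical for small variables. Larger weight tensors are normally saved with the model's weights rather than in its config.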


Let’s save a model with the custom layer:

This should work without error. Note that you have to build the model, e.g. by passing some data through it, before you save it.
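The same mechanism can be sketched without touching the disk. The example builds a model containing a hypothetical InnerLayer (it multiplies its input by x) by passing data through it. It then round-trips the architecture through JSON with model.to_json() and tf.keras.models.model_from_json. Both go through the layer's get_config/from_config methods, and model_from_json accepts a custom_objects mapping just as loading a saved model does. This JSON holds the architecture only, not the weights (this layer has none).

```python
import numpy as np
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    """Hypothetical custom layer that multiplies its input by x."""

    def __init__(self, x, **kwargs):
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return inputs * self.x

    def get_config(self):
        config = super().get_config()
        config.update({"x": self.x})
        return config


model = tf.keras.Sequential([InnerLayer(0.5)])
data = tf.ones((4, 3))
expected = model(data)  # build the model by passing some data through it

# Architecture only: every layer's get_config() output, encoded as JSON.
json_config = model.to_json()

# custom_objects maps the class name stored in the JSON back to the class.
restored = tf.keras.models.model_from_json(
    json_config, custom_objects={"InnerLayer": InnerLayer}
)
print(np.allclose(restored(data).numpy(), expected.numpy()))
```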

Saving a model with a nested layer

Next, let’s create a nested layer that has an InnerLayer:
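A sketch of such a nested layer follows. It assumes a hypothetical OuterLayer that takes an InnerLayer instance and scales its output by x_outer. Both classes and their arithmetic are illustrative. Note that get_config stores the layer object itself.

```python
import numpy as np
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    """Hypothetical custom layer that multiplies its input by x."""

    def __init__(self, x, **kwargs):
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return inputs * self.x

    def get_config(self):
        config = super().get_config()
        config.update({"x": self.x})
        return config


class OuterLayer(tf.keras.layers.Layer):
    """Hypothetical nested layer: scales the InnerLayer's output by x_outer."""

    def __init__(self, x_outer, inner_layer, **kwargs):
        super().__init__(**kwargs)
        self.x_outer = x_outer
        self.inner_layer = inner_layer  # Keras tracks this as a sublayer

    def call(self, inputs):
        return self.x_outer * self.inner_layer(inputs)

    def get_config(self):
        config = super().get_config()
        # Put the layer object itself in the config,
        # not self.inner_layer.get_config().
        config.update({"x_outer": self.x_outer, "inner_layer": self.inner_layer})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


outer = OuterLayer(0.2, InnerLayer(0.5))
config = outer.get_config()
rebuilt = OuterLayer.from_config(config)
print(np.allclose(rebuilt(tf.ones((1, 3))).numpy(), 0.1))
```

In this in-memory round trip the config still holds the live InnerLayer, so from_config simply reuses it. When the model is saved, serializing that object is left to Keras.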

It’s almost the same as before, but we have to add inner_layer to the config in get_config. Note that we do not write:

"inner_layer": self.inner_layer.get_config()

as this leads to the error:

AttributeError: 'dict' object has no attribute '_serialized_attributes'

To save the nested model:

Here we have saved the model to "test_save_nested" with save_traces=False.

If we instead use save_traces=True, we get the warning:

WARNING:absl:Found untraced functions such as inner_layer_1_layer_call_and_return_conditional_losses, inner_layer_1_layer_call_fn, inner_layer_1_layer_call_fn, inner_layer_1_layer_call_and_return_conditional_losses, inner_layer_1_layer_call_and_return_conditional_losses while saving (showing 5 of 5). These functions will not be directly callable after loading.

As discussed in the documentation, using save_traces=True is not needed since we have defined custom get_config/from_config methods:

New in TensorFlow 2.4: The argument save_traces has been added to model.save, which allows you to toggle SavedModel function tracing. Functions are saved to allow Keras to re-load custom objects without the original class definitions, so when save_traces=False, all custom objects must have defined get_config/from_config methods. When loading, the custom objects must be passed to the custom_objects argument. save_traces=False reduces the disk space used by the SavedModel and saving time.
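Passing custom objects at load time can be sketched by rebuilding a nested model from model.to_json() with tf.keras.models.model_from_json. That function accepts a custom_objects argument just as loading does. This stands in for the saved-model round trip rather than following the same code path, and JSON stores no weights (these hypothetical layers have none). The isinstance check in from_config is a defensive addition that is not in the article. It covers the case where, depending on the Keras version, the inner layer reaches from_config still in serialized form.

```python
import numpy as np
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    """Hypothetical custom layer that multiplies its input by x."""

    def __init__(self, x, **kwargs):
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return inputs * self.x

    def get_config(self):
        config = super().get_config()
        config.update({"x": self.x})
        return config


class OuterLayer(tf.keras.layers.Layer):
    """Hypothetical nested layer: scales the InnerLayer's output by x_outer."""

    def __init__(self, x_outer, inner_layer, **kwargs):
        super().__init__(**kwargs)
        self.x_outer = x_outer
        self.inner_layer = inner_layer

    def call(self, inputs):
        return self.x_outer * self.inner_layer(inputs)

    def get_config(self):
        config = super().get_config()
        config.update({"x_outer": self.x_outer, "inner_layer": self.inner_layer})
        return config

    @classmethod
    def from_config(cls, config):
        inner = config["inner_layer"]
        # Defensive: rebuild the inner layer if it arrives as a serialized dict;
        # otherwise it is already an InnerLayer instance and is used as-is.
        if isinstance(inner, dict):
            config["inner_layer"] = tf.keras.layers.deserialize(
                inner, custom_objects={"InnerLayer": InnerLayer}
            )
        return cls(**config)


model = tf.keras.Sequential(
    [tf.keras.Input(shape=(3,)), OuterLayer(0.2, InnerLayer(0.5))]
)
data = tf.ones((2, 3))

# Every custom class used anywhere in the model goes into custom_objects.
restored = tf.keras.models.model_from_json(
    model.to_json(),
    custom_objects={"OuterLayer": OuterLayer, "InnerLayer": InnerLayer},
)
print(np.allclose(restored(data).numpy(), model(data).numpy()))
```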

Saving models with many nested layers in a dictionary

Finally, let’s look at a nested layer that contains multiple inner layers, stored in a dictionary:
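A sketch of such a layer follows. It uses a hypothetical NestedDictLayer whose call sums the outputs of its inner layers and scales the result by x_outer; that arithmetic is an assumption. The layer keeps the inherited from_config. The sketch encodes a model containing it as JSON to show how the inner layers look once serialized.

```python
import json

import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    """Hypothetical custom layer that multiplies its input by x."""

    def __init__(self, x, **kwargs):
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return inputs * self.x

    def get_config(self):
        config = super().get_config()
        config.update({"x": self.x})
        return config


class NestedDictLayer(tf.keras.layers.Layer):
    """Hypothetical layer holding several InnerLayers in a dictionary."""

    def __init__(self, x_outer, inner_layers, **kwargs):
        super().__init__(**kwargs)
        self.x_outer = x_outer
        self.inner_layers = inner_layers  # dict mapping a key to a layer

    def call(self, inputs):
        # Sum the inner layers' outputs, then scale by x_outer.
        return self.x_outer * sum(
            layer(inputs) for layer in self.inner_layers.values()
        )

    def get_config(self):
        config = super().get_config()
        # Keras may wrap the attribute dict for tracking; dict() hands a
        # plain dictionary to the config.
        config.update(
            {"x_outer": self.x_outer, "inner_layers": dict(self.inner_layers)}
        )
        return config


x_inner = 0.5
layer = NestedDictLayer(
    0.2,
    {"lyr1": InnerLayer(x_inner), "lyr2": InnerLayer(x_inner)},
    name="nested_dict_layer",
)
model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), layer])

# Inspect the layer's config after it has been encoded as JSON.
saved = json.loads(model.to_json())
layer_cfg = next(
    cfg for cfg in saved["config"]["layers"] if cfg["class_name"] == "NestedDictLayer"
)
print(layer_cfg["config"]["inner_layers"]["lyr1"])
```

Each inner layer comes out as a plain dictionary with class_name and config keys, similar to the config printed below. Newer Keras versions may add further keys.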

Here we have two inner layers in the dictionary:

"lyr1": InnerLayer(x_inner),
"lyr2": InnerLayer(x_inner)

If we just use a naive from_config like this:

@classmethod
def from_config(cls, config):
    return cls(**config)

You will get the error:

AttributeError: 'dict' object has no attribute '_serialized_attributes'

This is because the inner layers reach from_config as serialized dictionaries rather than layer objects, as printing the config dictionary shows:

Config input:  {'name': 'nested_dict_layer', 'trainable': True, 'dtype': 'float32', 'x_outer': 0.2, 'inner_layers': {'lyr1': {'class_name': 'InnerLayer', 'config': {'name': 'inner_layer_1', 'trainable': True, 'dtype': 'float32', 'x': 0.5}}, 'lyr2': {'class_name': 'InnerLayer', 'config': {'name': 'inner_layer_2', 'trainable': True, 'dtype': 'float32', 'x': 0.5}}}}

Instead, we must recreate the inner layers in from_config:

@classmethod
def from_config(cls, config):
    print("Config input: ", config)

    # Rebuild each inner layer from its serialized {"class_name", "config"} entry
    inner_layers = {}
    for key, val in config["inner_layers"].items():
        inner_layers[key] = InnerLayer(**val['config'])
    config["inner_layers"] = inner_layers

    print("Config recreated: ", config)

    return cls(**config)

Then the recreated config is correct:

Config recreated:  {'name': 'nested_dict_layer', 'trainable': True, 'dtype': 'float32', 'x_outer': 0.2, 'inner_layers': {'lyr1': <__main__.InnerLayer object at 0x7fd208fe0d50>, 'lyr2': <__main__.InnerLayer object at 0x7fd1d851af90>}}

And saving/loading the model works as expected:
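Here is a sketch of the full round trip. It assumes hypothetical InnerLayer and NestedDictLayer classes: the inner layer multiplies by x, and the outer one sums the inner outputs and scales by x_outer. Its from_config follows the article's fix by rebuilding each dictionary entry with InnerLayer.from_config(val["config"]). For this layer that amounts to InnerLayer(**val["config"]). Entries that are already layer objects are kept as they are. As a stand-in for saving to disk, the model is round-tripped through JSON, which holds no weights (these layers have none).

```python
import numpy as np
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    """Hypothetical custom layer that multiplies its input by x."""

    def __init__(self, x, **kwargs):
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return inputs * self.x

    def get_config(self):
        config = super().get_config()
        config.update({"x": self.x})
        return config


class NestedDictLayer(tf.keras.layers.Layer):
    """Hypothetical layer holding several InnerLayers in a dictionary."""

    def __init__(self, x_outer, inner_layers, **kwargs):
        super().__init__(**kwargs)
        self.x_outer = x_outer
        self.inner_layers = inner_layers

    def call(self, inputs):
        return self.x_outer * sum(
            layer(inputs) for layer in self.inner_layers.values()
        )

    def get_config(self):
        config = super().get_config()
        config.update(
            {"x_outer": self.x_outer, "inner_layers": dict(self.inner_layers)}
        )
        return config

    @classmethod
    def from_config(cls, config):
        inner_layers = {}
        for key, val in config["inner_layers"].items():
            # Serialized entries are dicts with a "config" key: rebuild them.
            if isinstance(val, dict):
                val = InnerLayer.from_config(val["config"])
            inner_layers[key] = val
        config["inner_layers"] = inner_layers
        return cls(**config)


layer = NestedDictLayer(0.2, {"lyr1": InnerLayer(0.5), "lyr2": InnerLayer(0.5)})
model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), layer])
data = tf.ones((2, 3))

restored = tf.keras.models.model_from_json(
    model.to_json(),
    custom_objects={"NestedDictLayer": NestedDictLayer, "InnerLayer": InnerLayer},
)
print(np.allclose(restored(data).numpy(), model(data).numpy()))
```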


Done! These small tricks should help you serialize nested layers for saving and loading.