Saving nested layers in TensorFlow

How to serialize nested layers in TF 2.5.0.

A custom TF layer is one that subclasses tf.keras.layers.Layer. Custom layers are powerful on their own, but a particularly desirable feature is nesting them, i.e. using layers inside other layers. Serializing nested layers is a bit of a headache, however, and it is necessary in order to save models that contain them using model.save(...).

Nested layers. Image credit: author.

Saving a custom layer

Let’s first make a custom layer:
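A minimal sketch of such a layer is below. The name InnerLayer matches the configs printed later in this post, but its behavior (multiplying the input by a float x) is an assumption made for illustration.

```python
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    """Assumed custom layer: multiplies its input by a constant x."""

    def __init__(self, x, **kwargs):
        # kwargs carries base-layer arguments such as name, trainable and dtype
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return self.x * inputs

    def get_config(self):
        # Start from the base config (name, trainable, dtype), then add x
        config = super().get_config()
        config.update({"x": self.x})
        return config

    @classmethod
    def from_config(cls, config):
        # Every key in the config is an __init__ argument
        return cls(**config)


layer = InnerLayer(0.5, name="inner")
config = layer.get_config()
clone = InnerLayer.from_config(config)
```

Here from_config is effectively optional, since the base class's default also calls cls(**config). Writing it out makes the serialization pair explicit.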

The layer defines the get_config and from_config methods, which are used to serialize it. Custom attributes like self.x are added to the config by first calling the superclass’s get_config() and then using config.update({...}). Note that if you have a tf.Variable like:

self.abc = self.add_weight(...)

you can add it to the config by converting it to a NumPy array:

config.update({"abc": self.abc.numpy()})
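As a sketch of this pattern, assume a hypothetical WeightedLayer whose variable abc is created with add_weight. Its get_config stores a NumPy snapshot of the variable. This sketch's __init__ (an assumption, not shown in the original) accepts that value back and assigns it to the freshly created variable, so that from_config can restore it.

```python
import numpy as np
import tensorflow as tf


class WeightedLayer(tf.keras.layers.Layer):
    """Hypothetical layer whose state lives in a variable created by add_weight."""

    def __init__(self, units=3, abc=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.abc = self.add_weight(
            name="abc", shape=(units,), initializer="zeros", trainable=True
        )
        if abc is not None:
            # Restore a value that was previously stored in the config
            self.abc.assign(abc)

    def call(self, inputs):
        return inputs + self.abc

    def get_config(self):
        config = super().get_config()
        # .numpy() snapshots the variable's current value as a NumPy array
        config.update({"units": self.units, "abc": self.abc.numpy()})
        return config


layer = WeightedLayer(units=3)
layer.abc.assign(np.array([1.0, 2.0, 3.0], dtype="float32"))
clone = WeightedLayer.from_config(layer.get_config())
print(clone.abc.numpy())  # [1. 2. 3.]
```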

Let’s save a model with the custom layer:
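The following minimal sketch of this step assumes TF 2.x with its bundled Keras 2 (tf.keras), as in the article's TF 2.5.0. InnerLayer is redefined here with an assumed body so the block runs on its own, and save_and_reload is a hypothetical helper. The model is built by passing data through it, saved as a SavedModel directory, and reloaded with custom_objects so Keras can map the saved class name back to the Python class.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    """Assumed custom layer: multiplies its input by a constant x."""

    def __init__(self, x, **kwargs):
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return self.x * inputs

    def get_config(self):
        config = super().get_config()
        config.update({"x": self.x})
        return config


def save_and_reload(path):
    """Hypothetical helper: build, save and reload a model that uses InnerLayer."""
    model = tf.keras.Sequential([InnerLayer(0.5)])
    data = np.ones((4, 2), dtype="float32")
    # Passing data through the model builds it, which is needed before saving
    expected = model(data).numpy()
    # A path without an extension gives a SavedModel directory in tf.keras 2.x
    model.save(path)
    # custom_objects maps the saved class name "InnerLayer" to the Python class
    loaded = tf.keras.models.load_model(path, custom_objects={"InnerLayer": InnerLayer})
    return expected, loaded(data).numpy()


expected, restored = save_and_reload(os.path.join(tempfile.mkdtemp(), "custom_model"))
```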

This should work without error. Note that you have to build the model, e.g. by passing some data through it, before you save it.

Saving a model with a nested layer

Next, let’s create a nested layer that contains an InnerLayer:
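The following minimal sketch of such a nested layer assumes the same multiply-by-x InnerLayer. OuterLayer and its x_outer parameter are hypothetical names. What matters is that get_config stores the InnerLayer object itself under inner_layer.

```python
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    """Assumed inner layer: multiplies its input by a constant x."""

    def __init__(self, x, **kwargs):
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return self.x * inputs

    def get_config(self):
        config = super().get_config()
        config.update({"x": self.x})
        return config


class OuterLayer(tf.keras.layers.Layer):
    """Hypothetical nested layer: scales the InnerLayer's output by x_outer."""

    def __init__(self, x_outer, inner_layer, **kwargs):
        super().__init__(**kwargs)
        self.x_outer = x_outer
        self.inner_layer = inner_layer

    def call(self, inputs):
        return self.x_outer * self.inner_layer(inputs)

    def get_config(self):
        config = super().get_config()
        # Store the layer object itself, not self.inner_layer.get_config()
        config.update({"x_outer": self.x_outer, "inner_layer": self.inner_layer})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


outer = OuterLayer(0.2, InnerLayer(0.5))
clone = OuterLayer.from_config(outer.get_config())
```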

It’s almost the same as before, but we also have to add inner_layer to the dictionary returned by get_config, storing the layer object itself. Note that we do not write:

"inner_layer": self.inner_layer.get_config()

as this leads to the error:

AttributeError: 'dict' object has no attribute '_serialized_attributes'

To save the nested model:
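The following minimal sketch saves and reloads a model that contains the nested layer. It assumes TF 2.x with its bundled Keras 2 (tf.keras), as in the article's TF 2.5.0, because Keras 3's model.save supports neither the SavedModel directory format nor save_traces. OuterLayer, x_outer, the InnerLayer body and the helper save_and_reload_nested are assumptions for illustration. Both custom classes are passed to load_model through custom_objects.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    """Assumed inner layer: multiplies its input by a constant x."""

    def __init__(self, x, **kwargs):
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return self.x * inputs

    def get_config(self):
        config = super().get_config()
        config.update({"x": self.x})
        return config


class OuterLayer(tf.keras.layers.Layer):
    """Hypothetical nested layer; the inherited from_config calls cls(**config)."""

    def __init__(self, x_outer, inner_layer, **kwargs):
        super().__init__(**kwargs)
        self.x_outer = x_outer
        self.inner_layer = inner_layer

    def call(self, inputs):
        return self.x_outer * self.inner_layer(inputs)

    def get_config(self):
        config = super().get_config()
        # The layer object itself goes into the config
        config.update({"x_outer": self.x_outer, "inner_layer": self.inner_layer})
        return config


def save_and_reload_nested(path):
    """Hypothetical helper: save with save_traces=False, then reload."""
    model = tf.keras.Sequential([OuterLayer(0.2, InnerLayer(0.5))])
    data = np.ones((4, 2), dtype="float32")
    expected = model(data).numpy()  # builds the model before saving
    # Store only each layer's config, without traced functions
    model.save(path, save_traces=False)
    loaded = tf.keras.models.load_model(
        path, custom_objects={"InnerLayer": InnerLayer, "OuterLayer": OuterLayer}
    )
    return expected, loaded(data).numpy()


expected, restored = save_and_reload_nested(
    os.path.join(tempfile.mkdtemp(), "test_save_nested")
)
```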

Here we have called:

model.save("test_save_nested", save_traces=False)

If we instead use save_traces=True (the default), we get the warning:

WARNING:absl:Found untraced functions such as inner_layer_1_layer_call_and_return_conditional_losses, inner_layer_1_layer_call_fn, inner_layer_1_layer_call_fn, inner_layer_1_layer_call_and_return_conditional_losses, inner_layer_1_layer_call_and_return_conditional_losses while saving (showing 5 of 5). These functions will not be directly callable after loading.

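As discussed in the documentation, save_traces=True is not needed here because we have defined custom get_config/from_config methods. Disabling it stores only each layer’s config instead of its traced functions, which reduces serialization time and file size but requires every custom layer to implement get_config().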

Saving models with many nested layers in a dictionary

Finally, let’s look at a nested layer that contains multiple inner layers, stored in a dictionary:
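The following is a minimal sketch of such a layer. The class name NestedDictLayer is inferred from the layer name nested_dict_layer in the config printed below, and the names x_outer and inner_layers are inferred from that config's keys. The layer's body is assumed. Its from_config rebuilds the inner layers from their config dictionaries. As an addition of this sketch, it also passes through entries that are still layer objects, so an in-memory get_config/from_config round trip works too.

```python
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    """Assumed inner layer: multiplies its input by a constant x."""

    def __init__(self, x, **kwargs):
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return self.x * inputs

    def get_config(self):
        config = super().get_config()
        config.update({"x": self.x})
        return config


class NestedDictLayer(tf.keras.layers.Layer):
    """Sketch of a layer holding several InnerLayers in a dictionary."""

    def __init__(self, x_outer, inner_layers, **kwargs):
        super().__init__(**kwargs)
        self.x_outer = x_outer
        self.inner_layers = inner_layers

    def call(self, inputs):
        # Sum the inner layers' outputs, then scale by x_outer
        return self.x_outer * sum(lyr(inputs) for lyr in self.inner_layers.values())

    def get_config(self):
        config = super().get_config()
        # dict(...) copies the tracked dict Keras wraps around the attribute
        config.update({"x_outer": self.x_outer, "inner_layers": dict(self.inner_layers)})
        return config

    @classmethod
    def from_config(cls, config):
        config = dict(config)  # avoid mutating the caller's dict
        inner_layers = {}
        for key, val in config["inner_layers"].items():
            if isinstance(val, dict):
                # Loaded from disk: {"class_name": "InnerLayer", "config": {...}}
                inner_layers[key] = InnerLayer(**val["config"])
            else:
                # In-memory round trip: already a layer object
                inner_layers[key] = val
        config["inner_layers"] = inner_layers
        return cls(**config)


x_inner = 0.5
layer = NestedDictLayer(0.2, {"lyr1": InnerLayer(x_inner), "lyr2": InnerLayer(x_inner)})
```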

Here we have two inner layers in the dictionary:

"lyr1": InnerLayer(x_inner),
"lyr2": InnerLayer(x_inner)

If we just use a naive from_config like this:

@classmethod
def from_config(cls, config):
    return cls(**config)

Then we get the error:

AttributeError: 'dict' object has no attribute '_serialized_attributes'

This is because the inner layers reach from_config as plain config dictionaries rather than layer objects, as we see if we print out the config dictionary:

Config input:  {'name': 'nested_dict_layer', 'trainable': True, 'dtype': 'float32', 'x_outer': 0.2, 'inner_layers': {'lyr1': {'class_name': 'InnerLayer', 'config': {'name': 'inner_layer_1', 'trainable': True, 'dtype': 'float32', 'x': 0.5}}, 'lyr2': {'class_name': 'InnerLayer', 'config': {'name': 'inner_layer_2', 'trainable': True, 'dtype': 'float32', 'x': 0.5}}}}

Instead, we must recreate the inner layers in from_config:

@classmethod
def from_config(cls, config):
    print("Config input: ", config)
    inner_layers = {}
    for key, val in config["inner_layers"].items():
        # Each val is {'class_name': 'InnerLayer', 'config': {...}}
        inner_layers[key] = InnerLayer(**val['config'])
    config["inner_layers"] = inner_layers
    print("Config recreated: ", config)
    return cls(**config)

Then the recreated config is correct:

Config recreated:  {'name': 'nested_dict_layer', 'trainable': True, 'dtype': 'float32', 'x_outer': 0.2, 'inner_layers': {'lyr1': <__main__.InnerLayer object at 0x7fd208fe0d50>, 'lyr2': <__main__.InnerLayer object at 0x7fd1d851af90>}}

And saving/loading the model works as expected:
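The following minimal sketch shows the full save/load round trip for the dictionary case. It assumes TF 2.x with its bundled Keras 2 (tf.keras), as in the article's TF 2.5.0, because Keras 3's model.save does not accept save_traces. The InnerLayer body is assumed, and NestedDictLayer and save_and_reload_dict_model are hypothetical names. The model is saved with save_traces=False and reloaded with both custom classes in custom_objects.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    """Assumed inner layer: multiplies its input by a constant x."""

    def __init__(self, x, **kwargs):
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return self.x * inputs

    def get_config(self):
        config = super().get_config()
        config.update({"x": self.x})
        return config


class NestedDictLayer(tf.keras.layers.Layer):
    """Sketch of a layer holding several InnerLayers in a dictionary."""

    def __init__(self, x_outer, inner_layers, **kwargs):
        super().__init__(**kwargs)
        self.x_outer = x_outer
        self.inner_layers = inner_layers

    def call(self, inputs):
        return self.x_outer * sum(lyr(inputs) for lyr in self.inner_layers.values())

    def get_config(self):
        config = super().get_config()
        config.update({"x_outer": self.x_outer, "inner_layers": dict(self.inner_layers)})
        return config

    @classmethod
    def from_config(cls, config):
        config = dict(config)
        # After loading, each entry is a {"class_name": ..., "config": {...}} dict
        config["inner_layers"] = {
            key: InnerLayer(**val["config"]) if isinstance(val, dict) else val
            for key, val in config["inner_layers"].items()
        }
        return cls(**config)


def save_and_reload_dict_model(path):
    """Hypothetical helper: save with save_traces=False, then reload."""
    x_inner = 0.5
    layer = NestedDictLayer(0.2, {"lyr1": InnerLayer(x_inner), "lyr2": InnerLayer(x_inner)})
    model = tf.keras.Sequential([layer])
    data = np.ones((4, 2), dtype="float32")
    expected = model(data).numpy()  # builds the model before saving
    model.save(path, save_traces=False)
    loaded = tf.keras.models.load_model(
        path,
        custom_objects={"InnerLayer": InnerLayer, "NestedDictLayer": NestedDictLayer},
    )
    return expected, loaded(data).numpy()


expected, restored = save_and_reload_dict_model(
    os.path.join(tempfile.mkdtemp(), "test_save_nested_dict")
)
```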


Done! These small tricks should help you serialize nested layers for saving and loading.


Oliver K. Ernst
June 16, 2021

Read this on Medium