Saving nested layers in TensorFlow

How to serialize nested layers in TF 2.5.0.


A custom TF layer is one that subclasses tf.keras.layers.Layer. Custom layers are powerful on their own, but a particularly useful feature is the ability to nest layers inside one another. Serializing nested layers is a bit of a headache, however, and it is necessary in order to save models that contain them using model.save(...).

Nested layers. Image credit: author.

Saving a custom layer

Let’s first make a custom layer:
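A minimal sketch of such a layer follows. It assumes one float hyperparameter `x`, which is the attribute named in the configs printed later in this post. Purely for illustration, `call` multiplies its input by `x`:

```python
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    """Custom layer with one hyperparameter, x."""

    def __init__(self, x, **kwargs):
        # Pass name, trainable, dtype, ... through to the base Layer
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        # Illustrative behavior (an assumption): scale the input by x
        return self.x * inputs

    def get_config(self):
        # Base config (name, trainable, dtype, ...) plus the custom attribute
        config = super().get_config()
        config.update({"x": self.x})
        return config

    @classmethod
    def from_config(cls, config):
        # Every key in the config is a constructor argument
        return cls(**config)


# Round trip: layer -> config -> new layer
layer = InnerLayer(0.5)
config = layer.get_config()
restored = InnerLayer.from_config(config)
```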

This includes the get_config and from_config methods, which are used to serialize the custom layer. Custom attributes like self.x are included by first calling the superclass’s get_config() and then adding them with config.update({...}). If the layer also holds a weight (a tf.Variable), such as:

# e.g. in the layer's __init__ or build(), with numpy imported as np
self.abc = self.add_weight(
    name="abc",
    shape=(3,),
    trainable=False,
    initializer=tf.constant_initializer(np.random.rand(3)),
    dtype="float32",
)

you can add its current value to the config by converting it to a NumPy array:

config.update({"abc": self.abc.numpy()})
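The round trip for such a weight can be sketched with a hypothetical layer, `WeightedLayer`. Both that class name and its optional `abc` constructor argument are assumptions made for illustration. The layer stores the weight's value in its config and restores it in `__init__`.

Unlike the snippet above, it creates the weight with a `"zeros"` initializer and then assigns the values, so the example does not depend on how an initializer handles arrays. The constructor also accepts a plain list, which is what a JSON-serialized config may contain.

```python
import numpy as np
import tensorflow as tf


class WeightedLayer(tf.keras.layers.Layer):
    """Hypothetical layer holding a non-trainable weight "abc" of shape (3,)."""

    def __init__(self, abc=None, **kwargs):
        super().__init__(**kwargs)
        # New layers draw random values; layers rebuilt from a config reuse the stored ones
        init = np.random.rand(3) if abc is None else np.asarray(abc)
        self.abc = self.add_weight(
            name="abc",
            shape=(3,),
            trainable=False,
            initializer="zeros",
            dtype="float32",
        )
        self.abc.assign(init.astype("float32"))

    def call(self, inputs):
        return inputs * self.abc

    def get_config(self):
        config = super().get_config()
        # Store the weight's current value as a NumPy array
        config.update({"abc": self.abc.numpy()})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


layer = WeightedLayer()
restored = WeightedLayer.from_config(layer.get_config())
```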

Let’s save a model with the custom layer:
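Saving and reloading can be sketched under a few assumptions. The InnerLayer from above is repeated so the block runs on its own, the model is a one-layer Sequential, and the output goes to a temporary file named `test_save.keras`. The `.keras` extension is an assumption for newer Keras releases, which require one; the post itself targets TF 2.5, where an extension-free path such as "test_save" produces a SavedModel directory. The model is called on data first so that it is built before saving:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    def __init__(self, x, **kwargs):
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        # Illustrative behavior (an assumption): scale the input by x
        return self.x * inputs

    def get_config(self):
        config = super().get_config()
        config.update({"x": self.x})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


# Passing data through the model builds it, which is required before saving
model = tf.keras.Sequential([InnerLayer(0.5)])
data = np.ones((2, 3), dtype="float32")
expected = model(data).numpy()

path = os.path.join(tempfile.mkdtemp(), "test_save.keras")
model.save(path)

# custom_objects maps the stored class name back to the Python class
loaded = tf.keras.models.load_model(path, custom_objects={"InnerLayer": InnerLayer})
result = loaded(data).numpy()
```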

This should work without error. Note that you have to build the model, e.g. by passing some data through it, before you save it.

Saving a model with a nested layer

Next, let’s create a nested layer that contains an InnerLayer:
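A minimal sketch of such a layer follows. For illustration it assumes an outer hyperparameter `x_outer` and a `call` that scales the inner layer's output by it. Its `get_config` stores the inner layer object itself.

The `isinstance` check in `from_config` is a defensive addition, not something this post requires. It rebuilds the inner layer if a loader hands it back in the serialized `{"class_name": ..., "config": ...}` form that appears later in this post.

```python
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    def __init__(self, x, **kwargs):
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return self.x * inputs

    def get_config(self):
        config = super().get_config()
        config.update({"x": self.x})
        return config


class NestedLayer(tf.keras.layers.Layer):
    def __init__(self, x_outer, inner_layer, **kwargs):
        super().__init__(**kwargs)
        self.x_outer = x_outer
        # Assigning a layer to an attribute makes Keras track it as a sub-layer
        self.inner_layer = inner_layer

    def call(self, inputs):
        return self.x_outer * self.inner_layer(inputs)

    def get_config(self):
        config = super().get_config()
        # Store the layer object, not self.inner_layer.get_config()
        config.update({"x_outer": self.x_outer, "inner_layer": self.inner_layer})
        return config

    @classmethod
    def from_config(cls, config):
        inner = config["inner_layer"]
        # Defensive: rebuild the inner layer if it arrives in serialized form
        if isinstance(inner, dict):
            config["inner_layer"] = InnerLayer(**inner["config"])
        return cls(**config)


layer = NestedLayer(0.2, InnerLayer(0.5))

# Simulate the serialized form a loader may pass back to from_config
serialized = layer.get_config()
serialized["inner_layer"] = {
    "class_name": "InnerLayer",
    "config": layer.inner_layer.get_config(),
}
restored = NestedLayer.from_config(serialized)
```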

It’s almost the same as before, but we also have to add inner_layer (the layer object itself) to the dictionary returned by get_config. Note that we do not write:

"inner_layer": self.inner_layer.get_config()

as this leads to the error:

AttributeError: 'dict' object has no attribute '_serialized_attributes'

To save the nested model:

Here we have called:

model.save("test_save_nested", save_traces=False)

If we instead keep the default, save_traces=True, we get the warning:

WARNING:absl:Found untraced functions such as inner_layer_1_layer_call_and_return_conditional_losses, inner_layer_1_layer_call_fn, inner_layer_1_layer_call_fn, inner_layer_1_layer_call_and_return_conditional_losses, inner_layer_1_layer_call_and_return_conditional_losses while saving (showing 5 of 5). These functions will not be directly callable after loading.

As the documentation for save_traces explains, the traced functions are not needed when every custom layer implements get_config, as ours do (together with from_config). Disabling them also reduces serialization time and file size.

Saving models with many nested layers in a dictionary

Finally, let’s look at a nested layer that contains multiple inner layers, stored in a dictionary:
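A sketch of such a layer follows. For illustration it again assumes that `x_outer` scales the output, here applied to the sum of the inner layers' outputs. Its `from_config` rebuilds inner layers that arrive as `{"class_name": ..., "config": ...}` dictionaries, which is the fix this section arrives at. The example feeds it the config dictionary printed in this section:

```python
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    def __init__(self, x, **kwargs):
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return self.x * inputs

    def get_config(self):
        config = super().get_config()
        config.update({"x": self.x})
        return config


class NestedDictLayer(tf.keras.layers.Layer):
    def __init__(self, x_outer, inner_layers, **kwargs):
        super().__init__(**kwargs)
        self.x_outer = x_outer
        # Keras also tracks layers stored inside a dict attribute
        self.inner_layers = inner_layers

    def call(self, inputs):
        # Illustrative behavior (an assumption): sum inner outputs, scale by x_outer
        outputs = [self.inner_layers[key](inputs) for key in sorted(self.inner_layers)]
        total = outputs[0]
        for out in outputs[1:]:
            total = total + out
        return self.x_outer * total

    def get_config(self):
        config = super().get_config()
        # Copy the tracked dict into a plain dict of layer objects
        config.update({"x_outer": self.x_outer, "inner_layers": dict(self.inner_layers)})
        return config

    @classmethod
    def from_config(cls, config):
        inner_layers = {}
        for key, val in config["inner_layers"].items():
            # Serialized entries look like {"class_name": ..., "config": {...}}
            inner_layers[key] = InnerLayer(**val["config"]) if isinstance(val, dict) else val
        config["inner_layers"] = inner_layers
        return cls(**config)


# The config dictionary printed in this section, as from_config receives it
config_input = {
    "name": "nested_dict_layer", "trainable": True, "dtype": "float32", "x_outer": 0.2,
    "inner_layers": {
        "lyr1": {"class_name": "InnerLayer", "config": {
            "name": "inner_layer_1", "trainable": True, "dtype": "float32", "x": 0.5}},
        "lyr2": {"class_name": "InnerLayer", "config": {
            "name": "inner_layer_2", "trainable": True, "dtype": "float32", "x": 0.5}},
    },
}
restored = NestedDictLayer.from_config(config_input)
```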

Here we have two inner layers in the dictionary:

{
    "lyr1": InnerLayer(x_inner),
    "lyr2": InnerLayer(x_inner)
}

If we just use a naive from_config like this:

@classmethod
def from_config(cls, config):
    return cls(**config)

we get the same error:

AttributeError: 'dict' object has no attribute '_serialized_attributes'

This is because the inner layers reach from_config as plain dictionaries rather than InnerLayer objects, as printing the config dictionary shows:

Config input:  {'name': 'nested_dict_layer', 'trainable': True, 'dtype': 'float32', 'x_outer': 0.2, 'inner_layers': {'lyr1': {'class_name': 'InnerLayer', 'config': {'name': 'inner_layer_1', 'trainable': True, 'dtype': 'float32', 'x': 0.5}}, 'lyr2': {'class_name': 'InnerLayer', 'config': {'name': 'inner_layer_2', 'trainable': True, 'dtype': 'float32', 'x': 0.5}}}}

Instead, we must recreate the inner layers in from_config:

@classmethod
def from_config(cls, config):
    print("Config input: ", config)

    # Rebuild each inner layer from its serialized {"class_name", "config"} dict
    inner_layers = {}
    for key, val in config["inner_layers"].items():
        inner_layers[key] = InnerLayer(**val["config"])
    config["inner_layers"] = inner_layers

    print("Config recreated: ", config)

    return cls(**config)

Then the recreated config is correct:

Config recreated:  {'name': 'nested_dict_layer', 'trainable': True, 'dtype': 'float32', 'x_outer': 0.2, 'inner_layers': {'lyr1': <__main__.InnerLayer object at 0x7fd208fe0d50>, 'lyr2': <__main__.InnerLayer object at 0x7fd1d851af90>}}

And saving/loading the model works as expected:
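That round trip can be sketched under the same assumptions as before. The two layers from this section are repeated in compact form, and the Sequential model is built by calling it on data. The output goes to a temporary file named `test_save_nested_dict.keras`. This file name is hypothetical; newer Keras releases require an extension, whereas the post's TF 2.5 code saves a SavedModel directory.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    def __init__(self, x, **kwargs):
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return self.x * inputs

    def get_config(self):
        config = super().get_config()
        config.update({"x": self.x})
        return config


class NestedDictLayer(tf.keras.layers.Layer):
    def __init__(self, x_outer, inner_layers, **kwargs):
        super().__init__(**kwargs)
        self.x_outer = x_outer
        self.inner_layers = inner_layers

    def call(self, inputs):
        outputs = [self.inner_layers[key](inputs) for key in sorted(self.inner_layers)]
        total = outputs[0]
        for out in outputs[1:]:
            total = total + out
        return self.x_outer * total

    def get_config(self):
        config = super().get_config()
        config.update({"x_outer": self.x_outer, "inner_layers": dict(self.inner_layers)})
        return config

    @classmethod
    def from_config(cls, config):
        # Recreate inner layers that come back as {"class_name": ..., "config": {...}}
        config["inner_layers"] = {
            key: InnerLayer(**val["config"]) if isinstance(val, dict) else val
            for key, val in config["inner_layers"].items()
        }
        return cls(**config)


x_inner = 0.5
model = tf.keras.Sequential([
    NestedDictLayer(0.2, {"lyr1": InnerLayer(x_inner), "lyr2": InnerLayer(x_inner)})
])
data = np.ones((2, 3), dtype="float32")
expected = model(data).numpy()  # builds the model before saving

path = os.path.join(tempfile.mkdtemp(), "test_save_nested_dict.keras")
model.save(path)
loaded = tf.keras.models.load_model(
    path,
    custom_objects={"InnerLayer": InnerLayer, "NestedDictLayer": NestedDictLayer},
)
result = loaded(data).numpy()
```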

Conclusion

Done! These are small tricks, but they make it possible to serialize nested layers for saving and loading.

By Oliver K. Ernst, Ph.D. on June 16, 2021.


Exported from Medium on July 24, 2022.