A cheat sheet for custom TensorFlow layers and models

How to redefine everything from numpy, and some actually useful tricks for part assignment, saving custom layers, etc.

This article is about compiling all those TensorFlow manuals into one page. Image from Unsplash by Patrick Tomasso.

We’re talking about creating custom TensorFlow layers and models in TensorFlow v2.5.0.

No wasting time with introductions, let’s go:

Numpy mappings to TensorFlow

  • Keep in mind for all operations: in TensorFlow, the first axis is the batch_size.
  • Trick to get past the batch size when it’s inconvenient: wrap linear algebra in tf.map_fn , e.g. converting vectors of size n in a batch (batch_size x n) to diagonal matrices of size (batch_size x n x n) :
mats = tf.map_fn(lambda vec: tf.linalg.tensor_diag(vec), vecs)
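Note that some operations handle the batch axis themselves, so no mapping is needed. For example, from the docs for tf.linalg.matmul: The inputs must, following any transpositions, be tensors of rank >= 2 where the inner 2 dimensions specify valid matrix multiplication dimensions, and any further outer dimensions specify matching batch size.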
  • Sum an array: np.sum(x) → tf.math.reduce_sum(x)
  • Dot product of two vectors: np.dot(x,y) → tf.tensordot(x,y,1)
  • Matrix product with vector A.b : np.dot(A,b) → tf.linalg.matvec(A,b)
  • Trigonometry: np.sin → tf.math.sin
  • Transpose (batch_size x n x m) → (batch_size x m x n) : np.transpose(x, axes=(0,2,1)) → tf.transpose(x, perm=[0,2,1])
  • Vector to diagonal matrix: np.diag(x) → tf.linalg.tensor_diag(x)
  • Concatenate matrices: np.concatenate((A,B),axis=0) → tf.concat([A,B],1) , where the axis shifts by one because axis 0 is the batch.
  • Matrix flatten: [[A,B],[C,D]] into a single matrix:
tmp1 = tf.concat([A,B],2)
tmp2 = tf.concat([C,D],2)
mat = tf.concat([tmp1,tmp2],1)
  • Outer (Kronecker) product of a vector vec of size n with itself (makes a matrix n x n): tf.tensordot(vec,vec,axes=0) . The vectors usually come in a batch vecs of size (batch_size x n) , so we need to map:
mats = tf.map_fn(lambda vec: tf.tensordot(vec,vec,axes=0),vecs)
  • Zeros: np.zeros(n) → tf.zeros_like(x) , where x is a tensor that already has the shape you want. Note that to avoid the error “Cannot convert a partially known TensorShape to a Tensor”, you should use tf.zeros_like instead of tf.zeros , since the former does not require the size to be known until runtime.
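To make the batch-axis bookkeeping in the list above concrete, here is a minimal sketch that runs a few of the mappings on a small random batch and compares them with NumPy. The sizes, the variable names, and the NumPy reference expressions are assumptions chosen for illustration.

```python
import numpy as np
import tensorflow as tf

batch_size, n = 4, 3
rng = np.random.default_rng(0)
vecs_np = rng.random((batch_size, n)).astype("float32")
A_np = rng.random((batch_size, n, n)).astype("float32")

vecs = tf.constant(vecs_np)  # shape (batch_size, n)
A = tf.constant(A_np)        # shape (batch_size, n, n)

# Batch of vectors -> batch of diagonal matrices, shape (batch_size, n, n).
diag_mats = tf.map_fn(lambda vec: tf.linalg.tensor_diag(vec), vecs)

# Batch of outer products vec vec^T, shape (batch_size, n, n).
outer_mats = tf.map_fn(lambda vec: tf.tensordot(vec, vec, axes=0), vecs)

# Matrix-vector product; matvec treats the leading axis as the batch.
Ab = tf.linalg.matvec(A, vecs)

# Block matrix [[A, A], [A, A]] per batch element: columns are axis 2, rows are axis 1.
top = tf.concat([A, A], 2)
bottom = tf.concat([A, A], 2)
block = tf.concat([top, bottom], 1)

# NumPy references, computed one batch element at a time.
diag_ref = np.stack([np.diag(v) for v in vecs_np])
outer_ref = np.stack([np.outer(v, v) for v in vecs_np])
Ab_ref = np.stack([a @ v for a, v in zip(A_np, vecs_np)])
block_ref = np.stack([np.block([[a, a], [a, a]]) for a in A_np])
```

The point to notice is that every NumPy axis index moves up by one on the TensorFlow side, because axis 0 is reserved for the batch.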

Part assignment by indexes

In numpy we can just write:

a = np.random.rand(3)
a[0] = 5.0

This doesn’t work in TensorFlow:

a = tf.Variable(initial_value=np.random.rand(3),dtype='float32')
a[0] = 5.0
# TypeError: 'ResourceVariable' object does not support item assignment

Instead, you can add/subtract unit vectors/matrices/tensors with the appropriate values. The tf.one_hot function can be used to construct these:

e = tf.one_hot(indices=0,depth=3,dtype='float32')
# <tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 0., 0.], dtype=float32)>

and then change the value like this:

a = a + (new_val - old_val) * e

where a is the TF variable, new_val and old_val are floats, and e is the unit vector.

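For example, in the previous example, setting a[0] to 5.0 means new_val is 5.0 and old_val is the current a[0], so the update is a + (5.0 - a[0]) * e.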
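A minimal runnable sketch of that update, reusing a and e from the snippets above. The name updated is an assumption, used here so that the variable a itself is not rebound to a plain tensor.

```python
import numpy as np
import tensorflow as tf

a = tf.Variable(initial_value=np.random.rand(3), dtype="float32")
e = tf.one_hot(indices=0, depth=3, dtype="float32")  # unit vector [1, 0, 0]

new_val = 5.0
old_val = a[0]

# Adding (new_val - old_val) along e changes entry 0 only.
updated = a + (new_val - old_val) * e
```

Here updated is an ordinary tensor. If the variable itself should hold the new values, a.assign(updated) writes them back.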

Part assignment by indexes for matrices, etc.

Same as the vector example, but with unit matrices instead of unit vectors. These can be built by a small helper function, which you can adapt to higher order tensors as well.
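One possible shape for such a helper is sketched below. The name unit_matrix and its signature are hypothetical, not taken from the article; the sketch builds the unit matrix as the outer product of two one-hot vectors.

```python
import tensorflow as tf


def unit_matrix(i, j, n_rows, n_cols, dtype="float32"):
    """Hypothetical helper: an (n_rows x n_cols) matrix that is 1 at (i, j), 0 elsewhere."""
    row = tf.one_hot(indices=i, depth=n_rows, dtype=dtype)
    col = tf.one_hot(indices=j, depth=n_cols, dtype=dtype)
    # The outer product of the two one-hot vectors has a single 1 at position (i, j).
    return tf.tensordot(row, col, axes=0)


m = tf.ones((2, 3), dtype="float32")
e = unit_matrix(1, 2, 2, 3)

# Set m[1, 2] to 7.0 using the same add-the-difference trick as for vectors.
new_val = 7.0
updated = m + (new_val - m[1, 2]) * e
```

For a rank-3 tensor, the same idea would chain a third one-hot vector through another tf.tensordot(..., axes=0).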

Getting the shape of things

Use tf.shape(x) instead of x.shape . As one discussion of this puts it:

Do not use .shape; rather use tf.shape. However, h = tf.shape(x)[1]; w = tf.shape(x)[2] will let h, w be symbolic or graph-mode tensors (integer) that will contain a dimension of x. The value will be determined runtime. In such a case, tf.reshape(x, [-1, w, h]) will produce a (symbolic) tensor of shape [?, ?, ?] (still unknown) whose tensor shape will be known on runtime.
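The difference is easiest to see inside a traced function whose input dimensions are unknown. The sketch below assumes an input signature with all dimensions set to None; the function name and the seen_static_shapes list are hypothetical and only there to record what x.shape looks like at trace time.

```python
import tensorflow as tf

seen_static_shapes = []  # filled once, while the function is being traced


@tf.function(input_signature=[tf.TensorSpec(shape=[None, None, None], dtype=tf.float32)])
def reshape_dynamic(x):
    # Static shape: every dimension is unknown (None) during tracing.
    seen_static_shapes.append(x.shape.as_list())
    # Dynamic shape: integer tensors whose values exist only at runtime.
    h = tf.shape(x)[1]
    w = tf.shape(x)[2]
    return tf.reshape(x, [-1, w, h])


out_a = reshape_dynamic(tf.zeros((2, 3, 5)))
out_b = reshape_dynamic(tf.zeros((4, 2, 7)))
```

Both calls reuse the same traced graph, which is only possible because h and w are read from tf.shape rather than from x.shape.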

What to subclass — tf.keras.layers.Layer or tf.keras.Model?

Subclass Model if you want to use fit, evaluate, or something similar; otherwise subclass Layer.

From the docs:

Typically you inherit from keras.Model when you need the model methods like: Model.fit, Model.evaluate, and Model.save (see Custom Keras layers and models for details). One other feature provided by keras.Model (instead of keras.layers.Layer) is that in addition to tracking variables, a keras.Model also tracks its internal layers, making them easier to inspect.

What do I need to implement to subclass a Layer?

  • Obviously call super in the constructor.
  • Obviously implement the call method.
  • You should implement get_config and from_config — they are needed often for saving the layer, e.g. if save_traces=False in the model.
  • @tf.keras.utils.register_keras_serializable(package="my_package") should be added at the top of the class. From the docs:
This decorator injects the decorated class or function into the Keras custom object dictionary, so that it can be serialized and deserialized without needing an entry in the user-provided custom object dict.

This really means that if you leave it out, you must later load any model containing the layer the way the Keras documentation describes:

loaded_1 = keras.models.load_model(
    "my_model", custom_objects={"MyLayer": MyLayer}
)

which is kind of a pain.
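Putting the checklist above together, a custom layer might look like the following sketch. ScaleLayer and its initial_scale argument are hypothetical and exist only to give get_config something to store. The build method is an addition to the checklist, used here to create the weight.

```python
import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package="my_package")
class ScaleLayer(tf.keras.layers.Layer):
    """Hypothetical layer that multiplies its input by a trainable scalar."""

    def __init__(self, initial_scale=1.0, **kwargs):
        super().__init__(**kwargs)
        self.initial_scale = initial_scale

    def build(self, input_shape):
        # Create the weight lazily, the first time the layer is called.
        self.scale = self.add_weight(
            name="scale",
            shape=(),
            initializer=tf.keras.initializers.Constant(self.initial_scale),
            trainable=True,
        )

    def call(self, inputs):
        return self.scale * inputs

    def get_config(self):
        # Start from the base config (name, dtype, ...) and add our own argument.
        config = super().get_config()
        config.update({"initial_scale": self.initial_scale})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


layer = ScaleLayer(initial_scale=2.0)
y = layer(tf.ones((1, 3)))

# Round trip through the config, which is what loading relies on.
clone = ScaleLayer.from_config(layer.get_config())
```

Everything the constructor needs should appear in get_config, since from_config rebuilds the layer from that dictionary alone.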

What do I need to implement to subclass a Model?
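A minimal sketch, assuming the same checklist as for a Layer carries over: call super in the constructor, implement call, define get_config and from_config, and add the registration decorator. TinyModel and its units argument are hypothetical.

```python
import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package="my_package")
class TinyModel(tf.keras.Model):
    """Hypothetical two-layer model used only to illustrate the checklist."""

    def __init__(self, units=4, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.hidden = tf.keras.layers.Dense(units, activation="relu")
        self.out = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return self.out(self.hidden(inputs))

    def get_config(self):
        # Only the constructor arguments needed to rebuild the model.
        return {"units": self.units}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


model = TinyModel(units=4)
pred = model(tf.zeros((2, 3)))
clone = TinyModel.from_config(model.get_config())
```

Because this is a Model rather than a Layer, the two Dense layers are tracked automatically and show up in model.layers.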

Saving and loading models and layers

If you have correctly defined get_config and from_config for each of your layers and models as above, you can just save using:

model.save("model", save_traces=False)
# Or just save the weights, e.g. with model.save_weights

Note that you do not need the traces (save_traces=True is the default) if you have defined those correctly. This also makes the saved models much smaller!

To load, just use:

model = tf.keras.models.load_model("model")
# Or just load the weights into an already existing model, e.g. with model.load_weights

Note that we don’t need to specify any custom_objects if we correctly used the @tf.keras.utils.register_keras_serializable decorator everywhere.
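For the weights-only route mentioned in the comments above, a round trip could look like the sketch below. The make_model helper, the layer names, and the file name demo.weights.h5 are assumptions. The .h5 suffix is chosen so that the HDF5 weight format is used.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def make_model():
    # Hypothetical stand-in for a model built from your custom layers.
    return tf.keras.Sequential(
        [tf.keras.layers.Dense(2, name="dense_out")], name="demo_model"
    )


x = np.ones((1, 3), dtype="float32")

model = make_model()
expected = model(x).numpy()  # calling the model once creates its weights

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "demo.weights.h5")
    model.save_weights(path)

    # "An already existing model": same architecture, freshly initialized weights.
    restored = make_model()
    restored(x)  # create the weights before loading values into them
    restored.load_weights(path)

restored_output = restored(x).numpy()
```

Loading weights this way needs the code that builds the model, so it complements load_model rather than replacing it.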

How to reduce the size of TensorBoard callbacks if you have a large custom model

Similar to how save_traces reduces the size of saved models, you can use write_graph=False in the TensorBoard callback to reduce the size of these files:

import datetime, os
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, write_graph=False)

Or, for a ModelCheckpoint , you can use save_weights_only=True to just save the weights and not the model at all:

val_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, save_weights_only=True)  # checkpoint_path: placeholder for your own path


Hope this starter-set of tricks helped!