Implementing Swish Activation Function in Keras
Review of Keras
Keras is a favorite tool among many in machine learning. TensorFlow is even replacing its high-level API with Keras come TensorFlow version 2. For those new to it, Keras is called a “front-end” API for machine learning. Using Keras you can swap out the “backend” between several frameworks, officially including TensorFlow, Theano, and CNTK. One of my favorite libraries, PlaidML, has also built its own support for Keras.
This kind of backend-agnostic framework is great for developers. If you use Keras directly, you can use the PlaidML backend on macOS with GPU support while developing and creating your ML model. Then, when you are ready for production, you can swap out the backend for TensorFlow and have it serve predictions on a Linux server. All of this without changing any code, just a configuration file.
At some point in your journey Keras will start limiting what you are able to do. It is at this point that TensorFlow’s website will point you to their “expert” articles and start teaching you how to use TensorFlow’s low-level APIs to build neural networks without the limitations of Keras.
Before jumping into this lower level you might consider extending Keras before moving past it. This can be a great option to save reusable code written in Keras and to prototype changes to your network in a high-level framework that allows you to move quickly.
What is an Activation Function
If you are new to machine learning you might have heard of activation functions but not be quite sure how they work beyond setting the typical ReLU on your layers. Let us do a quick recap to make sure we know why we might want a custom one.
Activation functions are quite important to your layers. They sit at the end of your layers as little gatekeepers. As gatekeepers they affect what data gets through to the next layer, if any data at all is allowed to pass. What kind of complex mathematics determines this gatekeeping? Let us take a look at the Rectified Linear Unit, referred to as ReLU. It is computed by the function max(0, x). Yup, that is it! It simply makes sure the value returned doesn’t go below 0.
This simple gatekeeping function has become arguably the most popular activation function. This is mostly due to how fast it is to run the max function. However, ReLU has limitations.
Why the Swish Activation Function
There is one glaring issue with the ReLU function. In machine learning we learn from our errors at the end of the forward pass, then during the backward pass we update the weights and biases of our network on each layer to make better predictions. What happens during this backward pass between two neurons, one of which returned a negative number really close to 0 and another that returned a large negative number? They would be treated the same. ReLU outputs 0 for both, and its gradient is 0 for any negative input, so there is no way to know one was closer to 0 than the other: we removed this information during the forward pass. Once a neuron hits 0 it is rare for its weights to recover, and it will remain 0 going forward. This is called the ‘Dying ReLU Problem’.
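To make the problem concrete, here is a minimal pure-Python sketch (no Keras involved). The helper names relu and relu_grad are made up for illustration, and the gradient is taken to be 0 at exactly 0, which is a common convention rather than something this post specifies.

```python
def relu(x):
    """ReLU forward pass: clamp negative values to 0."""
    return max(0.0, x)


def relu_grad(x):
    """Derivative of ReLU with respect to its input (0 for x <= 0 by convention)."""
    return 1.0 if x > 0 else 0.0


small_negative = -0.001
large_negative = -100.0

# Both inputs produce the same output and the same (zero) gradient,
# so the backward pass cannot tell them apart.
outputs = [relu(small_negative), relu(large_negative)]
grads = [relu_grad(small_negative), relu_grad(large_negative)]
print(outputs, grads)
```

With a zero gradient, no update flows back through either neuron, which is why such a neuron tends to stay stuck.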
There are functions that try to address this problem, like the Leaky ReLU or the ELU. The Leaky ReLU and ELU functions both try to account for the fact that just returning 0 isn’t great for training the network.
ELU typically outperforms ReLU and its leaky cousin. However, there is one glaring issue with this function. The ELU calculation used depends on the value of x. This branching conditional check is expensive when compared to its linear relatives. As software developers we don’t think much about branching statements. However, in the world of ML, branching can sometimes be too costly.
Let us go ahead and define the math behind each of these methods.
ELU: α(exp(x) - 1) if x < 0 else x
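For reference, ReLU, Leaky ReLU, and ELU can be written side by side. This is a plain-Python sketch rather than Keras code. The function names and the default alpha values (0.01 for Leaky ReLU, 1.0 for ELU) are assumptions chosen for illustration, not values taken from this post.

```python
import math


def relu(x):
    # max(0, x): negative inputs become 0
    return max(0.0, x)


def leaky_relu(x, alpha=0.01):
    # Negative inputs keep a small slope alpha instead of being zeroed out
    # (alpha=0.01 is an assumed default)
    return x if x > 0 else alpha * x


def elu(x, alpha=1.0):
    # alpha * (exp(x) - 1) if x < 0 else x  (alpha=1.0 is an assumed default)
    return alpha * (math.exp(x) - 1.0) if x < 0 else x


for value in (-2.0, -0.5, 0.0, 1.5):
    print(value, relu(value), leaky_relu(value), elu(value))
```

Note how both Leaky ReLU and ELU return a non-zero value for negative inputs, so a small negative and a large negative input stay distinguishable.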
Looking at Swish we can see it is defined as the following:
x * sigmoid(β * x)
In the original paper they showed great results using β = 1, and that is what we used in the graph below.
For added fun I included a gif of the swish function so you can see what happens as we change the β value.
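The effect of the beta parameter can also be checked numerically. The sketch below is plain Python using only the math module; the helper names are illustrative and this is not the Keras implementation.

```python
import math


def sigmoid(z):
    # Logistic function 1 / (1 + exp(-z))
    return 1.0 / (1.0 + math.exp(-z))


def swish(x, beta=1.0):
    # x * sigmoid(beta * x)
    return x * sigmoid(beta * x)


# beta = 0 turns sigmoid into a constant 0.5, so swish becomes x / 2.
# A large beta makes sigmoid behave like a step, so swish approaches ReLU.
for beta in (0.0, 1.0, 50.0):
    print(beta, [round(swish(x, beta), 4) for x in (-3.0, -1.0, 0.0, 1.0, 3.0)])
```

Unlike ReLU, the function is smooth and lets small negative values through for moderate beta, which is the behaviour the curve illustrates.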
The big win we get with Swish is that, in the paper's ImageNet experiments, it outperforms ReLU by about 0.6%-0.9% in top-1 accuracy while costing close to the same computationally. You can find a graphing playground with a few activation functions defined and some values being passed through them here: Activation Functions. The research paper on Swish can be found here: 1710.05941v1 Swish: a Self-Gated Activation Function
Defining Swish in Keras
Okay, so we are sold on Swish and want to put it in all of our networks, right? Maybe not quite yet, but given how easy it is to swap out, we at least want to implement it and see if it can help our network improve.
In a simple network you might have something that looks like the below code. Let us see how we can use our own activation function.
model.add(Dense(256, activation = "relu"))
model.add(Dense(100, activation = "relu"))
First off we are going to create our activation function as a simple Python function with two parameters. We could just leave beta out of our function, but since the paper keeps it as a variable, we will follow that same construct.
from keras.backend import sigmoid

def swish(x, beta = 1):
    # Swish: x * sigmoid(beta * x); beta = 1 is the value used in this post
    return (x * sigmoid(beta * x))
Next we register this custom object with Keras. For this we get our custom objects, tell it to update, then pass in a dictionary with a key of what we want to call it and the activation function for it. Note here we pass the swish function into the Activation class to actually build the activation function.
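from keras.utils.generic_utils import get_custom_objects
from keras.layers import Activation

# Register the activation under the key 'swish' so layers can refer to it by name
get_custom_objects().update({'swish': Activation(swish)})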
Finally we can change our activation to say swish instead of relu.
model.add(Dense(256, activation = "swish"))
model.add(Dense(100, activation = "swish"))
Just like that we have now extended Keras with a new “state-of-the-art” activation function. By doing this we can help keep our models at the forefront of research while not jumping down just yet to TensorFlow’s low level APIs. You can find my notebook experimenting with the swish function referenced in this post here: Digit-Recognizer/Digit Recognizer – Swish.ipynb at master · nicollis/Digit-Recognizer · GitHub