To calculate derivatives of nested functions, we can use a rule from calculus: The Chain Rule.

### A picture is worth a thousand equations

As you probably noticed, the maths is starting to get quite dense. When we start working with neural networks, we can easily end up hundreds or thousands of functions deep, so to get a handle on things we’ll need a different strategy. Helpfully, there is one: turning the maths into a graph.

We can start with some rules:

- Variables are represented with circles and operations are represented with boxes
- Inputs to an operation are represented with arrows that point to the operation box. Outputs point away.

For example, here is the diagram for

And that’s it! All of the equations we’ll be working with can be represented graphically using these simple rules. To try it out, let’s draw the diagram for a more complex formula:

This is an example of a structure called a graph (also called a network). A lot of problems in computer science get much easier if you can represent them with a graph, and this is no exception.
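As a concrete sketch (the formula here is my own, not the one from the diagram above), the expression `(a + b) * c` can be written down as a graph in plain Python:

```
# Hypothetical example: the graph for (a + b) * c.
# Circles (variables) and boxes (operations) are all just nodes;
# each directed edge points from an input to the thing it feeds.
edges = [
    ("a", "add"), ("b", "add"),  # a and b flow into the add box
    ("add", "d"),                # the add box outputs variable d
    ("d", "mul"), ("c", "mul"),  # d and c flow into the mul box
    ("mul", "out"),              # the mul box outputs the result
]

# An adjacency list makes it easy to walk the graph later
graph = {}
for src, dst in edges:
    graph.setdefault(src, []).append(dst)

print(graph["a"])  # ['add']
```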

The real power of these diagrams is that they can also help us with our derivatives. Take

From before, we can find its derivatives by differentiating each operation wrt its inputs and multiplying the results together. In this case, we get:

We can also graph it like this:

If you imagine walking from the output to each of the inputs, you might notice a similarity between the edges you pass through and the equations above. Walking to one of the inputs, you’ll pass through `a->c->d`. Similarly, walking to the other input, you’ll pass through `a->d->e`. Notice that both paths go through `c`, the edge coming out of `add` that corresponds to its input. Also, both equations include the same term.

If I rename the edges as follows:

We can see that going from the output to an input, the edges we pass through are exactly the terms in the derivative. If we multiply them together, we get exactly the derivative we calculated before!

It turns out that this rule works in general:

If we have some operation, we should label the edge corresponding to each of its inputs with the derivative of the operation wrt that input.

Then, finding the derivative of the output node wrt any of the inputs comes down to a single rule:

The derivative of an output variable wrt one of the input variables can be found by traversing the graph from the output to the input and multiplying together the derivatives for every edge on the path
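As a quick numerical sanity check of this rule (the function here is my own made-up example, not one from the diagrams), consider h(x) = (3x + 1)², built from three operations:

```
# A made-up example: h(x) = (3x + 1)^2, built from three operations.
# Walking from the output back to x, we pass three edges; each edge
# is labelled with the derivative of its operation wrt its input.
x = 2.0
u = 3 * x   # mul:    du/dx = 3
v = u + 1   # add:    dv/du = 1
h = v ** 2  # square: dh/dv = 2v

# Multiply the edge derivatives together along the path
dh_dx = (2 * v) * 1 * 3

# Compare with the analytic derivative: h'(x) = 6(3x + 1)
print(dh_dx)            # 42.0
print(6 * (3 * x + 1))  # 42.0
```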

To cover every edge case, there are some extra details:

If a graph contains multiple paths from the output to an input, then the derivative is the sum of the products for each path

This comes from the case we saw earlier: when different functions share the same input, we have to add their derivative chains together.

If an edge is not the input to any function, its derivative is 1

This covers the edge that leads from the final operation to the output. You can think of this edge as having the derivative of the output wrt itself, which is always 1.
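Both extra details show up in this small made-up example, f(x) = x · x, where the same variable feeds both inputs of a multiply:

```
# Another made-up example: f(x) = x * x, i.e. mul(x, x).
# Both inputs of mul are the same variable, so there are two paths
# from the output back to x, and we sum the product along each.
x = 3.0

# d(mul)/d(first input) = second input, and vice versa
path_1 = x  # via the first input slot
path_2 = x  # via the second input slot

# The edge from mul to the output has derivative 1
df_dx = 1 * path_1 + 1 * path_2

print(df_dx)  # 6.0 -- matches d(x^2)/dx = 2x
```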

And that’s it! Let’s try it out:

Here, instead of writing the formula for each derivative, I have gone ahead and calculated its actual value: rather than just deriving formulae, we want the numeric value of each derivative once we plug in our input parameters.

All that remains is to multiply the local derivatives together along each path. We’ll call the product of derivatives along a single path a chain (after the chain rule).

We can get from the output to the input via two routes: the green path and the red path. Multiplying the derivatives along each path gives us one chain per path, and adding the two chains together gives the total derivative.

If we work out the derivative algebraically:

We can see that it seems to work! Calculating the remaining derivatives is left as an exercise for the reader (I’ve always wanted to say that).

To summarise, we have invented the following algorithm for calculating the derivative of a variable wrt its inputs:

- Turn the equation into a graph
- Label each edge with the appropriate derivative
- Find every path from the output to the input variable you care about
- Follow each path and multiply the derivatives you pass through
- Add together the results for each path
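The five steps above can be sketched on a tiny hand-built graph. The node names, the function h = (3x + 1)², and the recursive helper below are all my own illustration, not the implementation we'll build next:

```
# A minimal sketch of the algorithm on a hand-built graph.
# Each entry maps a node to its children along with the derivative
# labelling that edge; the graph encodes h = (3x + 1)^2 at x = 2.
x = 2.0
graph = {
    "h": [("v", 2 * (3 * x + 1))],  # dh/dv = 2v
    "v": [("u", 1.0)],              # dv/du = 1
    "u": [("x", 3.0)],              # du/dx = 3
    "x": [],
}

def derivative(graph, output, target):
    """Sum, over every path from output to target, the product
    of the edge derivatives along that path."""
    if output == target:
        return 1.0  # empty path: multiplying nothing gives 1
    total = 0.0
    for child, edge_derivative in graph[output]:
        total += edge_derivative * derivative(graph, child, target)
    return total

print(derivative(graph, "h", "x"))  # 42.0, matching h'(x) = 6(3x + 1)
```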

Now that we have an algorithm in pictures and words, let’s turn it into code.

### The Algorithm™

Surprisingly, we have actually already converted our functions into graphs. If you recall, when we generate a tensor from an operation, we record the inputs to the operation in the output tensor (in `.args`). We also stored the functions to calculate derivatives for each of the inputs in `.local_derivatives`, which means that we know both the destination and derivative for every edge that points to a given node. In other words, we’ve already completed steps 1 and 2.

The next challenge is to find all paths from the tensor we want to differentiate to the input tensors that created it. Because none of our operations are self-referential (outputs are never fed back in as inputs) and all of our edges have a direction, our graph of operations is a directed acyclic graph, or DAG. Because the graph has no cycles, we can find all paths to every parameter pretty easily with a Breadth First Search (a Depth First Search would also work, but BFS makes some optimisations easier, as we’ll see in part 2).
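Stripped of the tensor machinery, the path-finding idea looks like this. The toy DAG here is my own: `out` splits into two branches that both reach `x`, so `x` should end up with two paths:

```
from collections import deque

# A sketch of finding every path from the output to each leaf with a
# breadth-first search over a DAG stored as a dict of children.
children = {
    "out": ["left", "right"],
    "left": ["x"],
    "right": ["x"],
    "x": [],
}

paths = {}  # leaf -> list of paths that reach it
queue = deque([("out", ["out"])])
while queue:
    node, path = queue.popleft()  # FIFO: breadth-first order
    if not children[node]:        # reached a leaf
        paths.setdefault(node, []).append(path)
        continue
    for child in children[node]:
        queue.append((child, path + [child]))

for path in paths["x"]:
    print("->".join(path))
# out->left->x
# out->right->x
```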

To try it out, let’s recreate that giant graph we made earlier. We can do this by first calculating `L` from the inputs:

```
y = Tensor(1)
m = Tensor(2)
x = Tensor(3)
c = Tensor(4)
# L = (y - (mx + c))^2
left = _sub(y, _add(_mul(m, x), c))
right = _sub(y, _add(_mul(m, x), c))
L = _mul(left, right)
# Attaching names to tensors will make our
# diagram look nicer
y.name = "y"
m.name = "m"
x.name = "x"
c.name = "c"
L.name = "L"
```

And then using Breadth First Search to do 3 things:

- Find all nodes
- Find all edges
- Find all paths from `L` to our parameters

We haven’t implemented a simple way to check whether two tensors are identical, so we’ll need to compare hashes (by default, a Python object’s hash is based on its identity).

```
queue = [(L, [L])]
nodes = []
edges = []
while queue:
    node, current_path = queue.pop(0)  # FIFO pop makes this breadth-first
    # Record nodes we haven't seen before
    if hash(node) not in [hash(n) for n in nodes]:
        nodes.append(node)
    # If we have reached a parameter (it has no arguments
    # because it wasn't created by an operation) then
    # record the path taken to get here
    if not node.args:
        if node.paths is None:
            node.paths = []
        node.paths.append(current_path)
        continue
    for arg in node.args:
        queue.append((arg, current_path + [arg]))
        # Record every new edge
        edges.append((hash(node), hash(arg)))
```

Now that we’ve got all of the edges and nodes, we have complete knowledge of our computational graph. Let’s use networkx to plot it:

```
import networkx as nx

# Assign a unique integer to each
# unnamed node so we know which
# node is which in the picture
labels = {}
for i, node in enumerate(nodes):
    if node.name is None:
        labels[hash(node)] = str(i)
    else:
        labels[hash(node)] = node.name

graph = nx.DiGraph()
graph.add_edges_from(edges)
pos = nx.nx_agraph.pygraphviz_layout(graph, prog="dot")
nx.draw(graph, pos=pos, labels=labels)
```

If you squint a bit, you can see that this looks like the graph we made earlier! Let’s take a look at the paths the algorithm found from `L` to `x`.

```
for path in x.paths:
    steps = []
    for step in path:
        steps.append(labels[hash(step)])
    print("->".join(steps))
```

```
L->1->2->4->x
L->8->9->10->x
```

The paths look correct! All we need to do now is to modify the algorithm a bit to keep track of the chain of derivatives along each path.

```
y = Tensor(1)
m = Tensor(2)
x = Tensor(3)
c = Tensor(4)
# L = (y - (mx + c))^2
left = _sub(y, _add(_mul(m, x), c))
right = _sub(y, _add(_mul(m, x), c))
L = _mul(left, right)
y.name = "y"
m.name = "m"
x.name = "x"
c.name = "c"
L.name = "L"
```

```
queue = [(L, [L], [])]
nodes = []
edges = []
while queue:
    node, current_path, current_chain = queue.pop(0)  # FIFO pop makes this breadth-first
    # Record nodes we haven't seen before
    if hash(node) not in [hash(n) for n in nodes]:
        nodes.append(node)
    # If we have reached a parameter (it has no arguments
    # because it wasn't created by an operation) then
    # record the path taken to get here, along with the
    # chain of derivatives picked up along the way
    if not node.args:
        if node.paths is None:
            node.paths = []
        if node.chains is None:
            node.chains = []
        node.paths.append(current_path)
        node.chains.append(current_chain)
        continue
    for arg, derivative in zip(node.args, node.local_derivatives):
        next_path = current_path + [arg]
        next_chain = current_chain + [derivative]
        queue.append((arg, next_path, next_chain))
        # Record every new edge
        edges.append((hash(node), hash(arg)))
```

Let’s check if the derivatives were recorded correctly.

```
print(f"Number of chains: {len(x.chains)}")
for chain in x.chains:
    print(chain)
```

```
Number of chains: 2
[Tensor(-9), Tensor(-1), Tensor(1), Tensor(2)]
[Tensor(-9), Tensor(-1), Tensor(1), Tensor(2)]
```

Looks reasonable so far. We have 2 paths with identical chains, each containing 4 derivatives (one for each edge in the path), as expected.

Let’s multiply the derivatives together along each chain, add the chain totals together, and see if we get the right answer.

According to my calculations (and Wolfram Alpha), the derivative of L wrt x is -2m(y - (mx + c)). Plugging in the values of our tensors, we get -2 × 2 × (1 - (2 × 3 + 4)) = 36.

```
total_derivative = Tensor(0)
for chain in x.chains:
    chain_total = Tensor(1)
    for step in chain:
        chain_total = _mul(chain_total, step)
    total_derivative = _add(total_derivative, chain_total)
total_derivative
```

The correct answer! It looks like our algorithm works. All that remains is to put all the pieces together.
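As one last sanity check (my addition, not part of the original derivation), we can estimate the same derivative numerically with plain floats and a central difference:

```
# Numerically estimate dL/dx at (y=1, m=2, x=3, c=4) and compare it
# to the 36 our graph algorithm produced.
def loss(y, m, x, c):
    return (y - (m * x + c)) ** 2

h = 1e-6
numerical = (loss(1, 2, 3 + h, 4) - loss(1, 2, 3 - h, 4)) / (2 * h)
print(round(numerical, 3))  # 36.0
```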