Johannes von Oswald - Mechanistic Interpretability of in-context learning


I'm super happy to talk about mechanistic interpretability of in-context learning.

I'm at Google, and this has been work with colleagues at Google and ETH Zurich.

The motivation is that we wanted to understand the transformer architecture a bit better, and

especially its fascinating few-shot learning capabilities, which seem to magically appear

if you just train these models autoregressively. So you can provide a text description of a new task,

and the model seems to learn based on the examples that you give here. And

what we want to study is: What happens here? Why does this happen?

Just to jump directly to the hypothesis...

We think that next-token prediction is actually performed by a gradient-based mesa-optimizer

that is implemented in the transformer architecture. And because there is this nice optimizer in the architecture,

you can actually repurpose this for super cool things. For example, in-context learning or

potentially chain-of-thought, or other things.

Just a visualisation of the hypothesis. This huge thing is a transformer.

And in the beginning, you get the sequence. The first thing that the transformer is doing

is kind of looking at the sequence and deciding what it can consider an input and a target for a dataset.

That's what we call a mesa-dataset. The simplest instance of this is just to consider a sliding window.

So you just slide over your sequence and you consider every element as an input and a target.

And then the hypothesis is that the transformer, inside the architecture, is optimising,

let's say, fast weights... so implicit parameters that it learns based on the sequence

and uses for next-token prediction. First step: copy stuff together; then learn implicit fast weights.
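As a rough sketch of that first step (the function name here is my own illustration, not the paper's code), the sliding-window mesa-dataset just pairs every element of the sequence with its successor:

```python
import numpy as np

def mesa_dataset(seq):
    """Sliding-window mesa-dataset: every element is an input,
    and its successor is the corresponding target."""
    return seq[:-1], seq[1:]

seq = np.arange(5)                       # toy "sequence" 0, 1, 2, 3, 4
inputs, targets = mesa_dataset(seq)
# inputs pairs with targets shifted by one: (0,1), (1,2), (2,3), (3,4)
```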

We don't really study this in language models, but only in a very, very toy setting. So we consider

linear dynamical systems, where the teacher... W* here is an orthogonal matrix.

So every sequence comes from a freshly sampled orthogonal teacher.

You sample some x(0)s and perturb with a little bit of noise.

And you provide this data to the transformer, you train this autoregressively...

On every token, you try a next-token prediction.

This is not a classification task, so it's also slightly different from LLMs.
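A minimal sketch of the data-generating process, assuming the dynamics x(t+1) = W* x(t) + noise (sampling details such as the QR trick are my own choices, not necessarily the paper's):

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_sequence(d=4, T=20, noise_std=0.01):
    """One training sequence from a freshly sampled orthogonal teacher W*."""
    # Random orthogonal matrix via QR decomposition.
    W_star, _ = np.linalg.qr(rng.normal(size=(d, d)))
    x = rng.normal(size=d)
    seq = [x]
    for _ in range(T):
        x = W_star @ x + noise_std * rng.normal(size=d)  # x(t+1) = W* x(t) + eps
        seq.append(x)
    return np.stack(seq), W_star

seq, W_star = sample_sequence()
# The transformer is trained autoregressively: predict seq[t+1] from seq[:t+1].
```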

What is nice about this setting is that we actually know a ground-truth solution to it.

This is recursive least squares. So just a least-squares solver here, regularised,

and for every time step, t, you want a new one because you get new data at each time step.

So the transformer - to solve this problem very well - needs to implement a recursive least-squares algorithm.

And this actually requires T matrix inversions in the forward pass.
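Concretely, the ground-truth predictor at time t is the regularised least-squares estimate fit on all (x(s), x(s+1)) pairs seen so far; its closed form needs a fresh matrix inverse per step. A sketch in my own notation (lam is the ridge strength):

```python
import numpy as np

def ridge_solution(xs, lam=1e-6):
    """W_t minimising sum_s ||W x(s) - x(s+1)||^2 + lam ||W||^2,
    i.e. W_t = Y X^T (X X^T + lam I)^(-1) -- one inverse per time step."""
    X = np.stack(xs[:-1], axis=1)   # inputs  x(0) .. x(t-1), shape (d, t)
    Y = np.stack(xs[1:], axis=1)    # targets x(1) .. x(t),   shape (d, t)
    d = X.shape[0]
    return Y @ X.T @ np.linalg.inv(X @ X.T + lam * np.eye(d))

# Noiseless sanity check: with enough data the estimate recovers the teacher.
rng = np.random.default_rng(0)
d = 3
W_star, _ = np.linalg.qr(rng.normal(size=(d, d)))
xs = [rng.normal(size=d)]
for _ in range(10):
    xs.append(W_star @ xs[-1])
W_hat = ridge_solution(xs)
```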

So what can we say about the transformer architecture? Actually, it can solve that problem pretty well.

One nice result, which was actually in a previous paper, is that a linear self-attention layer -- so, different from softmax,

but still kind of close to the classic transformer architecture -- can very easily implement a gradient step.
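A toy verification of that correspondence (my own minimal construction, not the trained weights): one gradient step, delta_W = -eta * sum_s (W x_s - y_s) x_s^T, applied to a query x_q, is exactly the readout of a linear attention layer with values -eta (W x_s - y_s), keys x_s, and query x_q:

```python
import numpy as np

rng = np.random.default_rng(1)
d, n, eta = 3, 8, 0.1
X = rng.normal(size=(d, n))    # in-context inputs x_s (as columns)
Y = rng.normal(size=(d, n))    # in-context targets y_s
W = rng.normal(size=(d, d))    # current (fast) weights
x_q = rng.normal(size=d)       # query token

# One gradient step on L(W) = 0.5 * sum_s ||W x_s - y_s||^2, then predict:
grad = (W @ X - Y) @ X.T
pred_gd = (W - eta * grad) @ x_q

# The same prediction written as a linear self-attention readout:
# values v_s = -eta (W x_s - y_s), keys k_s = x_s, query q = x_q.
values = -eta * (W @ X - Y)
pred_attn = W @ x_q + values @ (X.T @ x_q)
# The two expressions agree term by term.
```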

And now, in follow-up work, what you can also show...

so you see the similarities of the expressions here... is that a single linear self-attention layer can actually implement a summand of this

truncated Neumann series, so it can actually invert these matrices -- or rather, projected matrices, where you

multiply with your x(T). Also: super well.

So this implies that the transformer architecture can actually invert these T matrices very efficiently, in O(log(log(1/epsilon))) depth, where

epsilon is the error that you want to have on the final matrix inversion.
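To see why so few layers suffice, here is a toy sketch (my own construction, on a well-conditioned matrix): the truncated Neumann series sum_{j&lt;k} (I - A)^j approximates A^(-1) with geometrically shrinking error, one summand per unit of depth, while a squaring-type recursion (Newton-Schulz) doubles the number of correct summands per step, reaching error epsilon in O(log(log(1/epsilon))) steps:

```python
import numpy as np

rng = np.random.default_rng(2)
d = 4
# Symmetric matrix with eigenvalues in [0.5, 1.5], so ||I - A|| = 0.5 < 1.
Q, _ = np.linalg.qr(rng.normal(size=(d, d)))
A = Q @ np.diag(np.linspace(0.5, 1.5, d)) @ Q.T
I = np.eye(d)

def neumann_inverse(A, k):
    """Truncated Neumann series: sum_{j=0}^{k-1} (I - A)^j ~= A^(-1)."""
    S, P = np.zeros_like(A), np.eye(len(A))
    for _ in range(k):
        S += P
        P = P @ (I - A)
    return S

err = lambda M: np.linalg.norm(M @ A - I)
errors = [err(neumann_inverse(A, k)) for k in (2, 8, 32)]  # shrinks geometrically

# Newton-Schulz: X <- X (2I - A X) squares the residual I - A X each step,
# so 5 steps already match ~2^5 = 32 Neumann summands.
X_ns = I.copy()
for _ in range(5):
    X_ns = X_ns @ (2 * I - A @ X_ns)
```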

OK, coming to the last slide. So what?

For a single linear self-attention layer, you can study this very thoroughly.

When you go deeper, we fall back to linear probes, and yeah,

you get -- though this is a cherry-picked result -- you get very nice probes here.

In figure B, you see that if you probe for the target...

This is showing what the transformer should predict at the end...

You see that over the layers it gets better and better and better.

And you can actually probe for the red curve as well.

So you can probe for the approximated projected matrix inverse here.

And you also see this nicely decreasing, so it actually implements this

... or seems to be implementing something similar to this... and that kind of closes the circle.

You can now use this as an in-context learner as well.