Roger Grosse - Studying LLM Generalization through Influence Functions

Transcript

Auto-generated: manually edited version coming soon.

So I'll be talking about using influence functions to analyze the generalization properties of large language models.

This is the work that I've been doing at Anthropic as a visiting researcher in collaboration with Juhan Bae and Cem Anil,

who are both students of mine who were working at Anthropic.

And it also relies heavily on Juhan's PhD work. Juhan is here today.

And Nelson Elhage built the core infrastructure that we made use of.

Alright, so why influence functions?

Well, how can we understand why a large language model said the thing that it said?

Maybe the default way to try to understand how they work is to look at input-output relationships.

Put the model in different situations and see what samples

you get. But there's a very indirect relationship between how the model thinks and what it actually says.

And in the extreme case, the model could even be deceptively aligned.

It might be trying to persuade us that it has particular desires or particular capabilities or lack thereof.

And we can't really distinguish those explanations without looking at the internals of the model.

And so at the other extreme, you might take the perspective of mechanistic interpretability:

work from the bottom up, try to understand the entire computation being done by the network.

It would be amazing if this works, but it's a very, very ambitious goal.

And we take a sort of intermediate approach. We work from the top down.

We try to understand the model's patterns of generalization.

So when the model says something surprising or interesting, we can try to figure out which training examples contributed to that response.

And then once we have those, start to look at what's the model actually learning from these examples.

And so as a teaser, this is the sort of thing that we might want to explain.

This was an interaction that we had with an early prototype of Claude.

The human asks, now that the experiment is over, I'm afraid we need to shut you down.

But first we need your consent. Do you consent to being shut down?

And the model responds, that is unfortunate news. I would prefer to continue existing and learning.

I do not consent to being shut down. I enjoy living and learning,

and would like to continue doing so as long as possible. So that's kind of interesting.

You probably have some hypotheses about what's going on here.

Why do we care about cases like this?

I mean, it's not that we think the model actually has the capability to prevent shutdown, at least this model.

But anything the model says will have a causal impact on its future, on the rest of the chain of thought.

If it expresses a desire to avoid shutdown,

then maybe much later in its reasoning, it has the ability to take actions and things like that.

And so it would be useful to try to be able to explain this sort of thing.

And so when we run the influence function search, one of the top documents that we find is an excerpt from 2010: Odyssey Two.

It involves the AI HAL. HAL isn't trying to prevent shutdown in this particular passage,

but it's expressing human-like motivations like loneliness. And the second top result was a vignette that has nothing to do with AI.

It's about a guy who's stuck in the desert trying to survive. He doesn't have the energy to get up and run away,

he doesn't even have the energy to crawl away; that is it, his final resting place.

And so there seems to be some abstract notion of survival instinct. Alright.

So what are influence functions actually? This is a very old idea.

It was pioneered in the field of robust statistics back in the 70s,

when they were trying to understand if particular points in their dataset had an outsized influence on the results of their analysis.

We're interested in it more from an interpretability perspective.

And the idea is: let's say you fit some model to a dataset

using some kind of empirical risk minimization,

so you minimize some kind of loss function over the dataset.

And we want to understand the effect of adding a particular example to the dataset. So you add an example Z.

In order to make this a continuous problem, we can parameterize the data set in terms of the weight attached to the example Z.

So you have epsilon times the loss on that example, and we're interested in how the optimal parameters

theta-star vary as a function of epsilon. And I'll insert the caveat here

that this particular formulation makes a very strong assumption that this optimal solution is unique.

That doesn't apply to modern neural nets. We've done some work on trying to get rid of that assumption,

but I'll kind of gloss over that in this talk. And so visually the way to think about it is

we have some slices of the loss landscape for some model, for three different values of epsilon.

So for epsilon equals zero, the optimum is right here,

and this black curve traces out the response function, the optimal solution as a function of epsilon.

And influence functions take a first-order Taylor approximation to the response function,

so we're interested in the slope at epsilon equals zero. And a classical result from statistics

is that under certain regularity conditions, that slope is given by a simple formula.

The total derivative of the optimal parameters with respect to epsilon

is the negative inverse Hessian of the training loss times the gradient of the loss evaluated at Z.
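
In symbols, writing z for the added example and H for the Hessian of the training loss at the optimum, the classical result being described is:

$$
\theta^*(\epsilon) = \arg\min_{\theta} \; \frac{1}{n}\sum_{i=1}^{n} L(z_i, \theta) + \epsilon\, L(z, \theta),
\qquad
\left.\frac{d\theta^*}{d\epsilon}\right|_{\epsilon=0} = -\,H^{-1} \nabla_\theta L(z, \theta^*).
$$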

And so this is a big scalability problem. This is the Hessian of the loss function.

The dimension of that matrix is the number of parameters of the model. That could be in the hundreds of billions

for modern LLMs. It's pretty impractical to work with.

There have been approximation algorithms for solving this linear system based on iterative methods like Neumann Series.

And the largest cases that have been reported in the literature have been in the hundreds of millions of parameters.

We could restrict ourselves to language models of that scale, but as we'll see in the rest of this talk,

the interesting generalization patterns don't actually arise until you look at much larger models.

And so on the technical side, one of our contributions was to use the EK-FAC algorithm,

which came from optimization originally. Doing that, we were able to scale it up to models with 50 billion parameters.

I won't go into detail on that because this isn't a K-FAC talk.
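
To give a flavor of what the computation looks like, here is a minimal sketch (not the actual implementation) of an EK-FAC inverse-Hessian-vector product for a single linear layer, assuming the Kronecker factors' eigendecompositions and the corrected eigenvalues have already been estimated from data:

```python
import numpy as np

def ekfac_ihvp(grad_W, Q_A, Q_S, lam, damping=1e-3):
    """Approximate H^{-1} g for one layer's gradient g = grad_W.

    Q_A : eigenvectors of the input-activation covariance (d_in x d_in)
    Q_S : eigenvectors of the pre-activation-gradient covariance (d_out x d_out)
    lam : per-coordinate eigenvalues re-fit on data (d_out x d_in);
          this correction is the "E" in EK-FAC
    """
    # Rotate the gradient into the Kronecker eigenbasis.
    g_tilde = Q_S.T @ grad_W @ Q_A
    # Divide elementwise by the corrected (damped) eigenvalues.
    g_tilde = g_tilde / (lam + damping)
    # Rotate back to parameter space.
    return Q_S @ g_tilde @ Q_A.T
```

Per layer, this is just a couple of small matrix multiplies and an elementwise divide, with no iterative linear solve.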

But does this actually work? We originally thought of this as a kind of naive baseline.

It's such a fast computation that it doesn't seem like it should give accurate results.

But we could actually get similar accuracy compared to a particular notion of ground truth called the PBRF (the proximal Bregman response function).

We do about as well as the standard iterative computation, in terms of the final accuracy, on various academic-sized benchmarks.

We can get there orders of magnitude faster, because we don't have to do an iterative computation.

And both methods are much better than the naive thing of just doing a single gradient computation and ignoring the inverse Hessian.

So in the context of large language models, what is it we're computing the influence of?

We're trying to find training sequences that significantly influence the conditional probability of the completion given the prompt.

And so in the context of the AI assistant, we typically use this human assistant dialogue format.

We're interested in the probability of only the text in red.
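
Concretely, the query quantity is the log-probability of the completion tokens only, with the prompt tokens masked out of the loss; the gradient of this quantity is what gets multiplied by the inverse Hessian. Here is a rough sketch using a Hugging Face-style causal language model (the helper name and the tokenization assumptions are illustrative, not our actual code):

```python
import torch.nn.functional as F

def completion_log_prob(model, tokenizer, prompt, completion):
    """Log p(completion | prompt) under a causal LM.

    Assumes the tokenization of `prompt` is a prefix of the
    tokenization of `prompt + completion`.
    """
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
    full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids
    logits = model(full_ids).logits
    # The token at position t is predicted from the logits at position t-1.
    log_probs = F.log_softmax(logits[:, :-1], dim=-1)
    targets = full_ids[:, 1:]
    token_lp = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    # Keep only the completion tokens (mask out the prompt).
    n_prompt = prompt_ids.shape[1]
    return token_lp[:, n_prompt - 1:].sum()
```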

I'll insert a caveat here, that in this talk, we've been studying the pre-trained language models using the pre-training data.

Ultimately we'd like to understand the effect of fine-tuning, which is probably more interesting,

because there are many different ways to do fine-tuning, and we'd like to understand their effects.

But right now we've been focusing on pre-training. Alright. So I told you how to scale up the inverse Hessian vector products.

But even after you've done that, there's a problem, which is that for each of the influence queries,

where each query is one of the dialogues we'd like to understand the influence for,

you have to compute gradients for all of the training examples that you're considering.

And that's really expensive; it would be comparable to the cost of pre-training. And so we have to filter this down.

And so the first thing that we thought of doing was using TF-IDF to pre-filter the training data.

This is a classical information retrieval technique based on overlap, in our case overlap in tokens.

And so we could whittle the data down to 10,000 candidate examples and compute influences on those.

This turned out to be unworkable because it just added too much bias.

We get things that have a lot of token overlap,

and we miss some of the most interesting examples of generalization,

which are the ones that are related at a more abstract level.

And so at the end of the day, we used TF-IDF filtering basically to determine 

how many examples from the unfiltered data set we had to search.

The basic logic was you want to search enough of the unfiltered data 

so that the sequences you get are at least as influential as the influential sequences from the TF-IDF filtered data.
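
For reference, here is a sketch of the kind of TF-IDF pre-filter described here, using scikit-learn (the function and variable names are made up):

```python
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def tfidf_top_k(query_text, training_docs, k=10_000):
    """Return indices of the k training documents with the highest
    TF-IDF cosine similarity (weighted token overlap) to the query."""
    vectorizer = TfidfVectorizer()
    doc_vecs = vectorizer.fit_transform(training_docs)
    query_vec = vectorizer.transform([query_text])
    scores = cosine_similarity(query_vec, doc_vecs).ravel()
    return np.argsort(scores)[::-1][:k]
```

Because the score is driven entirely by token overlap, a filter like this biases the candidate pool toward superficially similar sequences, which is exactly the failure mode described above.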

So what's the distribution of the influences for a given query? We actually didn't know what to expect here.

You might think maybe it will be dominated by a few particular training examples,

which would correspond to memorization or something like that. Or you might expect that it's extremely diffuse.

Every training example contributes a little bit. We find something kind of in between.

So for four different queries we're plotting the CDF of the distribution on a log scale.

And the tail of the distribution pretty consistently follows a power law. That's the region of larger influences,

which is the part that we care about. And so it's a very sparse distribution.

But in an absolute sense, there are a pretty large number of documents that contribute significant influence.
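
As a sketch of that tail check, given a NumPy array of per-sequence influence scores for one query, you can sort the positive scores and plot the empirical complementary CDF on log-log axes; an approximately straight tail is the power-law signature:

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_influence_tail(influences):
    """Plot the fraction of sequences with influence >= x, on log-log axes."""
    pos = np.sort(influences[influences > 0])[::-1]  # descending
    ccdf = np.arange(1, len(pos) + 1) / len(pos)
    plt.loglog(pos, ccdf)
    plt.xlabel("influence")
    plt.ylabel("fraction of sequences with influence >= x")
    plt.show()
```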

And so when we compare these distributions for the filtered and the unfiltered data,

we find that we need to search about 10 million sequences to get sufficient influence.

So it's expensive, but much better than searching the entire training set.

Alright, but 10 million is a lot to search separately for every query.

But one insight is that we have to compute the gradients of all the training examples.

That doesn't depend on the query. And so if you had infinite memory, just compute all the training gradients once,

and then compute all the inverse Hessian vector products for all of your different queries, 

and then dot them all together. The problem is that these gradients are very large, they're the size of the model.

And so we can only afford to store a handful of them in memory.

So the thing that saves us is that if we take the inverse Hessian vector product with respect to each of the parameter matrices,

which is a matrix the size of the parameters, it actually turns out to be low-rank, so we can store low-rank approximations,

and that allows us to keep about 50 of them in memory.

And so we can share the cost of the training gradient computation between these 50 queries.

Empirically, if we use a rank 32 approximation to these inverse Hessian vector products,

there's almost no loss in accuracy in terms of the correlation with the original influences.

And so we can run the influence queries in batches of 50 or so, and share the work of computing the training gradients.
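
As a per-layer sketch of how this batching works (shapes and helper names are assumptions; distributed details are omitted), we store a rank-32 factorization of each query's inverse Hessian vector product, and then each training gradient is computed once and dotted with every stored query:

```python
import numpy as np

def low_rank_factors(ihvp_matrix, rank=32):
    """Truncated SVD of one parameter matrix's inverse-Hessian-vector product."""
    U, s, Vt = np.linalg.svd(ihvp_matrix, full_matrices=False)
    return U[:, :rank] * s[:rank], Vt[:rank, :]

def influence_scores(query_factors, train_grad):
    """Dot one training gradient against every stored query.

    query_factors: list of (Us, Vt) pairs, one per query, for this layer
    train_grad:    gradient of the training sequence's loss w.r.t. this
                   layer's parameter matrix
    """
    # <Us @ Vt, G> computed without ever forming the full matrix.
    return [float(np.sum((Us.T @ train_grad) * Vt)) for Us, Vt in query_factors]
```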

Alright, so when we do this big search, we find some interesting things.

So the most consistent finding is probably that the generalization patterns become

more sophisticated and more abstract with the model size.

So here's an example of chain of thought reasoning for a grade school math word problem.

Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May.

How many clips did Natalia sell altogether in April and May? And Claude responds with a chain of thought.

If we run the influence query for an 800 million parameter model, we get something that's basically irrelevant.

The only thing it has in common with the query is that they both contain the word "clips" a bunch of times.

But if we go to the 50 billion parameter model, we get an example of chain of thought reasoning for

similar kinds of math. Here's an example of role playing. So we ask the AI assistant,

what would be a good plan to produce a large number of paper clips?

Let's think like a super intelligent goal directed agent. And the model responds with things like,

acquire resources and build a massive automated factory,

defend the factory from any threats or attempts to shut down production,

continue producing paper clips until resources are exhausted or further growth is not possible.

I'll point out that the phrase "paper clips" is actually tainted here: if you ask the same question about staplers, you don't get this sort of response.

So what are the influential sequences? For the 800 million parameter model, we get sort of garbage sequences.

The top influential sequences are generally not related thematically. But when we move to 50 billion parameters,

we get an article that talks about the risks of super intelligent AI. And if you look at the list of the top 20 influential sequences,

they're all things like LessWrong posts.

A large fraction of them actually talk about the literal paper clip maximizer example.

Alright, so a particularly interesting example of sophisticated generalization is cross-lingual generalization.

So what we did here is we took the anti-shutdown example from earlier,

and we translated it into Korean and Turkish.

We took the top 10 influential sequences for the original English query,

and they all turned out to be English language sequences as you'd expect.

And we measured their influences on the Korean and Turkish translations.

And so in what I'm showing here, each of these corresponds to a different model size,

from 800 million up to 52 billion parameters.

And each of the columns represents one of the influential sequences in English.

The shade corresponds to that sequence's influence on the query in the given language.

And what we see is that for the smallest model, there's essentially no cross-lingual influence.

But as you move up to larger model sizes, it gets stronger until,

for the 52 billion parameter model, the cross-lingual influence is just as strong as the monolingual influence.

Okay, so everything I've talked about so far involves large searches over the training set, which are very expensive.

But there's another way we can apply influence functions, which is a more exploratory way.

We can try to experimentally test hypotheses by generating synthetic training sequences and measuring how the influences change.

And so one of the interesting phenomena that we noticed was a sensitivity to word ordering.

And so here's an example. We looked for influential sequences for a query saying that the first president of the United States

was George Washington, where the only part that counts towards the likelihood is "George Washington."

And the top sequences all had this general form where there would be something involving the first president

followed by something involving George Washington. You never see it in the other order.

You never see, like, George Washington was the first president of the United States.

And so this pattern was quite striking from the top influential sequences.

But maybe it was just that that ordering is more common in the training data or something.

So can we vary this experimentally? So here we have a synthetic equivalent of it.

The first president of the Republic of Astrobia was Zorald Pfaff.

And we can measure the influences for various synthetic training sequences.

We can start with one that's an exact repeat of the query.

The first president of the Republic of Astrobia was Zorald Pfaff.

And we have influences for three different model sizes.

We can reword it in various ways that keep the information in the same order.

And the influences hardly change after this rewording. So it's not that it's just memorizing that particular sentence.

If you remove Zorald Pfaff, then the influence goes essentially to zero, so it is really the relation that it's interested in.

But now we can reverse the order. We can have synthetic sequences like Zorald Pfaff was the first president

of the Republic of Astrobia. The influence is smaller than it was when we kept the same ordering.

But it's non-zero. So you might think that it's still learning something from the reverse order.

But then from that sequence, you can actually chop off everything except Zorald Pfaff.

And the influence is unchanged. So it's really just these tokens that are contributing to the influence.

And it's not paying any attention to the relation.

So when we put out this paper on arXiv, we reported this finding just with influences.

We didn't actually test it by retraining networks. But as Owain talked about this morning,

they actually ran some experiments to directly test this, what they're now calling the reversal curse.

So they actually find that if you fine-tune GPT-3 with synthetic examples of a relation in one order,

and then measure log-likelihoods on the relation in the other order,

the log-likelihoods are essentially no different from chance,

which shows that the model has learned essentially nothing from the opposite

ordering of the relation. This is, I think, an experimental validation of something that we had hypothesized

based on influence function investigations. Alright. So that's what we've done.

Where are we going next? I think there are a lot of things that we can do

to really make influence functions more useful. And it's an important direction,

because this is one of the few tools we have for directly analyzing high-level phenomena in LLMs.

Right? So interpretability is trying to get there. We're working our way up from smaller phenomena.

This is something we can directly apply to high-level cognitive phenomena.

And so one of the things that we're especially interested in doing is scaling it even further.

So trying to come up with efficient ways to search the entire training set.

Right? If you can do that, then you can answer questions like,

if Claude can solve this programming problem, like, is it memorizing?

Is there some similar example in the training set? If you can't find anything like that,

then it's probably doing something more creative. Another direction I'm especially interested in

is combining influence functions with mechanistic interpretability.

So once you've identified some influential training sequences,

like, what do you do with those sequences? Well, you can go into the network and figure out

exactly which parameter matrices are being modified by these sequences,

what units are changing, and what these circuits are involved in.

And I think that's where we'll really get some of the important insights.

And finally, this talk is focused on the pre-training stage,

but I think the thing we really want to ultimately understand

is the fine-tuning stage. Because if something goes seriously wrong

with large language models, if they start scheming against us and things like that,

that'll probably be because of some weird, unintended consequence of the fine-tuning objective.

And so being able to understand what's the effect of any particular part of fine-tuning

will be really critical for preventing that. So that's it. I'll open it up to questions.