Sampling From Scratch

import math
import torch
import torch.nn.functional as F
from torch import Tensor, tensor
from jaxtyping import Float, Int
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.set_grad_enabled(False); # disable backprop
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print("device:", device)

model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name, local_files_only=True).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)
tokenizer.pad_token = tokenizer.eos_token

LLMs do not operate on words directly – each word is converted into a high dimensional vector that carries information through the model. At each layer, the model reads this vector, performs some computation (attention or an MLP) and writes the result back to the vector.

We call this vector the residual stream. To initially create these vectors from a sentence, we use a large lookup table that maps each “word” (or sub-word, see here for more info) to this high dimensional vector.

We call each “word” a token.
You can imagine token ~= word
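
For example (a quick check using the tokenizer loaded above), common words usually map to a single token, while rarer words get split into several sub-word pieces:

print(tokenizer.tokenize(" cat"))            # a single BPE token
print(tokenizer.tokenize(" sesquipedalian")) # split into several sub-word tokens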

These vectors have 768 dimensions in GPT-2, which can also be thought of as the width of the model (depth being the number of layers).

We look up each token in an embedding table. This is a map of ~50,000 tokens to a high dimensional embedding.

model.transformer.wte
Embedding(50257, 768)

Let’s see the first 10 dimensions of the token (word) 9246

token = 9246
first_n_dimensions = 10
model.transformer.wte.weight[token, :first_n_dimensions]
tensor([-0.0164, -0.0934,  0.2425,  0.1398,  0.0388, -0.2592, -0.2724, -0.1625,
         0.1683,  0.0829], device='mps:0', requires_grad=True)

And to find the corresponding string word associated with token 9246:

def decode(tokens) -> str:
    return tokenizer.decode(tokens)

print(f"decoded token: {repr(decode(token))}")
decoded token: 'cat'

Using the tokenize and decode functions, we can convert back and forth between a string and its tokens, and from tokens we can look up the initial model vectors (“embeddings”).

Notably, the tokenizer adds a “batch” dimension to the input, which allows us to process multiple inputs at the same time, for example running “the cat sat on the mat” and “I took my dog for a walk” in a single forward pass (see the example after the tokenize helper below).

The input to an LLM is a list of tokens, whose length we call the sequence length (seq, or T for the time dimension, for short).

def tokenize(input) -> Int[Tensor, "bs seq"]:
    return tokenizer(
        input,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=model.config.n_ctx,
    )["input_ids"].to(device)

prompt = 'the cat sat on a mat'
tokens = tokenize(prompt)
embeddings = model.transformer.wte.weight[tokens]

decoded = decode(tokens[0])

print(f"""\
# prompt
{prompt}

# tokens shape: {tuple(tokens.shape)}
{tokens.tolist()}

# decoded
{decoded}

# embeddings shape: {tuple(embeddings.shape)}
""")
# prompt
the cat sat on a mat

# tokens shape: (1, 6)
[[1169, 3797, 3332, 319, 257, 2603]]

# decoded
the cat sat on a mat

# embeddings shape: (1, 6, 768)

Output

Now, given the prompt input, let’s run the tokens through the model and look at the output. These output scores are called logits.

logits = model(tokens).logits

print(f"""\
# Tokens ({tuple(tokens.shape)})

# Logit Output ({tuple(logits.shape)})
""")
# Tokens ((1, 6))

# Logit Output ((1, 6, 50257))

The input has shape (batch size, sequence length), and the output has shape (batch size, sequence length, vocab size).

For each position in the sequence, the model outputs a score for every token in the vocabulary (~50K), representing how likely that token is to come next.
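
For example, here is a small sketch (reusing the logits computed above and the decode helper) showing the five highest scoring candidates for the final position:

top_scores, top_idxs = logits[0, -1].topk(5) # five largest scores at the last position
for score, idx in zip(top_scores.tolist(), top_idxs.tolist()):
    print(f"{decode(idx)!r}: {score:.2f}")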

For each token, we can see which token the model predicted as most likely.

for i in range(tokens.shape[1]):
    inp = decode(tokens[0, :i+1])
    pred = decode(logits[0, i].argmax())
    print(f"{repr(inp)} => {repr(pred)}")
'the' => ','
'the cat' => ','
'the cat sat' => ' on'
'the cat sat on' => ' the'
'the cat sat on a' => ' bench'
'the cat sat on a mat' => ','

So to continue generating text, we run an autoregressive loop: pick a token from the model’s prediction at the last position in the sequence, append it to the prompt, and repeat.

def generate(prompt, num_tokens, verbose=False):
    tokens = tokenize(prompt)
    for i in range(num_tokens):
        logits = model(tokens).logits[0, -1] # get the scores of the final token [shape: (n_vocab)]
        next_token = logits.argmax(keepdim=True) # pick the largest one
        tokens = torch.cat([ tokens, next_token[None] ], dim=1) # concatenate to the current text
        if verbose:
            print(decode(tokens[0]))
    return decode(tokens[0])

generate(prompt, num_tokens=20, verbose=True);
the cat sat on a mat,
the cat sat on a mat, and
the cat sat on a mat, and the
the cat sat on a mat, and the cat
the cat sat on a mat, and the cat sat
the cat sat on a mat, and the cat sat on
the cat sat on a mat, and the cat sat on a
the cat sat on a mat, and the cat sat on a mat
the cat sat on a mat, and the cat sat on a mat,
the cat sat on a mat, and the cat sat on a mat, and
the cat sat on a mat, and the cat sat on a mat, and the
the cat sat on a mat, and the cat sat on a mat, and the cat
the cat sat on a mat, and the cat sat on a mat, and the cat sat
the cat sat on a mat, and the cat sat on a mat, and the cat sat on
the cat sat on a mat, and the cat sat on a mat, and the cat sat on a
the cat sat on a mat, and the cat sat on a mat, and the cat sat on a mat
the cat sat on a mat, and the cat sat on a mat, and the cat sat on a mat,
the cat sat on a mat, and the cat sat on a mat, and the cat sat on a mat, and
the cat sat on a mat, and the cat sat on a mat, and the cat sat on a mat, and the
the cat sat on a mat, and the cat sat on a mat, and the cat sat on a mat, and the cat

Sampling from the Probability Distribution

But just picking the most likely token can give quite bland output. Instead, we can sample from a probability distribution over the vocabulary.

To build that distribution we use the softmax function, which takes the model’s output scores (which can be any number) and turns them into probabilities that add up to 1.
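
A quick sanity check with toy scores (made-up numbers, not model output): softmax maps arbitrary scores to positive values that sum to 1.

toy_scores = tensor([3.0, 1.0, -2.0])
probs = F.softmax(toy_scores, dim=-1)
print(probs, probs.sum()) # every entry is positive and the total is 1.0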

def generate(prompt, num_tokens, verbose=False, seed=42): # add a seed to keep the output deterministic. Try other seeds!
    torch.manual_seed(seed)
    tokens = tokenize(prompt)
    for i in range(num_tokens):
        logits = model(tokens).logits[0, -1]
        ### New lines
        probs = F.softmax(logits, dim=-1) # create probability distribution of scores
        next_token = torch.multinomial(probs, 1) # pick a single token from distribution
        ###
        tokens = torch.cat([ tokens, next_token[None] ], dim=1)
        if verbose:
            print(decode(tokens[0]))
    return decode(tokens[0])

generate(prompt, num_tokens=20, verbose=True);
the cat sat on a mat and
the cat sat on a mat and did
the cat sat on a mat and did something
the cat sat on a mat and did something which
the cat sat on a mat and did something which,
the cat sat on a mat and did something which, oddly
the cat sat on a mat and did something which, oddly enough
the cat sat on a mat and did something which, oddly enough,
the cat sat on a mat and did something which, oddly enough, most
the cat sat on a mat and did something which, oddly enough, most ordinary
the cat sat on a mat and did something which, oddly enough, most ordinary folk
the cat sat on a mat and did something which, oddly enough, most ordinary folk never
the cat sat on a mat and did something which, oddly enough, most ordinary folk never do
the cat sat on a mat and did something which, oddly enough, most ordinary folk never do!
the cat sat on a mat and did something which, oddly enough, most ordinary folk never do! )
the cat sat on a mat and did something which, oddly enough, most ordinary folk never do! ) It
the cat sat on a mat and did something which, oddly enough, most ordinary folk never do! ) It is
the cat sat on a mat and did something which, oddly enough, most ordinary folk never do! ) It is 137
the cat sat on a mat and did something which, oddly enough, most ordinary folk never do! ) It is 137E
the cat sat on a mat and did something which, oddly enough, most ordinary folk never do! ) It is 137EUNE

This already gives a much more interesting output! But how can we control how much of the distribution we sample from?

Temperature

Temperature controls how the distribution is sampled. It’s best understood in terms of the two approaches shown above:

  • Temperature 0: collapses the distribution, all probability is given to the token with the largest score (equivalent to greedy argmax)
  • Temperature 1: the standard softmax distribution, same as the sampling above

By increasing the temperature, we increase the chance of a token with a lower probability getting picked.
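
To see the effect in isolation, here is a toy example (made-up scores, not model output): lower temperatures concentrate the probability mass, higher temperatures spread it out.

toy_logits = tensor([2.0, 1.0, 0.1])
for t in [0.5, 1.0, 2.0]:
    print(t, F.softmax(toy_logits / t, dim=-1)) # low t sharpens the distribution, high t flattens it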

def generate(prompt, num_tokens, temperature=1.0, seed=42):
    torch.manual_seed(seed)
    tokens = tokenize(prompt)
    temperature = max(temperature, 1e-8) # temperature 0 => divide by _very small_ constant
    for i in range(num_tokens):
        logits = model(tokens).logits[0, -1]
        probs = F.softmax(logits / temperature, dim=-1) # divide scores by temperature: T < 1 sharpens the distribution, T > 1 flattens it
        next_token = torch.multinomial(probs, 1)
        tokens = torch.cat([ tokens, next_token[None] ], dim=1)
    return decode(tokens[0])

for temp in torch.arange(0, 2.2, 0.2):
    print(f"\n### {temp.item():.1f} ###")
    print(generate(prompt, num_tokens=20, temperature=temp))

### 0.0 ###
the cat sat on a mat, and the cat sat on a mat, and the cat sat on a mat, and the cat

### 0.2 ###
the cat sat on a mat and she was eating a bowl of rice. The cat was so hungry that she had to be fed

### 0.4 ###
the cat sat on a mat and she was screaming.

"I don't know what to do," she said.


### 0.6 ###
the cat sat on a mat and she looked at me with a smile and said, "by the way, I'm going home

### 0.8 ###
the cat sat on a mat and did something which was oddly fitting. It did not sit well!

I think I tried

### 1.0 ###
the cat sat on a mat and did something which, oddly enough, most ordinary folk never do! ) It is 137EUNE

### 1.2 ###
the cat sat on a mat and did something which none of us appreciated working for me -by Roy Cairity 13717 walks

### 1.4 ###
the cat sat on a mat and did something predictably right Floracles understandably squirmed excited -by Roy Collins 403 paths 137 moves 15

### 1.6 ###
the cat sat on a mat and did Sunshade Floracles instead game day morning -by Roy Collins 403 paths 137 moves dealt

### 1.8 ###
the cat sat on a mat and did Sunshade Floracles alternating physics workwitch syllby Roy Collins 403 paths 137 moves dealt

### 2.0 ###
the cat sat on a mat token (~nil predictably right Poké GraphPlex physics proofwitch botby lived broadcast 403 paths 137 moves dealt

As the temperature increases, less likely tokens get sampled, which can lead to more interesting output. Setting the temperature hyperparameter correctly can be key to good model performance.

Top K

Another parameter used in sampling is top_k. This prevents the model from making too “wild” predictions by limiting the probability distribution to the top k results.

In other words: currently we sample from the entire distribution of ~50,000 tokens, but it makes sense that only, say, the top 50 tokens are reasonable continuations.
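
As a reminder of the building block used below (a tiny standalone example): torch.topk returns both the largest values and their indices, and we need those indices to map back to token ids.

values, idxs = tensor([0.1, 4.0, 2.5, 0.3]).topk(2)
print(values, idxs) # the two largest values and the positions they came from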

def generate(prompt, num_tokens, temperature=1.0, top_k=50, seed=42):
    torch.manual_seed(seed)
    tokens = tokenize(prompt)
    temperature = max(temperature, 1e-8)
    for i in range(num_tokens):
        logits = model(tokens).logits[0, -1]
        if top_k:
            logits, idxs = logits.topk(top_k) # Sample only topk tokens
        else:
            idxs = torch.arange(len(logits), device=device) # All idxs
    
        probs = F.softmax(logits / temperature, dim=-1)
        next_token = idxs[torch.multinomial(probs, 1)] # we use the idxs of topk only
        tokens = torch.cat([ tokens, next_token[None] ], dim=1)
    return decode(tokens[0])

for temp in torch.arange(0, 2.2, 0.2):
    print(f"\n### Temperature {temp.item():.1f} ###")
    print(generate(prompt, num_tokens=20, temperature=temp))

### Temperature 0.0 ###
the cat sat on a mat, and the cat sat on a mat, and the cat sat on a mat, and the cat

### Temperature 0.2 ###
the cat sat on a mat, and the cat was sitting on the mat.

"I thought, 'Oh, I

### Temperature 0.4 ###
the cat sat on a mat.)

"This is what you have to do," said the boy. "You have to

### Temperature 0.6 ###
the cat sat on a mat.)

"They're just trying to get him to stop," says one visitor. "I

### Temperature 0.8 ###
the cat sat on a mat.)

"Yeah, you know, I thought it was a nice day last night (laughs

### Temperature 1.0 ###
the cat sat on a mat.)

If that's all it takes, well, that's how I love it that you

### Temperature 1.2 ###
the cat sat on a mat.)

If that's all it takes, well, that's how I love him the way

### Temperature 1.4 ###
the cat sat on a mat.)

If such a man had had one word to say to me about love of Christ:

### Temperature 1.6 ###
the cat sat on a mat.)

If such a man had had one word to say to me about love of Christ:

### Temperature 1.8 ###
the cat sat on a mat.)

If such a man had had one word to give to this matter his friend must be

### Temperature 2.0 ###
the cat sat on a mat.)

If such a man had called us one day while we was still being called names in

You can see that even at very high temperatures, the output no longer devolves into gibberish.

Min P

Top K can often be too naive a heuristic for sampling. A more common technique nowadays is to instead discard tokens that have too low a probability.

We do this by comparing each token’s probability to the probability of the most probable token.

For example, if the most probable token has 60% probability and we set min_p = 0.1, we discard all tokens with a probability less than 6%.
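
Concretely, a toy distribution matching the numbers above (not model output):

toy_probs = tensor([0.60, 0.25, 0.10, 0.05])
threshold = toy_probs.max() * 0.1 # min_p = 0.1 => keep tokens with at least 6% probability
print(toy_probs >= threshold)     # the 5% token is masked out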

def generate(
    prompt,
    num_tokens,
    temperature=1.0,
    top_k=None,
    min_p=None,
    seed=42
):
    torch.manual_seed(seed)
    tokens = tokenize(prompt)
    temperature = max(temperature, 1e-8)
    for i in range(num_tokens):
        logits = model(tokens).logits[0, -1]
        if top_k:
            logits, idxs = logits.topk(top_k)
        else:
            idxs = torch.arange(len(logits), device=device)

        # TODO: temperature before or after min_p?
        probs = F.softmax(logits / temperature, dim=-1)

        if min_p is not None:
            mask = probs >= (probs.max() * min_p) 
            idxs, probs = idxs[mask], probs[mask]

        next_token = idxs[torch.multinomial(probs, 1)]
        tokens = torch.cat([ tokens, next_token[None] ], dim=1)
    return decode(tokens[0])

for min_p in reversed(torch.logspace(start=math.log10(0.01), end=math.log10(0.5), steps=10, base=10)):
    print(f"\n### Min P: {min_p.item():.2f} ###")
    print(generate(prompt, num_tokens=20, temperature=1.5, min_p=min_p))

### Min P: 0.50 ###
the cat sat on a mat and the cat sat on a bench. The cat was sitting on a bench.

"You

### Min P: 0.32 ###
the cat sat on a mat with a large bag of water in it, and she had been in the water for a long time

### Min P: 0.21 ###
the cat sat on a mat with a small tray on the top, while the dog stood up, looking up at the cat and

### Min P: 0.14 ###
the cat sat on a mat that had been planted over a tall wall and its claws were stuck to the floor. She tried to

### Min P: 0.09 ###
the cat sat on a mat and told him to hold it down. But he shook his head, "Don't get involved,

### Min P: 0.06 ###
the cat sat on a mat and my cock was about a foot up from her anus. She walked back to my room and let

### Min P: 0.04 ###
the cat sat on a mat in a hospital bed beside a bed-covered seat."

"It's such a strong feeling

### Min P: 0.02 ###
the cat sat on a mat

Until God said, "Maybe it's best you don't be dead before you let the

### Min P: 0.02 ###
the cat sat on a mat: more reason to dig!

Joaquim Jimenez joined the rest of us over a

### Min P: 0.01 ###
the cat sat on a mat: I trust our ability to keep an eye on my kid and claim her your autograph this late

Frequency Penalty

As we’ve seen at low temperatures, the model has a tendency to repeat itself. To combat this we can apply a frequency penalty, which discourages the model from predicting the same tokens again.

The higher a token’s frequency in the sequence so far, the higher the penalty. If a token is not in the sequence, its count is 0 and no penalty is applied.
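
The counting is done with torch.bincount. Here is a toy example (made-up token ids) of the counts that get subtracted from the logits:

toy_tokens = tensor([5, 3, 5, 5])
counts = torch.bincount(toy_tokens, minlength=8)
print(counts) # token 5 has appeared 3 times, so its logit is penalised 3x as much as token 3's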

def generate(
    prompt,
    num_tokens,
    temperature=1.0,
    top_k=None,
    min_p=None,
    frequency_penalty=None,
    seed=42,
):
    torch.manual_seed(seed)
    tokens = tokenize(prompt)
    temperature = max(temperature, 1e-8)
    for i in range(num_tokens):
        logits = model(tokens).logits[0, -1]

        if frequency_penalty:
            *_, vocab_size = logits.shape
            # get frequency of each of the logits in the current output
            id_freqs = torch.bincount(tokens[0], minlength=vocab_size)
            logits -= frequency_penalty * id_freqs

        if top_k:
            logits, idxs = logits.topk(top_k)
        else:
            idxs = torch.arange(len(logits), device=device)

        # TODO: temperature before or after min_p?
        probs = F.softmax(logits / temperature, dim=-1)

        if min_p is not None:
            mask = probs >= (probs.max() * min_p) 
            idxs, probs = idxs[mask], probs[mask]

        next_token = idxs[torch.multinomial(probs, 1)]
        tokens = torch.cat([ tokens, next_token[None] ], dim=1)
    return decode(tokens[0])

for freq_penalty in torch.linspace(start=0, end=1., steps=6):
    print(f"\n### Frequency Penalty {freq_penalty.item():.1f} ###")
    print(generate(prompt, num_tokens=20, temperature=0., frequency_penalty=freq_penalty))

### Frequency Penalty 0.0 ###
the cat sat on a mat, and the cat sat on a mat, and the cat sat on a mat, and the cat

### Frequency Penalty 0.2 ###
the cat sat on a mat, and the cat sat on a chair.

"I'm not going to lie, I

### Frequency Penalty 0.4 ###
the cat sat on a mat, and the cat was sitting on a chair.

"I'm not sure what you're

### Frequency Penalty 0.6 ###
the cat sat on a mat, and the cat was sitting on a chair.

"I'm not sure what you're

### Frequency Penalty 0.8 ###
the cat sat on a mat, and the dog was sitting on a chair.

"I'm not sure what happened to

### Frequency Penalty 1.0 ###
the cat sat on a mat, and the dog was sitting on a chair.

"I'm not sure what happened to

It no longer repeats itself continuously.

KV Cache

When running our autoregressive generate function, we currently recompute the logits for every token in the sequence at every step, only to discard everything but the last position with

logits = model(tokens).logits[0, -1]
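
As a preview of the fix, here is a minimal sketch that leans on the cache Hugging Face exposes through use_cache / past_key_values (an assumption that we reuse the library's built-in cache rather than building our own yet):

out = model(tokens, use_cache=True)      # run the full prompt once, keeping the cache
past = out.past_key_values               # cached keys/values for every layer
next_token = out.logits[0, -1].argmax(keepdim=True)

out = model(next_token[None], past_key_values=past, use_cache=True) # feed only the new token
past = out.past_key_values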