Soft (Latent) Decoding
Early musings on a new decoding technique.
1. Abstract
Modern large language models are typically trained and used in a discrete autoregressive manner. At each step, the model projects its hidden state to vocabulary logits, and a single token is either sampled or chosen greedily. This discrete choice collapses the probability distribution at every timestep, potentially discarding latent or “parallel” reasoning paths.
In contrast, a soft (latent) chain-of-thought approach retains a mixture of possible tokens by computing a continuous next-token embedding. Rather than collapsing to one token, the model’s next input embedding is the probability-weighted average of all token embeddings under the softmax distribution. In effect, this removes the sampling step from the model.
2. Motivation
2.1 Standard Discrete Decoding
```python
tokens = [ BOS ]
embeddings = [ W_E[tokens] ]
for t in range(max_length):
    h = model(embeddings)
    logits = h[-1] @ W_U
    probabilities = softmax(logits)
    next_token = sample(probabilities)
    next_emb = W_E[next_token]  # discrete sampling
    tokens.append(next_token)
    embeddings.append(next_emb)
    if next_token == EOS:
        break
```
This method is efficient and well-aligned with the training distribution (discrete). However, it collapses the model’s uncertainty at every timestep, keeping only one token thread and discarding all other plausible tokens.
2.2 Soft (Latent) Chain-of-Thought
```python
tokens = [ BOS ]
embeddings = [ W_E[tokens] ]
for t in range(max_length):
    h = model(embeddings)
    logits = h[-1] @ W_U
    probabilities = softmax(logits)
    next_token = sample(probabilities)   # only for logging / EOS check
    next_emb = probabilities @ W_E       # use the full distribution instead of a single token
    tokens.append(next_token)
    embeddings.append(next_emb)
    if next_token == EOS:
        break
```
Instead of feeding back a discrete token embedding, feed back the softmax distribution’s probability-weighted mixture of token embeddings.
2.3 Motivations for Soft Decoding
- The model is allowed to maintain its distribution over tokens at each step, so its internal state may reflect several candidate continuations simultaneously.
- (Kind of like beam search without the combinatorial blowup.)
- The model remains end-to-end differentiable: gradients from future tokens can propagate through the sampling step.
3. Approaches to Creating a Soft Next-Token Embedding
3.1 Plain Softmax Mixture
Take the model’s logits, apply a standard softmax, then multiply the resulting distribution by the embedding matrix \(W_E\).
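Writing \(\mathbf{z}_t\) for the logits at step \(t\) (as in the Gumbel-Softmax note below) and treating \(W_E\) as a \(d_{\text{vocab}} \times d_{\text{model}}\) matrix with one row per token, the soft next-input embedding is

\[ \mathbf{e}_{t+1} = \text{Softmax}(\mathbf{z}_t)\, W_E . \]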
- Can produce a “blurry” embedding if the distribution is diffuse over many tokens, so the mixture no longer approximates any single token’s embedding.
3.2 Gumbel-Softmax – LOOK AT
- Paper Reference: Eric Jang et al., “Categorical Reparameterization with Gumbel-Softmax,” ICLR 2017.
- Key Mechanism: Add Gumbel noise \(\mathbf{g}\) to the logits and scale by temperature \(\tau\), then apply softmax:
\[ \mathbf{y}_t = \text{Softmax}\Bigl(\frac{\mathbf{z}_t + \mathbf{g}_t}{\tau}\Bigr). \]
For low \(\tau\), \(\mathbf{y}_t\) becomes close to a one-hot vector—mimicking discrete sampling—yet remains differentiable.
- Helpful for end-to-end gradient flow through “token choices”.
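A minimal sketch of this in PyTorch using the built-in `torch.nn.functional.gumbel_softmax`; the `tau` value and `hard` flag here are illustrative choices, not prescriptions from the paper:

```python
import torch
import torch.nn.functional as F

# logits: (batch_size, d_vocab), W_E: (d_vocab, d_model)
def gumbel_soft_embedding(logits, W_E, tau=0.5, hard=False):
    # Relaxed one-hot sample; low tau -> close to one-hot but still differentiable.
    y = F.gumbel_softmax(logits, tau=tau, hard=hard)  # (batch_size, d_vocab)
    # Mix token embeddings with the relaxed sample instead of a discrete lookup.
    return y @ W_E                                    # (batch_size, d_model)
```

With `hard=True`, the forward pass uses a true one-hot vector (a single token embedding) while the backward pass keeps the relaxed gradient (straight-through), which may help keep inputs closer to the discrete training distribution.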
4. Challenges and Caveats
4.1 Out-Of-Distribution Inputs
Pretrained LLMs only saw discrete token embeddings during training. A “soft mix” of embeddings can lie far from any single token embedding, making it out-of-distribution for the model and causing unpredictable outputs. Fine-tuning will be needed to adapt the model to these continuous inputs.
4.2 Sequential Fine-Tuning
Each step’s input depends on the previous step’s output (and so, transitively, on all earlier outputs). This is useful for gradient propagation, but it means the forward pass can no longer be parallelized across positions as in ordinary teacher forcing; training becomes sequential and compute-inefficient, much like an RNN.
4.3 Blurry embeddings / Entropy explosion
If the distribution is too broad, the embedding becomes an average over many tokens, which may push the model’s input off the manifold of token embeddings.
Possible solutions:
- Sampling adjustments: temperature, top-k, min-p
- Gumbel-Softmax?
- Entropy regularization (see the sketch after this list): \(\mathcal{L} = \mathcal{L}_{\text{xent}} + \beta \cdot H(P)\)
- \(\beta\) is a hyperparameter
- \(H\) is the entropy of the distribution \(P\)
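A minimal sketch of the entropy-regularization option, assuming `probs` is the per-step distribution \(P\) from the soft decoder and `xent_loss` is the usual cross-entropy loss (both names are placeholders):

```python
import torch

def entropy_regularized_loss(xent_loss, probs, beta=0.01):
    # H(P) = -sum_v P(v) log P(v), averaged over the batch
    entropy = -(probs * probs.clamp_min(1e-9).log()).sum(dim=-1).mean()
    # Penalize high entropy so the soft mixture stays concentrated on a few tokens.
    return xent_loss + beta * entropy
```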
4.4 Scheduled Sampling
It is not obvious exactly how to incorporate the ground truth into the sampling. Discrete decoding handles this simply with teacher forcing, replacing the predicted embedding with the correct one.
We may need to mix the ground-truth embedding with the soft sample to stabilize training, as shown below.
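One simple option, and the one used in the Stage 1 sketch below, is a convex combination controlled by a guidance weight \(\alpha\):

\[ \mathbf{e}_{t+1} = \alpha \, W_E[x^{\text{gt}}_{t+1}] + (1 - \alpha)\, P_t\, W_E , \]

where \(x^{\text{gt}}_{t+1}\) is the ground-truth next token and \(P_t\) is the predicted distribution; \(\alpha = 1\) recovers ordinary teacher forcing.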
4.5 Catastrophic Forgetting
- Partial freezing (maybe \(W_E\) / first layer only?)
- LoRA
- low LR
4.6 Integration with RLHF (PPO)
Standard PPO assumes discrete tokens as actions.
- If the model produces soft distributions, they must be discretized before a standard reward model can score them, for example as sketched below.
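A hypothetical helper for that discretization step, assuming the per-step distributions were recorded during soft decoding (the function name and the greedy argmax choice are assumptions, not part of the notes above):

```python
import torch

def discretize_soft_rollout(step_probs):
    # step_probs: list of (batch_size, d_vocab) distributions, one per decoding step.
    # Collapse each step to its most probable token so a standard reward model /
    # PPO pipeline can score the sequence as ordinary discrete text.
    return torch.stack([p.argmax(dim=-1) for p in step_probs], dim=1)  # (batch_size, seq_len)
```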
5. Implementation Sketch
5.1 Stage 1: SFT for Soft Decoding
```python
import torch
import torch.nn.functional as F
from fastcore.meta import delegates


def mk_proba_dist(
    logits,            # (batch_size, d_vocab)
    temperature=1.0,
    top_k=None,
    min_p=None,
):
    batch_size, d_vocab = logits.shape
    device = logits.device
    if top_k:
        logits, idxs = logits.topk(top_k, dim=-1)
    else:
        idxs = (
            torch.arange(d_vocab, device=device)
            .repeat(batch_size)
            .reshape(batch_size, d_vocab)
        )
    # TODO: temperature before or after min_p?
    probs = F.softmax(logits / temperature, dim=-1)

    if min_p is not None:
        max_probs = probs.max(dim=-1, keepdim=True).values
        threshold = max_probs * min_p
        mask = probs >= threshold
        probs = probs * mask
        probs = probs / probs.sum(dim=-1, keepdim=True)  # renormalize
        idxs = idxs * mask
    return idxs, probs
```
```python
@delegates(mk_proba_dist)
def soft_sampling_train_step(
    model,
    batch,            # tokens of shape (batch_size, seq_len)
    W_E,              # model's embedding matrix
    guidance_alpha,   # guidance weighting -- 1 equivalent to discrete sampling
    return_tokens=True,
    **kwargs,         # passed to mk_proba_dist
):
    "Single train step using soft sampling"
    assert 0 <= guidance_alpha <= 1
    batch_size, seq_len = batch.shape
    device = batch.device

    # cache
    past_key_values = None
    position_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)

    loss = torch.tensor(0., device=device)
    embeds = W_E[batch[:, :1]]  # BOS, shape: (batch_size, 1, d_model)
    tokens = [ batch[:, :1].detach().cpu() ]
    for t in range(1, seq_len):
        outputs = model(
            inputs_embeds=embeds[:, -1:],  # only the newest embedding; earlier positions live in the KV cache
            past_key_values=past_key_values,
            position_ids=position_ids,
            use_cache=True,
        )
        logits_t = outputs.logits[:, -1]
        past_key_values = outputs.past_key_values

        # NOTE: assumes top_k=None so p_t spans the full vocabulary
        i_t, p_t = mk_proba_dist(logits_t, **kwargs)

        # loss: cross-entropy against the ground-truth token, on the (already normalized) probabilities
        loss_t = F.nll_loss(p_t.clamp_min(1e-9).log(), batch[:, t])
        loss += loss_t

        # discrete sample -- for logging
        indices = torch.multinomial(p_t, 1)                 # (batch_size, 1)
        batch_indices = torch.arange(batch_size)[:, None]   # (batch_size, 1)
        next_token = i_t[batch_indices, indices].detach().cpu()
        tokens.append(next_token)

        # soft sample
        next_emb_soft = p_t @ W_E       # soft sampling
        next_emb_gt = W_E[batch[:, t]]  # guidance (teacher-forced) embedding

        next_embed = (
            guidance_alpha * next_emb_gt +
            (1 - guidance_alpha) * next_emb_soft
        )
        embeds = torch.cat([embeds, next_embed[:, None, :]], dim=1)
        position_ids += 1

    if return_tokens:
        tokens = torch.cat(tokens, dim=1)
    # normalize gradient: sum batch, mean sequence length
    loss /= seq_len
    return loss, tokens
```
I think I want an initial `guidance_alpha` of 1 to mimic discrete training, warm-starting the model at a stable baseline before we shift the input distribution. Perhaps warm up the lr here too. Then anneal `guidance_alpha` downward, i.e. warm the soft mixing weight (1 - `guidance_alpha`) up to some maximum value. A possible schedule is sketched below.
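A minimal sketch of such a schedule; the function name, step counts, and floor value are all placeholders:

```python
def guidance_alpha_schedule(step, warmup_steps=500, anneal_steps=5000, min_alpha=0.3):
    # Hold guidance_alpha at 1.0 (pure teacher forcing) during warm-start,
    # then linearly anneal it down to min_alpha, i.e. warm the soft mixing
    # weight (1 - alpha) up to its maximum of (1 - min_alpha).
    if step < warmup_steps:
        return 1.0
    frac = min((step - warmup_steps) / anneal_steps, 1.0)
    return 1.0 - frac * (1.0 - min_alpha)
```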
TODO: how exactly do we want to apply `guidance_alpha`? I think we should clamp the correct token to at least a `guidance_alpha` proportion of the probability distribution (this is helpful when the model already gives the correct token a high probability). A sketch of this variant follows.
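A hypothetical sketch of that clamping variant, applied to the output of `mk_proba_dist` before forming the soft embedding; the helper name is made up, and it assumes `top_k=None` so that `p_t` spans the full vocabulary:

```python
import torch

def clamp_ground_truth(p_t, gt_tokens, guidance_alpha):
    # p_t: (batch_size, d_vocab) probabilities; gt_tokens: (batch_size,) ground-truth ids
    gt_prob = p_t.gather(-1, gt_tokens[:, None])  # (batch_size, 1)
    needs_clamp = gt_prob < guidance_alpha
    # Shrink every token so guidance_alpha of mass can be given to the ground truth...
    scale = torch.where(
        needs_clamp,
        (1. - guidance_alpha) / (1. - gt_prob).clamp_min(1e-9),
        torch.ones_like(gt_prob),
    )
    p_new = p_t * scale
    # ...then write the clamped mass onto the ground-truth token (rows still sum to 1).
    new_gt = torch.where(needs_clamp, torch.full_like(gt_prob, guidance_alpha), gt_prob)
    return p_new.scatter(-1, gt_tokens[:, None], new_gt)
```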
5.2 Stage 2: RLHF / PPO
TODO
6. Additional Considerations
6.1 Comparison with Beam Search
Beam Search / Self-Consistency: Another way to keep multiple possibilities is to track multiple discrete beams. However, that can explode combinatorially.
If the soft decoding works, we effectively combine multiple tokens into a single “latent” path at each step.
6.3 Interpretability
We can record how the model’s beliefs evolve by sampling tokens at each step, though it’s not a “true” discrete chain.
6.4 Selectively use soft decoding
Perhaps only for thinking tokens? And switch back to discrete for the final answer.
6.5 Temperature / TopK / MinP as a learnable parameter
As the model is fully differentiable, we could treat temperature (and perhaps top-k and min-p) as learnable parameters; a minimal sketch follows the note below.
- This might collapse early in training, with the temperature driven toward greedy sampling to keep inputs close to the old (discrete) distribution.
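A minimal sketch of a learnable temperature, parameterized in log-space so it stays positive (the module name and initialization are assumptions):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LearnableTemperature(nn.Module):
    def __init__(self, init_temp=1.0):
        super().__init__()
        # Parameterize in log-space so the temperature stays positive.
        self.log_temp = nn.Parameter(torch.tensor(float(init_temp)).log())

    def forward(self, logits):
        # Lower temperatures push the soft mixture toward a single token.
        return F.softmax(logits / self.log_temp.exp(), dim=-1)
```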
6.6 Soft Beam Search
Instead of combining at every step, maintain a beam of plausible tokens and produce a soft mixture for each beam entry.
Or we could run a beam search and then collapse the beam’s embeddings back down into a single soft embedding.
6.7 Combine with prefix tuning
A specialized prefix of learnable vectors that signals to the model that it is operating in soft-decoding mode.
Alternatively, wrap the soft region in a new `<think>` token or something similar. A sketch of the learnable prefix follows.
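A minimal sketch of such a prefix; `n_prefix` and `d_model` are assumed values, and the vectors would be prepended to `inputs_embeds` ahead of the (soft) token embeddings:

```python
import torch
import torch.nn as nn

class SoftDecodePrefix(nn.Module):
    def __init__(self, n_prefix=16, d_model=4096):
        super().__init__()
        # Learnable vectors that mark "soft decoding mode" for the (otherwise frozen) model.
        self.prefix = nn.Parameter(torch.randn(n_prefix, d_model) * 0.02)

    def forward(self, inputs_embeds):
        # inputs_embeds: (batch_size, seq_len, d_model)
        batch_size = inputs_embeds.shape[0]
        prefix = self.prefix.unsqueeze(0).expand(batch_size, -1, -1)
        return torch.cat([prefix, inputs_embeds], dim=1)
```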
7. Summary
- Keep multiple reasoning paths alive in a single forward pass.
- Provide a differentiable approach to token selection, which is valuable in advanced training or RL settings.
- Stay fully differentiable through the sampling steps.
References
- Gumbel-Softmax
- Jang, E., Gu, S., & Poole, B. (2017). Categorical Reparameterization with Gumbel-Softmax. ICLR.
- arXiv:1611.01144
- Self-Consistency
- Wang, X., et al. (2022). Self-Consistency Improves Chain of Thought Reasoning in Language Models. arXiv:2203.11171
- RLHF
- Christiano, P., et al. (2017). Deep Reinforcement Learning from Human Preferences. NIPS.
- Bai, Y., et al. (2022). Training a Helpful and Harmless Assistant with Reinforcement Learning from Human Feedback. (Anthropic)
- Prefix / Prompt Tuning
- Lester, B., Al-Rfou, R., & Constant, N. (2021). The Power of Scale for Parameter-Efficient Prompt Tuning. EMNLP.
- Li, X. L., & Liang, P. (2021). Prefix-Tuning: Optimizing Continuous Prompts for Generation. ACL.