📚 Notes: How does sequence generation work in Fairseq/Espresso?
Published on Jan. 14, 2021, last edited on Jan. 14, 2021
Fairseq is a “sequence modeling toolkit written in PyTorch that allows researchers and developers to train custom models for translation, summarization, language modeling and other text generation tasks”. I will try to explain how Fairseq generates sequences with a language model using a beam search algorithm.
Important: I describe here a simplified version of my understanding of the sequence generation algorithm. For a fully detailed understanding, you must dive into the code.
Introduction
I’m going to explain the SequenceGenerator’s generate(…) function from Espresso, a toolkit based on Fairseq (reference commit: 6d29f790c24fbad336fee69db98ee0b34e5cd9b6, 23 February 2020). The SequenceGenerator from Espresso is almost identical to the original one in Fairseq.
The SequenceGenerator is used to generate outputs from a batch of inputs using models. We can see models as a list composed of one or more models:
- A transcription model, with an encoder part and a decoder part
- A language model: only the decoder part
The SequenceGenerator’s generate function is “only”:
def generate(self, models, sample, **kwargs):
    """Generate a batch of translations.

    Args:
        models (List[~fairseq.models.FairseqModel]): ensemble of models
        sample (dict): batch
        prefix_tokens (torch.LongTensor, optional): force decoder to begin
            with these tokens
        bos_token (int, optional): beginning of sentence token
            (default: self.eos)
    """
    lm_weight = kwargs.get('lm_weight', 0.0)
    if lm_weight == 0.0:
        # EnsembleModel: fuses the outputs of all the models together
        # (note: we can have only one model here) => probs are averaged
        model = EnsembleModel(models)
    else:
        # LMFusionModel: fuses models[0] (ASR model) with models[1] (LM model)
        # => probs are averaged, with a lm_weight ponderation for the LM
        model = LMFusionModel(models, lm_weight)
    return self._generate(model, sample, **kwargs)
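To make the “probs are averaged” comment above a bit more concrete, here is a minimal sketch (not the actual EnsembleModel code) of how per-token probabilities from several models can be averaged in log space; LMFusionModel does something similar but weights the LM’s contribution with lm_weight:
import math
import torch

def average_lprobs(lprobs_per_model):
    # lprobs_per_model: list of (beam_size, vocab_size) log-prob tensors, one per model.
    # Average the probabilities (not the log-probs), then go back to log space.
    stacked = torch.stack(lprobs_per_model, dim=0)  # (n_models, beam_size, vocab_size)
    return torch.logsumexp(stacked, dim=0) - math.log(len(lprobs_per_model))

# made-up usage: two models, 5 beams, a 100-token vocabulary
lprobs = [torch.log_softmax(torch.randn(5, 100), dim=-1) for _ in range(2)]
avg = average_lprobs(lprobs)  # still (5, 100); each row sums to 1 in probability space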
So, how does the _generate(model, sample, **kwargs) function work? Well, first, let’s look at the arguments:
def _generate(
    self,
    model,               # the model (EnsembleModel or LMFusionModel);
                         # used to call forward_encoder and forward_decoder
    sample,              # a dict containing at least "net_input", for example:
                         # {'net_input': {
                         #     'src_tokens': tensor([[35, 68]]),  # tensor with ids of tokens "already" emitted
                         #     'src_lengths': tensor([2]),        # the current length of this tensor
                         # }}
    prefix_tokens=None,  # can replace the src_tokens tensor in sample
    bos_token=None,
    **kwargs
):
If your model has an encoder, its output is computed first with the following line. If you don’t have an encoder, encoder_outs will be None.
encoder_outs = model.forward_encoder(encoder_input)
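For context, in the version of the code I read, encoder_input is built just before this call from sample['net_input'], essentially by dropping the decoder-side key (take the exact filtering below as an approximation):
encoder_input = {
    k: v for k, v in sample['net_input'].items()
    if k != 'prev_output_tokens'
}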
Decoding algorithm
Then, the decoding step uses a beam search algorithm to find the best sequences. For this example, we set beam_size = 5. That means that, at each decoding step, only the 5 most likely sequence hypotheses are kept, as the toy example below illustrates.
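Here is a toy illustration of that idea (beam size reduced to 2 and vocabulary reduced to 3 tokens, all numbers made up): every possible continuation of every hypothesis gets a cumulative score, and only the beam_size best ones survive.
import torch

beam_size = 2
# cumulative log-prob scores of the current hypotheses (made-up numbers)
scores = torch.tensor([-0.1, -0.7])
# log-prob of each possible next token (toy 3-token vocabulary) for each hypothesis
lprobs = torch.tensor([[-2.0, -0.3, -1.5],
                       [-0.2, -1.0, -2.5]])

# score every (hypothesis, next token) pair, then keep only the beam_size best ones
cand_scores = scores.unsqueeze(1) + lprobs          # (beam_size, vocab_size)
top = torch.topk(cand_scores.view(-1), k=beam_size)
print(top.values)  # tensor([-0.4000, -0.9000]) -> hypothesis 0 + token 1, hypothesis 1 + token 0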
Here is what happens at each decoding step:
1/ Forward
Log-probabilities for each token of the output dictionary are obtained with:
lprobs, avg_attn_scores = model.forward_decoder(
    tokens[:, :step + 1], encoder_outs, temperature=self.temperature,
)
lprobs[lprobs != lprobs] = -math.inf     # replace NaNs by -∞
lprobs[:, self.pad] = -math.inf          # never select pad
lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty
Here, tokens[:, :step + 1] is a 2D tensor of size (beam_size, step + 1) containing the token ids predicted so far for each beam. The returned lprobs is a new 2D tensor of size (beam_size, output_vocabulary_size): for each output token in each beam hypothesis, we have a log-probability.
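As a quick (made-up) check of these shapes, with beam_size = 5, a 1000-token output vocabulary and step = 2:
import torch

beam_size, vocab_size, step = 5, 1000, 2

# token ids predicted so far for each beam (values are placeholders)
tokens = torch.zeros(beam_size, step + 1, dtype=torch.long)
print(tokens[:, :step + 1].shape)  # torch.Size([5, 3])

# what forward_decoder returns: one log-prob per vocabulary entry per beam
lprobs = torch.log_softmax(torch.randn(beam_size, vocab_size), dim=-1)
print(lprobs.shape)                # torch.Size([5, 1000])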
We then need to run our beam search algorithm to select only the top hypotheses, but before that, we will handle the prefix tokens.
2/ Prefix tokens
When we ask our model(s) to generate some output, we might pass some prefix tokens to it: we constrain our model to output those tokens first. If our model is only a language model, here’s what that would mean:
Let’s say our prefix string is “I would like”. We convert this sentence to a tensor using indices from the output dictionary: we now have torch.tensor([5038, 8203, 2830]). Our system must now predict outputs that always start with 5038, 8203, and 2830.
Now, suppose we are at time step = 0. How do we make sure our model forces the output of all beams to be 5038?
# handle prefix tokens (possibly with different lengths)
if prefix_tokens is not None and step < prefix_tokens.size(1) and step < max_len:
    prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
    prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
    prefix_mask = prefix_toks.ne(self.pad)
    lprobs[prefix_mask] = -math.inf
    lprobs[prefix_mask] = lprobs[prefix_mask].scatter_(
        -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
    )
    if prefix_toks.eq(self.eos).any():
        # the prefix includes the EOS tag; not covered here for the sake of simplicity
        ...
Basically, what this code does is simple: it puts $-\infty$ everywhere in the lprobs tensor, except for the prefix token, where it keeps the already computed log-probability.
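Here is a tiny standalone example of this masking trick, with a 5-token vocabulary, 2 beams and a forced prefix token id of 3 (all values made up):
import math
import torch

lprobs = torch.log_softmax(torch.randn(2, 5), dim=-1)  # (beam_size, vocab_size)
prefix_toks = torch.tensor([3, 3])                     # forced token for each beam

# keep the log-prob already computed for the prefix token...
prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
# ...set everything to -inf, then restore it at the prefix position
lprobs[:] = -math.inf
lprobs.scatter_(-1, prefix_toks.unsqueeze(-1), prefix_lprobs)
print(lprobs)  # -inf everywhere except column 3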
3/ Beam Search
Now, it’s time to apply the Beam Search algorithm. Here is how it is defined:
class BeamSearch(Search):

    def __init__(self, tgt_dict):
        super().__init__(tgt_dict)

    @torch.jit.export
    def step(self, step: int, lprobs, scores):
        self._init_buffers(lprobs)
        bsz, beam_size, vocab_size = lprobs.size()

        if step == 0:
            # at the first step all hypotheses are equally likely, so use
            # only the first beam
            lprobs = lprobs[:, ::beam_size, :].contiguous()
        else:
            # make probs contain cumulative scores for each hypothesis
            lprobs.add_(scores[:, :, step - 1].unsqueeze(-1))

        top_prediction = torch.topk(
            lprobs.view(bsz, -1),  # convert to the j * vocab_size + i representation (see below)
            k=min(
                # Take the best 2 x beam_size predictions. We'll choose the first
                # beam_size of these which don't predict eos to continue with.
                beam_size * 2,
                lprobs.view(bsz, -1).size(1) - 1,  # -1 so we never select pad
            ),
        )
        self.scores_buf = top_prediction[0]
        self.indices_buf = top_prediction[1]

        # Project back into relative indices and beams:
        # before: token i in hypothesis j was represented as j * vocab_size + i
        # now: simply i, with an additional j value in beams_buf
        self.beams_buf = self.indices_buf // vocab_size
        self.indices_buf = self.indices_buf.fmod(vocab_size)

        # At this point, beams_buf and indices_buf are single-dim and contain relative indices
        return self.scores_buf, self.indices_buf, self.beams_buf
We use it by giving the log-probabilities of each output token for each of the $k$ hypotheses. We also give the scores of each of these hypotheses:
cand_scores, cand_indices, cand_beams = self.search.step(
    step,
    lprobs.view(bsz, -1, self.vocab_size),         # (bsz, beam_size, vocab_size): we have a log-prob
                                                   # for each token in each hypothesis
    scores.view(bsz, beam_size, -1)[:, :, :step],  # scores of each hypothesis at each step
)
Assuming self.search is a BeamSearch object, cand_indices will contain the top 2 * beam_size candidate token ids, and cand_beams tells us which hypothesis each of them extends. For example:
>>> cand_indices
tensor([[45591, 54801, 15738, 60750, 62425, 60750, 16349, 22326, 52779, 53156]])
>>> cand_beams
tensor([[0, 1, 0, 3, 2, 0, 0, 0, 4, 1]])
This means that the possible following tokens are:
- hypothesis $0$: 45591, 15738, 60750, 16349, 22326
- hypothesis $1$: 54801, 53156
- hypothesis $2$: 62425
- hypothesis $3$: 60750
- hypothesis $4$: 52779
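The j * vocab_size + i bookkeeping used inside BeamSearch.step can be checked with a small made-up example (vocab_size = 10 here, not the real vocabulary size):
import torch

vocab_size = 10
# flat candidate ids as returned by topk over the flattened (beam_size * vocab_size) axis
flat_indices = torch.tensor([5, 17, 23, 40, 31])

beams_buf = flat_indices // vocab_size       # which hypothesis each candidate extends
indices_buf = flat_indices.fmod(vocab_size)  # which token it adds
print(beams_buf)    # tensor([0, 1, 2, 4, 3])
print(indices_buf)  # tensor([5, 7, 3, 0, 1])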
4/ Update our tokens
Now that we have the most probable tokens inside cand_indices, associated with their beam hypotheses, we can update our final token tensor so that:
- The forward pass at the next step will see these new tokens as the “previous ones”
- We can construct our final sentences. A hypothesis is finished if (see the sketch below):
  - its last token is <eos> (end-of-sentence); and
  - its score is not $-\infty$
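A minimal sketch of that check (simplified; the real code also tracks per-sentence bookkeeping, and the tensor names below are made up):
import math
import torch

eos = 2  # hypothetical <eos> id
tokens_at_step = torch.tensor([2, 57, 2])               # last emitted token of each beam
scores_at_step = torch.tensor([-1.3, -0.8, -math.inf])  # cumulative score of each beam

# a hypothesis is finished if it just emitted <eos> with a finite score
eos_mask = tokens_at_step.eq(eos) & scores_at_step.ne(-math.inf)
print(eos_mask)  # tensor([ True, False, False])
The hypotheses that are still active then get their selected candidate tokens written into the token buffer: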
# active_hypos is a tensor containing ids of hypotheses that are still active
torch.gather(
    cand_indices, dim=1, index=active_hypos,
    out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
)
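To see what this gather does, here is a small made-up example with bsz = 1, beam_size = 2 and four candidates per sentence (without the out= buffer for simplicity):
import torch

cand_indices = torch.tensor([[45591, 54801, 15738, 60750]])  # candidate token ids
active_hypos = torch.tensor([[0, 2]])                        # candidates kept by the search

# pick, for each surviving hypothesis, the token id it should emit at step + 1
next_tokens = torch.gather(cand_indices, dim=1, index=active_hypos)
print(next_tokens)  # tensor([[45591, 15738]])
These tokens are written at position step + 1 of the token tensor, so the next forward_decoder call will see them as the previously emitted tokens.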