📚 Notes: How does sequence generation work in Fairseq/Espresso?

Published on Jan. 14, 2021, last edited on Jan. 14, 2021

Fairseq is a “sequence modeling toolkit written in PyTorch that allows researchers and developers to train custom models for translation, summarization, language modeling and other text generation tasks”. I will try to explain how Fairseq generates sequences with a language model using a beam search algorithm.

Important: I describe here a simplified version of my understanding of the sequence generation algorithm. For a fully detailed understanding, you must dive into the code.

Introduction

I’m going to explain the SequenceGenerator’s generate(…) function from Espresso, a toolkit based on Fairseq (reference commit: 6d29f790c24fbad336fee69db98ee0b34e5cd9b6, from 23 February 2020). The SequenceGenerator from Espresso is almost identical to the original one in Fairseq.

The SequenceGenerator is used to generate outputs from a batch of inputs using models. Here, models is a list of one or more models, for example:

- A transcription model, with an encoder part and a decoder part
- A language model, which has only the decoder part

The SequenceGenerator’s generate function is “only”:

def generate(self, models, sample, **kwargs):
    """Generate a batch of translations.

    Args:
        models (List[~fairseq.models.FairseqModel]): ensemble of models
        sample (dict): batch
        prefix_tokens (torch.LongTensor, optional): force decoder to begin
            with these tokens
        bos_token (int, optional): beginning of sentence token
            (default: self.eos)
    """
    lm_weight = kwargs.get('lm_weight', 0.0)
    if lm_weight == 0.0:
        # EnsembleModel: fuses the outputs of all the models. Note: there can be only one model here
        # => probs are averaged
        model = EnsembleModel(models)
    else:
        # LMFusionModel: fuses models[0] (ASR model) with models[1] (LM model)
        # => probs are averaged, with the LM weighted by lm_weight
        model = LMFusionModel(models, lm_weight)
    return self._generate(model, sample, **kwargs)
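
To make the two branches more concrete, here is a minimal pure-Python sketch of how per-token log-probabilities from several models could be combined. `ensemble_average` follows the comment above (probabilities averaged across models, computed in log space). `shallow_fusion` is one common way to weight a language model; it is an assumption made for illustration, not a copy of Espresso's LMFusionModel. Both function names are hypothetical.

```python
import math

def ensemble_average(lprobs_per_model):
    """Average the probabilities of several models, working in log space.

    lprobs_per_model: one list of log-probabilities per model.
    """
    n_models = len(lprobs_per_model)
    vocab_size = len(lprobs_per_model[0])
    averaged = []
    for tok in range(vocab_size):
        column = [lp[tok] for lp in lprobs_per_model]
        peak = max(column)
        if peak == -math.inf:
            averaged.append(-math.inf)  # every model rules this token out
            continue
        # log-sum-exp over the models, then subtract log(n) to get the mean
        lse = peak + math.log(sum(math.exp(x - peak) for x in column))
        averaged.append(lse - math.log(n_models))
    return averaged

def shallow_fusion(asr_lprobs, lm_lprobs, lm_weight):
    """Assumed fusion rule: ASR log-prob plus a weighted LM log-prob."""
    return [a + lm_weight * l for a, l in zip(asr_lprobs, lm_lprobs)]

# Toy two-token vocabulary (hypothetical values)
model_a = [math.log(0.5), math.log(0.5)]
model_b = [math.log(0.9), math.log(0.1)]
avg = ensemble_average([model_a, model_b])
fused = shallow_fusion(model_a, model_b, lm_weight=0.5)
```

With lm_weight = 0.0 the fused scores would equal the ASR scores, which is consistent with generate(…) falling back to the plain ensemble in that case.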

So, how does the _generate(model, sample, **kwargs) function work? Let’s start by looking at the arguments:

def _generate(
    self,
    model,  # the model (EnsembleModel or LMFusionModel)
            # will be used to call encoder_forward and decoder_forward
    sample, # a dict containing at least "net_input". for example:
            # {'net_input': {
            #   'src_tokens': tensor([[35, 68]]), # tensor with ids of tokens "already" emitted
            #   'src_lengths': tensor([2])        # the current length of this tensor
            # }}
    prefix_tokens=None, # can replace the src_tokens tensor in sample 
    bos_token=None,
    **kwargs
):

If your model has an encoder, its output is computed first with the line below. If there is no encoder, encoder_outs will be None.

encoder_outs = model.forward_encoder(encoder_input)

Decoding algorithm

Then, the decoding step uses a beam search algorithm to find the best sequences. For this example, we set beam_size = 5. That means that, at each decoding step, only the 5 most likely sequence hypotheses are kept.

Here is what happens at each decoding step:

1/ Forward

The log-probabilities of each token in the output dictionary are obtained with:

lprobs, avg_attn_scores = model.forward_decoder(
    tokens[:, :step + 1], encoder_outs, temperature=self.temperature,
)

lprobs[lprobs != lprobs] = -math.inf    # replace NaNs by -∞
lprobs[:, self.pad] = -math.inf         # never select pad
lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty

Here, tokens[:, :step + 1] is a 2D tensor of size (beam_size, step + 1) containing the ids of the tokens predicted so far for each beam (assuming a batch with a single input). The returned lprobs is a new 2D tensor of size (beam_size, output_vocabulary_size): it holds a log-probability for each output token in each beam hypothesis.
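
The three clean-up lines above can be mimicked on plain Python lists. This is only a sketch: the helper name `clean_lprobs` and the toy ids (pad = 1, unk = 3) are hypothetical. It also shows why `lprobs != lprobs` works as a NaN test: NaN is the only float value that is not equal to itself.

```python
import math

def clean_lprobs(lprobs, pad, unk, unk_penalty):
    """Apply the three fixes in place to a (num_hyps x vocab_size) list of lists."""
    for row in lprobs:
        for tok, value in enumerate(row):
            if value != value:          # true only for NaN
                row[tok] = -math.inf
        row[pad] = -math.inf            # pad can never be selected
        row[unk] -= unk_penalty         # unk becomes less attractive
    return lprobs

# Toy setup: one hypothesis, vocabulary of 4 tokens
rows = [[-0.5, -1.0, float("nan"), -2.0]]
cleaned = clean_lprobs(rows, pad=1, unk=3, unk_penalty=0.5)
```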

We then need to run our beam search algorithm to select only the top hypotheses, but before that, we will handle the prefix tokens.

2/ Prefix tokens

When we ask our model(s) to generate some output, we might pass some prefix tokens: we constrain the model to output those tokens first. If our model is only a language model, here’s what that would mean:

Let’s say our prefix string is “I would like”. We convert this sentence to a tensor using indices from the output dictionary: we now have torch.tensor([5038, 8203, 2830]). Our system must now predict outputs that always start with 5038, 8203, and 2830.
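
As a small illustration, the conversion and the per-step lookup could look like this in plain Python. The three-word vocabulary is a hypothetical stand-in for the real output dictionary (only the three ids come from the example above), and `encode_prefix` is not a Fairseq function.

```python
# Hypothetical toy vocabulary; a real output dictionary is far larger.
toy_vocab = {"I": 5038, "would": 8203, "like": 2830}

def encode_prefix(sentence, vocab):
    """Map a whitespace-tokenized sentence to a list of token ids."""
    return [vocab[word] for word in sentence.split()]

prefix_ids = encode_prefix("I would like", toy_vocab)

# One row per input sentence; column `step` holds the token forced at that step.
prefix_tokens = [prefix_ids]
step = 0
forced_at_step = [row[step] for row in prefix_tokens]
```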

Suppose we are at time step = 0. How do we make sure the model forces the output of every beam to be 5038?

# handle prefix tokens (possibly with different lengths)
if prefix_tokens is not None and step < prefix_tokens.size(1) and step < max_len:
    prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
    prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
    prefix_mask = prefix_toks.ne(self.pad)
    lprobs[prefix_mask] = -math.inf
    lprobs[prefix_mask] = lprobs[prefix_mask].scatter_(
        -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
    )

    if prefix_toks.eq(self.eos).any():
        # the prefix includes the EOS tag, not explained here for the sake of simplicity
        pass

Basically, what this code does is simple: for every hypothesis that still has a prefix token at this step, it puts $-\infty$ everywhere in its row of the lprobs tensor, except at the prefix token, where it keeps the log-probability that was already computed.
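
The same masking can be sketched without tensors, which may make the gather/scatter pair easier to follow. The function name `force_prefix` and the toy values are hypothetical; rows whose prefix token is pad are left untouched, as with prefix_mask in the code above.

```python
import math

def force_prefix(lprobs, prefix_toks, pad):
    """lprobs: one row of log-probs per hypothesis.
    prefix_toks: the token forced for each row at the current step."""
    for row_idx, forced in enumerate(prefix_toks):
        if forced == pad:
            continue                      # no prefix token left for this row
        kept = lprobs[row_idx][forced]    # what gather() reads
        lprobs[row_idx] = [-math.inf] * len(lprobs[row_idx])
        lprobs[row_idx][forced] = kept    # what scatter_() writes back
    return lprobs

# Toy example: 2 hypotheses, 4-token vocabulary, pad id 1
lprobs = [[-1.2, -0.3, -2.0, -0.9],
          [-0.7, -1.5, -0.4, -2.2]]
masked = force_prefix(lprobs, prefix_toks=[2, 1], pad=1)
```

After this step, any top-k selection over the first row can only pick token 2, because every other entry is $-\infty$.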

3/ Beam search

Now, it’s time to apply the Beam Search algorithm. Here is how it is defined:

class BeamSearch(Search):
    def __init__(self, tgt_dict):
        super().__init__(tgt_dict)

    @torch.jit.export
    def step(self, step: int, lprobs, scores):
        self._init_buffers(lprobs)
        bsz, beam_size, vocab_size = lprobs.size()

        if step == 0:
            # at the first step all hypotheses are equally likely, so use
            # only the first beam
            lprobs = lprobs[:, ::beam_size, :].contiguous()
        else:
            # make probs contain cumulative scores for each hypothesis
            lprobs.add_(scores[:, :, step - 1].unsqueeze(-1))

        top_prediction = torch.topk(
            lprobs.view(bsz, -1), # convert to j * vocab_size + i representation (see below)
            k=min(
                # Take the best 2 x beam_size predictions. We'll choose the first
                # beam_size of these which don't predict eos to continue with.
                beam_size * 2,
                lprobs.view(bsz, -1).size(1) - 1,  # -1 so we never select pad
            ),
        )
        self.scores_buf = top_prediction[0]
        self.indices_buf = top_prediction[1]
        
        # Project back into relative indices and beams
        # before : token i in hypothesis j would have been represented as j * vocab_size + i
        # now : simply i, with an additional j value in beams_buf
        self.beams_buf = self.indices_buf // vocab_size
        self.indices_buf = self.indices_buf.fmod(vocab_size)
        
        # At this point, beams_buf and indices_buf are single-dim and contain relative indices
        return self.scores_buf, self.indices_buf, self.beams_buf
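
To see the flat-index trick in isolation, here is a simplified pure-Python sketch of one beam search step for a single sentence. It is not the Fairseq implementation: `beam_step` is a hypothetical name, `sorted` stands in for torch.topk, and the toy scores are made up.

```python
def beam_step(step, lprobs, prev_scores, k):
    """lprobs: beam_size x vocab_size log-probs for one sentence.
    prev_scores: cumulative score of each hypothesis before this step."""
    vocab_size = len(lprobs[0])
    if step == 0:
        rows = [lprobs[0]]  # all beams are identical copies, keep only the first
    else:
        # cumulative score = score so far + log-prob of the new token
        rows = [[prev + lp for lp in row] for prev, row in zip(prev_scores, lprobs)]
    # flatten: position = beam * vocab_size + token
    flat = [score for row in rows for score in row]
    best = sorted(range(len(flat)), key=lambda i: flat[i], reverse=True)[:k]
    scores = [flat[i] for i in best]
    beams = [i // vocab_size for i in best]   # which hypothesis is extended
    tokens = [i % vocab_size for i in best]   # with which token
    return scores, tokens, beams

lprobs = [[-1.0, -0.1, -3.0],
          [-0.5, -2.0, -0.4]]
scores, tokens, beams = beam_step(1, lprobs, prev_scores=[-0.2, -1.0], k=3)
```

In this toy run the three best candidates are tokens 1 and 0 extending hypothesis 0, then token 2 extending hypothesis 1.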

We call it with the log-probabilities of every output token for each of the $k$ hypotheses, along with the scores accumulated so far by each hypothesis:

cand_scores, cand_indices, cand_beams = self.search.step(
  step,
  lprobs.view(bsz, -1, self.vocab_size), # (bsz, beam_size, vocab_size): we have a log-prob
                                         # for each token in each hypothesis
  scores.view(bsz, beam_size, -1)[:, :, :step], # scores for each hyps at each step
)

Assuming self.search is a BeamSearch object, cand_indices will contain the ids of the 2 × beam_size best candidate tokens, and cand_beams the hypothesis that each of them extends. For example, with beam_size = 5:

>>> cand_indices
tensor([[45591, 54801, 15738, 60750, 62425, 60750, 16349, 22326, 52779, 53156]])
>>> cand_beams
tensor([[0, 1, 0, 3, 2, 0, 0, 0, 4, 1]])
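
One way to read these two tensors is to pair them position by position. The short sketch below uses plain lists holding the same values as the example output; the grouping by beam is only for illustration and is not something Fairseq does.

```python
cand_indices = [45591, 54801, 15738, 60750, 62425, 60750, 16349, 22326, 52779, 53156]
cand_beams = [0, 1, 0, 3, 2, 0, 0, 0, 4, 1]

# Candidate n extends hypothesis cand_beams[n] with token cand_indices[n].
candidates = list(zip(cand_beams, cand_indices))

# Group the proposed next tokens by the hypothesis they extend.
by_beam = {}
for beam, token in candidates:
    by_beam.setdefault(beam, []).append(token)
```

Note that token 60750 appears twice: once as an extension of hypothesis 3 and once as an extension of hypothesis 0. These are two different candidate sequences.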

4/ Update our tokens

Now that we have the most probable tokens inside cand_indices, each associated with a beam hypothesis, we can update our token tensor:

# active_hypos is a tensor containing ids of hypotheses that are still active
torch.gather(
  cand_indices, dim=1, index=active_hypos,
  out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
)
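
For readers unfamiliar with torch.gather, here is a pure-Python equivalent for the 2D, dim=1 case. It assumes that active_hypos holds, for each sentence, positions into the candidate list; the values chosen for active_hypos below are hypothetical.

```python
def gather_dim1(source, index):
    """Equivalent of torch.gather(source, dim=1, index=index) on 2-D lists:
    out[b][j] = source[b][index[b][j]]"""
    return [[src_row[i] for i in idx_row] for src_row, idx_row in zip(source, index)]

cand_indices = [[45591, 54801, 15738, 60750, 62425, 60750, 16349, 22326, 52779, 53156]]
# Hypothetical: candidates at positions 1 and 5 were not kept
active_hypos = [[0, 2, 3, 4, 6]]
next_tokens = gather_dim1(cand_indices, active_hypos)
```

In the real code the result is written directly into column step + 1 of the token buffer through the out= argument, so each surviving hypothesis gets its new token in place.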