Treinamento Distribuído

O que significa “Treinamento Distribuído (Distributed Training)”

Treinamento distribuído é a prática de treinar um modelo de aprendizado de máquina (machine learning) em vários dispositivos de computação (compute devices) (GPUs/TPUs) e, frequentemente, em várias máquinas (nós (nodes)). O objetivo é reduzir o tempo de treinamento em tempo de relógio (wall-clock), viabilizar o treinamento de modelos maiores do que cabem em um único dispositivo e melhorar a taxa de processamento (throughput).

Em sua essência, o treinamento distribuído é um problema de coordenação:

  • Cada trabalhador (worker) computa algo (geralmente gradientes (gradients)).
  • Os trabalhadores precisam trocar informações (parâmetros (parameters), gradientes, ativações (activations), estado do otimizador (optimizer state)).
  • O sistema precisa fazer isso de forma eficiente e correta, lidando com limitações de rede e variabilidade na velocidade de computação.

O treinamento distribuído fica ao lado de estratégias como Paralelismo de Dados/Modelo/Pipeline (Data/Model/Pipeline Parallelism) e depende fortemente de boas práticas de engenharia em Carregamento de Dados e Pipelines de Entrada (Data Loading & Input Pipelines), Precisão Mista (Mixed Precision) e Checkpointing e Tolerância a Falhas (Checkpointing & Fault Tolerance).

Por que a comunicação domina em escala

Para a maior parte do treinamento em aprendizado profundo (deep learning), cada passo realiza:

  1. Passagem direta (forward pass): computar ativações
  2. Passagem reversa (backward pass): computar gradientes
  3. Atualização do otimizador (optimizer update): aplicar gradientes aos parâmetros

Em um cenário de processo único, os passos (1–3) são todos limitados por computação e memória. Em cenários distribuídos, os passos (2–3) frequentemente exigem comunicação, o que pode se tornar o gargalo.

Um modelo mental útil é: o treinamento distribuído é limitado por quão rápido os trabalhadores conseguem concordar sobre as atualizações.

Métricas-chave

  • Taxa de processamento: amostras/segundo (ou tokens/segundo).
  • Escalonamento (aceleração) (scaling (speedup)):
    [ \text{speedup}(N) = \frac{T(1)}{T(N)} ]
  • Eficiência de escalonamento (scaling efficiency): [ \text{efficiency}(N) = \frac{\text{speedup}(N)}{N} ]
  • Escalonamento forte (strong scaling): tamanho total do lote fixo; mais dispositivos reduzem o tempo por passo (mais difícil de escalar bem).
  • Escalonamento fraco (weak scaling): tamanho do lote por dispositivo fixo; o tamanho total do lote cresce com os dispositivos (frequentemente mais fácil para taxa de processamento, mas afeta otimização/generalização).

O desafio prático: a computação por passo tipicamente cresce com o tamanho do modelo, mas a comunicação por passo frequentemente cresce com o tamanho dos parâmetros e o número de dispositivos. Em algum ponto, a rede (largura de banda/latência) limita o desempenho.

Padrões de comunicação no treinamento distribuído

Sistemas distribuídos expõem um pequeno conjunto de primitivas de comunicação coletiva (collective communication primitives). A maioria das estratégias de treinamento é construída a partir delas:

  • Difusão (broadcast): um trabalhador envia um tensor para todos os outros (ex.: parâmetros iniciais).
  • Redução (reduce): agregar valores de todos os trabalhadores em um (ex.: somar gradientes).
  • Redução total (all-reduce): redução + distribuição do resultado de volta para todos os trabalhadores (o carro-chefe do treinamento com paralelismo de dados).
  • Coleta total (all-gather): cada trabalhador contribui com um pedaço; todos recebem a concatenação (comum em configurações fragmentadas).
  • Redução-dispersão (reduce-scatter): reduzir e dispersar partições do resultado para os trabalhadores (frequentemente pareada com coleta total).

Comunicação síncrona vs assíncrona

Treinamento síncrono (synchronous training) (mais comum hoje):

  • Cada passo usa o mesmo estado global do modelo.
  • Os trabalhadores calculam gradientes, depois sincronizam (ex.: redução total), e então atualizam.
  • Quase determinístico e estável.

Treinamento assíncrono (asynchronous training) (menos comum para aprendizado profundo em larga escala hoje):

  • Os trabalhadores prosseguem de forma independente e fazem push/pull de atualizações (frequentemente via um servidor de parâmetros (parameter server)).
  • Pode melhorar a utilização do hardware, mas introduz gradientes defasados (stale gradients), o que pode prejudicar a convergência e exigir ajuste cuidadoso.

A maior parte do treinamento moderno de modelos grandes usa coletivas síncronas devido a melhor comportamento de convergência e semântica mais simples.

Redução Total: A primitiva central do paralelismo de dados

No paralelismo de dados (data parallelism) clássico, cada trabalhador tem uma cópia completa do modelo e processa um minilote (mini-batch) diferente. Cada um computa gradientes (g_i). Para se comportar como um treinamento em um único dispositivo no lote combinado, os trabalhadores precisam do gradiente médio:

[ g = \frac{1}{N}\sum_{i=1}^N g_i ]

Isso é exatamente o que a redução total fornece (tipicamente soma), seguido de dividir por (N).

Por que a redução total é preferida

  • Sem gargalo central (ao contrário de um servidor de parâmetros ingênuo).
  • Escala bem em interconexões (interconnects) de alta largura de banda.
  • Amplamente otimizada em bibliotecas como NCCL (NVIDIA) e implementações de MPI.

Algoritmos comuns de redução total (nível de intuição)

Duas grandes famílias são amplamente usadas:

Redução total em anel (ring all-reduce)

  • Cada GPU envia/recebe blocos de/para vizinhos em um anel.
  • A comunicação é ótima em termos de largura de banda para tensores grandes.
  • O tempo escala aproximadamente com: [ \approx 2\cdot\frac{N-1}{N}\cdot\frac{\text{tensor_bytes}}{\text{bandwidth}} ]
  • A latência cresce linearmente com (N) (mais etapas ao redor do anel).

Redução total baseada em árvore (tree-based all-reduce) (ex.: árvore binária)

  • Reduz subindo uma árvore e depois difunde descendo.
  • Menos etapas de comunicação (melhor escalonamento de latência).
  • Pode ser menos ótima em largura de banda para mensagens muito grandes, dependendo de implementação/topologia.

Na prática, as bibliotecas escolhem algoritmos dinamicamente com base no tamanho do tensor, no número de ranks (ranks) e na topologia.

Exemplo prático: PyTorch DistributedDataParallel (DDP)

O DistributedDataParallel do PyTorch realiza redução total de gradientes (em buckets) durante a retropropagação (backprop).

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = MyModel().cuda()
    model = DDP(model, device_ids=[local_rank])

    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
    for batch in loader:  # use DistributedSampler in practice
        x, y = batch[0].cuda(), batch[1].cuda()
        loss = model(x, y)
        opt.zero_grad(set_to_none=True)
        loss.backward()     # triggers gradient all-reduce
        opt.step()

if __name__ == "__main__":
    main()

Um comando típico de execução:

torchrun --nproc_per_node=8 train.py

Detalhes do DDP que importam para desempenho:

  • Bucketização de gradientes (gradient bucketing): muitos tensores pequenos são agrupados para reduzir overhead por chamada.
  • Sobreposição (overlap): a comunicação de gradientes pode se sobrepor com partes posteriores da retropropagação.
  • Uma redução total por bucket: reduz o custo de latência (menos coletivas).

Topologia de comunicação: intra-nó vs inter-nó

O desempenho do treinamento distribuído depende fortemente da interconexão:

  • Dentro de um nó: GPUs podem se conectar via PCIe, NVLink/NVSwitch (largura de banda muito alta).
  • Entre nós: comumente InfiniBand (RDMA) ou Ethernet de alta velocidade (às vezes com RoCE).

O elo mais lento no caminho de comunicação tende a dominar. Um padrão frequente:

  • 8 GPUs dentro de um nó sincronizam rapidamente
  • A sincronização entre nós é muito mais lenta e se torna o gargalo em escala

As bibliotecas frequentemente usam redução total hierárquica (hierarchical all-reduce):

  1. reduzir dentro do nó (rápido)
  2. reduzir entre nós (lento)
  3. difundir de volta dentro do nó

Isso reduz o tráfego entre nós e pode melhorar a escalabilidade.

Gargalos de escalonamento: onde o desempenho vai para morrer

1) Limites de largura de banda (tensores grandes)

A redução total frequentemente é limitada por largura de banda (bandwidth-bound) para tensores grandes de gradiente. Se seu modelo tem (P) parâmetros e você faz redução total de gradientes em precisão total a cada passo:

  • Gradientes FP32: ~4 bytes/parâmetro
  • Gradientes BF16/FP16: ~2 bytes/parâmetro (se comunicados em precisão reduzida)

Exemplo: um modelo com 1B de parâmetros tem ~4 GB de gradientes em FP32. Mesmo com uma excelente interconexão, mover gigabytes por passo pode dominar o tempo.

Mitigações:

  • Comunicar em BF16/FP16 quando for seguro (frequentemente combinado com Precisão Mista).
  • Aumentar a sobreposição computação/comunicação (dimensionamento de buckets, agendamento).
  • Usar estratégias de fragmentação (sharding) (veja abaixo) para não fazer redução total de tudo em todo lugar.

2) Limites de latência (muitos tensores pequenos)

Coletivas pequenas sofrem com latência de disparo (launch latency) e overhead por mensagem. Mesmo com alta largura de banda, muitas reduções totais minúsculas podem ser lentas.

Mitigações:

  • Agrupar gradientes em tensores maiores (o DDP faz isso).
  • Fundir operações (dependente de framework e biblioteca).
  • Evitar pontos de sincronização muito granulares.

3) Trabalhadores atrasados e travamentos por sincronização

O treinamento síncrono avança no ritmo do trabalhador mais lento. Lentidões podem vir de:

  • carregamento de dados desigual
  • gargalos de CPU
  • limitação térmica (thermal throttling)
  • jitter do SO (OS jitter) / processos em segundo plano
  • congestionamento de rede
  • diferentes escolhas de kernel (kernel) devido a formas dinâmicas

Mitigações:

  • Garantir pipelines de entrada robustos (Carregamento de Dados e Pipelines de Entrada).
  • Fixar (pin) threads de CPU, ajustar workers do dataloader, fazer prefetch.
  • Usar formas de lote uniformes (ou bucketização cuidadosa para sequências de comprimento variável).
  • Monitorar tempo de passo por rank e contadores de rede.

4) Baixa sobreposição computação/comunicação

O ideal é esconder a comunicação sob a computação da retropropagação. Na prática, a sobreposição pode ser limitada por:

  • buckets pequenos demais
  • dependências entre camadas
  • restrições de agendamento do framework
  • retropropagação muito rápida (modelos pequenos) em que a comunicação não pode ser escondida

Mitigações:

  • Ajustar tamanhos de bucket (o PyTorch expõe alguns controles).
  • Usar tamanhos de lote maiores por GPU (mais computação por passo).
  • Preferir kernels fundidos (fused kernels) / atenção otimizada para aumentar a intensidade de computação.

5) Estado do otimizador e pressão de memória (levando a comunicação extra)

Otimizadores como Adam/AdamW mantêm estado adicional (ex.: momento/variância), frequentemente 2× o tamanho dos parâmetros. Em muitas configurações distribuídas, isso aumenta:

  • pegada de memória (limita tamanho do lote)
  • volume de comunicação (se os estados do otimizador forem fragmentados ou sincronizados)

Mitigações:

  • Otimizadores fragmentados (sharded optimizers) / particionamento no estilo ZeRO (ZeRO-style) (discutido abaixo).
  • Usar otimizadores com estado menor quando apropriado (dependente do problema).
  • Checkpointing de ativações (activation checkpointing) (não é uma correção de comunicação, mas ajuda a comportar lotes/modelos maiores).

6) Escalonamento do tamanho do lote e problemas de convergência (um “gargalo algorítmico”)

Mesmo que você consiga escalar o hardware perfeitamente, aumentar o tamanho total do lote pode mudar o comportamento de otimização. Treinamento com lotes grandes pode:

  • exigir mudanças em cronogramas de taxa de aprendizado (learning rate) e warmup
  • reduzir ruído do gradiente (às vezes melhorando, às vezes piorando a generalização)
  • atingir um “tamanho de lote crítico” além do qual mais lote traz retornos decrescentes

Isso é um gargalo de escalonamento porque você pode ser forçado a usar acumulação de gradientes (gradient accumulation) (mais passos por atualização) para manter a qualidade — reduzindo a aceleração que você esperava obter.

Contexto relacionado: Descida do Gradiente, Retropropagação.

Além da redução total: outros padrões de comunicação no treinamento moderno

À medida que os modelos cresceram, o paralelismo de dados puro (redução total de gradientes) tornou-se insuficiente. Duas evoluções comuns:

Paralelismo de dados fragmentado (reduce-scatter + all-gather)

Em vez de cada trabalhador manter todos os gradientes ou o estado do otimizador, os trabalhadores mantêm fragmentos (shards). Um padrão comum:

  • redução-dispersão de gradientes: cada trabalhador termina com apenas uma fatia dos gradientes reduzidos
  • cada trabalhador atualiza sua fatia dos parâmetros/estado do otimizador
  • coleta total dos parâmetros atualizados (ou fragmentos necessários) para a próxima passagem direta

Isso reduz a memória por trabalhador e pode reduzir o overhead de comunicação em alguns regimes, ao custo de um agendamento mais complexo.

Essas ideias são amplamente conhecidas por abordagens no estilo “ZeRO” (Zero Redundancy Optimizer) (popularizadas pelo DeepSpeed) e também aparecem em outros frameworks.

Comunicação em paralelismo de modelo e de pipeline

Quando os parâmetros são divididos entre dispositivos (Paralelismo de Dados/Modelo/Pipeline):

  • paralelismo de modelo (model parallelism) exige padrões de coleta total / tudo-para-tudo (all-to-all) para ativações ou resultados parciais
  • paralelismo de pipeline (pipeline parallelism) introduz envio/recebimento (send/recv) de ativações e gradientes entre estágios do pipeline

Esses padrões de comunicação podem ser menos regulares do que a redução total e mais sensíveis à latência e ao overhead de bolha (bubble overhead) (tempo ocioso do pipeline).

Dicas práticas de engenharia de desempenho

Meça antes de otimizar

Perguntas úteis:

  • Você está limitado por largura de banda ou por latência (latency-bound)?
  • A comunicação está se sobrepondo com a computação?
  • Qual coletiva domina (redução total vs coleta total vs tudo-para-tudo)?
  • O tráfego inter-nó é o gargalo?

Ferramentas comumente usadas:

  • Profiler do PyTorch + variáveis de ambiente de debug do NCCL (para insight de comunicação)
  • NVIDIA Nsight Systems para linha do tempo de kernel/comunicação
  • Monitoramento do cluster para utilização de rede e congestionamento

Técnicas que geralmente ajudam

  • Aumentar o trabalho por GPU: micro-lotes (micro-batches) maiores ou acumulação de gradientes (trade-off: memória, convergência).
  • Usar comunicação BF16/FP16 quando estável (frequentemente padrão em stacks modernos).
  • Ajustar tamanhos de bucket para reduzir overhead de coletivas pequenas, mantendo a capacidade de sobreposição.
  • Coletivas hierárquicas entre nós (frequentemente automático, mas uma configuração ciente da topologia importa).
  • Ajustar o pipeline de entrada para que as GPUs não fiquem ociosas esperando dados.

Técnicas com trade-offs

  • Compressão/rarefação de gradientes (gradient compression / sparsification): pode reduzir largura de banda, mas adiciona overhead de computação e pode afetar a convergência.
  • Atualizações assíncronas (asynchronous updates): melhoram utilização, mas introduzem defasagem e frequentemente exigem mais ajuste.
  • Checkpointing de ativações: economiza memória, mas aumenta computação (ainda pode valer a pena no geral se permitir melhor escalonamento).

Considerações de correção

Treinamento distribuído não é apenas “treinamento mais rápido”; ele muda propriedades numéricas e de agendamento:

  • Reduções em ponto flutuante (floating-point reductions) não são perfeitamente associativas, então a ordem de soma pode mudar os resultados levemente.
  • Não determinismo (non-determinism) pode vir de escolhas de kernel, reduções paralelas e timing de rede.
  • Dropout (dropout) e embaralhamento de dados precisam ser tratados com cuidado entre ranks.

Para práticas sobre determinismo e repetibilidade, veja Reprodutibilidade.

Quando o escalonamento deixa de compensar (e o que fazer)

Um resultado comum no mundo real é que adicionar GPUs produz retornos decrescentes. Diagnosticar por quê ajuda a escolher a correção certa:

  • Se o tempo de redução total cresce com a contagem de GPUs: você está limitado por comunicação
    → reduza o volume de comunicação (fragmentação, menor precisão), melhore a interconexão, aumente a computação por passo.
  • Se a variância do tempo de passo cresce: você está limitado por trabalhadores atrasados
    → corrija o pipeline de dados, reduza variabilidade (formas estáticas), inspecione ranks lentos.
  • Se a qualidade de validação piora com lote grande: limite algorítmico
    → ajuste taxa de aprendizado/warmup, considere acumulação de gradientes, ajuste regularização ou aceite escalonamento mais fraco.

Em stacks de treinamento maduras, os melhores resultados frequentemente vêm da combinação de:

  • paralelismo de dados para taxa de processamento,
  • fragmentação para reduzir memória/comunicação,
  • e paralelismo de pipeline/modelo cuidadosamente aplicado quando o modelo não cabe de outra forma.

Resumo

O treinamento distribuído é, fundamentalmente, sobre coordenar computação com comunicação eficiente. A primitiva de comunicação dominante no treinamento com paralelismo de dados é a redução total, mas o treinamento moderno em grande escala depende cada vez mais de redução-dispersão, coleta total e padrões cientes da topologia para evitar gargalos de escalonamento.

Os principais gargalos de escalonamento são:

  • limites de largura de banda (mover tensores enormes),
  • limites de latência (muitas coletivas pequenas),
  • trabalhadores atrasados (travamentos síncronos),
  • sobreposição incompleta computação/comunicação,
  • e limites algorítmicos do treinamento com lotes grandes.

Entender esses gargalos — e os padrões de comunicação por trás deles — é a diferença entre “mais GPUs” e treinamento mais rápido.