Checkpointing e Tolerância a Falhas
Visão geral
A criação de pontos de verificação (checkpointing) e a tolerância a falhas (fault tolerance) são práticas de engenharia que permitem que tarefas de treinamento (training jobs) de longa duração salvem progresso, retomem após interrupção e continuem apesar de falhas de hardware/software. Elas são essenciais para o aprendizado profundo (deep learning) moderno porque:
- O treinamento pode levar de horas a semanas.
- Falhas são comuns em escala: resets de unidades de processamento gráfico (GPU), instabilidades na rede, problemas no sistema de arquivos (filesystem), falhas por estouro de memória (out-of-memory), e preempções (preemptions) (por exemplo, instâncias spot (spot instances) ou manutenção agendada do cluster).
- O treinamento distribuído (distributed training) aumenta a “superfície” de falhas: com mais nós e processos, a probabilidade de que algo dê errado cresce.
A criação de pontos de verificação é o mecanismo para persistir o estado do treinamento. A tolerância a falhas é o projeto de sistema mais amplo que garante que a tarefa consiga se recuperar corretamente — às vezes automaticamente — sem perder trabalho demais e sem alterar silenciosamente o resultado do treinamento.
Este artigo foca em treinamento de aprendizado profundo em escala e se conecta a tópicos relacionados como Treinamento Distribuído (Distributed Training), Paralelismo de Dados/Modelo/Pipeline (Data/Model/Pipeline Parallelism), Precisão Mista (Mixed Precision), Carregamento de Dados e Pipelines de Entrada (Data Loading & Input Pipelines) e Reprodutibilidade (Reproducibility).
O que um “ponto de verificação” deve conter
Um ponto de verificação só é útil se capturar estado suficiente para retomar o treinamento como pretendido. Em aprendizado profundo, “pesos do modelo (model weights)” por si só geralmente não bastam.
Estado central de treinamento (normalmente obrigatório)
Parâmetros do modelo
- Os pesos/vieses das suas Redes Neurais (Neural Networks).
Estado do otimizador (optimizer)
- Para Adam/AdamW isso inclui estimativas do primeiro/segundo momento; para SGD com momentum, buffers de momentum; etc.
- Sem isso, retomar pode mudar de forma perceptível a trajetória de treinamento e a convergência.
Estado do agendador de taxa de aprendizado (learning-rate scheduler)
- Muitos agendamentos dependem da contagem de passos (aquecimento (warmup), decaimento cosseno (cosine decay), ciclo único (one-cycle)).
- Se o agendador reiniciar incorretamente, você pode ter um salto súbito na taxa de aprendizado.
Contador de passo global (global step) / iteração
- Uma única fonte de verdade para “onde estamos”.
Frequentemente necessário para correção ou reprodutibilidade
Estado do gerador de números aleatórios (random number generator, RNG)
- Python
random, RNG do NumPy, RNG do framework (geradores CPU/CUDA do PyTorch) e geradores por processo em configurações distribuídas. - Importante para dropout, aumento de dados (data augmentation), amostragem e algumas camadas.
- Python
Estado do carregador de dados (data loader) / amostrador (sampler)
- Se houver embaralhamento, você pode querer retomar no meio da época (epoch) sem repetir nem pular exemplos.
- Em treinamento distribuído, cada worker frequentemente tem seu próprio fragmento (shard) e fluxo de embaralhamento.
Estado de precisão mista
- Se estiver usando precisão mista automática (automatic mixed precision, AMP), salve o estado do escalador de gradientes (gradient scaler) (por exemplo,
GradScalerno PyTorch) para evitar instabilidade ao retomar. - Veja Precisão Mista.
- Se estiver usando precisão mista automática (automatic mixed precision, AMP), salve o estado do escalador de gradientes (gradient scaler) (por exemplo,
Estado de acumulação de gradientes (gradient accumulation)
- Se estiver acumulando gradientes ao longo de vários micro-lotes (microbatches), você pode precisar de:
- índice do passo de acumulação
- gradientes parcialmente acumulados (raramente salvos, mas podem importar se você quiser continuação exata)
- Se estiver acumulando gradientes ao longo de vários micro-lotes (microbatches), você pode precisar de:
Metadados “nice-to-have” (para operação e depuração)
Configuração do experimento
- Hiperparâmetros (hyperparameters), configuração de arquitetura, versão do dataset, hash do commit do código, tag da imagem de contêiner (container image).
- Isso dá suporte à auditabilidade e torna respondível “o que produziu este ponto de verificação?”.
Métricas (metrics) e estado de parada antecipada (early stopping)
- Melhor pontuação de validação até agora, contadores de paciência, etc.
Estratégias de pontos de verificação: trade-offs e padrões
A criação de pontos de verificação é, fundamentalmente, um trade-off entre:
- Overhead: tempo gasto gravando pontos de verificação (e potencialmente pausando o treinamento)
- Custo de recuperação: quanto trabalho é perdido quando ocorre uma falha
- Custo de armazenamento: quanto de armazenamento em disco/objetos você consome
Pontos de verificação completos vs parciais
Ponto de verificação completo: modelo + otimizador + agendador + gerador de números aleatórios + quaisquer extras
- Melhor para retomada real.
- Maior e mais lento para gravar.
Ponto de verificação apenas de pesos (às vezes chamado de “ponto de verificação de inferência (inference checkpoint)”)
- Útil para implantação, avaliação ou pontos de partida para fine-tuning.
- Não é suficiente para retomar o treinamento com fidelidade.
Um padrão comum é gravar:
- pontos de verificação de treinamento frequentes (estado completo) para recuperação de falhas
- snapshots ocasionais apenas do modelo para avaliação e acompanhamento de “melhor modelo”
Com que frequência criar pontos de verificação (teoria e prática)
Se falhas ocorrem aleatoriamente, um intervalo “ótimo” de pontos de verificação equilibra computação desperdiçada devido a falhas vs tempo gasto criando pontos de verificação. Resultados clássicos (fórmula de Young, refinada por Daly) sugerem um intervalo aproximadamente proporcional a:
- maior taxa de falhas ⇒ criar pontos de verificação mais frequentemente
- maior tempo de gravação do ponto de verificação ⇒ criar pontos de verificação com menos frequência
Na prática, equipes normalmente escolhem uma política como:
- a cada N passos (por exemplo, a cada 1000 iterações)
- a cada T minutos (por exemplo, a cada 15 minutos)
- além de em marcos importantes (fim de época, após avaliação)
Orientação prática:
- Para treinamento em nó único, criar pontos de verificação a cada 5–30 minutos é comum.
- Para grandes tarefas distribuídas em hardware preemptível/spot, criar pontos de verificação a cada 2–10 minutos pode se justificar se o I/O de pontos de verificação for eficiente.
Pontos de verificação síncronos vs assíncronos
Síncrono: o treinamento pausa enquanto grava.
- Simples e robusto.
- Pode deixar as GPUs ociosas, reduzindo a taxa de processamento (throughput).
Assíncrono: o treinamento continua enquanto uma linha de execução (thread) ou processo em segundo plano grava.
- Melhor utilização.
- Mais complexo: é preciso garantir um snapshot consistente (pesos/otimizador devem corresponder entre si) e gerenciar overhead de memória.
Muitos sistemas em grande escala usam pontos de verificação assíncronos mais pontos de verificação sincronizados “duros” periódicos.
Pontos de verificação incrementais e diferenciais
Em vez de gravar todo o estado a cada vez, você pode gravar apenas mudanças desde o último ponto de verificação. Isso pode reduzir I/O, mas adiciona complexidade e pode tornar restaurações mais lentas ou mais frágeis. É mais comum em sistemas que gerenciam pontos de verificação na camada de armazenamento ou no nível do framework do que em scripts ad-hoc.
Pontos de verificação fragmentados (sharded checkpoints) (crítico em escala)
Quando modelos e estados do otimizador ficam enormes (vários GB a TB), salvar um único arquivo monolítico torna-se impraticável.
Pontos de verificação fragmentados:
- dividem o estado em múltiplos arquivos (frequentemente um por rank (rank) ou por fragmento de parâmetros)
- reduzem gargalos de “um único escritor”
- permitem I/O em paralelo durante gravação e restauração
Isso é especialmente importante com:
- paralelismo de modelo (model parallelism) e paralelismo de pipeline (pipeline parallelism) (Paralelismo de Dados/Modelo/Pipeline)
- particionamento de estado do otimizador (optimizer state partitioning) (por exemplo, estilo ZeRO (ZeRO-style))
- abordagens totalmente fragmentadas (fully sharded approaches) como FSDP em treinamento distribuído
Tolerância a falhas em treinamento distribuído
No Treinamento Distribuído, falhas ficam mais complexas:
- Um rank cai → a tarefa inteira frequentemente aborta (a menos que um runtime elástico/tolerante a falhas seja usado).
- Pontos de verificação devem representar um estado global consistente entre ranks.
Snapshots consistentes e coordenação
O requisito-chave é que o ponto de verificação corresponda a um único passo lógico de treinamento para todos os processos. Abordagens típicas:
Barreira (barrier) e depois salvar
- Todos os ranks chegam a um ponto de sincronização.
- A gravação é coordenada de modo que cada rank escreva seu fragmento para o mesmo passo.
Ponto de verificação dirigido por coordenador
- O rank 0 grava metadados / “manifesto (manifest)”
- Outros ranks gravam seus fragmentos
- Um “commit” final indica completude do ponto de verificação
Treinamento elástico / reiniciável
Alguns runtimes suportam reiniciar a tarefa com um número diferente de workers (elasticidade (elasticity)). Isso é útil quando nós são preemptados ou a capacidade do cluster muda.
No entanto, a elasticidade complica:
- fragmentação de dados (data sharding) (quem processa quais amostras agora?)
- particionamento do otimizador (como redistribuir o estado do otimizador?)
- fluxos do gerador de números aleatórios e determinismo (determinism)
Se reprodutibilidade estrita importar, veja Reprodutibilidade e considere restringir a elasticidade ou gerenciar com cuidado a lógica de refragmentação (resharding).
Tornando pontos de verificação confiáveis: atomicidade, corrupção e escolhas de armazenamento
Falhas podem ocorrer durante a gravação do ponto de verificação. O resultado mais perigoso é carregar silenciosamente um ponto de verificação parcial ou corrompido.
Escritas atômicas com “escrever e depois renomear”
Um padrão padrão é:
- Gravar o ponto de verificação em um caminho temporário
- Fazer flush e sync (na medida do possível)
- Renomear/mover atomicamente para o caminho final
- Opcionalmente gravar/atualizar um ponteiro
latest
Em sistemas de arquivos POSIX (POSIX filesystems), renomear é tipicamente atômico.
import os
import torch
from pathlib import Path
def atomic_torch_save(obj, path: str):
path = Path(path)
tmp = path.with_suffix(path.suffix + ".tmp")
torch.save(obj, tmp)
os.replace(tmp, path) # atomic on most local/POSIX filesystems
Manifesto + marcador de commit
Para pontos de verificação fragmentados, uma estratégia comum é:
- Cada rank grava
ckpt_step_1000_rank_07.pt - O rank 0 grava um manifesto
ckpt_step_1000.jsonlistando:- número do passo
- fragmentos esperados
- hashes/tamanhos (opcional)
- Apenas quando todos os fragmentos existem o rank 0 grava um pequeno marcador
ckpt_step_1000.COMMITTED
Na restauração, você só carrega pontos de verificação que têm o marcador de commit (commit marker).
Onde armazenar pontos de verificação
- NVMe / SSD local: mais rápido, mas não é durável se o nó morrer.
- Sistema de arquivos de rede (NFS/Lustre/GPFS): durável e compartilhado, mas pode ser lento ou ficar sobrecarregado em escala.
- Armazenamento de objetos (object storage) (S3/GCS/Azure Blob): durável e escalável, mas tem semânticas diferentes (uploads, consistência eventual (eventual consistency), sem renomeação atômica).
Um padrão comum em produção:
- Gravar em disco local rápido
- Fazer upload de forma assíncrona para armazenamento de objetos durável
- Manter apenas um pequeno cache local
Exemplo prático: ponto de verificação de treinamento em PyTorch (processo único)
Este exemplo salva estado suficiente para retomar um loop de treinamento típico.
import torch
import random
import numpy as np
def save_checkpoint(path, model, optimizer, scheduler, scaler, step):
ckpt = {
"step": step,
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict() if scheduler is not None else None,
"scaler": scaler.state_dict() if scaler is not None else None,
"rng": {
"python": random.getstate(),
"numpy": np.random.get_state(),
"torch": torch.random.get_rng_state(),
"cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
},
}
atomic_torch_save(ckpt, path)
def load_checkpoint(path, model, optimizer, scheduler, scaler, map_location="cpu"):
ckpt = torch.load(path, map_location=map_location)
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
if scheduler is not None and ckpt["scheduler"] is not None:
scheduler.load_state_dict(ckpt["scheduler"])
if scaler is not None and ckpt["scaler"] is not None:
scaler.load_state_dict(ckpt["scaler"])
# Restore RNG states (optional but useful for reproducibility)
import random
import numpy as np
random.setstate(ckpt["rng"]["python"])
np.random.set_state(ckpt["rng"]["numpy"])
torch.random.set_rng_state(ckpt["rng"]["torch"])
if torch.cuda.is_available() and ckpt["rng"]["cuda"] is not None:
torch.cuda.set_rng_state_all(ckpt["rng"]["cuda"])
return ckpt["step"]
Notas:
- Salvar estados do gerador de números aleatórios da CUDA pode ser importante se você quiser retomar com comportamento estocástico semelhante.
- A reprodutibilidade exata ainda pode ser limitada por kernels não determinísticos (nondeterministic GPU kernels); veja Reprodutibilidade.
Exemplo prático: considerações de pontos de verificação distribuídos (PyTorch)
Com Paralelismo de Dados Distribuído (Distributed Data Parallel, DDP), um padrão comum é:
- o rank 0 grava um ponto de verificação “global” (porque todos os ranks têm pesos de modelo idênticos)
- ou cada rank grava seu próprio fragmento se estiver usando treinamento fragmentado (FSDP/ZeRO)
DDP: salvar apenas no rank 0
import torch.distributed as dist
def is_rank0():
return (not dist.is_initialized()) or dist.get_rank() == 0
if is_rank0():
save_checkpoint("ckpt_latest.pt", model, optimizer, scheduler, scaler, step)
dist.barrier() # ensure others don't race ahead if needed
Isso funciona para o DDP clássico porque cada rank mantém uma cópia completa dos parâmetros e do estado do otimizador (embora o estado do otimizador ainda possa ser grande).
Treinamento fragmentado (FSDP/ZeRO): salvar fragmentos
Em abordagens fragmentadas, cada rank possui apenas parte dos parâmetros/estado do otimizador. O ponto de verificação também precisa ser fragmentado, ou o framework precisa reunir o estado (frequentemente caro demais).
A maioria dos frameworks modernos fornece helpers embutidos para pontos de verificação distribuídos. Prefira essas APIs em vez de serialização (serialization) feita à mão, porque elas lidam com:
- particionamento de parâmetros/otimizador
- refragmentação na restauração
- padrões eficientes de comunicação coletiva (collective communication)
Retomar “corretamente”: fontes sutis de divergência
Mesmo que seu ponto de verificação carregue, a execução retomada pode divergir da original. Às vezes isso é aceitável; às vezes quebra experimentos.
Culpados comuns:
Mudanças na ordem dos dados no meio da época
Se você retomar no passo 12.345 mas seu carregador de dados reiniciar no começo de uma época, você repetirá dados e mudará gradientes.
Mitigações:
- criar pontos de verificação apenas nos limites de época (mais simples, mas desperdiça mais trabalho)
- persistir o estado do amostrador/carregador de dados
- usar embaralhamento determinístico derivado de
(epoch, global_step)para reconstruir a posição
Isso se relaciona de perto com Carregamento de Dados e Pipelines de Entrada.
Desalinhamento do agendamento de taxa de aprendizado
Se o estado do agendador não for salvo (ou se “step” estiver inconsistente), a taxa de aprendizado pode saltar.
Mitigações:
- sempre armazenar
global_step - salvar o estado do agendador ou calcular a taxa de aprendizado de forma puramente determinística a partir de
global_step
Reset do escalador de precisão mista
Se o escalador de precisão mista resetar, você pode ver uma explosão de NaNs ou lentidões.
Mitigação:
- salvar/carregar o estado do escalador (PyTorch
GradScaler, etc.) - veja Precisão Mista
Limite de acumulação de gradientes
Se você falhar no meio da acumulação e retomar no próximo micro-lote, sua atualização efetiva difere.
Mitigações:
- criar pontos de verificação apenas nos limites de atualização (após o passo do otimizador)
- ou armazenar o contador de acumulação (e aceitar que gradientes parciais tipicamente não são restaurados)
Kernels não determinísticos e não determinismo distribuído
Mesmo com o gerador de números aleatórios restaurado, a execução pode não ser idêntica bit a bit devido a não determinismo de kernels de GPU ou à ordem de operações coletivas.
Se reprodutibilidade for importante, veja Reprodutibilidade e considere:
- configurações de algoritmos determinísticos (com trade-offs de desempenho)
- seeds fixas por rank
- controle cuidadoso da fragmentação de dados
Lidando com preempção e interrupções planejadas
Em nuvem ou clusters compartilhados, tarefas podem ser interrompidas com pouco aviso. Uma configuração de treinamento tolerante a falhas deve:
- criar pontos de verificação com frequência suficiente para manter o trabalho perdido sob controle
- responder a sinais de terminação e criar ponto de verificação antes de sair
Pontos de verificação acionados por sinal (padrão comum)
Se seu ambiente envia SIGTERM antes de matar o processo, você pode capturá-lo e criar um ponto de verificação.
Conceitualmente:
- registrar um handler
- definir um flag
- criar o ponto de verificação em um limite seguro (fim do passo) em vez de dentro do handler
Isso evita corromper estado ou causar deadlock (deadlocking) em workers distribuídos.
Validação de ponto de verificação/tolerância a falhas (não pule isso)
Código de ponto de verificação frequentemente “parece ok” até você precisar. Trate como funcionalidade crítica e teste.
Testes recomendados:
Teste de retomada
- Treine por N passos
- Salve um ponto de verificação
- Retome e treine por mais M passos
- Compare com uma execução sem interrupção (permitindo pequeno não determinismo, se aplicável)
Teste de crash
- Mate aleatoriamente a tarefa (ou um rank) durante o treinamento
- Garanta que a lógica de reinício encontre o ponto de verificação confirmado mais recente e continue
Teste de corrupção
- Delete ou trunque um fragmento
- Garanta que o carregador rejeite o ponto de verificação e faça fallback para um mais antigo
Teste de escala (distribuído)
- Verifique que a criação de pontos de verificação não sobrecarrega o sistema de arquivos no número-alvo de nós
Boas práticas operacionais
Manter múltiplas gerações
Mantenha:
latest(mais recente)best(melhor pontuação de validação)- vários pontos de verificação históricos (por exemplo, últimos 3–10) em caso de corrupção ou regressão
Acompanhar “versões de esquema” do ponto de verificação
Quando você muda o código de modelo/otimizador, pontos de verificação antigos podem não carregar de forma limpa. Armazene um campo de versão:
checkpoint_version: 3model_config: {...}
Depois, escreva lógica explícita de migração quando necessário.
Separar artefatos de “treinamento” vs “serving”
Pontos de verificação de treinamento frequentemente incluem estado do otimizador e do gerador de números aleatórios; em serving tipicamente basta:
- pesos do modelo
- tokenizer / configuração de pré-processamento
- configurações de tempo de inferência
Atenção à segurança e privacidade
Pontos de verificação podem vazar:
- dados de treinamento memorizados
- comportamento proprietário do modelo
- configuração sensível
Aplique controle de acesso, criptografia em repouso (encryption-at-rest) e práticas cuidadosas de compartilhamento.
Como pontos de verificação se encaixam em treinamento em escala
A criação de pontos de verificação e a tolerância a falhas estão fortemente acopladas a outras preocupações de escala:
- Pipelines de entrada de alta taxa de transferência podem tornar falhas mais custosas se forem difíceis de “rebobinar”; veja Carregamento de Dados e Pipelines de Entrada.
- Estratégias de paralelismo determinam se pontos de verificação precisam ser fragmentados e quão cara é a restauração; veja Paralelismo de Dados/Modelo/Pipeline.
- Comunicação e orquestração de processos determinam modos de falha e comportamento de reinício; veja Treinamento Distribuído.
- Precisão mista adiciona estado extra (escaladores) e sensibilidade durante retomada; veja Precisão Mista.
- Reprodutibilidade exige captura cuidadosa de estado (gerador de números aleatórios, amostrador, contadores de passos) e consciência de não determinismo; veja Reprodutibilidade.
Resumo
A criação de pontos de verificação é o mecanismo prático que torna sobrevivíveis execuções longas de treinamento: você salva periodicamente um snapshot consistente do estado de treinamento e pode restaurá-lo após falhas ou interrupções. A tolerância a falhas vai além de salvar arquivos: inclui escritas atômicas/confirmadas, coordenação distribuída, armazenamento confiável, tratamento de sinais e teste do caminho de recuperação.
Quando bem feita, a criação de pontos de verificação transforma falhas de “dias perdidos de forma catastrófica” em “reiniciar de 10 minutos atrás”, o que é uma capacidade fundamental para treinar modelos modernos em escala.