Frameworks
O que “frameworks” significam em ML moderno
Em aprendizado de máquina (machine learning), um framework é a camada central de software que permite:
- Definir modelos (por exemplo, Redes Neurais, Arquitetura Transformer)
- Calcular gradientes via diferenciação automática (automatic differentiation, autodiff) para Retropropagação
- Executar treinamento com eficiência em aceleradores (GPUs/TPUs) e em muitos dispositivos
- Exportar ou servir modelos em ambientes de produção
Hoje, os três “frameworks de aprendizado profundo (deep learning)” mais comuns são PyTorch, TensorFlow e JAX. Eles se sobrepõem bastante em capacidade, mas diferem no modelo de programação, ecossistema, estratégia de implantação e em como alcançam desempenho.
Um modelo mental útil:
- PyTorch: Python-first, modo eager por padrão; amplamente usado em pesquisa e cada vez mais forte em produção.
- TensorFlow: ecossistema maduro de produção/implantação; execução em grafo é historicamente central, com APIs Keras de alto nível.
- JAX: “NumPy + autodiff + compilação”; se destaca em transformações composicionais (jit/vmap/pmap) e frequentemente lidera em padrões de desempenho em pesquisa.
Fundamentos teóricos: autodiff, grafos de computação e compilação
Diferenciação automática (autodiff)
O treinamento em aprendizado profundo é, em grande parte, a aplicação repetida de:
- Passagem para frente (forward pass): calcular predições e perda
- Passagem para trás (backward pass): calcular gradientes da perda em relação aos parâmetros
- Atualização: aplicar um passo do otimizador (uma forma de Descida do Gradiente)
Frameworks fornecem autodiff em modo reverso (reverse-mode autodiff), que é eficiente quando você tem muitos parâmetros e uma única perda escalar (o caso comum).
Execução eager vs execução em grafo
Dois estilos amplos de execução:
- Eager (imperativa): as operações executam imediatamente conforme o Python roda.
- Prós: depuração fácil, fluxo de controle intuitivo.
- Contras: o overhead do Python pode limitar o desempenho; mais difícil otimizar globalmente.
- Grafo (compilada): capturar o cálculo em um grafo e então otimizá-lo/compilá-lo.
- Prós: otimizações globais, runtime mais rápido, implantação mais fácil.
- Contras: rastreamento/compilação pode ser confuso; depuração às vezes mais difícil.
Frameworks modernos borram essa linha:
- PyTorch: eager por padrão, com
torch.compilepara otimizar/compilar. - TensorFlow: eager por padrão no TF2, mas
tf.functionconstrói grafos. - JAX: estilo funcional;
jax.jitcompila e faz cache de grafos rastreados via XLA.
XLA e compilação
XLA (Accelerated Linear Algebra) é um compilador que otimiza programas de tensores para CPU/GPU/TPU. Ele é:
- Central no JAX
- Comum no TensorFlow
- Cada vez mais importante no PyTorch (via
torch.compilee pilhas de compilação relacionadas)
A compilação é mais importante quando:
- O treinamento é em grande escala ou sensível à latência
- Você usa TPUs
- Você quer kernels fundidos e menos overhead em nível de Python
PyTorch
Modelo de programação
PyTorch é imperativo/eager-first: operações com tensores executam imediatamente, e os gradientes são registrados dinamicamente.
Conceitos-chave:
torch.Tensorcomrequires_grad=Truetorch.autogradconstrói um grafo de computação dinâmico durante a execução- Você normalmente escreve loops de treinamento explícitos (embora existam wrappers de alto nível)
Pontos fortes
- Excelente ergonomia para pesquisa: depuração direta, fluxo de controle flexível.
- Grande ecossistema: forte suporte em muitos repositórios e bibliotecas de modelos, especialmente em PLN e visão.
- Produção melhorou: TorchScript (histórico), exportação ONNX,
torch.compilee opções fortes de runtime (frequentemente em conjunto com ferramentas da NVIDIA).
Exemplo prático: um passo mínimo de treinamento
Abaixo está um pequeno exemplo de “um único passo” (não é uma configuração completa de dataloader) para ilustrar o estilo:
import torch
import torch.nn as nn
import torch.optim as optim
# Simple linear model
model = nn.Linear(10, 1)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
x = torch.randn(32, 10) # batch of 32
y = torch.randn(32, 1)
model.train()
pred = model(x)
loss = loss_fn(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())
Desempenho e compilação: `torch.compile`
Versões recentes do PyTorch oferecem suporte a compilar partes dos modelos para melhorar a velocidade:
compiled_model = torch.compile(model) # PyTorch 2.x
pred = compiled_model(x)
Isso pode trazer ganhos reais de velocidade (especialmente para cargas de trabalho do tipo transformer), mas a compilação pode introduzir casos de borda (ops não suportadas, formas dinâmicas). É melhor adotar de forma iterativa: compile o modelo, rode testes, faça profiling e então amplie a cobertura.
Notas de ecossistema
Add-ons comuns incluem:
- Orquestração de treinamento: PyTorch Lightning, Accelerate
- Treinamento distribuído: DDP, FSDP
- Hubs de modelos: muitos modelos em Hubs & Registros de Modelos são PyTorch-first
TensorFlow (e Keras)
Modelo de programação
O TensorFlow 2 tornou a execução eager o padrão, mas sua história “pronta para produção” frequentemente usa compilação em grafo via @tf.function. O Keras fornece uma API de alto nível para definir modelos.
Conceitos-chave:
tf.Tensoretf.Variabletf.GradientTapepara autodiff em modo eager@tf.functionpara rastrear/compilar Python em um grafo
Pontos fortes
- Ecossistema de implantação: TensorFlow Serving, TFLite (mobile/edge), TF.js (navegador).
- Keras: loops de treinamento padronizados e de alto nível (
model.compile,model.fit) podem reduzir boilerplate. - Suporte a TPU: integração forte, especialmente em ambientes do Google Cloud.
Exemplo prático: passo de treinamento customizado com `GradientTape`
import tensorflow as tf
model = tf.keras.Sequential([
tf.keras.layers.Dense(1)
])
optimizer = tf.keras.optimizers.Adam(1e-3)
loss_fn = tf.keras.losses.MeanSquaredError()
x = tf.random.normal([32, 10])
y = tf.random.normal([32, 1])
with tf.GradientTape() as tape:
pred = model(x, training=True)
loss = loss_fn(y, pred)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
print(float(loss))
Compilação em grafo com `tf.function`
@tf.function
def train_step(x, y):
with tf.GradientTape() as tape:
pred = model(x, training=True)
loss = loss_fn(y, pred)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
return loss
Isso pode melhorar significativamente o throughput ao compilar o passo em um grafo otimizado.
Notas de ecossistema
O TensorFlow é frequentemente escolhido quando:
- Você precisa de implantação em edge/mobile (TFLite)
- Você precisa de inferência no navegador (TF.js)
- Sua organização já tem infraestrutura em TF
Para rastreamento de experimentos e pipelines, ainda é comum combinar TF com ferramentas de Ferramentas de Experimentos e ferramentas de dados de Dados.
JAX
Modelo de programação
JAX é melhor entendido como:
- API ao estilo NumPy (
jax.numpy), mas as operações são puras e diferenciáveis - Autodiff via
jax.grad - Compilação via
jax.jit(XLA) - Vetorização fácil via
jax.vmap - Paralelismo entre dispositivos via
jax.pmap(e sistemas mais novos comopjit/sharding em configurações modernas de JAX)
JAX incentiva fortemente um estilo funcional (functional style):
- Evitar mutação in-place de arrays
- Tratar parâmetros como valores explícitos passados para funções
- Usar transformações (
grad,jit,vmap) para gerar programas eficientes
Pontos fortes
- Transformações composicionais:
grad(jit(f)),vmap(grad(f))etc. são padrões de primeira classe. - Excelente potencial de desempenho: a compilação XLA é central; frequentemente muito rápida quando escrita de maneira “amigável ao JAX”.
- Ótimo para pesquisa: especialmente quando você precisa de matemática customizada, meta-aprendizado (meta-learning), paralelismo em grande escala ou vetorização elegante.
Exemplo prático: um pequeno passo de regressão linear
Isto é intencionalmente de baixo nível para mostrar a abordagem “funções + transformações” do JAX:
import jax
import jax.numpy as jnp
key = jax.random.key(0)
W = jax.random.normal(key, (10, 1))
b = jnp.zeros((1,))
def model(params, x):
W, b = params
return x @ W + b
def loss_fn(params, x, y):
pred = model(params, x)
return jnp.mean((pred - y) ** 2)
grad_fn = jax.grad(loss_fn)
x = jax.random.normal(jax.random.key(1), (32, 10))
y = jax.random.normal(jax.random.key(2), (32, 1))
params = (W, b)
grads = grad_fn(params, x, y)
lr = 1e-3
params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
print(loss_fn(params, x, y))
Em projetos reais, você normalmente usa uma biblioteca do ecossistema:
- Flax (biblioteca de redes neurais)
- Optax (otimizadores)
- Orbax (checkpointing)
- Haiku ou Equinox (bibliotecas alternativas de redes neurais)
Compilação JIT
jit_loss = jax.jit(loss_fn)
jit_grad = jax.jit(grad_fn)
A compilação tem um custo inicial (a primeira chamada rastreia e compila), mas as chamadas subsequentes são rápidas.
Principais armadilhas
- Polimorfismo de forma (shape polymorphism) e fluxo de controle podem ser difíceis: muitos ramos do lado do Python são resolvidos no momento do rastreamento.
- Você precisa pensar com cuidado sobre o que é compilado e armazenado em cache (por exemplo, mudar formas pode disparar recompilação).
- Depurar dentro de funções compiladas com
jité diferente (embora as ferramentas tenham melhorado).
Considerações comuns entre frameworks
Alocação em dispositivos e precisão mista
Os três suportam GPUs; o suporte a TPU é mais forte em JAX e TensorFlow (embora exista suporte a TPU no PyTorch via integrações).
Precisão mista (FP16/BF16) é uma alavanca importante de desempenho, especialmente para grandes modelos transformer. Cada framework tem suas próprias APIs e boas práticas; na prática, você também depende do hardware e da pilha do fornecedor (CUDA, cuDNN, NCCL).
Treinamento distribuído
Em escala, você vai se preocupar com:
- Paralelismo de dados (replicar o modelo; dividir lotes)
- Paralelismo de modelo (dividir o modelo entre dispositivos)
- Otimizadores e parâmetros fragmentados (sharded)
- Eficiência de comunicação (all-reduce, reduce-scatter)
Wrappers e bibliotecas de alto nível frequentemente importam mais do que a escolha do framework base. Sua pilha de avaliação (ver Harnesses de Avaliação) e logging/gestão de artefatos (ver Ferramentas de Experimentos) deve se integrar bem.
Exportação e interoperabilidade (ONNX e além)
- ONNX pode ajudar a exportar modelos entre ecossistemas, mas a cobertura é imperfeita — especialmente para ops mais novas ou camadas customizadas.
- Para aplicações com LLMs, você também pode exportar para runtimes especializados ou servidores de inferência; a escolha prática frequentemente depende da sua pilha de Ferramentas para LLM.
Como escolher entre PyTorch, TensorFlow e JAX
Escolher é menos sobre “qual é o melhor” e mais sobre casar restrições: habilidades do time, alvos de implantação, requisitos de desempenho e dependências do ecossistema.
Escolha PyTorch se…
- Você quer o padrão mais comum em pesquisa e ecossistemas de modelos open-source.
- Você valoriza facilidade de depuração e flexibilidade (fluxo de controle dinâmico, iteração mais fácil).
- Você planeja fazer fine-tuning ou estender repositórios de modelos populares (muitos são PyTorch-first).
- Você quer um caminho forte para alto desempenho com
torch.compile, FSDP e ferramentas modernas.
Casos de uso típicos:
- Fine-tuning e experimentação com LLMs
- Protótipos de pesquisa de modelos customizados que podem virar sistemas de produção
- Ambientes de treinamento centrados em GPU
Escolha TensorFlow se…
- Seu requisito principal é implantação madura para:
- Mobile/edge (TFLite)
- Navegador (TF.js)
- Infra de serving estabelecida (TF Serving)
- Você prefere APIs de alto nível no estilo Keras e loops de treinamento padronizados.
- Sua organização já tem pipelines em TF e expertise operacional.
Casos de uso típicos:
- Inferência no dispositivo (mobile, embarcado)
- Ambientes corporativos com infraestrutura TF existente
- Produtos que se beneficiam do toolchain de implantação do TF
Escolha JAX se…
- Você quer transformações de programa composicionais (
jit,vmap,grad, sharding) e se sente confortável com um estilo mais funcional. - Você roda em TPUs ou quer padrões de desempenho XLA-first.
- Você faz pesquisa que se beneficia de vetorização e compilação (por exemplo, meta-learning, simulação em lote grande, paralelismo avançado).
Casos de uso típicos:
- Pesquisa e treinamento fortemente baseados em TPU
- Loops de treinamento altamente otimizados onde compilação e vetorização são centrais
- Fluxos de trabalho que se beneficiam de código matemático limpo, semelhante ao NumPy
Checklist prático de decisão
Faça estas perguntas em ordem:
Onde o modelo precisa rodar?
- Mobile/navegador: TensorFlow costuma ser o caminho mais fácil.
- GPU em servidor: PyTorch é um padrão muito comum; todos podem funcionar.
- TPU: JAX ou TensorFlow geralmente fornecem a experiência mais fluida.
Você precisa integrar com um ecossistema específico?
- Se seu modelo/código-alvo é PyTorch-first, trocar de framework é caro.
- Se sua pilha de implantação já é TF Serving/TFLite, TensorFlow reduz atrito.
Quão importante é “depuração pythônica” vs desempenho centrado em compilação?
- Muita iteração/depuração: PyTorch tende a ser o mais simples.
- Mentalidade de desempenho “compile tudo”: JAX brilha; TF também é forte com
tf.function.
O que seu time já conhece?
- Familiaridade com o framework costuma ser o maior multiplicador de produtividade.
Quais são suas restrições fora do modelo?
- Pipelines de dados (Dados), acesso a datasets (Datasets & Hospedagem), avaliação (Harnesses de Avaliação), rastreamento de experimentos (Ferramentas de Experimentos) e licenciamento (Modelos Abertos & Licenças) podem dominar o sucesso prático.
Uma recomendação pragmática de workflow
Muitos times convergem para uma abordagem híbrida:
- Pesquisa / desenvolvimento de modelo: PyTorch (ou JAX em certas organizações de pesquisa)
- Treinamento em produção: PyTorch ou TensorFlow dependendo da infraestrutura
- Implantação:
- Inferência em servidor: frequentemente PyTorch + um servidor/runtime de inferência, ou TF Serving
- Edge/navegador: frequentemente TensorFlow (TFLite / TF.js)
O “melhor” framework é aquele que minimiza o atrito ao longo de todo o ciclo de vida: experimentação → treinamento → avaliação → implantação → manutenção.
Resumo
- PyTorch: melhor padrão geral para pesquisa e uma grande fração de ML em produção; eager-first com compilação e escalabilidade cada vez mais fortes.
- TensorFlow: ecossistema de implantação mais forte de ponta a ponta (especialmente edge e navegador) e ferramentas maduras para grafos; Keras oferece ergonomia de alto nível.
- JAX: framework funcional, baseado em transformações, com XLA no núcleo; se destaca quando você quer compilação/vetorização composicionais e fluxos de trabalho fortes com TPUs.
Se você estiver em dúvida:
- Comece com PyTorch para a maioria dos trabalhos gerais de ML e LLMs.
- Prefira TensorFlow se você sabe que precisa implantar em mobile/navegador ou já tem infraestrutura TF.
- Prefira JAX se você está otimizando para treinamento TPU-first ou fluxos de pesquisa que se beneficiam da composicionalidade de
jit/vmap.