Frameworks

O que “frameworks” significam em ML moderno

Em aprendizado de máquina (machine learning), um framework é a camada central de software que permite:

  • Definir modelos (por exemplo, Redes Neurais, Arquitetura Transformer)
  • Calcular gradientes via diferenciação automática (automatic differentiation, autodiff) para Retropropagação
  • Executar treinamento com eficiência em aceleradores (GPUs/TPUs) e em muitos dispositivos
  • Exportar ou servir modelos em ambientes de produção

Hoje, os três “frameworks de aprendizado profundo (deep learning)” mais comuns são PyTorch, TensorFlow e JAX. Eles se sobrepõem bastante em capacidade, mas diferem no modelo de programação, ecossistema, estratégia de implantação e em como alcançam desempenho.

Um modelo mental útil:

  • PyTorch: Python-first, modo eager por padrão; amplamente usado em pesquisa e cada vez mais forte em produção.
  • TensorFlow: ecossistema maduro de produção/implantação; execução em grafo é historicamente central, com APIs Keras de alto nível.
  • JAX: “NumPy + autodiff + compilação”; se destaca em transformações composicionais (jit/vmap/pmap) e frequentemente lidera em padrões de desempenho em pesquisa.

Fundamentos teóricos: autodiff, grafos de computação e compilação

Diferenciação automática (autodiff)

O treinamento em aprendizado profundo é, em grande parte, a aplicação repetida de:

  1. Passagem para frente (forward pass): calcular predições e perda
  2. Passagem para trás (backward pass): calcular gradientes da perda em relação aos parâmetros
  3. Atualização: aplicar um passo do otimizador (uma forma de Descida do Gradiente)

Frameworks fornecem autodiff em modo reverso (reverse-mode autodiff), que é eficiente quando você tem muitos parâmetros e uma única perda escalar (o caso comum).

Execução eager vs execução em grafo

Dois estilos amplos de execução:

  • Eager (imperativa): as operações executam imediatamente conforme o Python roda.
    • Prós: depuração fácil, fluxo de controle intuitivo.
    • Contras: o overhead do Python pode limitar o desempenho; mais difícil otimizar globalmente.
  • Grafo (compilada): capturar o cálculo em um grafo e então otimizá-lo/compilá-lo.
    • Prós: otimizações globais, runtime mais rápido, implantação mais fácil.
    • Contras: rastreamento/compilação pode ser confuso; depuração às vezes mais difícil.

Frameworks modernos borram essa linha:

  • PyTorch: eager por padrão, com torch.compile para otimizar/compilar.
  • TensorFlow: eager por padrão no TF2, mas tf.function constrói grafos.
  • JAX: estilo funcional; jax.jit compila e faz cache de grafos rastreados via XLA.

XLA e compilação

XLA (Accelerated Linear Algebra) é um compilador que otimiza programas de tensores para CPU/GPU/TPU. Ele é:

  • Central no JAX
  • Comum no TensorFlow
  • Cada vez mais importante no PyTorch (via torch.compile e pilhas de compilação relacionadas)

A compilação é mais importante quando:

  • O treinamento é em grande escala ou sensível à latência
  • Você usa TPUs
  • Você quer kernels fundidos e menos overhead em nível de Python

PyTorch

Modelo de programação

PyTorch é imperativo/eager-first: operações com tensores executam imediatamente, e os gradientes são registrados dinamicamente.

Conceitos-chave:

  • torch.Tensor com requires_grad=True
  • torch.autograd constrói um grafo de computação dinâmico durante a execução
  • Você normalmente escreve loops de treinamento explícitos (embora existam wrappers de alto nível)

Pontos fortes

  • Excelente ergonomia para pesquisa: depuração direta, fluxo de controle flexível.
  • Grande ecossistema: forte suporte em muitos repositórios e bibliotecas de modelos, especialmente em PLN e visão.
  • Produção melhorou: TorchScript (histórico), exportação ONNX, torch.compile e opções fortes de runtime (frequentemente em conjunto com ferramentas da NVIDIA).

Exemplo prático: um passo mínimo de treinamento

Abaixo está um pequeno exemplo de “um único passo” (não é uma configuração completa de dataloader) para ilustrar o estilo:

import torch
import torch.nn as nn
import torch.optim as optim

# Simple linear model
model = nn.Linear(10, 1)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

x = torch.randn(32, 10)       # batch of 32
y = torch.randn(32, 1)

model.train()
pred = model(x)
loss = loss_fn(pred, y)

optimizer.zero_grad()
loss.backward()
optimizer.step()

print(loss.item())

Desempenho e compilação: `torch.compile`

Versões recentes do PyTorch oferecem suporte a compilar partes dos modelos para melhorar a velocidade:

compiled_model = torch.compile(model)  # PyTorch 2.x
pred = compiled_model(x)

Isso pode trazer ganhos reais de velocidade (especialmente para cargas de trabalho do tipo transformer), mas a compilação pode introduzir casos de borda (ops não suportadas, formas dinâmicas). É melhor adotar de forma iterativa: compile o modelo, rode testes, faça profiling e então amplie a cobertura.

Notas de ecossistema

Add-ons comuns incluem:

  • Orquestração de treinamento: PyTorch Lightning, Accelerate
  • Treinamento distribuído: DDP, FSDP
  • Hubs de modelos: muitos modelos em Hubs & Registros de Modelos são PyTorch-first

TensorFlow (e Keras)

Modelo de programação

O TensorFlow 2 tornou a execução eager o padrão, mas sua história “pronta para produção” frequentemente usa compilação em grafo via @tf.function. O Keras fornece uma API de alto nível para definir modelos.

Conceitos-chave:

  • tf.Tensor e tf.Variable
  • tf.GradientTape para autodiff em modo eager
  • @tf.function para rastrear/compilar Python em um grafo

Pontos fortes

  • Ecossistema de implantação: TensorFlow Serving, TFLite (mobile/edge), TF.js (navegador).
  • Keras: loops de treinamento padronizados e de alto nível (model.compile, model.fit) podem reduzir boilerplate.
  • Suporte a TPU: integração forte, especialmente em ambientes do Google Cloud.

Exemplo prático: passo de treinamento customizado com `GradientTape`

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1)
])

optimizer = tf.keras.optimizers.Adam(1e-3)
loss_fn = tf.keras.losses.MeanSquaredError()

x = tf.random.normal([32, 10])
y = tf.random.normal([32, 1])

with tf.GradientTape() as tape:
    pred = model(x, training=True)
    loss = loss_fn(y, pred)

grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))

print(float(loss))

Compilação em grafo com `tf.function`

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        pred = model(x, training=True)
        loss = loss_fn(y, pred)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

Isso pode melhorar significativamente o throughput ao compilar o passo em um grafo otimizado.

Notas de ecossistema

O TensorFlow é frequentemente escolhido quando:

  • Você precisa de implantação em edge/mobile (TFLite)
  • Você precisa de inferência no navegador (TF.js)
  • Sua organização já tem infraestrutura em TF

Para rastreamento de experimentos e pipelines, ainda é comum combinar TF com ferramentas de Ferramentas de Experimentos e ferramentas de dados de Dados.

JAX

Modelo de programação

JAX é melhor entendido como:

  • API ao estilo NumPy (jax.numpy), mas as operações são puras e diferenciáveis
  • Autodiff via jax.grad
  • Compilação via jax.jit (XLA)
  • Vetorização fácil via jax.vmap
  • Paralelismo entre dispositivos via jax.pmap (e sistemas mais novos como pjit/sharding em configurações modernas de JAX)

JAX incentiva fortemente um estilo funcional (functional style):

  • Evitar mutação in-place de arrays
  • Tratar parâmetros como valores explícitos passados para funções
  • Usar transformações (grad, jit, vmap) para gerar programas eficientes

Pontos fortes

  • Transformações composicionais: grad(jit(f)), vmap(grad(f)) etc. são padrões de primeira classe.
  • Excelente potencial de desempenho: a compilação XLA é central; frequentemente muito rápida quando escrita de maneira “amigável ao JAX”.
  • Ótimo para pesquisa: especialmente quando você precisa de matemática customizada, meta-aprendizado (meta-learning), paralelismo em grande escala ou vetorização elegante.

Exemplo prático: um pequeno passo de regressão linear

Isto é intencionalmente de baixo nível para mostrar a abordagem “funções + transformações” do JAX:

import jax
import jax.numpy as jnp

key = jax.random.key(0)
W = jax.random.normal(key, (10, 1))
b = jnp.zeros((1,))

def model(params, x):
    W, b = params
    return x @ W + b

def loss_fn(params, x, y):
    pred = model(params, x)
    return jnp.mean((pred - y) ** 2)

grad_fn = jax.grad(loss_fn)

x = jax.random.normal(jax.random.key(1), (32, 10))
y = jax.random.normal(jax.random.key(2), (32, 1))

params = (W, b)
grads = grad_fn(params, x, y)

lr = 1e-3
params = jax.tree.map(lambda p, g: p - lr * g, params, grads)

print(loss_fn(params, x, y))

Em projetos reais, você normalmente usa uma biblioteca do ecossistema:

  • Flax (biblioteca de redes neurais)
  • Optax (otimizadores)
  • Orbax (checkpointing)
  • Haiku ou Equinox (bibliotecas alternativas de redes neurais)

Compilação JIT

jit_loss = jax.jit(loss_fn)
jit_grad = jax.jit(grad_fn)

A compilação tem um custo inicial (a primeira chamada rastreia e compila), mas as chamadas subsequentes são rápidas.

Principais armadilhas

  • Polimorfismo de forma (shape polymorphism) e fluxo de controle podem ser difíceis: muitos ramos do lado do Python são resolvidos no momento do rastreamento.
  • Você precisa pensar com cuidado sobre o que é compilado e armazenado em cache (por exemplo, mudar formas pode disparar recompilação).
  • Depurar dentro de funções compiladas com jit é diferente (embora as ferramentas tenham melhorado).

Considerações comuns entre frameworks

Alocação em dispositivos e precisão mista

Os três suportam GPUs; o suporte a TPU é mais forte em JAX e TensorFlow (embora exista suporte a TPU no PyTorch via integrações).

Precisão mista (FP16/BF16) é uma alavanca importante de desempenho, especialmente para grandes modelos transformer. Cada framework tem suas próprias APIs e boas práticas; na prática, você também depende do hardware e da pilha do fornecedor (CUDA, cuDNN, NCCL).

Treinamento distribuído

Em escala, você vai se preocupar com:

  • Paralelismo de dados (replicar o modelo; dividir lotes)
  • Paralelismo de modelo (dividir o modelo entre dispositivos)
  • Otimizadores e parâmetros fragmentados (sharded)
  • Eficiência de comunicação (all-reduce, reduce-scatter)

Wrappers e bibliotecas de alto nível frequentemente importam mais do que a escolha do framework base. Sua pilha de avaliação (ver Harnesses de Avaliação) e logging/gestão de artefatos (ver Ferramentas de Experimentos) deve se integrar bem.

Exportação e interoperabilidade (ONNX e além)

  • ONNX pode ajudar a exportar modelos entre ecossistemas, mas a cobertura é imperfeita — especialmente para ops mais novas ou camadas customizadas.
  • Para aplicações com LLMs, você também pode exportar para runtimes especializados ou servidores de inferência; a escolha prática frequentemente depende da sua pilha de Ferramentas para LLM.

Como escolher entre PyTorch, TensorFlow e JAX

Escolher é menos sobre “qual é o melhor” e mais sobre casar restrições: habilidades do time, alvos de implantação, requisitos de desempenho e dependências do ecossistema.

Escolha PyTorch se…

  • Você quer o padrão mais comum em pesquisa e ecossistemas de modelos open-source.
  • Você valoriza facilidade de depuração e flexibilidade (fluxo de controle dinâmico, iteração mais fácil).
  • Você planeja fazer fine-tuning ou estender repositórios de modelos populares (muitos são PyTorch-first).
  • Você quer um caminho forte para alto desempenho com torch.compile, FSDP e ferramentas modernas.

Casos de uso típicos:

  • Fine-tuning e experimentação com LLMs
  • Protótipos de pesquisa de modelos customizados que podem virar sistemas de produção
  • Ambientes de treinamento centrados em GPU

Escolha TensorFlow se…

  • Seu requisito principal é implantação madura para:
    • Mobile/edge (TFLite)
    • Navegador (TF.js)
    • Infra de serving estabelecida (TF Serving)
  • Você prefere APIs de alto nível no estilo Keras e loops de treinamento padronizados.
  • Sua organização já tem pipelines em TF e expertise operacional.

Casos de uso típicos:

  • Inferência no dispositivo (mobile, embarcado)
  • Ambientes corporativos com infraestrutura TF existente
  • Produtos que se beneficiam do toolchain de implantação do TF

Escolha JAX se…

  • Você quer transformações de programa composicionais (jit, vmap, grad, sharding) e se sente confortável com um estilo mais funcional.
  • Você roda em TPUs ou quer padrões de desempenho XLA-first.
  • Você faz pesquisa que se beneficia de vetorização e compilação (por exemplo, meta-learning, simulação em lote grande, paralelismo avançado).

Casos de uso típicos:

  • Pesquisa e treinamento fortemente baseados em TPU
  • Loops de treinamento altamente otimizados onde compilação e vetorização são centrais
  • Fluxos de trabalho que se beneficiam de código matemático limpo, semelhante ao NumPy

Checklist prático de decisão

Faça estas perguntas em ordem:

  1. Onde o modelo precisa rodar?

    • Mobile/navegador: TensorFlow costuma ser o caminho mais fácil.
    • GPU em servidor: PyTorch é um padrão muito comum; todos podem funcionar.
    • TPU: JAX ou TensorFlow geralmente fornecem a experiência mais fluida.
  2. Você precisa integrar com um ecossistema específico?

    • Se seu modelo/código-alvo é PyTorch-first, trocar de framework é caro.
    • Se sua pilha de implantação já é TF Serving/TFLite, TensorFlow reduz atrito.
  3. Quão importante é “depuração pythônica” vs desempenho centrado em compilação?

    • Muita iteração/depuração: PyTorch tende a ser o mais simples.
    • Mentalidade de desempenho “compile tudo”: JAX brilha; TF também é forte com tf.function.
  4. O que seu time já conhece?

    • Familiaridade com o framework costuma ser o maior multiplicador de produtividade.
  5. Quais são suas restrições fora do modelo?

Uma recomendação pragmática de workflow

Muitos times convergem para uma abordagem híbrida:

  • Pesquisa / desenvolvimento de modelo: PyTorch (ou JAX em certas organizações de pesquisa)
  • Treinamento em produção: PyTorch ou TensorFlow dependendo da infraestrutura
  • Implantação:
    • Inferência em servidor: frequentemente PyTorch + um servidor/runtime de inferência, ou TF Serving
    • Edge/navegador: frequentemente TensorFlow (TFLite / TF.js)

O “melhor” framework é aquele que minimiza o atrito ao longo de todo o ciclo de vida: experimentação → treinamento → avaliação → implantação → manutenção.

Resumo

  • PyTorch: melhor padrão geral para pesquisa e uma grande fração de ML em produção; eager-first com compilação e escalabilidade cada vez mais fortes.
  • TensorFlow: ecossistema de implantação mais forte de ponta a ponta (especialmente edge e navegador) e ferramentas maduras para grafos; Keras oferece ergonomia de alto nível.
  • JAX: framework funcional, baseado em transformações, com XLA no núcleo; se destaca quando você quer compilação/vetorização composicionais e fluxos de trabalho fortes com TPUs.

Se você estiver em dúvida:

  • Comece com PyTorch para a maioria dos trabalhos gerais de ML e LLMs.
  • Prefira TensorFlow se você sabe que precisa implantar em mobile/navegador ou já tem infraestrutura TF.
  • Prefira JAX se você está otimizando para treinamento TPU-first ou fluxos de pesquisa que se beneficiam da composicionalidade de jit/vmap.