Inferência Variacional

O que é inferência variacional?

Inferência variacional (variational inference, VI) é uma abordagem baseada em otimização para aproximar uma distribuição a posteriori bayesiana intratável. Na Inferência Bayesiana (Bayesian Inference), queremos o posterior

[ p(z \mid x) = \frac{p(x,z)}{p(x)} ]

onde:

  • (x) são variáveis observadas (dados)
  • (z) são variáveis latentes/parâmetros
  • (p(x,z)=p(x\mid z)p(z)) é a distribuição conjunta (joint distribution)
  • (p(x)=\int p(x,z),dz) é a verossimilhança marginal (marginal likelihood) (também conhecida como evidência (evidence))

A dificuldade central é que a evidência (p(x)) geralmente exige integrar (ou somar) sobre (z) de alta dimensionalidade, o que muitas vezes é computacionalmente intratável para modelos realistas.

A inferência variacional substitui a inferência exata por:

  1. Escolher uma família tratável de distribuições (q_\lambda(z)) (parametrizada por (\lambda))
  2. Otimizar (\lambda) para que (q_\lambda(z)) esteja “próxima” do verdadeiro posterior (p(z\mid x))

A noção mais comum de proximidade é a divergência de Kullback–Leibler (Kullback–Leibler divergence):

[ \mathrm{KL}\big(q_\lambda(z)\ |\ p(z\mid x)\big) ]

Minimizar essa divergência KL é equivalente a maximizar o limite inferior da evidência (evidence lower bound, ELBO), o que transforma a inferência em um problema familiar de otimização, solucionável com ferramentas como Descida do Gradiente (Gradient Descent).

A inferência variacional é amplamente usada porque, com frequência, é mais rápida e mais escalável do que métodos baseados em amostragem (sampling-based methods), como Monte Carlo via Cadeia de Markov (Markov Chain Monte Carlo, MCMC), especialmente em grandes conjuntos de dados e em cenários de aprendizado profundo (deep learning) (por exemplo, Autoencoders Variacionais (Variational Autoencoders, VAEs)).

Da regra de Bayes (Bayes’ rule) ao ELBO

O objetivo: aproximar o posterior

Gostaríamos de resolver:

[ q^*(z) = \arg\min_{q \in \mathcal{Q}} \mathrm{KL}\left(q(z)\ |\ p(z\mid x)\right) ]

onde (\mathcal{Q}) é a família variacional escolhida (por exemplo, Gaussianas fatoradas).

No entanto, (\mathrm{KL}(q | p(z\mid x))) envolve o posterior (p(z\mid x)), que contém (p(x)), que é intratável. O truque-chave é reescrever o objetivo em termos da conjunta (p(x,z)).

Derivação do ELBO

Comece com a evidência em log:

[ \log p(x) = \log \int p(x,z),dz ]

Introduza qualquer distribuição (q(z)) e use a desigualdade de Jensen:

[ \log p(x)= \log \int q(z)\frac{p(x,z)}{q(z)}dz \ge \int q(z)\log\frac{p(x,z)}{q(z)}dz ]

Defina o ELBO:

[ \mathcal{L}(q) = \mathbb{E}{q(z)}[\log p(x,z)] - \mathbb{E}{q(z)}[\log q(z)] ]

Isso pode ser reescrito como:

[ \mathcal{L}(q) = \mathbb{E}_{q(z)}[\log p(x\mid z)] - \mathrm{KL}\big(q(z)\ |\ p(z)\big) ]

Assim, o ELBO equilibra:

  • Ajuste aos dados: (\mathbb{E}_{q}[\log p(x\mid z)])
  • Regularização em direção ao prior: (-\mathrm{KL}(q(z)|p(z)))

Mais importante:

[ \log p(x) = \mathcal{L}(q) + \mathrm{KL}\big(q(z)\ |\ p(z\mid x)\big) ]

Como a divergência KL é não negativa, (\mathcal{L}(q)\le \log p(x)). Portanto, maximizar o ELBO é equivalente a minimizar (\mathrm{KL}(q|p)).

Escolhendo uma família variacional \(\mathcal{Q}\)

A acurácia e o custo computacional da inferência variacional dependem fortemente da família (\mathcal{Q}).

Escolhas comuns incluem:

  • Campo médio (mean-field) (totalmente fatorada): [ q(z)=\prod_{i=1}^m q_i(z_i) ] Esta é a “VI padrão” mais comum porque simplifica a otimização, mas frequentemente subestima correlações no posterior.

  • Famílias variacionais estruturadas (structured variational families): retêm algumas dependências, por exemplo, [ q(z)=q(z_a)q(z_b\mid z_a) ] Úteis quando correlações são cruciais.

  • Aproximações de família exponencial (exponential-family approximations) (frequentemente conjugadas (conjugate) ao modelo), permitindo atualizações eficientes.

  • Famílias flexíveis como fluxos normalizantes (normalizing flows), que transformam uma densidade base simples em uma densidade complexa. (Elas são populares na modelagem generativa profunda moderna, muitas vezes junto com Redes Neurais (Neural Networks).)

Um ponto conceitual-chave: a inferência variacional é aproximada por construção. Você escolhe uma família computacionalmente gerenciável e obtém a melhor aproximação dentro dessa família.

Inferência variacional de campo médio (mean-field variational inference, MFVI)

A ideia

A inferência variacional de campo médio assume uma fatoração:

[ q(z)=\prod_{i=1}^m q_i(z_i) ]

Isso transforma um problema difícil de inferência em alta dimensionalidade em um conjunto de problemas de otimização acoplados para cada fator (q_i).

Inferência variacional por ascensão por coordenadas (coordinate ascent variational inference, CAVI)

Para muitos modelos (especialmente modelos de família exponencial condicionalmente conjugados), podemos otimizar o ELBO por ascensão por coordenadas (coordinate ascent): atualizando um fator (q_i) por vez enquanto mantemos os demais fixos.

Um resultado clássico:

[ \log q_i^*(z_i) = \mathbb{E}{q{-i}}[\log p(x,z)] + \text{constant} ]

onde (q_{-i} = \prod_{j\ne i} q_j(z_j)).

Interpretação:

  • O fator ótimo para (z_i) é proporcional ao exponencial do log da conjunta esperado sob os outros fatores.
  • Em modelos conjugados, isso mantém (q_i) na mesma família de distribuição (por exemplo, Gaussiana, Gamma, Dirichlet), fornecendo atualizações em forma fechada.

Exemplo prático: modelo de mistura bayesiano (alto nível)

Considere um modelo de mistura Gaussiana bayesiano com atribuições latentes de cluster (c_n) e parâmetros de cluster (\theta_k). A inferência exata do posterior acopla todas as atribuições e parâmetros. A inferência variacional de campo médio pode escolher:

[ q(c,\theta)=\left(\prod_{n} q(c_n)\right)\left(\prod_k q(\theta_k)\right) ]

A CAVI alterna entre:

  • Atualizar cada (q(c_n)) usando log-verossimilhanças de cluster esperadas sob (q(\theta_k))
  • Atualizar cada (q(\theta_k)) usando “contagens suaves” de (q(c_n))

Isso produz um algoritmo que lembra o EM, mas bayesiano: ele retorna distribuições sobre parâmetros, não apenas estimativas pontuais como em Estimativa de Máxima Verossimilhança (Maximum Likelihood Estimation, MLE).

Inferência variacional estocástica (stochastic variational inference, SVI)

A CAVI pode ser muito eficiente, mas em conjuntos de dados massivos ela ainda pode ser cara demais porque as atualizações frequentemente exigem expectativas sobre todos os pontos de dados.

A inferência variacional estocástica escala a inferência variacional para grandes conjuntos de dados usando:

  • Minilotes (minibatches) de dados
  • Gradientes estocásticos (stochastic gradients) do ELBO
  • Muitas vezes, gradientes naturais (natural gradients) para convergência mais rápida em cenários de família exponencial

Esboço em alto nível:

  1. Amostre um minilote (B \subset {1,\dots,N})
  2. Calcule uma estimativa não enviesada do gradiente do ELBO (ou gradiente natural)
  3. Atualize (\lambda \leftarrow \lambda + \rho_t \widehat{\nabla_\lambda \mathcal{L}}) com um cronograma de taxa de aprendizado (\rho_t)

A SVI tornou a inferência variacional prática para modelos bayesianos em larga escala, como modelos de tópicos (por exemplo, LDA) e modelos bayesianos hierárquicos em pipelines de escala industrial.

Inferência variacional caixa-preta (black-box variational inference, BBVI)

Campo médio + conjugação é conveniente, mas modelos probabilísticos modernos frequentemente são não conjugados (non-conjugate) (por exemplo, verossimilhanças de regressão logística, verossimilhanças de redes neurais). Nesse caso, não conseguimos derivar atualizações de CAVI em forma fechada.

A inferência variacional caixa-preta usa estimadores genéricos de gradiente para o ELBO:

[ \nabla_\lambda \mathcal{L}(\lambda) = \nabla_\lambda \mathbb{E}{q\lambda(z)}[\log p(x,z) - \log q_\lambda(z)] ]

Dois estimadores de gradiente amplamente usados são:

1) Estimador de função de pontuação (score-function, REINFORCE)

[ \nabla_\lambda \mathcal{L}(\lambda) = \mathbb{E}{q\lambda(z)}\Big[(\log p(x,z)-\log q_\lambda(z))\nabla_\lambda \log q_\lambda(z)\Big] ]

  • Funciona para muitas distribuições (incluindo latentes discretos)
  • Frequentemente tem alta variância, então precisa de redução de variância (baselines/variáveis de controle)

2) Truque de reparametrização (reparameterization trick)

Se pudermos escrever amostras como uma transformação determinística de ruído:

[ z = g(\epsilon;\lambda), \quad \epsilon \sim p(\epsilon) ]

então:

[ \nabla_\lambda \mathcal{L}(\lambda) = \nabla_\lambda \mathbb{E}{\epsilon}[\log p(x,g(\epsilon;\lambda)) - \log q\lambda(g(\epsilon;\lambda))] ]

Isso normalmente gera gradientes de menor variância e é o principal mecanismo por trás de Autoencoders Variacionais.

Exemplo mínimo no estilo PyTorch (PyTorch-like): BBVI com reparametrização

Abaixo está um esboço simplificado para um posterior variacional Gaussiano (q_\lambda(z)=\mathcal{N}(\mu,\sigma^2)) (diagonal), otimizando um ELBO dadas as funções log_joint(x, z) e log_q(z, mu, log_std):

import torch

mu = torch.zeros(d, requires_grad=True)
log_std = torch.zeros(d, requires_grad=True)
opt = torch.optim.Adam([mu, log_std], lr=1e-3)

for step in range(num_steps):
    eps = torch.randn(d)
    z = mu + torch.exp(log_std) * eps  # reparameterization

    elbo = log_joint(x, z) - log_q(z, mu, log_std)  # Monte Carlo estimate
    loss = -elbo  # maximize ELBO <=> minimize negative ELBO

    opt.zero_grad()
    loss.backward()
    opt.step()

Na prática, você usaria minilotes, múltiplas amostras de (\epsilon) e estabilização numérica cuidadosa, mas o padrão é o mesmo: defina uma estimativa do ELBO e otimize-a com diferenciação automática (automatic differentiation, autodiff) (veja também Retropropagação (Backpropagation)).

Como a inferência variacional se relaciona com inferência bayesiana, MCMC e EM

Inferência variacional vs inferência bayesiana exata

  • A inferência bayesiana exata produz o verdadeiro posterior (p(z\mid x)).
  • A inferência variacional produz uma aproximação (q_\lambda(z)) dentro de uma família escolhida.

A inferência variacional é melhor vista como uma aproximação fundamentada: ela otimiza um objetivo claro (ELBO/KL) e fornece uma distribuição (útil para incerteza), mas não é exata.

Inferência variacional vs MCMC

Métodos de Monte Carlo via Cadeia de Markov (por exemplo, Monte Carlo Hamiltoniano) aproximam o posterior gerando amostras cuja distribuição empírica se aproxima de (p(z\mid x)).

Trade-offs típicos:

  • Acurácia:

    • MCMC pode ser assintoticamente exata (dado tempo suficiente e diagnósticos corretos).
    • A inferência variacional é enviesada devido à família variacional restrita e à direção da KL.
  • Velocidade e escalabilidade:

    • A inferência variacional frequentemente é mais rápida, e a SVI escala bem para conjuntos de dados enormes.
    • MCMC pode ser caro em alta dimensionalidade ou com muitos dados, a menos que seja especializado.
  • Modos de falha:

    • A inferência variacional frequentemente subestima a incerteza (especialmente com campo médio).
    • MCMC pode misturar mal (por exemplo, posteriors multimodais) e requer verificações cuidadosas de convergência.

Um fluxo de trabalho comum na prática é: prototipar com inferência variacional por velocidade e, em seguida, validar com MCMC em subconjuntos menores quando viável.

Inferência variacional vs EM

O algoritmo EM (Expectation-Maximization, EM) pode ser visto como a otimização de um limite inferior de (\log p(x)) usando uma distribuição auxiliar sobre variáveis latentes.

  • Em EM, a “distribuição variacional” frequentemente é definida como a condicional exata (p(z\mid x,\theta)) a cada iteração, e os parâmetros (\theta) são otimizados.
  • Em inferência variacional, você normalmente aproxima o posterior sobre variáveis latentes e/ou parâmetros, resultando em um análogo bayesiano do EM.

A inferência variacional de campo médio para alguns modelos pode se parecer com EM, mas a inferência variacional retorna distribuições (posteriors) em vez de estimativas pontuais.

Inferência variacional em VAEs

Um Autoencoder Variacional (Variational Autoencoder, VAE) é um modelo generativo profundo em que:

  • (p_\theta(x\mid z)) é um decodificador (decoder) (rede neural)
  • (p(z)) é um prior simples (frequentemente Normal padrão)
  • (q_\phi(z\mid x)) é um codificador (encoder) (rede de inferência (inference network)) que amortiza a inferência

O objetivo do VAE para cada ponto de dados (x) é o ELBO:

[ \mathbb{E}{q\phi(z\mid x)}[\log p_\theta(x\mid z)] - \mathrm{KL}(q_\phi(z\mid x)|p(z)) ]

Isso é inferência variacional com uma reviravolta importante:

  • Em vez de otimizar parâmetros variacionais (\lambda) separados por ponto de dados, aprendemos um modelo de inferência amortizada (amortized inference) (q_\phi(z\mid x)) que prevê parâmetros variacionais a partir de (x).

Isso permite inferência eficiente em tempo de teste com uma única passagem direta (forward pass) pelo codificador.

Considerações práticas, diagnósticos e armadilhas comuns

A direção da KL importa: \(\mathrm{KL}(q\|p)\) é “busca de modo (mode-seeking)”

A inferência variacional padrão minimiza (\mathrm{KL}(q|p)), o que penaliza colocar massa onde (p) tem pouca massa, mas é mais tolerante a deixar de cobrir regiões onde (p) tem massa. Na prática, isso frequentemente produz:

  • Bom ajuste a um modo dominante
  • Variância a posteriori subestimada
  • Dificuldade em cobrir múltiplos modos (posteriors multimodais)

Alternativas (por exemplo, KL reverso (reverse KL), (\alpha)-divergências ((\alpha)-divergences)) existem, mas são menos “padrão” e podem ser mais difíceis de otimizar.

Ótimos locais e inicialização

O ELBO frequentemente é não convexo (non-convex) em (\lambda). Uma boa inicialização pode fazer diferença:

  • inicializar próximo do prior
  • usar inicialização aquecida (warm-start) a partir de modelos mais simples
  • usar múltiplas reinicializações aleatórias para modelos sensíveis

Monitoramento e avaliação

O ELBO é útil, mas não garante incerteza correta. Ferramentas comuns de avaliação incluem:

  • Verificações preditivas a posteriori (posterior predictive checks): simular a partir de (p(x_{\text{new}}\mid z)) com (z\sim q(z)) e comparar com dados reais
  • Log-verossimilhança em conjunto de validação (held-out log likelihood) (quando estimável)
  • Métricas de calibração (calibration metrics) e regras de pontuação próprias (proper scoring rules) (veja Regras de Pontuação Próprias (Proper Scoring Rules))
  • Correções por amostragem por importância (importance sampling) (por exemplo, PSIS-LOO em fluxos de trabalho bayesianos) para avaliar a qualidade da aproximação

Quando campo médio é grosseiro demais

Se correlações importam (por exemplo, modelos hierárquicos, forte acoplamento de parâmetros), considere:

  • famílias variacionais estruturadas
  • parametrizações mais ricas do posterior (Gaussianas com covariância completa (full-covariance Gaussians))
  • fluxos normalizantes
  • abordagens híbridas (inicialização por inferência variacional + refinamento curto com MCMC)

Onde a inferência variacional é usada

A inferência variacional é um mecanismo geral de inferência e aparece em ML e IA:

  • Modelos de tópicos (topic models) e modelos bayesianos de texto (uso clássico de campo médio + SVI)
  • Modelos lineares generalizados (generalized linear models) bayesianos (por exemplo, regressão logística (logistic regression)) com BBVI
  • Aprendizado profundo bayesiano (Bayesian deep learning) (posteriors aproximados sobre pesos ou variáveis latentes)
  • Sistemas de programação probabilística que oferecem backends automáticos de inferência variacional (veja Programação Probabilística (Probabilistic Programming))
  • Modelagem generativa com variáveis latentes (latent-variable generative modeling), especialmente VAEs e extensões
  • Comparação aproximada de modelos (approximate model comparison) via ELBO como proxy para a evidência (relacionado a Critérios de Informação (Information Criteria), embora não sejam a mesma coisa)

Resumo

  • A inferência variacional transforma a aproximação do posterior bayesiano em um problema de otimização.
  • Você escolhe uma família tratável (q_\lambda(z)) e minimiza (\mathrm{KL}(q|p)), ou, de forma equivalente, maximiza o ELBO.
  • A inferência variacional de campo médio simplifica o posterior por fatoração e pode ser otimizada por ascensão por coordenadas em modelos conjugados.
  • A inferência variacional estocástica escala a inferência variacional com minilotes e otimização estocástica.
  • A inferência variacional caixa-preta usa estimadores genéricos de gradiente (função de pontuação ou reparametrização) para lidar com modelos não conjugados.
  • A inferência variacional tipicamente é mais rápida e mais escalável do que MCMC, mas pode ser enviesada e subestimar a incerteza — especialmente com famílias restritivas.
  • VAEs são uma aplicação moderna proeminente da inferência variacional, usando inferência amortizada e o truque de reparametrização.

Se você já está confortável com a regra de Bayes e objetivos baseados em expectativa, a inferência variacional é uma ponte natural entre modelagem probabilística e o aprendizado de máquina moderno centrado em otimização.