Nx.Vulkan.ChainShaderSpecs (nx_vulkan v0.1.0)


Catalog of family specs for templated chain-shader synthesis.

Each spec is a %ShaderTemplate.FamilySpec{} describing how to fill the holes in Nx.Vulkan.ShaderTemplate's GLSL skeleton.

Phase 1 ships:

  • :beta - Beta(α, β) on logit-unconstrained space (q ∈ ℝ → q* = sigmoid(q))
  • :gamma - Gamma(α, β) on log-unconstrained space (q ∈ ℝ → q* = exp(q))
  • :lognormal - Lognormal(μ, σ) on log-unconstrained space.
               Mathematically reduces to Normal(μ, σ) on q_uc,
               because the Jacobian log|dq/dq_uc| = q_uc exactly
               cancels the -log(q) = -q_uc term in log p(q).
               Shipped to prove the template handles
               same-skeleton-different-name cases. Use the
               hand-written Normal shader in production unless you
               specifically need :lognormal as a label.

The six existing hand-written shaders (Normal, Exponential, StudentT, Cauchy, HalfNormal, Weibull) are NOT replicated here; they live as literal SPV files in nx_vulkan/priv/shaders/. The goal of Phase 1 is to demonstrate a NEW shader synthesized from scratch.

Functions

beta()

Beta(α, β) on logit-unconstrained space.

Unconstrained: q_uc = logit(q), q = sigmoid(q_uc) ∈ (0, 1).

log p(q | α, β) = (α-1) log q + (β-1) log(1-q) - log B(α, β)

log p(q_uc) = log p(q) + log|dq/dq_uc|
            = log p(q) + log(q (1-q))
            = α log q + β log(1-q) - log B(α, β)

d/dq_uc log p(q_uc) = α (1-q) - β q = α - (α+β) q
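
The closed-form gradient can be sanity-checked against a central finite difference. The sketch below is plain Elixir and does not touch the shader path; the module name `BetaLogitCheck` and all of its functions are hypothetical helpers written for this illustration, not part of nx_vulkan. The log B(α, β) term is dropped because it does not depend on q_uc.

```elixir
defmodule BetaLogitCheck do
  # Hypothetical helper for illustration; not part of nx_vulkan.

  # Maps the unconstrained value q_uc back to q in (0, 1).
  def sigmoid(q_uc), do: 1.0 / (1.0 + :math.exp(-q_uc))

  # Unnormalized log p(q_uc) = α log q + β log(1 - q).
  # The constant -log B(α, β) is omitted; it has zero gradient.
  def logp(q_uc, alpha, beta) do
    q = sigmoid(q_uc)
    alpha * :math.log(q) + beta * :math.log(1.0 - q)
  end

  # Closed-form gradient from the derivation: α - (α + β) q.
  def grad(q_uc, alpha, beta), do: alpha - (alpha + beta) * sigmoid(q_uc)

  # Central finite difference of logp with step h.
  def numeric_grad(q_uc, alpha, beta, h \\ 1.0e-5) do
    (logp(q_uc + h, alpha, beta) - logp(q_uc - h, alpha, beta)) / (2.0 * h)
  end
end

q_uc = 0.3
analytic = BetaLogitCheck.grad(q_uc, 2.0, 3.0)
numeric = BetaLogitCheck.numeric_grad(q_uc, 2.0, 3.0)
IO.inspect({analytic, numeric}, label: "analytic vs numeric")
```

At q_uc = 0 the sigmoid is exactly 0.5, so for α = 2, β = 3 the gradient is 2 - 5·0.5 = -0.5, which makes a convenient spot check.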

beta_push(n, k, eps, alpha, beta, logp_const)

Push-constant byte layout for Beta. Packed as the C struct:

    uint  n;
    uint  K;
    float eps;
    float alpha;
    float beta_param;
    float logp_const;

Total: 24 bytes.

Caller computes logp_const = lgamma(α+β) - lgamma(α) - lgamma(β), which equals -log B(α, β), and passes it explicitly. This keeps lgamma out of nx_vulkan's dependency surface.

Returns a binary suitable for the shim's push-constants pointer.
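
As a rough illustration of the 24-byte layout, the sketch below packs the six fields with Elixir's binary syntax. It is not nx_vulkan's implementation: the module name `PushSketch` is hypothetical, and little-endian field order with no padding is an assumption (the documentation states only the field order and the total size). The values passed for n, K, and eps are arbitrary placeholders.

```elixir
defmodule PushSketch do
  # Hypothetical stand-in for the documented layout; not nx_vulkan's code.
  # Assumption: little-endian fields, tightly packed (6 x 4 bytes = 24 bytes).
  def pack(n, k, eps, alpha, beta, logp_const) do
    <<
      n::unsigned-little-integer-size(32),
      k::unsigned-little-integer-size(32),
      eps::float-little-size(32),
      alpha::float-little-size(32),
      beta::float-little-size(32),
      logp_const::float-little-size(32)
    >>
  end
end

# For Beta(2, 3): B(2, 3) = Γ(2)Γ(3)/Γ(5) = 2/24 = 1/12, so
# logp_const = -log B(2, 3) = log 12. Non-integer parameters would
# need an lgamma implementation supplied by the caller.
logp_const = :math.log(12.0)

# n, K and eps here are arbitrary illustrative values.
push = PushSketch.pack(1024, 8, 0.01, 2.0, 3.0, logp_const)
IO.inspect(byte_size(push), label: "push size in bytes")
```

Using the closed form for integer parameters keeps the example free of an lgamma dependency, which mirrors the reason the documented function takes logp_const as an argument.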

gamma()

Gamma(α, β) on log-unconstrained space.

Unconstrained: q_uc = log(q), q = exp(q_uc) ∈ (0, ∞).

log p(q | α, β) = α·log β + (α-1) log q - β q - lgamma(α)

log p(q_uc) = log p(q) + log|dq/dq_uc|
            = log p(q) + q_uc
            = α·log β + α·q_uc - β·exp(q_uc) - lgamma(α)

d/dq_uc log p(q_uc) = α - β·exp(q_uc)

Push fields share Beta's layout (alpha, beta_param, logp_const). The logp_const captures the q-independent normalizer α·log β - lgamma(α).
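
One quick check on the Gamma gradient: it vanishes where exp(q_uc) = α/β, so the unconstrained density peaks at q_uc = log(α/β). The sketch below evaluates that point and also computes the normalizer for an integer α, where lgamma(α) = log((α-1)!). The module `GammaLogCheck` and its function names are hypothetical and exist only for this illustration.

```elixir
defmodule GammaLogCheck do
  # Hypothetical helper for illustration; not part of nx_vulkan.

  # Documented gradient of log p(q_uc): α - β·exp(q_uc).
  def grad(q_uc, alpha, beta), do: alpha - beta * :math.exp(q_uc)

  # q-independent normalizer α·log β - lgamma(α), restricted here to
  # positive integer α so that lgamma(α) = log((α-1)!) needs no library.
  def logp_const_integer_alpha(alpha, beta) when is_integer(alpha) and alpha >= 1 do
    log_factorial =
      Enum.reduce(1..(alpha - 1)//1, 0.0, fn i, acc -> acc + :math.log(i) end)

    alpha * :math.log(beta) - log_factorial
  end
end

alpha = 3
beta = 2.0

# Stationary point of log p(q_uc): exp(q_uc) = α/β.
q_uc_star = :math.log(alpha / beta)
g_at_mode = GammaLogCheck.grad(q_uc_star, alpha, beta)

# For α = 3, β = 2: 3·log 2 - log 2! = 2·log 2.
const = GammaLogCheck.logp_const_integer_alpha(alpha, beta)
IO.inspect({g_at_mode, const}, label: "gradient at mode, logp_const")
```

The gradient is positive to the left of q_uc_star and negative to the right, as expected for a single interior maximum.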

gamma_push(n, k, eps, alpha, beta, logp_const)

Push for Gamma: same 24-byte layout as Beta.

Caller computes logp_const = α·log(β) - lgamma(α) and passes it explicitly. This keeps lgamma out of nx_vulkan's dependency surface.

lognormal()

Lognormal(μ, σ) on log-unconstrained space.

Unconstrained: q_uc = log(q), q = exp(q_uc) ∈ (0, ∞).

log p(q | μ, σ) = -log q - 0.5 log(2π σ²) - (log q - μ)² / (2σ²)

log p(q_uc) = log p(q) + q_uc
            = -0.5 log(2π σ²) - (q_uc - μ)² / (2σ²)

This is literally Normal(μ, σ) on q_uc. The -log q = -q_uc term and the Jacobian +q_uc cancel exactly. The grad and logp expressions below are identical to what a Normal-on-q_uc spec would produce.

d/dq_uc log p(q_uc) = -(q_uc - μ) / σ²

Shipped as a working template-driven shader to demonstrate the template handles this family. In production use the hand-written Normal shader (binary-identical math, vendored .spv).
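
The cancellation can be confirmed numerically: evaluating the Lognormal log-density at q = exp(q_uc), then adding the Jacobian term q_uc, should match the Normal log-density at q_uc. The sketch below does this in plain Elixir; `LognormalCheck` and its functions are hypothetical helpers for this illustration, and the test point is arbitrary.

```elixir
defmodule LognormalCheck do
  # Hypothetical helper for illustration; not part of nx_vulkan.

  # Shared per-element constant: -0.5·log(2π σ²).
  def log_norm_const(sigma), do: -0.5 * :math.log(2.0 * :math.pi() * sigma * sigma)

  # log p(q | μ, σ) for the Lognormal density on q > 0.
  def lognormal_logp(q, mu, sigma) do
    log_q = :math.log(q)
    z = log_q - mu
    log_norm_const(sigma) - log_q - z * z / (2.0 * sigma * sigma)
  end

  # Normal(μ, σ) log-density evaluated at x.
  def normal_logp(x, mu, sigma) do
    z = x - mu
    log_norm_const(sigma) - z * z / (2.0 * sigma * sigma)
  end
end

q_uc = 0.7
mu = 0.2
sigma = 1.5
q = :math.exp(q_uc)

# Adding log|dq/dq_uc| = q_uc moves the density to unconstrained space.
via_lognormal = LognormalCheck.lognormal_logp(q, mu, sigma) + q_uc
via_normal = LognormalCheck.normal_logp(q_uc, mu, sigma)
IO.inspect({via_lognormal, via_normal}, label: "lognormal+jacobian vs normal")
```

The two values agree up to floating-point rounding.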

lognormal_push(n, k, eps, mu, sigma)

Push for Lognormal: 24 bytes, layout {n, K, eps, mu, sigma, logp_const}. Unlike beta_push and gamma_push, this function takes no logp_const argument: the per-element constant logp_const = -0.5·log(2π σ²) is computed on the host from sigma.
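
A minimal sketch of how a five-argument packer could derive logp_const from sigma before writing the 24 bytes. This is not the library's code: `LognormalPushSketch` is a hypothetical module, and little-endian, tightly packed fields are an assumption. The n, K, and eps values are arbitrary.

```elixir
defmodule LognormalPushSketch do
  # Hypothetical stand-in; not nx_vulkan's implementation.
  # Assumption: little-endian fields, tightly packed.
  def pack(n, k, eps, mu, sigma) do
    # Host-side constant: -0.5·log(2π σ²).
    logp_const = -0.5 * :math.log(2.0 * :math.pi() * sigma * sigma)

    <<
      n::unsigned-little-integer-size(32),
      k::unsigned-little-integer-size(32),
      eps::float-little-size(32),
      mu::float-little-size(32),
      sigma::float-little-size(32),
      logp_const::float-little-size(32)
    >>
  end
end

# With sigma = 1 the constant is -0.5·log(2π).
push = LognormalPushSketch.pack(1024, 8, 0.01, 0.0, 1.0)

# Skip the first five 4-byte fields and read back the constant as float32.
<<_head::binary-size(20), packed_const::float-little-size(32)>> = push
IO.inspect(packed_const, label: "logp_const as stored")
```

The value read back is a 32-bit float, so it matches -0.5·log(2π) only to single precision.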