Generative Adversarial Network framework.
Provides building blocks for GAN architectures including standard GAN, WGAN (Wasserstein), and conditional GAN variants.
Architecture

    Noise z ~ N(0, I)        Real data x
            |                     |
            v                     v
     +------------+        +-------------+
     |  Generator |        | Discrimin.  |
     | G(z) -> x' |        |  D(x) -> p  |
     +------------+        +-------------+
            |                     |
            v                     v
      Fake samples         Real/Fake score

Training
GANs use adversarial training:
- Discriminator: maximize log(D(x)) + log(1 - D(G(z)))
- Generator: maximize log(D(G(z))) (or minimize -log(D(G(z))))
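As a concrete check of the two objectives, here is a small Nx sketch that evaluates both losses from discriminator probabilities. The tensor values are made up for illustration; only the Nx library is assumed, not this module's own loss functions:

```elixir
# Illustrative discriminator outputs (made-up values, not from this module).
d_real = Nx.tensor([0.9, 0.8, 0.7])   # D(x) on real data
d_fake = Nx.tensor([0.2, 0.1, 0.3])   # D(G(z)) on fakes
eps = 1.0e-7                          # avoid log(0)

# Discriminator loss: -mean(log D(real)) - mean(log(1 - D(fake)))
d_loss =
  Nx.add(d_real, eps)
  |> Nx.log()
  |> Nx.mean()
  |> Nx.negate()
  |> Nx.subtract(Nx.mean(Nx.log(Nx.add(Nx.subtract(1, d_fake), eps))))

# Non-saturating generator loss: -mean(log D(G(z)))
g_loss = Nx.negate(Nx.mean(Nx.log(Nx.add(d_fake, eps))))
```

Pushing D(fake) toward 1 drives g_loss toward 0, which is exactly what the generator update tries to do.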
Usage
{generator, discriminator} = GAN.build(
  latent_size: 128,
  output_size: 784,
  generator_sizes: [256, 512],
  discriminator_sizes: [512, 256]
)
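Assuming the returned models are ordinary single-input Axon graphs, sampling from a (randomly initialized) generator might look like this; `GAN.build` and the sizes come from the usage above, and everything else is standard Axon/Nx API:

```elixir
{generator, _discriminator} = GAN.build(latent_size: 128, output_size: 784)

# Compile the generator into init/predict functions.
{init_fn, predict_fn} = Axon.build(generator)
params = init_fn.(Nx.template({1, 128}, :f32), %{})

# Map a batch of noise vectors through G to get fake samples.
key = Nx.Random.key(42)
{z, _key} = Nx.Random.normal(key, shape: {16, 128})
fake = predict_fn.(params, z)
# fake should have shape {16, 784}
```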
Summary
Functions
Build generator and discriminator networks.
Build a conditional generator.
Build the discriminator network.
Build the generator network.
Standard GAN discriminator loss.
Standard GAN generator loss (non-saturating).
Gradient penalty for WGAN-GP.
Wasserstein GAN discriminator (critic) loss.
Wasserstein GAN generator loss.
Types
@type build_opt() ::
        {:output_size, pos_integer()}
        | {:latent_size, pos_integer()}
        | {:generator_sizes, [pos_integer()]}
        | {:discriminator_sizes, [pos_integer()]}
        | {:activation, atom()}
        | {:output_activation, atom()}
Options for build/1.
Functions
Build generator and discriminator networks.
Options
:latent_size - Size of the noise vector z (default: 128)
:output_size - Size of the generated output (required)
:generator_sizes - Hidden layer sizes for G (default: [256, 512])
:discriminator_sizes - Hidden layer sizes for D (default: [512, 256])
:activation - Activation function (default: :relu)
:output_activation - Generator output activation (default: :sigmoid)
Returns
Tuple of {generator, discriminator} Axon models.
Build a conditional generator.
Takes both noise z and conditioning label y.
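A common conditioning scheme (not necessarily the exact one used by this function) concatenates the noise with a one-hot label before the first dense layer; a minimal Nx sketch of that input preparation, with made-up batch data:

```elixir
key = Nx.Random.key(0)
{z, _key} = Nx.Random.normal(key, shape: {4, 128})

# One-hot encode class labels 0..9 and append them to z.
y = Nx.tensor([3, 1, 4, 1])
one_hot = Nx.equal(Nx.new_axis(y, -1), Nx.iota({1, 10}))
input = Nx.concatenate([z, Nx.as_type(one_hot, :f32)], axis: 1)
# input shape: {4, 138} — noise plus label in one vector
```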
Build the discriminator network.
Maps data to a real/fake probability.
Build the generator network.
Maps latent noise z to data space.
Standard GAN discriminator loss.
L_D = -mean(log(D(real))) - mean(log(1 - D(fake)))
Standard GAN generator loss (non-saturating).
L_G = -mean(log(D(G(z))))
Gradient penalty for WGAN-GP.
Penalizes gradients that deviate from unit norm along interpolations between real and fake data.
Requires a PRNG key for sampling random interpolation coefficients.
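The penalty above can be sketched with Nx's autodiff; the `critic/1` below is a stand-in scalar-score function (the real critic would be the discriminator model), and the rest follows the interpolation-then-gradient-norm recipe described here:

```elixir
defmodule GPSketch do
  import Nx.Defn

  # Stand-in critic: any differentiable map from a batch to one score per sample.
  defn critic(x), do: Nx.sum(Nx.tanh(x), axes: [1])

  # WGAN-GP penalty: push ||grad_x D(x_hat)|| toward 1 on interpolations x_hat.
  defn gradient_penalty(real, fake, key) do
    # Random per-sample interpolation coefficients in [0, 1).
    {eps, _key} = Nx.Random.uniform(key, shape: {Nx.axis_size(real, 0), 1})
    interp = eps * real + (1 - eps) * fake

    # Gradient of the summed critic scores w.r.t. the interpolated inputs.
    grads = grad(interp, fn x -> Nx.sum(critic(x)) end)
    norms = Nx.sqrt(Nx.sum(grads * grads, axes: [1]) + 1.0e-12)

    Nx.mean((norms - 1) ** 2)
  end
end
```

The caller supplies the PRNG key, e.g. `GPSketch.gradient_penalty(real, fake, Nx.Random.key(0))`; the result is typically scaled by a lambda (commonly 10) and added to the critic loss.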
Wasserstein GAN discriminator (critic) loss.
L_D = mean(D(fake)) - mean(D(real))
Wasserstein GAN generator loss.
L_G = -mean(D(G(z)))
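Both Wasserstein losses reduce to plain means of critic scores; a minimal Nx illustration with made-up score tensors:

```elixir
# Critic scores on real and generated batches (illustrative values only).
d_real = Nx.tensor([1.5, 2.0, 0.5])
d_fake = Nx.tensor([-0.5, 0.0, 0.25])

# Critic loss: mean(D(fake)) - mean(D(real)) -> wants real scores high, fake low.
d_loss = Nx.subtract(Nx.mean(d_fake), Nx.mean(d_real))

# Generator loss: -mean(D(G(z))) -> wants fake scores high.
g_loss = Nx.negate(Nx.mean(d_fake))
```

Note there is no log or sigmoid here: the critic outputs an unbounded score, which is why WGAN training pairs these losses with a Lipschitz constraint such as the gradient penalty above.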