Edifice.Generative.GAN (Edifice v0.2.0)


Generative Adversarial Network framework.

Provides building blocks for GAN architectures including standard GAN, WGAN (Wasserstein), and conditional GAN variants.

Architecture

Noise z ~ N(0, I)          Real data x
       |                        |
       v                        v
+-------------+          +-------------+
|  Generator  |          |Discriminator|
| G(z) -> x'  |          | D(x) -> p   |
+-------------+          +-------------+
       |                        |
       v                        v
 Fake samples            Real/Fake score

Training

GANs use adversarial training:

  • Discriminator: maximize log(D(x)) + log(1 - D(G(z)))
  • Generator: maximize log(D(G(z))), i.e. minimize -log(D(G(z))) (the non-saturating form, used in place of minimizing log(1 - D(G(z))))
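For reference, the two bullets above are the two sides of the original GAN minimax game. The block below writes that game out with expectations and, assuming the discriminator ends in a sigmoid so that D = σ(a) for a logit a, shows the per-sample derivatives that motivate the non-saturating generator loss:

```latex
\min_G \max_D \; V(D, G)
  = \mathbb{E}_{x \sim p_{\text{data}}}\!\left[\log D(x)\right]
  + \mathbb{E}_{z \sim \mathcal{N}(0, I)}\!\left[\log\left(1 - D(G(z))\right)\right]

\frac{\partial}{\partial a}\log\bigl(1 - \sigma(a)\bigr) = -\sigma(a),
\qquad
\frac{\partial}{\partial a}\bigl(-\log \sigma(a)\bigr) = -\bigl(1 - \sigma(a)\bigr)
```

Early in training the discriminator rejects fakes confidently, so σ(a) = D(G(z)) is close to 0. The saturating gradient -σ(a) then nearly vanishes, while the non-saturating gradient stays close to 1 in magnitude.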

Usage

alias Edifice.Generative.GAN

{generator, discriminator} = GAN.build(
  latent_size: 128,
  output_size: 784,
  generator_sizes: [256, 512],
  discriminator_sizes: [512, 256]
)

Summary

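Types

  • build_opt() - Options for build/1.

Functions

  • build(opts \\ []) - Build generator and discriminator networks.
  • build_conditional_generator(opts \\ []) - Build a conditional generator.
  • build_discriminator(opts \\ []) - Build the discriminator network.
  • build_generator(opts \\ []) - Build the generator network.
  • discriminator_loss(real_scores, fake_scores) - Standard GAN discriminator loss.
  • generator_loss(fake_scores) - Standard GAN generator loss (non-saturating).
  • gradient_penalty(real_data, fake_data, critic_fn, params, key) - Gradient penalty for WGAN-GP.
  • wasserstein_critic_loss(real_scores, fake_scores) - Wasserstein GAN discriminator (critic) loss.
  • wasserstein_generator_loss(fake_scores) - Wasserstein GAN generator loss.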

Types

build_opt()

@type build_opt() ::
  {:output_size, pos_integer()}
  | {:latent_size, pos_integer()}
  | {:generator_sizes, [pos_integer()]}
  | {:discriminator_sizes, [pos_integer()]}
  | {:activation, atom()}
  | {:output_activation, atom()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: {Axon.t(), Axon.t()}

Build generator and discriminator networks.

Options

  • :latent_size - Size of noise vector z (default: 128)
  • :output_size - Size of generated output (required)
  • :generator_sizes - Hidden layer sizes for G (default: [256, 512])
  • :discriminator_sizes - Hidden layer sizes for D (default: [512, 256])
  • :activation - Activation function (default: :relu)
  • :output_activation - Generator output activation (default: :sigmoid)
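To make the option list above concrete, here is a minimal sketch of how the documented defaults combine with caller-supplied options. GANOptsSketch and resolve/1 are hypothetical names, and the Keyword-merge approach is an assumption about how such defaults are typically applied, not Edifice's own option handling.

```elixir
defmodule GANOptsSketch do
  # Hypothetical helper: reproduces the defaults documented for build/1.
  # This is an illustration, not Edifice's own option handling.
  @defaults [
    latent_size: 128,
    generator_sizes: [256, 512],
    discriminator_sizes: [512, 256],
    activation: :relu,
    output_activation: :sigmoid
  ]

  def resolve(opts) when is_list(opts) do
    # :output_size has no default, so fail early when it is missing.
    if not Keyword.has_key?(opts, :output_size) do
      raise ArgumentError, "build/1 requires :output_size"
    end

    # Caller-supplied values override the defaults; the rest are kept.
    Keyword.merge(@defaults, opts)
  end
end

opts = GANOptsSketch.resolve(output_size: 784, latent_size: 64)
```

Here opts keeps the default hidden sizes and activations while latent_size becomes 64; calling resolve([]) raises because :output_size is required.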

Returns

Tuple of {generator, discriminator} Axon models.

build_conditional_generator(opts \\ [])

@spec build_conditional_generator(keyword()) :: Axon.t()

Build a conditional generator.

Takes both noise z and conditioning label y.
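One common way to feed a conditioning label y into a generator is to one-hot encode it and concatenate it with the noise vector z. Whether build_conditional_generator/1 conditions exactly this way is an assumption; ConditionalInputSketch is a hypothetical module that only shows how such a combined input is formed, using plain lists.

```elixir
defmodule ConditionalInputSketch do
  # One-hot encode an integer class label as a list of floats.
  def one_hot(label, num_classes)
      when is_integer(label) and label >= 0 and label < num_classes do
    for i <- 0..(num_classes - 1), do: if(i == label, do: 1.0, else: 0.0)
  end

  # Build the generator input by appending the label encoding to the noise z.
  def generator_input(z, label, num_classes) when is_list(z) do
    z ++ one_hot(label, num_classes)
  end
end
```

With a latent size of 128 and 10 classes, the combined input would have 138 entries.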

build_discriminator(opts \\ [])

@spec build_discriminator(keyword()) :: Axon.t()

Build the discriminator network.

Maps data to a real/fake probability.
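A discriminator that outputs a probability usually ends in a logistic sigmoid over a single logit. That build_discriminator/1 does exactly this is an assumption here; ScoreSketch is a hypothetical module showing the squashing step, written in a two-branch form so that large negative logits do not overflow :math.exp/1.

```elixir
defmodule ScoreSketch do
  # Logistic sigmoid: squashes an unbounded logit into the interval (0, 1).
  # Non-negative logits: exp(-logit) <= 1, so no overflow.
  def sigmoid(logit) when logit >= 0, do: 1.0 / (1.0 + :math.exp(-logit))

  # Negative logits: rewrite as e / (1 + e) with e = exp(logit) < 1.
  def sigmoid(logit) do
    e = :math.exp(logit)
    e / (1.0 + e)
  end
end
```

The Wasserstein losses below expect unbounded critic scores, so a WGAN critic normally omits this final sigmoid.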

build_generator(opts \\ [])

@spec build_generator(keyword()) :: Axon.t()

Build the generator network.

Maps latent noise z to data space.
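The generator's input is a batch of noise vectors drawn from z ~ N(0, I). The sketch below samples such a batch with Erlang's :rand module and an explicit seed so runs are reproducible. NoiseSketch is a hypothetical module; in practice the noise would be an Nx tensor fed to the generator model rather than nested lists.

```elixir
defmodule NoiseSketch do
  # Draw `batch` latent vectors of length `latent_size` from N(0, I),
  # threading an explicit :rand state so results are reproducible.
  def sample(batch, latent_size, seed) when batch > 0 and latent_size > 0 do
    state = :rand.seed_s(:exsss, seed)

    {rows, _final_state} =
      Enum.map_reduce(1..batch, state, fn _row, st ->
        # Each call returns {standard_normal_float, next_state}.
        Enum.map_reduce(1..latent_size, st, fn _col, st2 -> :rand.normal_s(st2) end)
      end)

    rows
  end
end
```

Calling NoiseSketch.sample(4, 128, {1, 2, 3}) gives four 128-dimensional noise vectors, and the same seed always gives the same batch.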

discriminator_loss(real_scores, fake_scores)

Standard GAN discriminator loss.

L_D = -mean(log(D(real))) - mean(log(1 - D(fake)))
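As a plain-Elixir sketch of the formula above on lists of probabilities (GANLossSketch is hypothetical; the library itself works on tensors): the clamp to a small epsilon is an assumption added so :math.log/1 never receives 0.0, and Edifice may guard against that differently.

```elixir
defmodule GANLossSketch do
  # Assumption: clamp probabilities to at least @eps, because :math.log(0.0)
  # raises ArithmeticError in Erlang. Edifice may handle this differently.
  @eps 1.0e-7

  defp mean(xs), do: Enum.sum(xs) / length(xs)
  defp safe_log(p), do: :math.log(max(p, @eps))

  # L_D = -mean(log(D(real))) - mean(log(1 - D(fake)))
  def discriminator_loss(real_scores, fake_scores) do
    real_term = mean(Enum.map(real_scores, &safe_log/1))
    fake_term = mean(Enum.map(fake_scores, &safe_log(1.0 - &1)))
    -real_term - fake_term
  end
end
```

A perfect discriminator (real scores 1.0, fake scores 0.0) drives the loss to 0; scores of 0.5 everywhere give 2 * log(2), about 1.386.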

generator_loss(fake_scores)

Standard GAN generator loss (non-saturating).

L_G = -mean(log(D(G(z))))
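The sketch below computes this loss on a list of discriminator probabilities and, for comparison, the per-sample gradients of the saturating and non-saturating losses with respect to the discriminator logit, assuming D = sigmoid(a). GeneratorLossSketch and logit_grads/1 are hypothetical, and the epsilon clamp is an assumption.

```elixir
defmodule GeneratorLossSketch do
  # Assumption: clamp scores away from 0 so :math.log/1 stays finite.
  @eps 1.0e-7

  # L_G = -mean(log(D(G(z))))
  def generator_loss(fake_scores) do
    logs = Enum.map(fake_scores, fn p -> :math.log(max(p, @eps)) end)
    -Enum.sum(logs) / length(logs)
  end

  # Per-sample gradients with respect to the logit a, where p = sigmoid(a):
  #   saturating      d/da  log(1 - p) = -p
  #   non-saturating  d/da -log(p)     = -(1 - p)
  def logit_grads(p), do: %{saturating: -p, non_saturating: -(1.0 - p)}
end
```

At p = 0.01, a fake the discriminator confidently rejects, logit_grads/1 gives a saturating gradient of -0.01 against a non-saturating gradient of -0.99, which is why the non-saturating form trains the generator more reliably early on.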

gradient_penalty(real_data, fake_data, critic_fn, params, key)

Gradient penalty for WGAN-GP.

Penalizes the critic when the norm of its gradient with respect to its input deviates from 1, evaluated at random interpolations between real and fake samples.

Requires a PRNG key for sampling random interpolation coefficients.
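The idea can be sketched without any autodiff library. Following the WGAN-GP definition, each sample gets eps ~ U[0, 1), the point x_hat = eps * real + (1 - eps) * fake is formed, and the penalty is the batch mean of (||grad D(x_hat)||_2 - 1)^2. GradientPenaltySketch is hypothetical: it uses central finite differences instead of automatic differentiation, takes a seed tuple instead of a PRNG key, and does not take a params argument, so it illustrates the quantity rather than the signature of gradient_penalty/5.

```elixir
defmodule GradientPenaltySketch do
  # Central finite-difference gradient of a scalar critic at x (a list of floats).
  # Illustration only: a real implementation would use automatic differentiation.
  def numeric_grad(critic_fn, x, h \\ 1.0e-5) do
    x
    |> Enum.with_index()
    |> Enum.map(fn {_xi, i} ->
      plus = List.update_at(x, i, &(&1 + h))
      minus = List.update_at(x, i, &(&1 - h))
      (critic_fn.(plus) - critic_fn.(minus)) / (2 * h)
    end)
  end

  # GP = mean over samples of (||grad D(x_hat)||_2 - 1)^2, where
  # x_hat = eps * real + (1 - eps) * fake and eps ~ U[0, 1) per sample.
  def penalty(real_batch, fake_batch, critic_fn, seed) do
    state = :rand.seed_s(:exsss, seed)

    {terms, _final_state} =
      real_batch
      |> Enum.zip(fake_batch)
      |> Enum.map_reduce(state, fn {real, fake}, st ->
        {eps, st} = :rand.uniform_s(st)
        x_hat = Enum.zip_with(real, fake, fn r, f -> eps * r + (1 - eps) * f end)

        grad_norm =
          critic_fn
          |> numeric_grad(x_hat)
          |> Enum.map(&(&1 * &1))
          |> Enum.sum()
          |> :math.sqrt()

        {:math.pow(grad_norm - 1, 2), st}
      end)

    Enum.sum(terms) / length(terms)
  end
end
```

For a linear critic such as fn [a, b] -> 3 * a + 4 * b end the gradient norm is 5 everywhere, so the penalty is 16 whatever eps is drawn; rescaling the weights to 0.6 and 0.8 brings it to about 0.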

wasserstein_critic_loss(real_scores, fake_scores)

Wasserstein GAN discriminator (critic) loss.

L_D = mean(D(fake)) - mean(D(real))
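A list-based sketch of the critic loss (WassersteinLossSketch is a hypothetical module). Unlike the standard GAN losses, the scores are unbounded real numbers rather than probabilities, so no log or clamping is involved.

```elixir
defmodule WassersteinLossSketch do
  defp mean(xs), do: Enum.sum(xs) / length(xs)

  # L_D = mean(D(fake)) - mean(D(real)); critic scores are unbounded reals.
  def critic_loss(real_scores, fake_scores), do: mean(fake_scores) - mean(real_scores)
end
```

Multiplying every score by 10 also multiplies this loss by 10, so an unconstrained critic could push it toward negative infinity. That is why WGAN training enforces a Lipschitz constraint on the critic, either by weight clipping or with the gradient penalty described above.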

wasserstein_generator_loss(fake_scores)

Wasserstein GAN generator loss.

L_G = -mean(D(G(z)))
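The sketch below pairs this generator loss with the critic objective used in WGAN-GP, where the critic loss is combined with a gradient penalty weighted by a coefficient lambda (the WGAN-GP paper used lambda = 10). WGANObjectivesSketch and critic_objective/4 are hypothetical, and how Edifice expects the penalty to be weighted and added is not specified on this page.

```elixir
defmodule WGANObjectivesSketch do
  defp mean(xs), do: Enum.sum(xs) / length(xs)

  # L_G = -mean(D(G(z))): the generator pushes critic scores on fakes upward.
  def generator_loss(fake_scores), do: -mean(fake_scores)

  # Hypothetical helper: WGAN-GP critic objective, i.e. the critic loss
  # mean(D(fake)) - mean(D(real)) plus lambda times a precomputed penalty gp.
  def critic_objective(real_scores, fake_scores, gp, lambda \\ 10.0) do
    mean(fake_scores) - mean(real_scores) + lambda * gp
  end
end
```

In a training step the critic would minimize critic_objective/4 for several iterations per generator update, after which the generator minimizes generator_loss/1 on fresh fake scores.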