Generative Adversarial Network
Table of Contents
1. Generative Adversarial Networks (GAN)
   1.1. GAN Loss Function
   1.2. Example
   1.3. Evaluation
1. Generative Adversarial Networks (GAN)

Let's say we have about 200,000 satellite images and we want to improve their quality. The images will capture: coastlines, ports, cities, farms, mountains, oceans, suburbs.

The labels will be the aerial images corresponding to each satellite image. The aerial images have 4x the resolution of the satellite images (e.g. a 2x2 section would expand to a 4x4 section).

Note: The aerial images have to be taken from the exact same location (and preferably at the same time of day).

We'll use the satellite images as the features.

What model could we use for this problem?

We can use a type of neural network (NN) model called a Generative Adversarial Network (GAN). A GAN is really just two NNs.

One NN is called the generator. The generator takes in a low-resolution image (i.e. a satellite image) and produces its best guess of what the equivalent high-resolution image (i.e. the aerial image) would be.

The second part of the GAN is called the discriminator. The discriminator takes in real aerial images and the aerial image estimates (produced by the generator), and decides whether each input image is real or fake.

Essentially, the generator works to confuse the discriminator, and the discriminator does its best to differentiate between the generated images and the real images.

How do GANs get trained?

Feed the low-res image to the generator to produce an estimated hi-res image.
Feed the estimated hi-res image to the discriminator.
Propagate the discriminator loss all the way back to the generator.
* Here, we're fixing the weights of the discriminator, and we're only updating the generator according to the loss from its attempt to fool the discriminator.

1.1. GAN Loss Function

Discriminator's Loss:

$$\max_D \frac{1}{m}\sum_{i=1}^{m}\left[\log D\left(x^{(i)}\right) + \log\left(1 - D\left(G\left(z^{(i)}\right)\right)\right)\right]$$

Generator's Loss:

$$\min_G \frac{1}{m}\sum_{i=1}^{m}\log\left(1 - D\left(G\left(z^{(i)}\right)\right)\right)$$

Combined GAN's Loss Function (Adversarial Min-Max):

$$\min_G \max_D \frac{1}{m}\sum_{i=1}^{m}\left[\log D\left(x^{(i)}\right) + \log\left(1 - D\left(G\left(z^{(i)}\right)\right)\right)\right]$$

Here D(x) is the discriminator's estimated probability that the image x is real, G(z) is the generator's output for the input z, and the average runs over m samples.

How do we know when this loss function has converged?

When the accuracy of the discriminator drops to around 50%, i.e. it can't do better than random chance (which means the generator can successfully fool the discriminator).

1.2. Example

Going back to our satellite example:

The generator is going to be a CNN.
Input → low-res satellite images
Output → estimated aerial images (hi-res)

The CNN has to do:
* Upsample → 4x resolution
* Pixel Shuffle
* Residual connections (to avoid vanishing gradients)

The discriminator is also a CNN, with a sigmoid binary classifier at the end.
* We're using Leaky ReLU.

Mini-batch size → 16
* Note: The batch size is a bit smaller than usual. The discriminator goes through a few iterations of training before we allow the generator to start learning from the discriminator's decisions. Keeping the mini-batches small means that the discriminator only gets 16 images to train on before the generator gets to jump in and begin training as well.
This is so the discriminator doesn't out-learn the generator. If the discriminator has already learned too much, its gradients will be small and it won't supply any valuable feedback to the generator. The generator and the discriminator have to learn together.

Adjust the learning rate:
* We can also adjust (lower) the learning rate of the discriminator to make sure that the generator can keep up and learn together with it.

Mode collapse
* Mode collapse is when the generator finds an image that fools the discriminator and then keeps outputting that exact same image.
* This can be avoided by using an unrolled GAN → it lets the generator see what the discriminator will look like a few steps ahead, so the generator is discouraged from learning some local exploitation of the discriminator.
* It now has to account for what the discriminator will look for in the future as well.

1.3. Evaluation

In our test set, if we have the true labels (e.g. the hi-res images in the example above), we can just take the per-pixel MSE.

Another thing to look at is the Peak Signal-to-Noise Ratio (PSNR):

$$\mathrm{PSNR} = 20\log_{10}\left(\text{pixel}_{\max}\right) - 10\log_{10}\left(\mathrm{MSE}\right)$$

This is a gauge of how noisy the produced image is → the smaller the MSE, the better (higher) the PSNR.

We can also use human raters to determine whether they can tell the difference between our generated images and the real images.
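
The per-pixel MSE and the PSNR from the evaluation section can be computed in a few lines. The following is a minimal sketch using NumPy, not a reference implementation: the function names `mse` and `psnr`, and the default `pixel_max=255.0` (which assumes 8-bit images), are illustrative choices that do not come from the text above.

```python
import numpy as np


def mse(reference, estimate):
    """Mean squared error per pixel between two images of the same shape."""
    # Cast to float first so that uint8 images do not wrap around on subtraction.
    ref = np.asarray(reference, dtype=np.float64)
    est = np.asarray(estimate, dtype=np.float64)
    return float(np.mean((ref - est) ** 2))


def psnr(reference, estimate, pixel_max=255.0):
    """PSNR in dB: 20*log10(pixel_max) - 10*log10(MSE).

    pixel_max is the largest possible pixel value (assumed 255 for 8-bit images).
    """
    err = mse(reference, estimate)
    if err == 0.0:
        # Identical images: the ratio is unbounded.
        return float("inf")
    return float(20.0 * np.log10(pixel_max) - 10.0 * np.log10(err))


# Toy example: a flat 4x4 "real" image and an estimate that is off by 10 everywhere.
real = np.full((4, 4), 100, dtype=np.uint8)
fake = np.full((4, 4), 110, dtype=np.uint8)
print(mse(real, fake))             # 100.0
print(round(psnr(real, fake), 2))  # 28.13
```

Because the MSE enters with a negative sign, a lower MSE gives a higher PSNR, which matches the rule of thumb above. In practice both metrics would be averaged over the whole test set rather than computed on a single image pair.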