GAN (Generative Adversarial Networks) Tutorial

This article is reposted from Cnblogs (博客园). I found it fairly clearly written, so I am saving it here for future reference.

GAN Generative Adversarial Networks (Part 1): Principles

A Generative Adversarial Network (GAN) is a deep learning model.

A GAN consists of two core modules:

  • 1. the generator
  • 2. the discriminator

An intuitive explanation of how GANs work

To explain the idea behind GANs intuitively, consider an analogy with counterfeiting money (the analogy is purely for explanation).
Suppose we are given a counterfeiting task.

You have a pile of genuine banknotes, a detector whose ability to tell real notes from fake ones can keep improving, and a counterfeiting machine whose forging ability can also keep improving.

1. We keep strengthening the detector so that, as far as possible, it classifies genuine notes as genuine and counterfeit notes as counterfeit. (The detector outputs a probability between 0 and 1; the closer to 0, the more likely the note is judged to be fake.)

2. We let the counterfeiting machine keep producing fake notes, mix them with genuine ones, and hand them to the detector. Based on the detector's output (the probability), we keep improving the counterfeiting machine so that its fakes are judged genuine with ever higher probability.
This creates a spear-versus-shield standoff: the counterfeiting machine and the detector are locked in a continuous contest, and both keep drawing lessons from the confrontation and improving.

As the two sides keep competing, both abilities keep rising. In the end we obtain an expert note detector and a counterfeiting genius with an extremely strong ability to learn. When its fakes are mixed with genuine notes, even our expert detector can no longer tell them apart and judges them all to be genuine.

At that point, the notes produced by the counterfeiting machine can effectively pass as genuine in the market: if even our sophisticated detector cannot distinguish them, ordinary detectors certainly cannot.

Summary of the GAN principle

As described above, the principle of a GAN is: under continuous feedback from a discriminator whose judging ability keeps improving, the generator's parameters are repeatedly refined until the generator's outputs can pass the discriminator's judgment.
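The adversarial game summarized above is conventionally written as a minimax objective (standard notation from the GAN literature, not part of the original post): the discriminator D maximizes it while the generator G minimizes it:

```latex
\min_G \max_D \, V(D, G) =
  \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\log D(x)\big]
+ \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```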


GAN Generative Adversarial Networks (Part 2): a TensorFlow code example

Implementation

When I first studied this, I mainly followed this blog post, which is well written: https://xyang35.github.io/2017/08/22/GAN-1/

1. The goal of this section is to implement the simplest possible GAN example, to help build intuition for the algorithm.

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API, used throughout this example
from matplotlib import pyplot as plt

batch_size = 4

2. The real dataset. The GAN will learn this dataset and then generate data following the same distribution.

X = np.random.normal(size=(1000, 2))
A = np.array([[1, 2], [-0.1, 0.5]])
b = np.array([1, 2])
X = np.dot(X, A) + b

plt.scatter(X[:, 0], X[:, 1])
plt.show()

# used later to repeatedly draw batch_size samples of x from the dataset
def iterate_minibatch(x, batch_size, shuffle=True):
    indices = np.arange(x.shape[0])
    if shuffle:
        np.random.shuffle(indices)

    for i in range(0, x.shape[0], batch_size):
        yield x[indices[i:i + batch_size], :]
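As a quick sanity check (the toy array below is hypothetical, not from the original post), the iterator yields `batch_size` rows at a time, with a smaller final batch if the data does not divide evenly:

```python
import numpy as np

def iterate_minibatch(x, batch_size, shuffle=True):
    # Same iterator as above: shuffle the row indices, then yield consecutive slices.
    indices = np.arange(x.shape[0])
    if shuffle:
        np.random.shuffle(indices)
    for i in range(0, x.shape[0], batch_size):
        yield x[indices[i:i + batch_size], :]

toy = np.arange(20).reshape(10, 2)           # 10 samples, 2 features each
batches = list(iterate_minibatch(toy, batch_size=4))
print([b.shape for b in batches])            # 3 batches: 4 + 4 + 2 rows
```

Every row appears exactly once per pass, so one full loop over the iterator is one epoch.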

3. The GAN class

It contains the generator and the discriminator.

class GAN(object):
    def __init__(self):
        # initialization: build the model here
        ...
    def netG(self, z):
        # the generator
        ...
    def netD(self, x, reuse=False):
        # the discriminator
        ...

4. The generator netG

It takes an arbitrary input z and returns the result of the matrix operation z*W + b (a fully connected layer).

    def netG(self, z):
        """1-layer fully connected network"""

        with tf.variable_scope("generator") as scope:
            W = tf.get_variable(name="g_W", shape=[2, 2],
                                initializer=tf.contrib.layers.xavier_initializer(),
                                trainable=True)
            b = tf.get_variable(name="g_b", shape=[2],
                                initializer=tf.zeros_initializer(),
                                trainable=True)
            return tf.matmul(z, W) + b

5. The discriminator netD

The discriminator is a 3-layer fully connected network. The hidden layers use the tanh activation; the output layer has no activation (it returns raw logits).

    def netD(self, x, reuse=False):
        """3-layer fully connected network"""

        with tf.variable_scope("discriminator") as scope:
            if reuse:
                scope.reuse_variables()

            W1 = tf.get_variable(name="d_W1", shape=[2, 5],
                                 initializer=tf.contrib.layers.xavier_initializer(),
                                 trainable=True)
            b1 = tf.get_variable(name="d_b1", shape=[5],
                                 initializer=tf.zeros_initializer(),
                                 trainable=True)
            W2 = tf.get_variable(name="d_W2", shape=[5, 3],
                                 initializer=tf.contrib.layers.xavier_initializer(),
                                 trainable=True)
            b2 = tf.get_variable(name="d_b2", shape=[3],
                                 initializer=tf.zeros_initializer(),
                                 trainable=True)
            W3 = tf.get_variable(name="d_W3", shape=[3, 1],
                                 initializer=tf.contrib.layers.xavier_initializer(),
                                 trainable=True)
            b3 = tf.get_variable(name="d_b3", shape=[1],
                                 initializer=tf.zeros_initializer(),
                                 trainable=True)

            layer1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
            layer2 = tf.nn.tanh(tf.matmul(layer1, W2) + b2)
            return tf.matmul(layer2, W3) + b3

6. The __init__ function

    def __init__(self):
        # input, output
        # placeholder for the random noise input
        self.z = tf.placeholder(tf.float32, shape=[None, 2], name='z')
        # placeholder for the real data
        self.x = tf.placeholder(tf.float32, shape=[None, 2], name='real_x')

        # define the network
        # the generator transforms the random noise into fake data
        self.fake_x = self.netG(self.z)

        # the discriminator scores the real data
        # reuse=False: the variables are not shared yet, so TensorFlow allocates them
        self.real_logits = self.netD(self.x, reuse=False)

        # the discriminator scores the fake data
        # reuse=True: share (reuse) the variables already created inside netD
        self.fake_logits = self.netD(self.fake_x, reuse=True)


        # define losses
        # discriminator loss: how well it classifies real data as real and fake data as fake
        self.loss_D = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.real_logits,
                                                                             labels=tf.ones_like(self.real_logits))) + \
                      tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_logits,
                                                                             labels=tf.zeros_like(self.fake_logits)))
        # generator loss: how well the fake data fools the discriminator into judging it real
        self.loss_G = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_logits,
                                                                             labels=tf.ones_like(self.fake_logits)))

        # collect variables
        t_vars = tf.trainable_variables()
        # variables used by the discriminator
        self.d_vars = [var for var in t_vars if 'd_' in var.name]
        # variables used by the generator
        self.g_vars = [var for var in t_vars if 'g_' in var.name]
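Both losses above rely on `tf.nn.sigmoid_cross_entropy_with_logits`, which applies the sigmoid and the cross-entropy in one numerically stable step. A minimal NumPy sketch of the stable formula TensorFlow documents for this op (the sample logits here are hypothetical):

```python
import numpy as np

def sigmoid_cross_entropy_with_logits(logits, labels):
    # Numerically stable form: max(x, 0) - x*z + log(1 + exp(-|x|)),
    # which avoids overflow for large-magnitude logits.
    x, z = logits, labels
    return np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))

logits = np.array([2.0, -1.0, 0.0])
# Naive definition for comparison, with labels = 1 (the "real" target in loss_D):
# -log(sigmoid(x))
s = 1.0 / (1.0 + np.exp(-logits))
naive = -np.log(s)
stable = sigmoid_cross_entropy_with_logits(logits, np.ones_like(logits))
print(np.allclose(naive, stable))  # True
```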

7. Training

gan = GAN()

#Adam optimizers (a variant of stochastic gradient descent); note the different learning rates for D and G
d_optim = tf.train.AdamOptimizer(learning_rate=0.05).minimize(gan.loss_D, var_list=gan.d_vars)
g_optim = tf.train.AdamOptimizer(learning_rate=0.01).minimize(gan.loss_G, var_list=gan.g_vars)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    #loop over the data for 10 epochs
    for epoch in range(10):
        avg_loss = 0.
        count = 0
        #draw batch_size samples at random from the real dataset
        for x_batch in iterate_minibatch(X, batch_size=batch_size):
            # generate noise z, one noise vector per sample in the batch
            z_batch = np.random.normal(size=(batch_size, 2))

            # update D network
            #feed the real batch and the noise batch to the session and run one optimization step
            loss_D, _ = sess.run([gan.loss_D, d_optim],
                                 feed_dict={
                                     gan.z: z_batch,
                                     gan.x: x_batch,
                                 })

            # update G network
            loss_G, _ = sess.run([gan.loss_G, g_optim],
                                 feed_dict={
                                     gan.z: z_batch,
                                     gan.x: np.zeros(z_batch.shape),  # dummy input
                                 })

            avg_loss += loss_D
            count += 1

        avg_loss /= count
        #show the generated samples once per epoch
        z = np.random.normal(size=(100, 2))
        # draw 100 random indices in [0, 1000) to sample from the real data
        excerpt = np.random.randint(1000, size=100)
        fake_x, real_logits, fake_logits = sess.run([gan.fake_x, gan.real_logits, gan.fake_logits],
                                                    feed_dict={gan.z: z, gan.x: X[excerpt, :]})
        # the outputs are raw logits, so threshold at 0 (equivalent to sigmoid > 0.5)
        accuracy = 0.5 * (np.sum(real_logits > 0) / 100. + np.sum(fake_logits < 0) / 100.)
        print('\ndiscriminator loss at epoch %d: %f' % (epoch, avg_loss))
        print('\ndiscriminator accuracy at epoch %d: %f' % (epoch, accuracy))
        plt.scatter(X[:, 0], X[:, 1])
        plt.scatter(fake_x[:, 0], fake_x[:, 1])
        plt.show()
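Since `netD` returns raw logits (no sigmoid in the output layer), thresholding the probability at 0.5 is exactly the same as thresholding the logit at 0, which is what the accuracy check above relies on. A small check with hypothetical discriminator outputs:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

logits = np.array([-2.0, -0.1, 0.3, 4.0])  # hypothetical discriminator logits
probs = sigmoid(logits)
# sigmoid is monotonic and sigmoid(0) == 0.5, so the two thresholds agree.
print(np.array_equal(probs > 0.5, logits > 0))  # True
```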




GAN Generative Adversarial Networks (Part 3): generating MNIST data

Using a GAN to generate MNIST-like digit images.

Imports, data loading, and conventions

import numpy as np
import matplotlib.pyplot as plt
import input_data  # a small utility file for reading the data; not essential for understanding
import tensorflow as tf

# load the data
mnist = input_data.read_data_sets('data/', one_hot=True)
trainimg = mnist.train.images

X = mnist.train.images[:, :]
batch_size = 64

#yields batches of real data
def iterate_minibatch(x, batch_size, shuffle=True):
    indices = np.arange(x.shape[0])
    if shuffle:
        np.random.shuffle(indices)
    for i in range(0, x.shape[0]-1000, batch_size):
        temp = x[indices[i:i + batch_size], :]
        temp = np.array(temp) * 2 - 1
        yield np.reshape(temp, (-1, 28, 28, 1))
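The `temp * 2 - 1` step above matters: MNIST pixels lie in [0, 1], while the generator ends in tanh and therefore outputs values in (-1, 1), so the real data is rescaled to the same range. A minimal check of the transform:

```python
import numpy as np

pixels = np.array([0.0, 0.25, 0.5, 1.0])  # MNIST pixel values lie in [0, 1]
rescaled = pixels * 2 - 1                  # same transform as in iterate_minibatch
print(rescaled)                            # maps [0, 1] onto [-1, 1]
```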

GAN class structure

class GAN(object):
    def __init__(self):
        # initialization: build the model here
        ...
    def netG(self, z):
        # the generator
        ...
    def netD(self, x, reuse=False):
        # the discriminator
        ...

The generator netG

It wraps the random input z (of shape (1, 100)) into forged data.

The pipeline is: fully connected -> reshape -> transposed convolutions.

It uses batch normalization, Leaky ReLU, dropout, and tanh.

    # wraps the random input z (shape (1, 100)) into forged data
    # pipeline: fully connected -> reshape -> transposed convolutions
    # uses batch normalization, Leaky ReLU, dropout, tanh
    def netG(self,z,alpha=0.01):
        with tf.variable_scope('generator') as scope:
            layer1 = tf.layers.dense(z, 4 * 4 * 512)  # fully connected layer, output shape (n, 4*4*512)
            layer1 = tf.reshape(layer1, [-1, 4, 4, 512])
            # batch normalization
            layer1 = tf.layers.batch_normalization(layer1, training=True)
            # Leaky ReLU
            layer1 = tf.maximum(alpha * layer1, layer1)
            # dropout
            layer1 = tf.nn.dropout(layer1, keep_prob=0.8)

            # 4 x 4 x 512 to 7 x 7 x 256
            layer2 = tf.layers.conv2d_transpose(layer1, 256, 4, strides=1, padding='valid')
            layer2 = tf.layers.batch_normalization(layer2, training=True)
            layer2 = tf.maximum(alpha * layer2, layer2)
            layer2 = tf.nn.dropout(layer2, keep_prob=0.8)

            # 7 x 7 x 256 to 14 x 14 x 128
            layer3 = tf.layers.conv2d_transpose(layer2, 128, 3, strides=2, padding='same')
            layer3 = tf.layers.batch_normalization(layer3, training=True)
            layer3 = tf.maximum(alpha * layer3, layer3)
            layer3 = tf.nn.dropout(layer3, keep_prob=0.8)

            # 14 x 14 x 128 to 28 x 28 x 1
            logits = tf.layers.conv2d_transpose(layer3, 1, 3, strides=2, padding='same')
            # the original MNIST pixels lie in [0, 1], while the generated images lie in (-1, 1),
            # so remember to rescale the MNIST pixels accordingly during training
            outputs = tf.tanh(logits)

            return outputs
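The activation `tf.maximum(alpha * layer, layer)` used throughout `netG` (and `netD` below) is Leaky ReLU: positives pass through unchanged, negatives are scaled by `alpha`. A NumPy sketch of the same expression:

```python
import numpy as np

def leaky_relu(x, alpha=0.01):
    # Equivalent to tf.maximum(alpha * x, x) for 0 < alpha < 1:
    # for x >= 0, x >= alpha*x; for x < 0, alpha*x > x.
    return np.maximum(alpha * x, x)

x = np.array([-100.0, -1.0, 0.0, 2.0])
print(leaky_relu(x))  # negatives shrink by a factor of alpha, positives unchanged
```

Unlike plain ReLU, the small negative slope keeps a gradient flowing for negative pre-activations, which is a common choice in GAN discriminators.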

The discriminator netD

Using stacked convolutions followed by a fully connected layer, the discriminator classifies its input as real data or fake data.

    def netD(self, x, reuse=False,alpha=0.01):
        with tf.variable_scope('discriminator') as scope:
            if reuse:
                scope.reuse_variables()
            layer1 = tf.layers.conv2d(x, 128, 3, strides=2, padding='same')
            layer1 = tf.maximum(alpha * layer1, layer1)
            layer1 = tf.nn.dropout(layer1, keep_prob=0.8)

            # 14 x 14 x 128 to 7 x 7 x 256
            layer2 = tf.layers.conv2d(layer1, 256, 3, strides=2, padding='same')
            layer2 = tf.layers.batch_normalization(layer2, training=True)
            layer2 = tf.maximum(alpha * layer2, layer2)
            layer2 = tf.nn.dropout(layer2, keep_prob=0.8)

            # 7 x 7 x 256 to 4 x 4 x 512
            layer3 = tf.layers.conv2d(layer2, 512, 3, strides=2, padding='same')
            layer3 = tf.layers.batch_normalization(layer3, training=True)
            layer3 = tf.maximum(alpha * layer3, layer3)
            layer3 = tf.nn.dropout(layer3, keep_prob=0.8)

            # 4 x 4 x 512 to 4*4*512 x 1
            flatten = tf.reshape(layer3, (-1, 4 * 4 * 512))
            f = tf.layers.dense(flatten, 1)
            return f

The __init__ function

There is a pre-training step that feeds real data to the discriminator to warm up its judging ability.

    # pre-training step: feed real data to the discriminator to train its judging ability
    def __init__(self):
        self.z = tf.placeholder(tf.float32, shape=[batch_size, 100], name='z')  # random input
        self.x = tf.placeholder(tf.float32, shape=[batch_size, 28, 28, 1], name='real_x')  # real images

        self.fake_x = self.netG(self.z)  # wrap the random input into fake images

        self.pre_logits = self.netD(self.x, reuse=False)  # score on real data during pre-training (raw logits, no sigmoid)
        self.real_logits = self.netD(self.x, reuse=True)  # score on real data (raw logits, no sigmoid)
        self.fake_logits = self.netD(self.fake_x, reuse=True)  # score on fake data (raw logits, no sigmoid)

        # pre-training loss: the discriminator's score for judging real data as real
        self.loss_pre_D = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.pre_logits,
                                                                                 labels=tf.ones_like(self.pre_logits)))
        # discriminator loss: judge real data as real and fake data as fake
        self.loss_D = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.real_logits,
                                                                             labels=tf.ones_like(self.real_logits))) + \
                      tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_logits,
                                                                             labels=tf.zeros_like(self.fake_logits)))
        # generator loss: how strongly the fake data is judged to be real
        self.loss_G = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_logits,
                                                                             labels=tf.ones_like(self.fake_logits)))

        # collect the variable handles of the generator and the discriminator, used for the optimizer updates
        t_vars = tf.trainable_variables()
        self.g_vars = [var for var in t_vars if var.name.startswith("generator")]
        self.d_vars = [var for var in t_vars if var.name.startswith("discriminator")]

Training

gan = GAN()
#optimizer for the discriminator pre-training
d_pre_optim = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.loss_pre_D, var_list=gan.d_vars)
#optimizer for the discriminator
d_optim = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.loss_D, var_list=gan.d_vars)
#optimizer for the generator
g_optim = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.loss_G, var_list=gan.g_vars)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    #pre-train the discriminator for two epochs
    for i in range(2):
        print('discriminator pre-training, epoch ' + str(i))
        for x_batch in iterate_minibatch(X, batch_size=batch_size):
            loss_pre_D, _ = sess.run([gan.loss_pre_D, d_pre_optim],
                                     feed_dict={
                                         gan.x: x_batch
                                     })
    #adversarial training for 5 epochs
    for epoch in range(5):
        print('adversarial training, epoch ' + str(epoch))
        avg_loss = 0
        count = 0
        for x_batch in iterate_minibatch(X, batch_size=batch_size):
            z_batch = np.random.uniform(-1, 1, size=(batch_size, 100))  # random starting noise

            loss_D, _ = sess.run([gan.loss_D, d_optim],
                                 feed_dict={
                                     gan.z: z_batch,
                                     gan.x: x_batch
                                 })

            loss_G, _ = sess.run([gan.loss_G, g_optim],
                                 feed_dict={
                                     gan.z: z_batch,
                                     # gan.x: np.zeros(z_batch.shape)
                                 })

            avg_loss += loss_D
            count += 1

        # show the current predictions
        if True:
            avg_loss /= count
            z = np.random.normal(size=(batch_size, 100))
            excerpt = np.random.randint(100, size=batch_size)
            needTest = np.reshape(X[excerpt, :], (-1, 28, 28, 1))
            fake_x, real_logits, fake_logits = sess.run([gan.fake_x, gan.real_logits, gan.fake_logits],
                                                        feed_dict={gan.z: z, gan.x: needTest})
            # accuracy = (np.sum(real_logits > 0.5) + np.sum(fake_logits < 0.5)) / (2 * batch_size)
            print('real_logits')
            print(len(real_logits))
            print('fake_logits')
            print(len(fake_logits))
            print('\ndiscriminator loss at epoch %d: %f' % (epoch, avg_loss))
            # print('\ndiscriminator accuracy at epoch %d: %f' % (epoch, accuracy))
            print('----')
            print()

            # curr_img = np.reshape(trainimg[i, :], (28, 28))  # 28 by 28 matrix
            curr_img = np.reshape(fake_x[0], (28, 28))
            plt.matshow(curr_img, cmap=plt.get_cmap('gray'))
            plt.show()
            curr_img2 = np.reshape(fake_x[10], (28, 28))
            plt.matshow(curr_img2, cmap=plt.get_cmap('gray'))
            plt.show()
            curr_img3 = np.reshape(fake_x[20], (28, 28))
            plt.matshow(curr_img3, cmap=plt.get_cmap('gray'))
            plt.show()

            curr_img4 = np.reshape(fake_x[30], (28, 28))
            plt.matshow(curr_img4, cmap=plt.get_cmap('gray'))
            plt.show()

            curr_img5 = np.reshape(fake_x[40], (28, 28))
            plt.matshow(curr_img5, cmap=plt.get_cmap('gray'))
            plt.show()
            # plt.figure(figsize=(28, 28))

            # plt.title("" + str(i) + "th Training Data "
            #           + "Label is " + str(curr_label))
            # print("" + str(i) + "th Training Data "
            #       + "Label is " + str(curr_label))

            # plt.scatter(X[:, 0], X[:, 1])
            # plt.scatter(fake_x[:, 0], fake_x[:, 1])
            # plt.show()

Results

GAN Generative Adversarial Networks (Part 4): SRGAN super-resolution image reconstruction

Paper PDF: https://arxiv.org/pdf/1609.04802v1.pdf

Actual results

  • The sharpness falls short of what I hoped for.
  • There are noticeable differences in color.
  • An idea for improvement: add a color discriminator and feed the color values back to the generator.

The SRGAN paper builds on the GAN framework, using a generative adversarial network to reconstruct images at high resolution.

There are open-source SRGAN projects on GitHub, but their authors handle many more edge cases and use more sophisticated tricks, which makes the code fairly hard to read and understand.

To understand the paper thoroughly, I rewrote an SRGAN super-resolution model here, combining the paper, the open-source code, and my own understanding.

The GAN principle

Under continuous feedback from a discriminator whose judging ability keeps improving, the generator's parameters are refined until its outputs can pass the discriminator's judgment. (See the other posts in this series.)

The modules SRGAN uses, and how they relate


The loss values are computed according to this structure.

Note: VGG19 is an already-trained model, used here only to extract features.

The generator is optimized by stochastic gradient descent based on three computed quantities:

  • (1) the discriminator's verdict on the generated data
  • (2) the VGG19 feature comparison
  • (3) the MSE gap between the generated image and the target image
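The three quantities above are combined into a single generator loss with the same weights used in the code later in this post (1e-3 for the adversarial term, 2e-6 for the VGG feature term, 1.0 for the pixel MSE). The toy values below are hypothetical, purely to show the weighting:

```python
# Toy values for the three loss terms (hypothetical numbers for illustration)
adv_loss = 0.7     # (1) discriminator's verdict on the generated images
vgg_loss = 1200.0  # (2) MSE between VGG19 feature maps
mse_loss = 0.05    # (3) pixel-wise MSE between generated and target images

# Same weighting as the generator loss in the SRGAN code:
g_loss = 1e-3 * adv_loss + 2e-6 * vgg_loss + mse_loss
print(g_loss)
```

The tiny weights keep the (numerically much larger) adversarial and feature terms from drowning out the pixel MSE.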

The generator and discriminator architectures from the paper


The generator: one convolution layer, 16 residual convolution blocks, then the sum of the first convolution's output and the 16-block residual output, followed by convolution + 2x transposed convolution, convolution + 2x transposed convolution, and a tanh rescaling that produces the final output.

The discriminator: 8 convolution layers, then a reshape and a fully connected layer. (The paper uses two fully connected layers; I use only one here, because the parameter count was too large for my 6 GB of GPU memory.)

VGG19: the features returned by the fourth block of VGG19 are compared with MSE.

Note the usual tricks: batch normalization, Leaky ReLU, and so on.
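Before the MSE feature comparison, the images must be converted to what the pretrained VGG19 expects: rescaled from the generator's (-1, 1) range back to [0, 255], converted from RGB to BGR, and mean-subtracted per channel. A NumPy sketch mirroring the `vgg19()` method in the code below (the mid-gray test image is hypothetical):

```python
import numpy as np

VGG_MEAN = [103.939, 116.779, 123.68]  # per-channel BGR means used by the pretrained VGG19

def vgg_preprocess(img):
    # img: HxWx3 RGB image scaled to [-1, 1], as produced by the generator's tanh
    rgb = (img + 1) * (255.0 / 2)                 # back to [0, 255]
    r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
    # convert to BGR and subtract the per-channel means, as in vgg19()
    return np.stack([b - VGG_MEAN[0], g - VGG_MEAN[1], r - VGG_MEAN[2]], axis=-1)

img = np.zeros((2, 2, 3))  # mid-gray image: every channel is 0 in [-1, 1]
out = vgg_preprocess(img)
print(out[0, 0])           # each channel is 127.5 minus the corresponding BGR mean
```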

Code walkthrough

import numpy as np
import os
import tensorlayer as tl
import tensorflow as tf

#load the VGG19 parameters from vgg19.npy
vgg19_npy_path = "./vgg19.npy"
if not os.path.isfile(vgg19_npy_path):
    print("Please download vgg19.npz from : https://github.com/machrisaa/tensorflow-vgg")
    exit()
npz = np.load(vgg19_npy_path, encoding='latin1').item()
w_params = []
b_params = []
for val in sorted(npz.items()):
    W = np.asarray(val[1][0])
    b = np.asarray(val[1][1])
    # print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
    w_params.append(W)
    b_params.append(b)  # append (not extend), so each bias stays a whole vector


#used by tensorlayer when loading images: randomly crop a 192x192 patch;
#if memory is tight, this is a place to optimize
def crop_sub_imgs_fn(x, is_random=True):
    x = tl.prepro.crop(x, wrg=192, hrg=192, is_random=is_random)
    x = x / (255. / 2.)
    x = x - 1.
    return x
#resize the array; if memory is tight, this is a place to optimize
def downsample_fn(x):
    x = tl.prepro.imresize(x, size=[48, 48], interp='bicubic', mode=None)
    x = x / (255. / 2.)
    x = x - 1.
    return x

# hyperparameters
config = {
    "epoch": 5,
}

# reduce this if you run out of memory
batch_size = 10


class SRGAN(object):
    def __init__(self):
        # with tf.device('/gpu:0'):
        #placeholder for the low-resolution images to be reconstructed
        self.x = tf.placeholder(tf.float32, shape=[batch_size, 48, 48, 3], name='train_bechanged')
        #placeholder for the target high-resolution images to learn from
        self.y = tf.placeholder(tf.float32, shape=[batch_size, 192, 192, 3], name='train_target')
        self.init_fake_y = self.generator(self.x)  # fake images produced during pre-training
        self.fake_y = self.generator(self.x, reuse=True)  # fake images produced during full training

        #placeholder for the low-resolution test images to be reconstructed
        self.test_x = tf.placeholder(tf.float32, shape=[1, None, None, 3], name='test_generator')
        #the reconstructed test images
        self.test_fake_y = self.generator(self.test_x, reuse=True)  # generated fake images

        #resize the generated images for VGG
        self.fake_y_vgg = tf.image.resize_images(
            self.fake_y, size=[224, 224], method=0,
            align_corners=False)
        #resize the target images for VGG
        self.real_y_vgg = tf.image.resize_images(
            self.y, size=[224, 224], method=0,
            align_corners=False)
        #extract features from the fake images
        self.fake_y_feature = self.vgg19(self.fake_y_vgg)  # feature maps of the fake images
        #extract features from the target images
        self.real_y_feature = self.vgg19(self.real_y_vgg, reuse=True)  # feature maps of the real images

        # self.pre_dis_logits = self.discriminator(self.fake_y)  # discriminator's score for the pre-training images
        self.fake_dis_logits = self.discriminator(self.fake_y, reuse=False)  # discriminator's score for the fake images
        self.real_dis_logits = self.discriminator(self.y, reuse=True)  # discriminator's score for the real images

        # loss used to pre-train the generator
        self.init_mse_loss = tf.losses.mean_squared_error(self.init_fake_y, self.y)

        # loss used to optimize the discriminator
        self.D_loos = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.real_dis_logits,
                                                                             labels=tf.ones_like(
                                                                                 self.real_dis_logits))) + \
                      tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_dis_logits,
                                                                             labels=tf.zeros_like(
                                                                                 self.fake_dis_logits)))

        # the three generator terms: the discriminator's verdict on the fakes,
        # the pixel gap to the target image, and the feature gap to the target features
        self.D_loos_Ge = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_dis_logits, labels=tf.ones_like(self.fake_dis_logits)))
        self.mse_loss = tf.losses.mean_squared_error(self.fake_y, self.y)
        self.loss_vgg = tf.losses.mean_squared_error(self.fake_y_feature, self.real_y_feature)

        #generator loss: the weighted sum of the three terms above
        self.G_loos = 1e-3 * self.D_loos_Ge + 2e-6 * self.loss_vgg + self.mse_loss

        #collect the trainable variables to be updated by each optimizer
        t_vars = tf.trainable_variables()
        self.g_vars = [var for var in t_vars if var.name.startswith('trainGenerator')]
        self.d_vars = [var for var in t_vars if var.name.startswith('discriminator')]



    # the generator: 1 initial conv + 16 residual blocks + 2 rounds of 2x transposed conv + 1 final conv
    def generator(self, input, reuse=False):
        with tf.variable_scope('trainGenerator') as scope:
            if reuse:
                scope.reuse_variables()
            n = tf.layers.conv2d(input, 64, 3, strides=1, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            prellu_param = tf.get_variable('p_alpha', n.get_shape()[-1], initializer=tf.constant_initializer(0.0),
                                           dtype=tf.float32)
            n = tf.nn.relu(n) + prellu_param * (n - abs(n)) * 0.02
            # n = tf.nn.relu(n)
            temp = n
            # start of the residual blocks
            for i in range(16):
                nn = tf.layers.conv2d(n, 64, 3, strides=1, padding='SAME', activation=None, use_bias=True,
                                      bias_initializer=None)
                nn = tf.layers.batch_normalization(nn, training=True)
                prellu_param = tf.get_variable('p_alpha' + str(2 * i + 1), n.get_shape()[-1],
                                               initializer=tf.constant_initializer(0.0),
                                               dtype=tf.float32)
                nn = tf.nn.relu(nn) + prellu_param * (nn - abs(nn)) * 0.02

                nn = tf.layers.conv2d(nn, 64, 3, strides=1, padding='SAME', activation=None, use_bias=True,
                                      bias_initializer=None)
                nn = tf.layers.batch_normalization(nn, training=True)
                # prellu_param = tf.get_variable('p_alpha' + str(2 * i + 2), n.get_shape()[-1],
                #                                initializer=tf.constant_initializer(0.0),
                #                                dtype=tf.float32)
                # nn = tf.nn.relu(nn) + prellu_param * (nn - abs(nn)) * 0.02
                n = nn + n

            n = tf.layers.conv2d(n, 64, 3, strides=1, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.layers.batch_normalization(n, training=True)
            # prellu_param = tf.get_variable('p_alpha_34', n.get_shape()[-1],
            #                                initializer=tf.constant_initializer(0.0),
            #                                dtype=tf.float32)
            # n = tf.nn.relu(n) + prellu_param * (n - abs(n)) * 0.02

            #note the use of temp here; see the generator diagram in the paper
            n = temp + n

            # upsample the features back to an image
            n = tf.layers.conv2d_transpose(n, 256, 3, strides=2, padding='SAME', activation=None, use_bias=True,
                                           bias_initializer=None)

            n = tf.layers.conv2d(n, 256, 3, strides=1, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.layers.batch_normalization(n, training=True)
            n = tf.nn.relu(n)

            n = tf.layers.conv2d_transpose(n, 256, 3, strides=2, padding='SAME', activation=None, use_bias=True,
                                           bias_initializer=None)
            n = tf.layers.conv2d(n, 256, 3, strides=1, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.layers.batch_normalization(n, training=True)
            n = tf.nn.relu(n)

            n = tf.layers.conv2d(n, 3, 1, strides=1, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.nn.tanh(n)
            return n


    #the discriminator
    def discriminator(self, input, reuse=False):
        # input size: 192x192 (the high-resolution target size)
        with tf.variable_scope('discriminator') as scope:
            if reuse:
                scope.reuse_variables()
            # 1
            n = tf.layers.conv2d(input, 64, 3, strides=1, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.maximum(0.01 * n, n)
            # 2
            n = tf.layers.conv2d(n, 64, 3, strides=2, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.layers.batch_normalization(n, training=True)
            n = tf.maximum(0.01 * n, n)

            # 3
            n = tf.layers.conv2d(n, 128, 3, strides=1, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.layers.batch_normalization(n, training=True)
            n = tf.maximum(0.01 * n, n)

            # 4
            n = tf.layers.conv2d(n, 128, 3, strides=2, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.layers.batch_normalization(n, training=True)
            n = tf.maximum(0.01 * n, n)

            # 5
            n = tf.layers.conv2d(n, 256, 3, strides=1, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.layers.batch_normalization(n, training=True)
            n = tf.maximum(0.01 * n, n)

            # 6
            n = tf.layers.conv2d(n, 256, 3, strides=2, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.layers.batch_normalization(n, training=True)
            n = tf.maximum(0.01 * n, n)

            # 7
            n = tf.layers.conv2d(n, 512, 3, strides=1, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.layers.batch_normalization(n, training=True)
            n = tf.maximum(0.01 * n, n)

            # 8
            n = tf.layers.conv2d(n, 512, 3, strides=2, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.layers.batch_normalization(n, training=True)
            n = tf.maximum(0.01 * n, n)

            flatten = tf.reshape(n, (input.get_shape()[0], -1))
            # not enough memory: shrink the fully connected part
            # f = tf.layers.dense(flatten, 1024)
            # the paper uses a leaky relu here; I rely on the dense layer's defaults instead
            f = tf.layers.dense(flatten, 1, bias_initializer=tf.contrib.layers.xavier_initializer())

            return f
    #VGG19 feature extraction (uses the pretrained weights loaded above)
    def vgg19(self, input, reuse=False):
        VGG_MEAN = [103.939, 116.779, 123.68]
        with tf.variable_scope('vgg19') as scope:
            # if reuse:
            #     scope.reuse_variables()
            # ====================
            print("build model started")
            rgb_scaled = (input + 1) * (255.0 / 2)
            # Convert RGB to BGR
            red, green, blue = tf.split(rgb_scaled, 3, 3)
            assert red.get_shape().as_list()[1:] == [224, 224, 1]
            assert green.get_shape().as_list()[1:] == [224, 224, 1]
            assert blue.get_shape().as_list()[1:] == [224, 224, 1]
            bgr = tf.concat(
                [
                    blue - VGG_MEAN[0],
                    green - VGG_MEAN[1],
                    red - VGG_MEAN[2],
                ], axis=3)
            assert bgr.get_shape().as_list()[1:] == [224, 224, 3]

            # --------------------

            n = tf.nn.conv2d(bgr, w_params[0], name='conv2_1', strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[0])
            n = tf.nn.relu(n)
            n = tf.nn.conv2d(n, w_params[1], name='conv2_2', strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[1])
            n = tf.nn.relu(n)
            n = tf.nn.max_pool(n, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), padding='SAME')

            # return n

            # two
            n = tf.nn.conv2d(n, w_params[2], strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[2])
            n = tf.nn.relu(n)
            n = tf.nn.conv2d(n, w_params[3], strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[3])
            n = tf.nn.relu(n)
            n = tf.nn.max_pool(n, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), padding='SAME')
            # three
            n = tf.nn.conv2d(n, w_params[4], strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[4])
            n = tf.nn.relu(n)
            n = tf.nn.conv2d(n, w_params[5], strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[5])
            n = tf.nn.relu(n)
            n = tf.nn.conv2d(n, w_params[6], strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[6])
            n = tf.nn.relu(n)
            n = tf.nn.conv2d(n, w_params[7], strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[7])
            n = tf.nn.relu(n)
            n = tf.nn.max_pool(n, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), padding='SAME')
            # four
            n = tf.nn.conv2d(n, w_params[8], strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[8])
            n = tf.nn.relu(n)
            n = tf.nn.conv2d(n, w_params[9], strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[9])
            n = tf.nn.relu(n)

            n = tf.nn.conv2d(n, w_params[10], strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[10])
            n = tf.nn.relu(n)
            n = tf.nn.conv2d(n, w_params[11], strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[11])
            n = tf.nn.relu(n)
            n = tf.nn.max_pool(n, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), padding='SAME')
            return n

            # # five
            # n = tf.nn.conv2d(n, w_params[12], strides=(1, 1, 1, 1), padding='SAME')
            # n = tf.add(n, b_params[12])
            # n = tf.nn.relu(n)
            # n = tf.nn.conv2d(n, w_params[13], strides=(1, 1, 1, 1), padding='SAME')
            # n = tf.add(n, b_params[13])
            # n = tf.nn.relu(n)
            #
            # n = tf.nn.conv2d(n, w_params[14], strides=(1, 1, 1, 1), padding='SAME')
            # n = tf.add(n, b_params[14])
            # n = tf.nn.relu(n)
            # n = tf.nn.conv2d(n, w_params[15], strides=(1, 1, 1, 1), padding='SAME')
            # n = tf.add(n, b_params[15])
            # n = tf.nn.relu(n)
            # n = tf.nn.max_pool(n, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), padding='SAME')
            # return n

            # the features are compared with MSE here, so the fully connected head is not needed
            # flatten = tf.reshape(n, (input.get_shape()[0], -1))
            # f = tf.layers.dense(flatten, 4096)
            # f = tf.layers.dense(f, 4096)
            # f = tf.layers.dense(f, 1)
            # return n


gan = SRGAN()
G_OPTIM_init = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.init_mse_loss, var_list=gan.g_vars)
D_OPTIM = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.D_loos, var_list=gan.d_vars)
G_OPTIM = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.G_loos, var_list=gan.g_vars)

saver = tf.train.Saver(max_to_keep=3)

init = tf.global_variables_initializer()


#Load the training images from the folder. Only the file list is loaded here, to avoid exhausting memory by loading too many images at once
train_hr_img_list = sorted(tl.files.load_file_list(path='F:\\theRoleOfCOde\深度学习\SRGAN_PF\gaoqing', regx='.*.png', printable=False))[:100]
#Load the images
train_hr_imgs = tl.vis.read_images(train_hr_img_list, path='F:\\theRoleOfCOde\深度学习\SRGAN_PF\gaoqing', n_threads=1)

#Load the list of test images in the folder
test_img_list = sorted(tl.files.load_file_list(path='F:\\theRoleOfCOde\深度学习\SRGAN_PF\SRGAN_PF\img\\test', regx='.*.png', printable=False))[:6]
test_img = tl.vis.read_images(test_img_list, path='F:\\theRoleOfCOde\深度学习\SRGAN_PF\SRGAN_PF\img\\test', n_threads=1)



#Three run modes:
#restore: restore a previously saved model and continue training
#pre: restore a trained model and run prediction on the test images
#anything else (here 'go'): train from scratch

#During training, the model is tested periodically: the generated image
#tensors are saved as numpy arrays, which the utility function at the end
#converts back into viewable images
with tf.Session() as sess:
    type = 'go'
    if type == 'restore':
        saver.restore(sess, "./save/nets/ckpt-0-80")
        print('--------------- Restoring saved checkpoint, continuing training ---------------')
        for epoch in range(0):  # range(0): the generator pre-training pass is skipped in this mode
            for idx in range(0, (len(train_hr_imgs) // 10), batch_size):
                # print(type(train_hr_imgs[idx:idx + batch_size]))
                b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn,
                                                      is_random=True)
                b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
                print('-------------pre_generator:' + str(epoch) + '_' + str(idx) + '----------------')
                for i in range(40):
                    init_mse_loss, _ = sess.run([gan.init_mse_loss, G_OPTIM_init],
                                                feed_dict={
                                                    gan.x: b_imgs_96,
                                                    gan.y: b_imgs_384
                                                })
                    print('init_mse_loss:' + str(init_mse_loss))
            saver.save(sess, "save/nets/better_ge.ckpt")
        for epoch in range(config["epoch"]):
            for idx in range(0, len(train_hr_imgs), batch_size):
                # print(type(train_hr_imgs[idx:idx + batch_size]))
                b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn,
                                                      is_random=True)
                b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
                print('-------------' + str(epoch) + '_' + str(idx) + '----------')
                for i in range(25):
                    loss_D, _ = sess.run([gan.D_loos, D_OPTIM],
                                         feed_dict={
                                             gan.x: b_imgs_96,
                                             gan.y: b_imgs_384
                                         })
                    loss_G, _ = sess.run([gan.G_loos, G_OPTIM],
                                         feed_dict={
                                             gan.x: b_imgs_96,
                                             gan.y: b_imgs_384
                                         })
                    print(loss_D, loss_G)
                if idx % 20 == 0:
                    saver.save(sess, "./save/nets/better_all_" + str(epoch) + "_" + str(idx) + '.ckpt')

                    _imgs = (np.asanyarray(test_img[0:1]) / (255. / 2.)) - 1
                    _imgs = _imgs[:, :, :, 0:3]
                    result_fake_y = sess.run([gan.test_fake_y], feed_dict={
                        gan.test_x: _imgs
                    })  # the generated fake image
                    # result=sess.run(result_fake_y)
                    strpath = './preImg/result_' + str(epoch) + '_' + str(idx) + '_1.npy'
                    np.save(strpath, result_fake_y)

                    _imgs2 = (np.asanyarray(test_img[1:2]) / (255. / 2.)) - 1
                    _imgs2 = _imgs2[:, :, :, 0:3]
                    result_fake_y = sess.run([gan.test_fake_y], feed_dict={
                        gan.test_x: _imgs2
                    })  # the generated fake image
                    # result=sess.run(result_fake_y)
                    strpath = './preImg/result_' + str(epoch) + '_' + str(idx) + '_2.npy'
                    np.save(strpath, result_fake_y)
                    # print(type(result_fake_y))
    elif type == 'pre':
        saver.restore(sess, "save/nets/better_all_1_28.ckpt")
        print('--------------- Restoring trained model, running prediction ---------------')
        for num in range(6):
            _imgs = (np.asanyarray(test_img[num:(num + 1)]) / (255. / 2.)) - 1
            print(_imgs.shape)
            _imgs = _imgs[:, :, :, 0:3]
            # time.sleep(1)
            result_fake_y = sess.run([gan.test_fake_y], feed_dict={
                gan.test_x: _imgs
            })  # the generated fake image
            strpath = './preImg/pre_result_' + str(num) + '.npy'
            np.save(strpath, result_fake_y)
            print('ok')
    else:
        sess.run(init)
        print('--------------- Starting a new training run ---------------')
        for epoch in range(2):
            for idx in range(0, len(train_hr_imgs), batch_size):
                # print(type(train_hr_imgs[idx:idx + batch_size]))
                b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn,
                                                      is_random=True)
                b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
                print('-------------pre_generator:' + str(epoch) + '_' + str(idx) + '----------------')
                for i in range(25):
                    init_mse_loss, _ = sess.run([gan.init_mse_loss, G_OPTIM_init],
                                                feed_dict={
                                                    gan.x: b_imgs_96,
                                                    gan.y: b_imgs_384
                                                })
                    print('init_mse_loss:' + str(init_mse_loss))
        saver.save(sess, "save/nets/cnn_mnist_basic_generator.ckpt")
        for epoch in range(config["epoch"]):
            for idx in range(0, len(train_hr_imgs), batch_size):
                # print(type(train_hr_imgs[idx:idx + batch_size]))
                b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn,
                                                      is_random=True)
                b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
                print('-------------' + str(epoch) + '_' + str(idx) + '----------')
                for i in range(25):
                    loss_D, _ = sess.run([gan.D_loos, D_OPTIM],
                                         feed_dict={
                                             gan.x: b_imgs_96,
                                             gan.y: b_imgs_384
                                         })
                    loss_G, _ = sess.run([gan.G_loos, G_OPTIM],
                                         feed_dict={
                                             gan.x: b_imgs_96,
                                             gan.y: b_imgs_384
                                         })
                    print(loss_D, loss_G)
                if idx % 20 == 0:
                    _imgs = (np.asanyarray(test_img[0:1]) / (255. / 2.)) - 1
                    _imgs = _imgs[:, :, :, 0:3]
                    result_fake_y = sess.run([gan.test_fake_y], feed_dict={
                        gan.test_x: _imgs
                    })  # the generated fake image
                    # result=sess.run(result_fake_y)
                    strpath = './preImg/result_' + str(epoch) + '_' + str(idx) + '_1.npy'
                    np.save(strpath, result_fake_y)

                    _imgs2 = (np.asanyarray(test_img[1:2]) / (255. / 2.)) - 1
                    _imgs2 = _imgs2[:, :, :, 0:3]
                    result_fake_y = sess.run([gan.test_fake_y], feed_dict={
                        gan.test_x: _imgs2
                    })  # the generated fake image
                    # result=sess.run(result_fake_y)
                    strpath = './preImg/result_' + str(epoch) + '_' + str(idx) + '_2.npy'
                    np.save(strpath, result_fake_y)
                    saver.save(sess, "save/nets/ckpt-" + str(epoch) + '-' + str(idx))
                    # print(type(result_fake_y))

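In the loops above, `crop_sub_imgs_fn` and `downsample_fn` are helper functions from the SRGAN example code that are not shown in this post. Judging from the variable names `b_imgs_384` and `b_imgs_96`, they crop a 384×384 patch from each HR image and downsample it 4× to the 96×96 LR input. The following is only a minimal numpy sketch of that preprocessing; the random crop and the 4× block-mean downsampling are assumptions, not the exact implementation:

```python
import numpy as np

def crop_384(img, rng=np.random):
    """Randomly crop a 384x384 patch (a guess at what crop_sub_imgs_fn does)."""
    h, w = img.shape[:2]
    top = rng.randint(0, h - 384 + 1)
    left = rng.randint(0, w - 384 + 1)
    return img[top:top + 384, left:left + 384]

def downsample_4x(img):
    """Downsample 384 -> 96 by averaging 4x4 blocks (a guess at downsample_fn)."""
    h, w, c = img.shape
    return img.reshape(h // 4, 4, w // 4, 4, c).mean(axis=(1, 3))

hr = np.random.randint(0, 256, size=(400, 500, 3)).astype(np.float64)
patch = crop_384(hr)          # HR training patch, like b_imgs_384
lr = downsample_4x(patch)     # LR network input, like b_imgs_96
print(patch.shape, lr.shape)  # (384, 384, 3) (96, 96, 3)
```

Each LR pixel is the mean of the corresponding 4×4 HR block, which is one common way to build the LR/HR training pairs.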
A utility function for viewing the results: it converts a saved numpy matrix back into an image.

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

npz = np.load('../preImg/pre_result_5.npy', encoding='latin1')
print(npz.shape)
data = ((npz[0][0]) + 1) * (255. / 2.)
print(data)

new_im = Image.fromarray(data.astype(np.uint8))
new_im.show()
new_im.save('result.png')
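Both the training script and this viewer rely on the same pixel scaling: images are mapped from [0, 255] to [-1, 1] with `x / (255/2) - 1` before entering the network, and the saved outputs are mapped back with `(x + 1) * (255/2)`. A quick round-trip check of these two mappings:

```python
import numpy as np

def normalize(img):
    # [0, 255] -> [-1, 1], as applied to the test images above
    return img / (255. / 2.) - 1.

def denormalize(x):
    # [-1, 1] -> [0, 255], as applied before Image.fromarray above
    return (x + 1.) * (255. / 2.)

img = np.random.randint(0, 256, size=(8, 8, 3)).astype(np.float64)
assert np.allclose(denormalize(normalize(img)), img)
print(normalize(np.array([0., 127.5, 255.])))  # [-1.  0.  1.]
```

If the network output drifts outside [-1, 1], `denormalize` can produce values outside [0, 255], which is why the viewer casts with `astype(np.uint8)` only after rescaling.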

Author: 洪卫
Copyright notice: Unless otherwise stated, all articles on this blog are licensed under CC BY 4.0. Please credit the source (洪卫) when reposting.
Posted: 2019-12-10