下面用 image_dataset_from_directory 加载数据集、再用它训练 GAN 的做法对吗？
加载数据集：
# Load every image under Abstract_gallery/ as an *unlabeled* dataset for
# GAN training. An unconditional GAN needs no class labels. With
# labels='inferred' each element would be an (images, labels) tuple, and a
# training loop that treats each element as an image batch (e.g. calls
# `batch.shape[0]` or feeds it straight to the discriminator) would break.
train_ds = keras.utils.image_dataset_from_directory(
    directory='Abstract_gallery/',
    labels=None,          # yield bare image batches, no label tensors
    image_size=(64, 64),  # every image is resized to 64x64 on load
    batch_size=128,
)
# image_dataset_from_directory does not rescale: pixels arrive as float32
# in [0, 255]. Map them to [-1, 1] so real images share the range of the
# generated ones.
# NOTE(review): assumes the generator's output layer is tanh; if it ends in
# sigmoid, use `x / 255.0` instead -- TODO confirm.
train_ds = train_ds.map(lambda x: (x - 127.5) / 127.5)
使用该数据集训练 GAN：
# Alternate discriminator / generator updates over every batch of train_ds.
# NOTE(review): assumes `generator`, `discriminator`, `GAN`, `epochs`,
# `Dloss` and `Gloss` are defined earlier, and that `GAN` is
# generator -> discriminator. In Keras the discriminator must be set
# `trainable = False` before `GAN.compile(...)` (but after
# `discriminator.compile(...)`), so that GAN.train_on_batch updates only
# the generator -- TODO confirm.
latent_dim = 100  # length of each noise vector fed to the generator
for epoch in range(epochs):
    for batch in train_ds:
        # A dataset built with labels='inferred' yields (images, labels)
        # tuples; one built with labels=None yields bare image batches.
        # Accept either form.
        real_images = batch[0] if isinstance(batch, (tuple, list)) else batch
        # The final batch of an epoch can be smaller than the configured
        # batch size, so size all labels and noise from the batch received.
        n = int(real_images.shape[0])
        y_real = np.ones(n)
        y_fake = np.zeros(n)

        # ---- train D: one step on generated images, one on real images ----
        z = np.random.normal(0, 1, [n, latent_dim])
        generated_images = generator(z)
        d_loss_fake = discriminator.train_on_batch(generated_images, y_fake)
        d_loss_real = discriminator.train_on_batch(real_images, y_real)
        # train_on_batch returns a scalar, or a list [loss, *metrics] when
        # the model was compiled with metrics. np.add sums element-wise in
        # both cases, whereas `+=` on lists would concatenate them.
        Dloss.append(np.add(d_loss_fake, d_loss_real))

        # ---- train G: ask D to label fresh fakes as real ----
        z = np.random.normal(0, 1, [n, latent_dim])
        Gloss.append(GAN.train_on_batch(z, y_real))