# gnn/beispiele/12.3_U-Net.py
#
# 72 lines
# 2.4 KiB
# Python

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D
from keras.layers import Add
# Load Fashion-MNIST; the class labels are discarded because the denoising
# task uses the clean images themselves as targets.
(train_data, _), (test_data, _) = fashion_mnist.load_data()
# Normalize the pixel values to [0, 1]
train_data = train_data.astype('float32') / 255.0
test_data = test_data.astype('float32') / 255.0
# Reshape to (N, 28, 28, 1) so every image carries a single channel axis
train_data = train_data.reshape(-1, 28, 28, 1)
test_data = test_data.reshape(-1, 28, 28, 1)
# Add Gaussian noise to the training and test data.
# np.random.normal returns float64; casting the noise to float32 keeps the
# noisy arrays in float32 (like the clean targets) instead of silently
# upcasting them to float64 and doubling their memory footprint.
noise_factor = 0.5
train_noisy = train_data + noise_factor * np.random.normal(
    size=train_data.shape
).astype('float32')
test_noisy = test_data + noise_factor * np.random.normal(
    size=test_data.shape
).astype('float32')
# ... and clip the result back into the valid range [0, 1]
train_noisy = np.clip(train_noisy, 0.0, 1.0)
test_noisy = np.clip(test_noisy, 0.0, 1.0)
# The U-Net model definition
def unet_model(input_shape):
    """Build a small U-Net-style convolutional network for image denoising.

    The encoder applies two Conv2D + 2x2 max-pooling stages, followed by a
    128-filter bottleneck. The decoder applies two Conv2D + 2x upsampling
    stages. Encoder feature maps are merged into the decoder by element-wise
    addition (``Add``), and a final sigmoid convolution produces a
    single-channel output.

    Args:
        input_shape: Shape of one input image without the batch axis, as
            ``(height, width, channels)``. Every known (non-``None``) spatial
            size must be a multiple of 4.

    Returns:
        An uncompiled Keras ``Model`` mapping the input image to a
        single-channel image with sigmoid-activated values.

    Raises:
        ValueError: If a known height or width is not a multiple of 4.
    """
    # Two 2x2 poolings followed by two 2x upsamplings only reproduce the
    # encoder's spatial sizes when height and width are divisible by 4
    # (e.g. 30 -> 15 -> 8 -> 16, which no longer matches 15). Fail early with
    # a clear message instead of an opaque shape mismatch inside Add().
    for size in input_shape[:2]:
        if size is not None and size % 4 != 0:
            raise ValueError(
                "unet_model requires height and width to be multiples of 4, "
                f"got input_shape={tuple(input_shape)}"
            )

    input_img = Input(shape=input_shape)

    # Encoder
    x1 = Conv2D(32, (3, 3), activation='relu', padding='same')(input_img)
    x2 = MaxPooling2D((2, 2), padding='same')(x1)
    x3 = Conv2D(64, (3, 3), activation='relu', padding='same')(x2)
    x4 = MaxPooling2D((2, 2), padding='same')(x3)

    # Bottleneck
    bn = Conv2D(128, (3, 3), activation='relu', padding='same')(x4)

    # Decoder
    x5 = Conv2D(64, (3, 3), activation='relu', padding='same')(bn)
    x6 = UpSampling2D((2, 2))(x5)
    # Skip connection: add the 64-channel encoder features x3
    x6 = Add()([x6, x3])
    x7 = Conv2D(32, (3, 3), activation='relu', padding='same')(x6)
    x8 = UpSampling2D((2, 2))(x7)
    # Skip connection: add the 32-channel encoder features x1
    x8 = Add()([x8, x1])
    x9 = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x8)

    return Model(input_img, x9)
# Build the network for 28x28 single-channel images.
model = unet_model(input_shape=(28, 28, 1))
model.compile(loss='mean_squared_error', optimizer='adam')
# Train the model to map noisy images back to their clean originals; the
# noisy/clean test pairs serve as validation data.
model.fit(
    x=train_noisy,
    y=train_data,
    batch_size=128,
    epochs=10,
    validation_data=(test_noisy, test_data),
)
# Denoise the noisy test images with the trained model.
denoised_images = model.predict(test_noisy)
# Plot: top row shows noisy inputs, bottom row the corresponding model outputs.
n = 10
plt.figure(figsize=(20, 6))
rows = ((test_noisy, 'verrauscht'), (denoised_images, 'entrauscht'))
for row, (images, title) in enumerate(rows):
    # The 1-based column number doubles as the sample index, so the samples
    # with indices 1..n are displayed.
    for col in range(1, n + 1):
        ax = plt.subplot(2, n, row * n + col)
        ax.imshow(images[col].reshape(28, 28), cmap='gray')
        ax.set_title(title)
plt.show()