72 lines
2.4 KiB
Python
72 lines
2.4 KiB
Python
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
import tensorflow as tf
|
|
from tensorflow.keras.datasets import fashion_mnist
|
|
from tensorflow.keras.models import Model
|
|
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D
|
|
from keras.layers import Add
|
|
|
|
# Load Fashion-MNIST; the class labels are discarded because the
# denoising task only needs the images themselves.
(train_data, _), (test_data, _) = fashion_mnist.load_data()

# Convert to float32, divide by 255 to normalise into [0, 1], and append a
# trailing channel axis so each split has shape (N, 28, 28, 1).
train_data = (train_data.astype('float32') / 255.0).reshape(-1, 28, 28, 1)
test_data = (test_data.astype('float32') / 255.0).reshape(-1, 28, 28, 1)

# Corrupt both splits with additive Gaussian noise scaled by noise_factor,
# then clamp the result back into [0, 1]. The training noise is drawn
# before the test noise.
noise_factor = 0.5
train_noisy = np.clip(
    train_data + noise_factor * np.random.normal(size=train_data.shape),
    0.0,
    1.0,
)
test_noisy = np.clip(
    test_data + noise_factor * np.random.normal(size=test_data.shape),
    0.0,
    1.0,
)
|
|
|
|
# U-Net-style model definition (encoder/decoder with additive skip connections)
|
|
def unet_model(input_shape):
    """Build a small U-Net-style convolutional denoising network.

    The encoder applies two 3x3 convolutions (32 and 64 filters), each
    followed by 2x2 max pooling. A 128-filter convolution forms the
    bottleneck. The decoder applies two 3x3 convolutions (64 and 32
    filters), each followed by 2x upsampling, and adds the matching encoder
    feature map element-wise after every upsampling step. A final 3x3
    convolution with one filter and a sigmoid activation produces the output.

    For the 28x28 input used in this file the spatial sizes run
    28 -> 14 -> 7 -> 14 -> 28, so both additions line up.
    NOTE(review): spatial sizes that are not a multiple of 4 are expected to
    produce mismatched shapes at the Add layers -- confirm before reusing
    this with other input sizes.

    Args:
        input_shape: Shape of a single input sample (without the batch
            axis), passed directly to ``Input``; e.g. ``(28, 28, 1)``.

    Returns:
        An uncompiled ``Model`` mapping the input tensor to the
        single-channel sigmoid output.
    """

    def conv3x3(filters, activation='relu'):
        # Every convolution in this network shares kernel size and padding.
        return Conv2D(filters, (3, 3), activation=activation, padding='same')

    inputs = Input(shape=input_shape)

    # Encoder: keep the pre-pooling feature maps for the skip connections.
    enc_full = conv3x3(32)(inputs)
    pooled_half = MaxPooling2D((2, 2), padding='same')(enc_full)
    enc_half = conv3x3(64)(pooled_half)
    pooled_quarter = MaxPooling2D((2, 2), padding='same')(enc_half)

    # Bottleneck
    bottleneck = conv3x3(128)(pooled_quarter)

    # Decoder, first stage: back to half resolution, then add the
    # 64-channel encoder features.
    dec_half = conv3x3(64)(bottleneck)
    dec_half = UpSampling2D((2, 2))(dec_half)
    dec_half = Add()([dec_half, enc_half])

    # Decoder, second stage: back to full resolution, then add the
    # 32-channel encoder features.
    dec_full = conv3x3(32)(dec_half)
    dec_full = UpSampling2D((2, 2))(dec_full)
    dec_full = Add()([dec_full, enc_full])

    outputs = conv3x3(1, activation='sigmoid')(dec_full)

    return Model(inputs, outputs)
|
|
|
|
# Build the network for 28x28 single-channel images and train it to map the
# noisy images back to the clean ones (pixel-wise mean squared error). The
# test split is used as validation data.
model = unet_model((28, 28, 1))
model.compile(optimizer='adam', loss='mean_squared_error')
model.fit(
    train_noisy,
    train_data,
    epochs=10,
    batch_size=128,
    validation_data=(test_noisy, test_data),
)

# Denoise the noisy test images with the trained model.
denoised_images = model.predict(test_noisy)

# Plot: top row shows noisy inputs, bottom row the corresponding model outputs.
n = 10
plt.figure(figsize=(20, 6))
for i in range(1, n + 1):
    # i is both the 1-based subplot slot and the sample index, so the
    # figure shows test samples 1..n (sample 0 is not displayed).
    panels = (
        (i, test_noisy[i], 'verrauscht'),
        (i + n, denoised_images[i], 'entrauscht'),
    )
    for slot, image, title in panels:
        ax = plt.subplot(2, n, slot)
        plt.imshow(image.reshape(28, 28), cmap='gray')
        ax.set_title(title)
plt.show()