Graduate School/Neural Network
Face Recognition
Implement a CNN Model¶
Import Library¶
In [1]:
import scipy.io as sio
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
%matplotlib inline
2022-12-07 17:52:14.633420: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
Load Dataset¶
In [2]:
d = sio.loadmat("./face.mat")
In [3]:
images = d["images"]
landmarks = d["landmarks"]
print(images.shape, landmarks.shape)
print(images.max())
(2000, 56, 56, 3) (2000, 68, 2) 1.1094794840691424
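Note that the maximum pixel value is about 1.109, slightly above 1.0; this is why matplotlib later warns about clipping when plotting. A minimal sketch (with a small made-up array, not the real data) of clipping to the valid `[0, 1]` range before plotting:

```python
import numpy as np

# Illustrative array with values slightly outside [0, 1], like the dataset's max of ~1.109.
arr = np.array([[0.5, 1.11], [-0.02, 0.99]])
clipped = np.clip(arr, 0.0, 1.0)  # matplotlib does this implicitly and warns
print(clipped.min(), clipped.max())  # 0.0 1.0
```

Clipping (or rescaling) the images explicitly would silence the warning without changing what is displayed.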
Visualize the Data¶
In [4]:
plt.imshow(images[1])
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Out[4]:
<matplotlib.image.AxesImage at 0x7fcbd6a85970>
In [5]:
plt.imshow(images[1])
for point in landmarks[1]:
    plt.plot(point[0]*56, point[1]*56, "r+")
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Create Train and Test Datasets¶
In [6]:
landmarks = np.reshape(d["landmarks"], [-1, 68*2])  # flatten (68, 2) landmarks to 136-d vectors
x_train = images[:100]
y_train = landmarks[:100]
x_test = images[-100:]
y_test = landmarks[-100:]
print(x_train.shape, x_test.shape)
(100, 56, 56, 3) (100, 56, 56, 3)
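The slicing above takes the first 100 and last 100 samples (the middle 1,800 go unused). If `face.mat` stores samples in a meaningful order, a shuffled split may be safer; the sketch below uses placeholder zero arrays with the same shapes as the loaded data:

```python
import numpy as np

# Placeholder arrays standing in for the loaded images/landmarks.
images = np.zeros((2000, 56, 56, 3))
landmarks = np.zeros((2000, 136))

# Shuffle indices once, then slice, so train and test are disjoint random subsets.
rng = np.random.default_rng(0)
idx = rng.permutation(len(images))
x_train, y_train = images[idx[:100]], landmarks[idx[:100]]
x_test, y_test = images[idx[-100:]], landmarks[idx[-100:]]
print(x_train.shape, x_test.shape)  # (100, 56, 56, 3) (100, 56, 56, 3)
```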
Create Model¶
In [7]:
def conv(x, k, f):  # two stacked conv layers
    for _ in range(2):
        x = tf.keras.layers.Conv2D(kernel_size=k, filters=f, activation=tf.nn.relu, padding="same")(x)
    return x

def pool(x):  # 2x2 max-pooling layer
    return tf.keras.layers.MaxPool2D(pool_size=2, strides=2)(x)
x_in = x = tf.keras.Input(shape=[56, 56, 3])
x = conv(x, 3, 16) # output should be [56, 56, 16]
x = pool(x) # output should be [28, 28, 16]
x = conv(x, 3, 16) # output should be [28, 28, 16]
x = pool(x) # output should be [14, 14, 16]
x = conv(x, 3, 16) # output should be [14, 14, 16]
x = pool(x) # output should be [7, 7, 16] => 7*7*16=784 values
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(units=256, activation=tf.nn.relu)(x)
y = tf.keras.layers.Dense(units=136, activation=tf.nn.sigmoid)(x) # 136=68*2
model = tf.keras.Model(inputs=x_in, outputs=y)
2022-12-07 17:52:16.724291: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA. To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-12-07 17:52:17.290186: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 22121 MB memory: -> device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:65:00.0, compute capability: 8.6
Show Model Information¶
In [8]:
model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                   Output Shape              Param #
=================================================================
 input_1 (InputLayer)           [(None, 56, 56, 3)]       0
 conv2d (Conv2D)                (None, 56, 56, 16)        448
 conv2d_1 (Conv2D)              (None, 56, 56, 16)        2320
 max_pooling2d (MaxPooling2D)   (None, 28, 28, 16)        0
 conv2d_2 (Conv2D)              (None, 28, 28, 16)        2320
 conv2d_3 (Conv2D)              (None, 28, 28, 16)        2320
 max_pooling2d_1 (MaxPooling2D) (None, 14, 14, 16)        0
 conv2d_4 (Conv2D)              (None, 14, 14, 16)        2320
 conv2d_5 (Conv2D)              (None, 14, 14, 16)        2320
 max_pooling2d_2 (MaxPooling2D) (None, 7, 7, 16)          0
 flatten (Flatten)              (None, 784)               0
 dense (Dense)                  (None, 256)               200960
 dense_1 (Dense)                (None, 136)               34952
=================================================================
Total params: 247,960
Trainable params: 247,960
Non-trainable params: 0
_________________________________________________________________
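As a sanity check, the parameter counts in the summary can be reproduced by hand (plain Python, no TensorFlow needed): each Conv2D has `k*k*c_in*f + f` parameters, each Dense has `in*out + out`.

```python
# Hand-check of the summary's parameter counts.
k, c_in, f = 3, 3, 16
first_conv = k * k * c_in * f + f    # 3*3*3*16 + 16 = 448 (RGB input)
later_conv = k * k * f * f + f       # 3*3*16*16 + 16 = 2320 (16-channel input)
dense1 = 7 * 7 * 16 * 256 + 256      # flattened 784 features -> 256 units
dense2 = 256 * 136 + 136             # 256 -> 136 landmark coordinates
total = first_conv + 5 * later_conv + dense1 + dense2
print(total)  # 247960
```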
Training Model¶
In [9]:
def loss_fn(y_true, y_pred):
    # MSE loss over all 136 coordinates
    mse = tf.reduce_mean(tf.reduce_mean(tf.square(y_true - y_pred), axis=1))
    # Cosine loss: 1 - cosine similarity between the flattened landmark vectors
    norm_y_true = tf.sqrt(tf.reduce_sum(tf.square(y_true), axis=1))
    norm_y_pred = tf.sqrt(tf.reduce_sum(tf.square(y_pred), axis=1))
    y_true_y_pred = tf.reduce_sum(tf.multiply(y_true, y_pred), axis=1)
    cos_loss = tf.reduce_mean(1 - y_true_y_pred / (norm_y_true * norm_y_pred))
    return mse + 0.5 * cos_loss

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), loss=loss_fn, metrics=[tf.keras.metrics.MSE])
history = model.fit(x_train, y_train, batch_size=32, epochs=100, validation_data=(x_test, y_test))
Epoch 1/100
2022-12-07 17:52:19.777359: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8401
2022-12-07 17:52:21.233090: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
4/4 [==============================] - 4s 128ms/step - loss: 0.0291 - mean_squared_error: 0.0175 - val_loss: 0.0236 - val_mean_squared_error: 0.0142
Epoch 2/100
4/4 [==============================] - 0s 31ms/step - loss: 0.0197 - mean_squared_error: 0.0118 - val_loss: 0.0139 - val_mean_squared_error: 0.0080
...
Epoch 99/100
4/4 [==============================] - 0s 31ms/step - loss: 2.7749e-04 - mean_squared_error: 1.6380e-04 - val_loss: 0.0037 - val_mean_squared_error: 0.0021
Epoch 100/100
4/4 [==============================] - 0s 47ms/step - loss: 2.9252e-04 - mean_squared_error: 1.7434e-04 - val_loss: 0.0036 - val_mean_squared_error: 0.0020
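The combined loss can be sanity-checked outside TensorFlow with a small NumPy re-implementation of `loss_fn` (same formula, `np` in place of `tf`). Identical prediction and target vectors should give zero loss:

```python
import numpy as np

def loss_fn_np(y_true, y_pred):
    # Per-sample MSE, averaged over the batch
    mse = np.mean(np.mean((y_true - y_pred) ** 2, axis=1))
    # 1 - cosine similarity, averaged over the batch
    nt = np.sqrt(np.sum(y_true ** 2, axis=1))
    npred = np.sqrt(np.sum(y_pred ** 2, axis=1))
    dot = np.sum(y_true * y_pred, axis=1)
    cos_loss = np.mean(1 - dot / (nt * npred))
    return mse + 0.5 * cos_loss

y = np.array([[0.2, 0.4, 0.6, 0.8]])
print(loss_fn_np(y, y))        # ~0.0: identical vectors incur no loss
print(loss_fn_np(y, y + 0.1))  # positive: any deviation is penalized
```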
Predict¶
In [10]:
p = model.predict(x_test[2:3])
print(p.shape)
1/1 [==============================] - 0s 138ms/step
(1, 136)
Visualize the Prediction¶
In [11]:
p = np.reshape(p, [68, 2])
plt.imshow(x_test[2])
for point in p:
    plt.plot(point[0]*56, point[1]*56, "r+")
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
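Beyond the visual check, a per-landmark error in pixels is a natural metric here, since both predictions and targets are normalized coordinates on a 56x56 image. A sketch with hypothetical two-landmark arrays (the values below are made up for illustration):

```python
import numpy as np

# Hypothetical normalized (x, y) landmarks: ground truth vs. prediction.
y_true = np.array([[0.25, 0.50], [0.75, 0.50]])
y_pred = np.array([[0.25, 0.52], [0.73, 0.50]])

# Euclidean distance per landmark, scaled to pixels, then averaged.
err_px = np.mean(np.linalg.norm((y_true - y_pred) * 56, axis=1))
print(err_px)  # ~1.12 pixels average error
```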