import tensorflow as tf
import os
# Dataset参数
BATCH_SIZE = 32
IMG_HEIGHT = 224
IMG_WIDTH = 224
NUM_CLASSES = 10
# 图像文件路径
data_path = "/path/to/image/data"
# 获取所有图像文件路径和其对应标签
image_paths = []
labels = []
for i in range(NUM_CLASSES):
folder_path = os.path.join(data_path, str(i))
for img_name in os.listdir(folder_path):
img_path = os.path.join(folder_path, img_name)
image_paths.append(img_path)
labels.append(i)
# 创建tf.data.Dataset
def parse_image(img_path, label):
img_str = tf.io.read_file(img_path)
img = tf.image.decode_png(img_str, channels=3)
img = tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])
img = tf.cast(img, tf.float32) / 255.
return img, label
image_dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels))
image_dataset = image_dataset.map(parse_image).batch(BATCH_SIZE)
# 运行模型
autoencoder = tf.keras.models.Sequential([ ... ])
autoencoder.compile(optimizer='adam', loss='mse')
autoencoder.fit(image_dataset, epochs=5)
在上述代码示例中,parse_image
函数定义了如何解析每个图像文件,并将其转换为合适的形式。tf.data.Dataset.from_tensor_slices
方法可接受一个元组作为输入,其中包含所有图像文件路径和对应标签。parse_image
函数将这些路径和标签转换为图像数据集,并使用batch
方法创建相同大小的数据批次。autoencoder
模型