Read Images from TFRecords Format
- 1. TFRecord Format
- 2. Reading Unknown Data
- 3. TFRecord format (PNG raw file)
- 4. TFRecord format (JPEG raw file)
import numpy as np
import tensorflow as tf
import glob
%matplotlib inline
# Peek at the first serialized record to discover which features it contains.
raw_records = tf.data.TFRecordDataset("images/TFRecords/my-tfR.tfrecords")
for raw_record in raw_records.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    # Print only each feature's name and value kind (bytes_list / float_list /
    # int64_list); printing the whole Example would dump the encoded image bytes.
    for feature_name, feature in example.features.feature.items():
        print(feature_name, feature.WhichOneof("kind"))
# A dataset over the PNG TFRecord file, consumed by the PNG parser below.
raw_image_dataset = tf.data.TFRecordDataset("images/TFRecords/my-tfR.tfrecords")

# Schema passed to tf.io.parse_single_example: three scalar int64 fields
# plus the encoded image stored as a scalar byte string.
image_feature_description = dict(
    height=tf.io.FixedLenFeature([], tf.int64),
    width=tf.io.FixedLenFeature([], tf.int64),
    no_c=tf.io.FixedLenFeature([], tf.int64),
    raw_image=tf.io.FixedLenFeature([], tf.string),
)
def _parse_image_function(example_proto):
    """Decode one serialized tf.train.Example that holds a PNG image.

    Parses the record with the module-level ``image_feature_description``
    schema, so every feature listed there must be present in the record.

    Args:
        example_proto: Scalar string tensor containing one serialized Example.

    Returns:
        A uint8 tensor of shape (height, width, 3).
    """
    example = tf.io.parse_single_example(example_proto, image_feature_description)
    # decode_png returns uint8 pixels (not float32); channels=3 forces RGB output.
    return tf.io.decode_png(example["raw_image"], channels=3)
# Lazily apply the PNG parser to every record in the file.
parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
parsed_image_dataset  # bare expression: rendered as cell output in a notebook
for decoded in parsed_image_dataset.take(1):
    print(decoded.shape)
def _parse_image_fct(example_proto):
    """Decode one serialized tf.train.Example that holds a JPEG image.

    Args:
        example_proto: Scalar string tensor containing one serialized Example.

    Returns:
        A float32 tensor of shape (height, width, channels) with pixel values
        in [0, 255]. The channel count is whatever the JPEG itself encodes.
    """
    image_feature_description = {
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'no_c': tf.io.FixedLenFeature([], tf.int64),
        'raw_image': tf.io.FixedLenFeature([], tf.string),
    }
    # Every feature above has no default, so each must exist in the record,
    # even though only the image bytes are used below.
    example = tf.io.parse_single_example(example_proto, image_feature_description)
    # channels=0 keeps the channel count stored in the JPEG.
    image = tf.io.decode_jpeg(contents=example["raw_image"], channels=0)
    # decode_jpeg yields uint8; cast so callers can do float arithmetic (e.g. / 255).
    return tf.cast(image, tf.float32)
def load_dataset(filename):
    """Return a dataset of decoded float32 JPEG images read from the TFRecord `filename`."""
    return tf.data.TFRecordDataset(filename).map(_parse_image_fct)
def get_dataset(filename, BATCH_SIZE, shuffle_buffer=2048):
    """Build a shuffled, batched, prefetched pipeline of decoded JPEG images.

    Args:
        filename: Path to a TFRecord file of JPEG-encoded Examples.
        BATCH_SIZE: Number of images per batch; the final batch may be smaller.
        shuffle_buffer: Number of elements held in the shuffle buffer
            (defaults to 2048, the previously hard-coded value).

    Returns:
        A ``tf.data.Dataset`` yielding float32 image batches.
    """
    dataset = load_dataset(filename)
    dataset = dataset.shuffle(shuffle_buffer)
    # Batching stacks images, so all images in a batch must share one shape.
    dataset = dataset.batch(BATCH_SIZE)
    # prefetch() requires a buffer size; AUTOTUNE lets tf.data choose it.
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)
# `%matplotlib inline` only selects the backend; it does not bind `plt`.
import matplotlib.pyplot as plt

BATCH_SIZE = 2
filename = "images/TFRecords/my-tfR-JPEG.tfrecords"
dataset = get_dataset(filename, BATCH_SIZE)
# First shuffled batch, shaped (batch, height, width, channels).
tfR_image = next(iter(dataset))

# Loop over the images actually in the batch: it holds fewer than BATCH_SIZE
# images when the file has fewer records than that.
for i in range(tfR_image.shape[0]):
    # Pixels are float32 in [0, 255]; imshow expects floats in [0, 1].
    image = tfR_image[i].numpy() / 255.
    if image.shape[-1] == 1:
        # imshow rejects (H, W, 1); drop the channel axis for grayscale JPEGs.
        image = image[..., 0]
    fig = plt.figure(figsize=(10, 10))
    # 111 fills the whole figure (212 drew into the lower half of an empty 2x1 grid).
    ax = fig.add_subplot(111)
    im = ax.imshow(image)
    fig.colorbar(im, ax=ax)