import numpy as np
import tensorflow as tf
import glob
%matplotlib inline

1. TFRecord Format

  • the TFRecord format itself knows nothing about image formats
  • it can store either dense arrays or encoded image files (e.g. PNG, JPEG)
  • in contrast to imread and imsave, TF decouples reading from decoding and encoding from writing

2. Reading Unknown Data

# Open the TFRecord file as a dataset of serialized, still-undecoded records.
raw_records = tf.data.TFRecordDataset("images/TFRecords/my-tfR.tfrecords")

# Pull a single record; only a blank line is printed for it.
first_raw_record = raw_records.take(1)
for raw_record in first_raw_record:
    print("")
    # To inspect the record's contents, parse it into a tf.train.Example:
    #   example = tf.train.Example()
    #   example.ParseFromString(raw_record.numpy())
# ...and then evaluate `example` to display it.

3. TFRecord format (PNG raw file)

# Serialized records read from the PNG TFRecord file.
raw_image_dataset = tf.data.TFRecordDataset("images/TFRecords/my-tfR.tfrecords")

# Schema used to parse each tf.train.Example: three scalar int64 fields
# followed by one scalar string field ('raw_image').
image_feature_description = {
    key: tf.io.FixedLenFeature([], tf.int64)
    for key in ("height", "width", "no_c")
}
image_feature_description["raw_image"] = tf.io.FixedLenFeature([], tf.string)

def _parse_image_function(example_proto):
    """Parse one serialized tf.train.Example and decode its PNG image.

    Args:
        example_proto: Serialized tf.train.Example whose fields match the
            module-level ``image_feature_description``.

    Returns:
        The decoded image as a uint8 tensor with 3 channels.
    """
    example = tf.io.parse_single_example(example_proto, image_feature_description)
    # The 'height', 'width' and 'no_c' fields are parsed but not needed here:
    # the decoder takes the spatial shape from the PNG data itself.
    # decode_png yields uint8 pixel values (not float32); channels=3 requests
    # a 3-channel output.
    return tf.io.decode_png(example["raw_image"], channels=3)

# Apply the PNG parser to every serialized record of the dataset.
parsed_image_dataset = raw_image_dataset.map(map_func=_parse_image_function)
# Bare expression: in a notebook this displays the dataset's repr.
parsed_image_dataset
<MapDataset shapes: (None, None, 3), types: tf.uint8>
# Decode only the first record and report the shape of the resulting image.
first_parsed = parsed_image_dataset.take(1)
for image in first_parsed:
    print(image.shape)
(600, 400, 3)

4. TFRecord format (JPEG raw file)

def _parse_image_fct(example_proto):
    """Parse one serialized tf.train.Example and decode its JPEG image.

    Args:
        example_proto: Serialized tf.train.Example carrying scalar int64
            fields 'height', 'width', 'no_c' and a scalar string field
            'raw_image'.

    Returns:
        The decoded image cast to float32 (values are not rescaled).
    """
    image_feature_description = {
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'no_c': tf.io.FixedLenFeature([], tf.int64),
        'raw_image': tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(example_proto, image_feature_description)
    # The 'height', 'width' and 'no_c' fields are parsed but not used: no
    # reshape is applied, the decoder output is returned as-is.
    # channels=0 keeps the number of channels stored in the JPEG itself.
    image = tf.io.decode_jpeg(contents=example["raw_image"], channels=0)
    # Cast only; scaling to [0, 1] is left to the caller.
    return tf.cast(image, tf.float32)

#parsed_image_dataset = raw_image_dataset.map(_parse_image_function)

def load_dataset(filename):
    """Return a dataset of decoded images read from the TFRecord at *filename*.

    Every serialized record is passed through ``_parse_image_fct``.
    """
    return tf.data.TFRecordDataset(filename).map(_parse_image_fct)

def get_dataset(filename, BATCH_SIZE):
    """Build a shuffled, batched image dataset from a TFRecord file.

    Args:
        filename: Forwarded to ``load_dataset``.
        BATCH_SIZE: Number of elements combined into each batch.

    Returns:
        The dataset after shuffling with a 2048-element buffer and batching.
    """
    shuffled = load_dataset(filename).shuffle(2048)
    # Prefetching is deliberately left out, as in the disabled original call.
    return shuffled.batch(BATCH_SIZE)


# Batch size and input file for the JPEG TFRecord demo.
BATCH_SIZE = 2
filename = "images/TFRecords/my-tfR-JPEG.tfrecords"
dataset = get_dataset(filename, BATCH_SIZE)

# First batch of the dataset; indexed below along its leading (batch) axis.
tfR_image = next(iter(dataset))

# NOTE(review): `plt` is not imported in the visible part of this file —
# presumably matplotlib.pyplot; confirm it is in scope before running.
# Loop over the images actually present in the batch instead of a hard-coded
# count, so a batch shorter than expected does not raise an indexing error.
for i in range(tfR_image.shape[0]):
    fig = plt.figure(figsize=(10, 10))
    ax1 = fig.add_subplot(212)
    # The parser returns unscaled float32 values; divide by 255 for display.
    # Keep the image handle separate from the Axes instead of rebinding `ax1`.
    img_handle = ax1.imshow(tfR_image[i, :, :, :] / 255.)
    plt.colorbar(img_handle)