import numpy as np
import sys
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow import keras
from functools import partial
import IPython.display as display
# Confirm the imports above succeeded, then show the installed TensorFlow
# version (the bare expression below is only displayed as a notebook cell result).
print("finish")
tf.__version__
# Output:
#   finish
#   '2.3.0'
# Colab only: mount Google Drive so the dataset paths under
# "/content/drive/My Drive/..." used below are readable.
from google.colab import drive
drive.mount('/content/drive')
# Output: Mounted at /content/drive
import glob
from random import shuffle

# Root folder of the CFD dataset on the mounted Google Drive.
_DATA_ROOT = "/content/drive/My Drive/datasets/cfd_tum"

# Source .npz files.  glob() yields paths in arbitrary directory order, so
# sort them to make the record order of the written TFRecord reproducible.
train_file_path = _DATA_ROOT + "/train/*.npz"
train_files_npz = sorted(glob.glob(train_file_path))
test_file_path = _DATA_ROOT + "/test/*.npz"
test_files_npz = sorted(glob.glob(test_file_path))

# Split converted in this run.  To build the training file instead, switch
# ``list_filenames`` and ``tfrecords_filename`` together.
list_filenames = test_files_npz
#list_filenames = list_filenames[:10]
tfrecords_train_name = _DATA_ROOT + "/cfdTUM-Training.tfrecords"
tfrecords_test_name = _DATA_ROOT + "/cfdTUM-Test.tfrecords"
tfrecords_filename = tfrecords_test_name

def _create_byte_feature(value):
    """Wrap the bytes held by an eager string tensor in a ``tf.train.Feature``.

    Args:
        value: eager tensor whose ``.numpy()`` yields the raw bytes, e.g. the
            output of ``tf.io.serialize_tensor`` in this file.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds that single value.
    """
    payload = value.numpy()
    byte_list = tf.train.BytesList(value=[payload])
    return tf.train.Feature(bytes_list=byte_list)


# Channel names in the order they are stacked along axis 0 of each file's
# "a" array; they double as the feature keys of the written records.
_NPZ_CHANNELS = ("b_vel_x", "b_vel_y", "b_geo", "p_vel_x", "p_vel_y", "p_press")

# Convert every .npz file in ``list_filenames`` into one tf.train.Example.
# Each channel is flattened to 1-D and stored as tf.io.serialize_tensor bytes.
with tf.io.TFRecordWriter(tfrecords_filename) as writer:
    # The loop variable keeps the name ``filename``: a later cell may still
    # read it after the loop finishes.
    for filename in list_filenames:
        # np.load returns an NpzFile that keeps the archive open; the context
        # manager closes it once the "a" array has been read into memory.
        with np.load(filename) as data:
            raw_data = data["a"]
        feature = {
            name: _create_byte_feature(
                tf.io.serialize_tensor(np.reshape(raw_data[idx], [-1, 1]).squeeze())
            )
            for idx, name in enumerate(_NPZ_CHANNELS)
        }
        example_message = tf.train.Example(
            features=tf.train.Features(feature=feature)
        )
        writer.write(example_message.SerializeToString())
# Leaving the ``with`` block already flushed and closed the writer.
print("Finish")
# Output: Finish
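# NOTE: stale draft of the parsing schema. The writer above stores the
# pressure channel under "p_press", not "pressure"; the schema actually used
# is the one inside _parse_image below.
#feature_description = {'b_vel_x': tf.io.FixedLenFeature([], tf.string),
#                       'b_vel_y': tf.io.FixedLenFeature([], tf.string),
#                       'b_geo': tf.io.FixedLenFeature([], tf.string),
#                       'p_vel_x': tf.io.FixedLenFeature([], tf.string),
#                       'p_vel_y': tf.io.FixedLenFeature([], tf.string),
#                       'pressure': tf.io.FixedLenFeature([], tf.string)
#                      }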
def _parse_tensor(exp):
    """Decode a tensor serialized with ``tf.io.serialize_tensor`` into an image.

    Args:
        exp: scalar string tensor holding the serialized bytes. The payload
            must have dtype float64, otherwise ``tf.io.parse_tensor`` fails.

    Returns:
        The decoded values reshaped to ``(128, 128, 1)``.
    """
    decoded = tf.io.parse_tensor(exp, out_type=tf.float64)
    image_shape = (128, 128, 1)
    return tf.reshape(decoded, image_shape)

def _parse_image(example):
    """Parse one serialized ``tf.train.Example`` into its six field images.

    Every feature is expected to be a string holding a tensor serialized with
    ``tf.io.serialize_tensor``. NOTE(review): decoding assumes that tensor is
    float64 (see ``_parse_tensor``); presumably the source .npz arrays are
    float64 — confirm.

    Args:
        example: scalar string tensor containing one serialized Example.

    Returns:
        Tuple ``(b_geo, b_velx, b_vely, p_velx, p_vely, p_press)``; each
        element is a float64 tensor of shape ``(128, 128, 1)``.
    """
    keys = ("b_vel_x", "b_vel_y", "b_geo", "p_vel_x", "p_vel_y", "p_press")
    feature_description = {key: tf.io.FixedLenFeature([], tf.string) for key in keys}
    parsed = tf.io.parse_single_example(example, feature_description)

    b_geo = _parse_tensor(parsed["b_geo"])
    b_velx = _parse_tensor(parsed["b_vel_x"])
    b_vely = _parse_tensor(parsed["b_vel_y"])
    p_velx = _parse_tensor(parsed["p_vel_x"])
    p_vely = _parse_tensor(parsed["p_vel_y"])
    # Decode the pressure channel itself; this previously reshaped p_vely
    # here, so the returned "pressure" was a copy of the y-velocity.
    p_press = _parse_tensor(parsed["p_press"])

    return b_geo, b_velx, b_vely, p_velx, p_vely, p_press


def load_dataset(filename):
    """Return a dataset of parsed CFD samples read from TFRecord file(s).

    Args:
        filename: path, or list of paths, of TFRecord files produced by the
            conversion loop in this notebook.

    Returns:
        ``tf.data.Dataset`` yielding the tuples produced by ``_parse_image``:
        ``(b_geo, b_velx, b_vely, p_velx, p_vely, p_press)``.
    """
    # Read the file the caller asked for; this previously ignored ``filename``
    # and always opened the global ``tfrecords_filename``.
    dataset = tf.data.TFRecordDataset(filename)
    return dataset.map(_parse_image)


def get_dataset(filename, BATCH_SIZE, shuffle_buffer=2048):
    """Return a shuffled, batched dataset built with ``load_dataset``.

    Args:
        filename: TFRecord path(s), forwarded to ``load_dataset``.
        BATCH_SIZE: number of samples per batch.
        shuffle_buffer: size of the shuffle buffer (default 2048, the value
            previously hard-coded).

    Returns:
        ``tf.data.Dataset`` of batched ``_parse_image`` tuples.
    """
    dataset = load_dataset(filename)
    dataset = dataset.shuffle(shuffle_buffer)
    return dataset.batch(BATCH_SIZE)


BATCH_SIZE = 12
# Pass the TFRecord path explicitly. The bare ``filename`` still in scope is
# the loop variable of the conversion loop, i.e. the last .npz file.
dataset = get_dataset(tfrecords_filename, BATCH_SIZE)
b_geo, b_velx, b_vely, p_velx, p_vely, p_pressure = next(iter(dataset))
print("Finish")
# Output: Finish
# Preview up to three samples of the fetched batch, one figure per sample:
# the top row shows the b_* fields, the bottom row the p_* fields, each with
# its own colorbar.
_panels = (
    ("b_geo", b_geo),
    ("b_vel_x", b_velx),
    ("b_vel_y", b_vely),
    ("p_vel_x", p_velx),
    ("p_vel_y", p_vely),
    ("p_press", p_pressure),
)
# min() guards against a batch with fewer than three samples.
for i in range(min(3, b_geo.shape[0])):
    fig = plt.figure(figsize=(10, 10))
    for k, (title, field) in enumerate(_panels, start=1):
        ax = fig.add_subplot(2, 3, k)
        im = ax.imshow(field[i, :, :, 0])
        ax.set_title(title)
        fig.colorbar(im, ax=ax)