I have fixed your code.
There were 2 issues.
The main one was the value returned by your body function.
You need to understand that at every iteration of the "loop", the output of the previous iteration is passed back into the body function as its arguments, in the same order.
In your code, you were returning the sim value as the first element of the list. So, in the second iteration, the first argument is no longer t1 but just a single value. That is why you were getting that error: you were trying to index the sim value, not t1.
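To see why the order matters, here is a minimal plain-Python model of the loop-variable threading described above. It is an illustration only, not TensorFlow's implementation; the names `while_loop_model`, `buggy_body` and `fixed_body` and the toy data are hypothetical. The buggy body returns its result in the position where `data` should be, so the second call tries to index a plain number, the same kind of failure as indexing `sim` instead of `t1`.

```python
def while_loop_model(cond, body, loop_vars):
    """Plain-Python model of how loop variables are threaded.

    Each iteration's return value becomes the next iteration's
    arguments, matched by position. Illustration only.
    """
    loop_vars = list(loop_vars)
    while cond(*loop_vars):
        loop_vars = list(body(*loop_vars))
    return loop_vars


def cond(data, i, n, result):
    return i < n


def buggy_body(data, i, n, result):
    # Bug: the new result is returned in data's position, so on the
    # next call `data` is a plain int and `data[i]` raises TypeError.
    return [data[i] * 2, i + 1, n, result]


def fixed_body(data, i, n, result):
    # Same order as the parameters: data, i, n, result
    return [data, i + 1, n, data[i] * 2]


data = [10, 20, 30]
print(while_loop_model(cond, fixed_body, [data, 0, len(data), 0]))
# prints [[10, 20, 30], 3, 3, 60]

try:
    while_loop_model(cond, buggy_body, [data, 0, len(data), 0])
except TypeError as err:
    print("buggy_body failed:", err)
```

Note that, like the loop in this answer, the result variable is overwritten on every iteration, so only the last value survives.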
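The other issue was in the call to tf.image.ssim. It expects a batch of images, i.e. a 4-D tensor of shape [batch, height, width, channels], so a single 64x64 grayscale image must be passed as [1, 64, 64, 1], but you were passing [64, 64, 1]. (For the 28x28 MNIST images used in the code below, that is [1, 28, 28, 1].)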
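A quick NumPy sketch of the shape change; the zero-filled arrays are hypothetical stand-ins for real images. `tf.expand_dims(x, axis=0)` adds the batch axis to a tensor in the same way.

```python
import numpy as np

# Hypothetical stand-in for one 64x64 grayscale image: (height, width, channels)
image = np.zeros((64, 64, 1), dtype=np.uint8)

# Add a leading batch axis: (64, 64, 1) -> (1, 64, 64, 1)
batch = np.expand_dims(image, axis=0)
print(batch.shape)  # (1, 64, 64, 1)

# A 2-D MNIST-style image needs both a channel axis and a batch axis
mnist_like = np.zeros((28, 28), dtype=np.uint8)
mnist_batch = mnist_like[np.newaxis, :, :, np.newaxis]
print(mnist_batch.shape)  # (1, 28, 28, 1)
```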
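```python
def body(t1, t2, i, iters, sim):
    a = tf.expand_dims(t1[i], axis=2)  # (H, W) -> (H, W, 1)
    a = tf.expand_dims(a, axis=0)      # make it a batch: (1, H, W, 1)
    b = tf.expand_dims(t2, axis=0)     # make it a batch: (1, H, W, 1)
    sim = tf.image.ssim(a, b, 255)
    # Return the loop variables in the same order they were received
    return [t1, t2, tf.add(i, 1), iters, sim]
```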
Here is the whole code:
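```python
import numpy as np
import tensorflow as tf

# Importing a generic dataset from Keras (MNIST: 28x28 grayscale images)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(
    path='mnist.npz'
)

# This is my reference image, shape (28, 28, 1)
x_reference = np.expand_dims(x_train[0], axis=2)

t1 = tf.constant(x_train)               # all images, shape (60000, 28, 28)
t2 = tf.constant(x_reference)           # reference image, shape (28, 28, 1)
iters = tf.constant(x_train.shape[0])   # 60000

def cond(t1, t2, i, iters, sim):
    return tf.less(i, iters)

def body(t1, t2, i, iters, sim):
    a = tf.expand_dims(t1[i], axis=2)   # (28, 28) -> (28, 28, 1)
    a = tf.expand_dims(a, axis=0)       # make it a batch: (1, 28, 28, 1)
    b = tf.expand_dims(t2, axis=0)      # make it a batch: (1, 28, 28, 1)
    sim = tf.image.ssim(a, b, 255)      # float32, shape (1,)
    # Return the loop variables in the same order they were received
    return [t1, t2, tf.add(i, 1), iters, sim]

# Start sim with the same dtype and shape that body returns (float32, shape [1])
# so the loop also works in graph mode (e.g. inside tf.function)
res = tf.while_loop(cond, body, [t1, t2, 0, iters, tf.zeros([1])],
                    parallel_iterations=60000)
# res[4] holds the SSIM of the last image against the reference
```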