Home > AI > Uncategorized

TensorFlow – Preserving intermediate results

# API
# tf.add_to_collection()
# tf.get_collection()
# Purpose:
# preserve intermediate results so they can be inspected when debugging the
# network, e.g. when the loss stays the same from step to step.


import tensorflow as tf

# Two float32 variables, each initialized from a truncated normal with
# stddev 0.1: `a` is rank-4 with shape [2, 3, 3, 3], `b` is a 2x2 matrix.
a = tf.get_variable(
    name='a',
    shape=[2, 3, 3, 3],
    dtype=tf.float32,
    initializer=tf.truncated_normal_initializer(stddev=0.1),
)
b = tf.get_variable(
    name='b',
    shape=[2, 2],
    dtype=tf.float32,
    initializer=tf.truncated_normal_initializer(stddev=0.1),
)

# Register both variables (a first, then b) in the 'network' collection so
# they can be fetched later via tf.get_collection('network') and evaluated.
tf.add_to_collection('network', a)
tf.add_to_collection('network', b)

def fc(input, category):
    """Build a fully connected layer producing `category` outputs per example.

    Every dimension after the first is flattened into a single feature axis,
    so a rank-4 input of shape [batch, height, width, channel] is reshaped to
    [batch, height * width * channel] before the matrix multiply.

    Args:
        input: Tensor of rank >= 2 whose first axis is the batch axis. All
            non-batch dimensions must be statically known because they size
            the weight matrix; the batch dimension may be unknown (None).
        category: Number of output units.

    Returns:
        Tensor of shape [batch, category] equal to ``flatten(input) @ w + b``.

    Raises:
        ValueError: if `input` has rank < 2 or an unknown non-batch dimension.
    """
    dims = input.shape.as_list()
    if len(dims) < 2:
        raise ValueError(
            'fc expects an input of rank >= 2, got shape {}'.format(dims))
    feature_dims = dims[1:]
    if None in feature_dims:
        raise ValueError(
            'fc needs statically known non-batch dimensions, '
            'got shape {}'.format(dims))

    input_size = 1
    for d in feature_dims:
        input_size *= d
    output_size = category

    w_init = tf.truncated_normal([input_size, output_size], stddev=0.1)
    w = tf.Variable(w_init, name='w')
    # Bias starts at ones (not zeros); kept as in the original layer.
    bias = tf.Variable(tf.ones([output_size]))

    # -1 on the batch axis lets TF infer it, so an unknown (None) batch size
    # works too; the feature axis is pinned by the weight matrix.
    flat = tf.reshape(input, [-1, input_size])
    return tf.matmul(flat, w) + bias


# Attach the fully connected head (2 categories) to `a` first, so its
# variables already exist when the initializer op is created below.
output = fc(a, 2)

with tf.Session() as sess:
    # Initialize every global variable: a, b, and the fc layer's w and bias.
    sess.run(tf.global_variables_initializer())

    # Dump each tensor registered in the 'network' collection so its current
    # value can be inspected (useful when, e.g., the loss stops changing).
    for tensor in tf.get_collection('network'):
        header = 'name: {}, shape: {}'.format(tensor.name, tensor.shape)
        print(header)
        print(sess.run(tensor))
        print('--------------')

    print(sess.run(output))

 

Related posts:

Leave a Reply