from collections import Counter

import tensorflow as tf

def count_nums(true_labels, num_classes):
    """Count how many samples belong to each class.

    Args:
        true_labels: Iterable of integer class labels (list, tuple, NumPy
            array, generator, ...).
        num_classes: Number of classes; labels are expected to lie in
            ``range(num_classes)``.

    Returns:
        A list of length ``num_classes`` whose i-th entry is the number of
        occurrences of label ``i``. Labels outside ``range(num_classes)``
        are not counted.
    """
    # One pass over the labels instead of one ``list.count`` scan per class
    # (O(n + k) rather than O(n * k)). Counter also accepts any iterable, not
    # only objects that provide a ``.count`` method.
    counts = Counter(true_labels)
    return [counts[label] for label in range(num_classes)]

def accuracy(confusion_matrix, true_labels, num_classes):
    """Compute the per-class accuracy (recall) from a confusion matrix.

    Args:
        confusion_matrix: Square, row-indexable 2-D matrix (nested lists or a
            NumPy array). Its diagonal entry ``[i][i]`` is the number of
            class-``i`` samples that were predicted as class ``i``.
        true_labels: Iterable of ground-truth integer labels. It is used to
            count the number of samples in each class.
        num_classes: Number of classes to report.

    Returns:
        A list of ``num_classes`` values. The i-th value is the fraction of
        class-``i`` samples that were classified correctly. A class with no
        samples in ``true_labels`` gets 0.0 instead of raising
        ``ZeroDivisionError``.
    """
    # Number of test samples in each class.
    class_counts = count_nums(true_labels, num_classes)

    # Number of correctly classified samples per class, read from the diagonal.
    # A matrix smaller than num_classes (e.g. when the highest labels never
    # occur) has no correct predictions for the missing classes.
    rows = len(confusion_matrix)
    correct = [confusion_matrix[i][i] if i < rows else 0 for i in range(num_classes)]

    # Per-class accuracy = correct / total. Empty classes report 0.0.
    return [hit / total if total else 0.0 for hit, total in zip(correct, class_counts)]

# Test data.
y_true = [0, 1, 2, 3, 1, 2, 3, 4, 1]  # ground-truth labels
y_pred = [1, 1, 2, 3, 1, 2, 3, 4, 2]  # predicted labels
NUM_CLASSES = 5

# Build the confusion-matrix op in an explicit graph so that it can be run by a
# TF1-style Session on TensorFlow 2.x as well, where eager execution is the
# default and tf.InteractiveSession / tf.confusion_matrix no longer exist.
# Passing num_classes fixes the matrix at NUM_CLASSES x NUM_CLASSES. Otherwise
# its size is inferred from the labels, and accuracy() below could index past
# it whenever the largest label is absent from the data.
graph = tf.Graph()
with graph.as_default():
    op = tf.math.confusion_matrix(y_true, y_pred, num_classes=NUM_CLASSES)

# Execute the graph. The context manager closes the session even if run fails.
print("confusion matrix in tensorflow: ")
with tf.compat.v1.Session(graph=graph) as sess:
    confusion_matrix = sess.run(op)
print(confusion_matrix)

# Per-class accuracy.
acc = accuracy(confusion_matrix, y_true, num_classes=NUM_CLASSES)
print(acc)