Tensorboard追踪不到网络模型图

代码如下

import keras
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras import *


class Mnist_Model(Model):
    def __init__(self):
        super(Mnist_Model, self).__init__()
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, inputs, training=None, mask=None):
        x = self.flatten(inputs)
        x = self.d1(x)
        y = self.d2(x)
        return y


#  初始化模型
mnist = Mnist_Model()
#  定义优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
#  定义损失函数
scc_loss = tf.keras.losses.SparseCategoricalCrossentropy()
#  定义损失评价标准
loss_metric = tf.keras.metrics.SparseCategoricalAccuracy()
#  加载数据集

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# x_train = x_train.reshape(60000, 784).astype("float32") / 255
x_train = x_train / 255
train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train = train.batch(64)

#  可视化日志目录
log_dir = "./logs"
summary_writer = tf.summary.create_file_writer(log_dir)
tf.summary.trace_on(graph=True, profiler=True)
# tf.summary.trace_on(graph=True, profiler=False)
#  定义训练总论数
epochs = 2
#  打开每一轮
x_step = 0
for epoch in range(epochs):
    print("第 %d 开始" % (epoch,))
    #  打开每一个块
    for step, x_batch_train in enumerate(train):
        x_tr, y_gt = x_batch_train
        with tf.GradientTape() as tape:
            y_pre = mnist(x_tr)
            loss = scc_loss(y_true=y_gt, y_pred=y_pre)

        grads = tape.gradient(loss, mnist.trainable_variables)
        optimizer.apply_gradients(zip(grads, mnist.trainable_variables))

        loss_metric(y_true=y_gt, y_pred=y_pre)

        if step % 100 == 0:
            print("步骤 %d: 平均准确度 = %.4f" % (step, loss_metric.result()))
        x_step = x_step + 1
        with summary_writer.as_default():
            tf.summary.scalar('loss', loss, step=x_step)
            tf.summary.scalar('loss_metric', loss_metric.result(), step=x_step)

with summary_writer.as_default():
    tf.summary.trace_export(name="model_trace", step=0, profiler_outdir=log_dir)
mnist.save('./model/Mnist_model0')

要看模型的图结构的话,需要使用 tf.function 把模型转换成TensorFlow图模型。参考 https://tf.wiki/zh_hans/basic/tools.html#graphprofileTensorFlow常用模块 — 简单粗暴 TensorFlow 2 0.4 beta 文档