TensorFlow 分布式训练

老师你好,想请问一下:strategy = tf.distribute.MirroredStrategy() 是否支持用 Model 子类化方式创建的自定义模型进行多 GPU 分布式训练?我自定义了一个 ResNet18 模型,在 CIFAR-100 上做多 GPU 分布式训练,但训练效果很差,测试集准确率(acc)只有 1% 左右,好像参数没有被更新;而改用 Sequential 按层顺序创建模型、其他代码保持不变时,效果就提升很多。
构建模型的代码如下,麻烦帮忙看下有什么问题:

with strategy.scope ():
    # Factory for Conv2D layers sharing one padding / regularizer / initializer
    # setup. The kernel size is NOT fixed here; callers pass it themselves.
    def regularized_padded_conv(*args, **kwargs):
        """Create a `layers.Conv2D` with the shared convolution settings.

        Positional and keyword arguments are forwarded to `layers.Conv2D`
        unchanged. The options below are always added on top; passing any of
        them again through `kwargs` raises `TypeError` (duplicate keyword).
        """
        shared_options = dict(
            padding='same',
            kernel_regularizer=regularizers.l2(5e-5),
            use_bias=False,
            kernel_initializer='glorot_normal',
        )
        return layers.Conv2D(*args, **kwargs, **shared_options)

    # Basic residual block: the building unit used for ResNet-18 and ResNet-34.
    class BasecBlock(layers.Layer):
        """Two conv + batch-norm stages joined with a shortcut connection.

        Args:
            in_channels: Channel count of the incoming feature map. Only used
                to decide whether the shortcut needs a projection.
            out_channels: Number of filters of both 3x3 convolutions.
            stride: Stride of the first convolution and, when a projection
                shortcut is created, of the shortcut convolution.
        """

        # Multiplier applied to `out_channels` when sizing the shortcut.
        expansion = 1

        def __init__(self, in_channels, out_channels, stride=1):
            super().__init__()
            # Stage 1: the only convolution of the main path that may stride.
            self.conv1 = regularized_padded_conv(
                out_channels, kernel_size=3, strides=stride)
            self.bn1 = layers.BatchNormalization()

            # Stage 2: keeps the spatial size produced by stage 1.
            self.conv2 = regularized_padded_conv(
                out_channels, kernel_size=3, strides=1)
            self.bn2 = layers.BatchNormalization()

            # Shortcut: project the input whenever its shape would not match
            # the main path (strided block or different channel count).
            if stride != 1 or in_channels != self.expansion * out_channels:
                # NOTE(review): the projection uses a 3x3 kernel; reference
                # ResNet implementations usually use 1x1 here. Kept as-is so
                # the parameter count / existing checkpoints do not change.
                self.shortcut = Sequential([
                    regularized_padded_conv(
                        self.expansion * out_channels,
                        kernel_size=3, strides=stride),
                    layers.BatchNormalization(),
                ])
            else:
                # Identity shortcut. Accepts `training` positionally or by
                # keyword so it is call-compatible with the Sequential branch.
                self.shortcut = lambda x, training=None: x

        # `call` is deliberately NOT wrapped in `tf.function`: Keras already
        # traces the model into a graph for fit/evaluate/predict, and an extra
        # nested function around every layer's `call` is redundant.
        # NOTE(review): the nested `tf.function` is the prime suspect for the
        # subclassed model failing to train under MirroredStrategy - confirm
        # by re-running training after this change.
        def call(self, inputs, training=False):
            """Apply the residual block.

            Args:
                inputs: Input feature map tensor.
                training: Forwarded to the batch-norm layers and the shortcut.

            Returns:
                `relu(main_path(inputs) + shortcut(inputs))`.
            """
            x = self.conv1(inputs)
            x = self.bn1(x, training=training)
            x = tf.nn.relu(x)

            x = self.conv2(x)
            x = self.bn2(x, training=training)

            x_short = self.shortcut(inputs, training=training)

            # Residual addition followed by the block's final activation.
            return tf.nn.relu(x + x_short)

    # Custom ResNet model assembled from residual blocks (subclassed Keras Model).
    class ResNet(tf.keras.Model):
        """ResNet backbone with a linear classification head.

        Args:
            blocks: Residual block class, instantiated as
                `blocks(in_channels, out_channels, stride)` (e.g. `BasecBlock`).
                The original notes also mention a Bottleneck block for
                ResNet-50/101/152; it is not defined in the visible code.
            layer_dims: Four integers, the number of blocks in each of the
                four stages (`[2, 2, 2, 2]` gives ResNet-18 with basic blocks).
            initial_filters: Filters of the stem convolution and of stage 1;
                stages 2-4 use 2x, 4x and 8x this value.
            num_classes: Number of units of the final dense layer.
        """

        def __init__(self, blocks, layer_dims, initial_filters=64, num_classes=100):
            super().__init__()

            # Channel count fed to the next block; updated by build_resblock.
            self.in_channels = initial_filters

            # Stem: a single stride-1 3x3 convolution followed by batch norm.
            self.stem = Sequential([
                regularized_padded_conv(initial_filters, kernel_size=3, strides=1),
                layers.BatchNormalization(),
            ])

            # Four stages; every stage after the first downsamples by 2.
            self.layer1 = self.build_resblock(
                blocks, initial_filters, layer_dims[0], stride=1)
            self.layer2 = self.build_resblock(
                blocks, initial_filters * 2, layer_dims[1], stride=2)
            self.layer3 = self.build_resblock(
                blocks, initial_filters * 4, layer_dims[2], stride=2)
            self.layer4 = self.build_resblock(
                blocks, initial_filters * 8, layer_dims[3], stride=2)

            self.final_bn = layers.BatchNormalization()

            self.avg_pool = layers.GlobalAveragePooling2D()
            # No activation: the model outputs raw scores (logits).
            self.dense = layers.Dense(num_classes)

        def build_resblock(self, blocks, out_channels, num_blocks, stride):
            """Stack `num_blocks` residual blocks into one `Sequential` stage.

            Only the first block of the stage uses `stride`; the remaining
            blocks use stride 1. `self.in_channels` is updated after every
            block so the next block knows its input width.
            """
            block_strides = [stride] + [1] * (num_blocks - 1)
            # A block emits `out_channels * expansion` channels; blocks
            # without an `expansion` attribute are treated as expansion 1.
            produced_channels = out_channels * getattr(blocks, 'expansion', 1)

            res_block = Sequential()
            for block_stride in block_strides:
                res_block.add(blocks(self.in_channels, out_channels, block_stride))
                self.in_channels = produced_channels
            return res_block

        # `call` is deliberately NOT wrapped in `tf.function`: Keras already
        # traces the model into a graph for fit/evaluate/predict, so the extra
        # decorator is redundant.
        # NOTE(review): the nested `tf.function` is the prime suspect for the
        # subclassed model failing to train under MirroredStrategy - confirm
        # by re-running training after this change.
        def call(self, inputs, training=None):
            """Run the forward pass.

            Args:
                inputs: Batch of input images.
                training: Forwarded to every sub-layer that contains batch
                    normalization. Defaults to None.

            Returns:
                The output of the final dense layer (no softmax applied).
                NOTE(review): make sure the loss is configured for logits.
            """
            x = self.stem(inputs, training=training)
            x = tf.nn.relu(x)

            x = self.layer1(x, training=training)
            x = self.layer2(x, training=training)
            x = self.layer3(x, training=training)
            x = self.layer4(x, training=training)

            x = self.final_bn(x, training=training)
            x = tf.nn.relu(x)
            x = self.avg_pool(x)
            return self.dense(x)
        
    def resnet18():
        """Return a `ResNet` built from `BasecBlock` with two blocks per stage."""
        blocks_per_stage = [2] * 4
        return ResNet(BasecBlock, blocks_per_stage)
    # Instantiate the network inside the strategy scope.
    model = resnet18()
    # Build the model for input shape (None, 32, 32, 3): unspecified batch
    # size, 32x32 inputs, last dimension 3 (CIFAR-100 per the question text).
    model.build((None, 32, 32, 3))