Attaching TextCNN or LSTM after a BERT model (bert+cnn)

bert+cnn

Building the model with keras_bert

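# Imports used by the snippets in this article (Keras 2.x + keras_bert)
from keras.layers import Input, Lambda, Dense
from keras.models import Model
from keras.optimizers import Adam
from keras_bert import load_trained_model_from_checkpoint

# config_path / checkpoint_path: paths to your pre-trained BERT config and checkpoint (defined elsewhere)

# bert: baseline that classifies from the [CLS] vector
def get_model():
    bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path)
    for l in bert_model.layers:
        l.trainable = True  # fine-tune every BERT layer
    T1 = Input(shape=(None,))  # token ids
    T2 = Input(shape=(None,))  # segment ids
    T = bert_model([T1, T2])   # (batch, seq_len, hidden)
    T = Lambda(lambda x: x[:, 0])(T)  # keep position 0 ([CLS]) -> (batch, hidden)
    output = Dense(4, activation='softmax')(T)  # four classes
    model = Model([T1, T2], output)
    model.compile(
        loss='categorical_crossentropy',
        optimizer=Adam(1e-5),  # small learning rate for fine-tuning
        metrics=['accuracy']
    )
    model.summary()
    return model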
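In the baseline, `bert_model([T1, T2])` returns one hidden vector per token, and `Lambda(lambda x: x[:, 0])` keeps only position 0, the `[CLS]` token, which then feeds the 4-way softmax. A pure-Python sketch of that slice on a hypothetical toy tensor (batch 2, sequence length 3, hidden size 4; real BERT-base uses hidden size 768):

```python
# Pure-Python sketch of what Lambda(lambda x: x[:, 0]) does to BERT's output.
# The toy tensor is hypothetical: batch=2, seq_len=3, hidden=4.

def take_cls(seq_output):
    """(batch, seq_len, hidden) -> (batch, hidden): keep the first position of each sequence."""
    return [sequence[0] for sequence in seq_output]

bert_output = [
    [[0.1, 0.2, 0.3, 0.4],   # sample 0, position 0 ([CLS])
     [0.5, 0.6, 0.7, 0.8],   # sample 0, position 1
     [0.0, 0.0, 0.0, 0.0]],  # sample 0, position 2
    [[1.0, 1.1, 1.2, 1.3],   # sample 1, position 0 ([CLS])
     [1.4, 1.5, 1.6, 1.7],
     [1.8, 1.9, 2.0, 2.1]],
]

cls_vectors = take_cls(bert_output)
print(cls_vectors)  # [[0.1, 0.2, 0.3, 0.4], [1.0, 1.1, 1.2, 1.3]]
```

The sequence axis disappears, which is why a `Dense` layer can follow directly; a TextCNN or LSTM head instead needs the full `(batch, seq_len, hidden)` tensor.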

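The code above is the BERT model for a four-class classification task. If you attach a TextCNN directly after T (the BERT output sequence), Keras raises an error.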

from keras.layers import Conv1D, GlobalMaxPooling1D, Concatenate

def get_model():
    bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path)
    for l in bert_model.layers:
        l.trainable = True
    T1 = Input(shape=(None,))
    T2 = Input(shape=(None,))
    T = bert_model([T1, T2])  # (batch, seq_len, hidden), still carrying BERT's padding mask
    convs = []
    for kernel_size in [3, 4, 5]:
        # raises the TypeError below: Conv1D does not accept the incoming mask
        c = Conv1D(128, kernel_size, activation='relu')(T)
        c = GlobalMaxPooling1D()(c)  # max over time -> (batch, 128)
        convs.append(c)
    x = Concatenate()(convs)  # (batch, 3 * 128)
    output = Dense(4, activation='softmax')(x)
    model = Model([T1, T2], output)
    model.compile(
        loss='categorical_crossentropy',
        optimizer=Adam(1e-5),
        metrics=['accuracy']
    )
    model.summary()
    return model

Running the BERT + TextCNN code above produces this error:

TypeError: Layer conv1d_1 does not support masking, but was passed an input_mask: Tensor("model_2/Encoder-12-FeedForward-Add/All:0", shape=(?, ?), dtype=bool)
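The `input_mask` named in the error is a boolean tensor of shape `(?, ?)`, i.e. `(batch, seq_len)`, marking which positions are real tokens and which are padding. A minimal pure-Python sketch of such a mask, assuming token id 0 is the padding id (Keras's `mask_zero` convention, which keras_bert's token embedding appears to follow); the token ids are hypothetical:

```python
# Sketch of a padding mask of shape (batch, seq_len), like the bool tensor in the error.
# Assumption: token id 0 is padding (Keras's mask_zero convention).
# The token ids below are hypothetical.

PAD_ID = 0

def padding_mask(token_ids):
    """True for real tokens, False for padding positions."""
    return [[tok != PAD_ID for tok in seq] for seq in token_ids]

batch = [
    [7, 15, 3, 0, 0],    # 3 real tokens, right-padded to length 5
    [7, 15, 22, 9, 3],   # 5 real tokens, no padding
]

mask = padding_mask(batch)
print(mask)
# [[True, True, True, False, False], [True, True, True, True, True]]
print([sum(m) for m in mask])  # [3, 5]
```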

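The error occurs because the CNN layer (Conv1D) does not support masked input: BERT's output T carries a padding mask, and Keras passes it on to the next layer. The fix is to define a custom NonMasking layer that returns its input unchanged but drops the mask, and insert it before the CNN layers: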

from keras.layers import Input, Dense, Conv1D, GlobalMaxPooling1D, Concatenate, Layer
from keras.models import Model
from keras.optimizers import Adam
from keras_bert import load_trained_model_from_checkpoint


class NonMasking(Layer):
    """Identity layer that swallows the incoming mask so mask-unaware layers can follow."""
    def __init__(self, **kwargs):
        super(NonMasking, self).__init__(**kwargs)
        # set after the base constructor so it is not overwritten
        self.supports_masking = True

    def compute_mask(self, inputs, mask=None):
        # do not pass the mask to the next layers
        return None

    def call(self, x, mask=None):
        return x  # values pass through unchanged

    def compute_output_shape(self, input_shape):
        # Keras 2 name for the old Keras 1 get_output_shape_for
        return input_shape


def get_model():
    bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path)
    for l in bert_model.layers:
        l.trainable = True
    T1 = Input(shape=(None,))
    T2 = Input(shape=(None,))
    T = bert_model([T1, T2])
    T = NonMasking()(T)  # drop BERT's padding mask before the CNN layers
    convs = []
    for kernel_size in [3, 4, 5]:
        c = Conv1D(128, kernel_size, activation='relu')(T)
        c = GlobalMaxPooling1D()(c)
        convs.append(c)
    x = Concatenate()(convs)
    output = Dense(4, activation='softmax')(x)
    model = Model([T1, T2], output)
    model.compile(
        loss='categorical_crossentropy',
        optimizer=Adam(1e-5),
        metrics=['accuracy']
    )
    model.summary()
    return model
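Once the mask is dropped, the TextCNN head runs three Conv1D branches (kernel sizes 3, 4, 5, 128 filters each) over BERT's token vectors, max-pools each branch over time, and concatenates the results into a 3 × 128 = 384-dimensional vector for the softmax. A pure-Python sketch of the same computation, assuming toy sizes, random weights, no bias, and Keras's default `padding='valid'`:

```python
# Pure-Python sketch of the TextCNN head: Conv1D(filters, k, activation='relu')
# with 'valid' padding, then GlobalMaxPooling1D, then Concatenate.
# Toy sizes and random weights are hypothetical; the article uses 128 filters.
import random

def conv1d_relu(seq, kernels, k):
    """seq: (seq_len, hidden); kernels: list of (k, hidden) weights.
    Returns (seq_len - k + 1, filters) with ReLU applied."""
    out = []
    for start in range(len(seq) - k + 1):
        window = seq[start:start + k]
        row = []
        for w in kernels:
            s = sum(w[i][j] * window[i][j] for i in range(k) for j in range(len(window[0])))
            row.append(max(0.0, s))  # ReLU
        out.append(row)
    return out

def global_max_pool(feature_map):
    """(steps, filters) -> (filters,): max over the time axis."""
    return [max(col) for col in zip(*feature_map)]

def textcnn_head(seq, filters, kernel_sizes, rng):
    """Concatenate the pooled outputs of one conv branch per kernel size."""
    pooled = []
    for k in kernel_sizes:
        kernels = [[[rng.uniform(-1, 1) for _ in range(len(seq[0]))] for _ in range(k)]
                   for _ in range(filters)]
        pooled.extend(global_max_pool(conv1d_relu(seq, kernels, k)))
    return pooled

rng = random.Random(0)
seq_len, hidden, filters = 10, 6, 8
bert_seq = [[rng.uniform(-1, 1) for _ in range(hidden)] for _ in range(seq_len)]
features = textcnn_head(bert_seq, filters, [3, 4, 5], rng)
print(len(features))  # 24 == 3 kernel sizes * 8 filters
```

The sketch assumes `seq_len >= max(kernel_sizes)`. With 'valid' padding each branch produces `seq_len - k + 1` steps, so its output no longer lines up one-to-one with the token positions of BERT's padding mask; this is one way to see why the mask cannot simply be handed on to the convolution.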

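With NonMasking inserted before the convolutions, the BERT + TextCNN model no longer raises the error.
Other models such as an LSTM can of course be attached as well. Keras's LSTM supports masking, so it can take BERT's output directly, without NonMasking: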

from keras.layers import LSTM

def get_model():
    bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path)
    for l in bert_model.layers:
        l.trainable = True
    T1 = Input(shape=(None,))
    T2 = Input(shape=(None,))
    T = bert_model([T1, T2])
    # LSTM supports masking, so no NonMasking layer is needed;
    # return_sequences=False returns only the final output -> (batch, 128)
    x = LSTM(128, return_sequences=False)(T)
    output = Dense(4, activation='softmax')(x)
    model = Model([T1, T2], output)
    model.compile(
        loss='categorical_crossentropy',
        optimizer=Adam(1e-5),
        metrics=['accuracy']
    )
    model.summary()
    return model
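The LSTM head works without NonMasking because Keras recurrent layers consume the mask: at a masked (padded) time step they carry the previous output and state forward instead of updating them, so with `return_sequences=False` the returned vector is the one after the last real token of a right-padded sequence. A sketch of that rule with a toy scalar RNN (a simplified recurrence, not a real LSTM; the weights `w`, `u` are hypothetical):

```python
# Toy scalar RNN (not a real LSTM) illustrating Keras-style masking:
# at a masked step the previous state is carried over unchanged.
import math

def rnn_last_state(xs, mask=None, w=0.5, u=1.0):
    """h_t = tanh(w * h_{t-1} + u * x_t); returns the state after the last step."""
    h = 0.0
    for t, x in enumerate(xs):
        new_h = math.tanh(w * h + u * x)
        if mask is None or mask[t]:
            h = new_h  # real token: update the state
        # padded token: keep h as it was
    return h

real = [0.3, -0.2, 0.8]
padded = real + [0.0, 0.0]              # right-padded to length 5
mask = [True, True, True, False, False]

print(rnn_last_state(real) == rnn_last_state(padded, mask))  # True
print(rnn_last_state(real) == rnn_last_state(padded))        # False
```

Without the mask, the two padding steps keep changing the state, so the final vector depends on how much padding the batch happened to need.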
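Disclaimer: This article was contributed by a user and does not represent the position of the wpsshop blog (【wpsshop博客】). Copyright belongs to the original author, and the site assumes no legal liability. If you find infringing content, please contact us. When reposting, please cite the source: https://www.wpsshop.cn/w/笔触狂放9/article/detail/342446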