赞
踩
RNN用来处理序列数据
具有记忆能力
- LSTM可以避免梯度消失的问题
- LSTM的记忆力要比SimpleRNN强
结构图:
文本分类任务中,CNN可以用来提取句子中类似N-Gram的关键信息,适合短句子文本。TextRNN擅长捕获更长的序列信息。具体到文本分类任务中,从某种意义上可以理解为可以捕获变长、单向的N-Gram信息(Bi-LSTM可以是双向)。
一句话简介:textRNN指的是利用RNN循环神经网络解决文本分类问题,通常使用LSTM和GRU这种变形的RNN,而且使用双向,两层架构居多。
基本处理步骤:
流程:embedding—>BiLSTM—>concat final output/average all output—–>softmax layer
两种形式:
上述结构也可以添加dropout/L2正则化或BatchNormalization 来防止过拟合以及加速模型训练。
任务:输入前两个此,预测下一个词
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
dtype = torch.FloatTensor
sentences = [ "i like dog", "i love coffee", "i hate milk"]
word_list = " ".join(sentences).split()
vocab = list(set(word_list))
word2idx = {
w: i for i, w in enumerate(vocab)}
idx2word = {
i: w for i, w in enumerate(vocab)}
n_class = len(vocab)
# TextRNN Parameter batch_size = 2 n_step = 2 # number of cells(= number of Step) 输入有多少个单词 n_hidden = 5 # number of hidden units in one cell def make_data(sentences): input_batch = [] target_batch = [] for sen in sentences: word = sen.split() input = [word2idx[n] for n in word[:-1]] target = word2idx[word[-1]] input_batch.append(np.eye(n_class)[input]) # one-hot编码 target_batch.append(target) return input_batch, target_batch input_batch, target_batch = make_data(sentences) input_batch, target_batch = torch.Tensor(input_batch), torch.LongTensor(target_batch) dataset = Data.TensorDataset(input_batch, target_batch) loader = Data.DataLoader
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。