
[Paper Reproduction] Text Classification with fastText


A Few Words First

Today is a catch-up-on-notes day...

Today's paper is Bag of Tricks for Efficient Text Classification, from Facebook AI Research,

which is the fastText we all know and use.

What is most delightful is that the paper comes with the fastText toolkit. The code quality is very high, the paper's results can be reproduced with a single command, and the package is by now very professionally maintained: there are an official fastText website and GitHub repository, plus Python bindings that can be installed directly via pip. A model that is this accurate and this fast is an absolute workhorse in practice.
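For reference, training a classifier with the official Python bindings takes only a couple of lines. This is a minimal sketch: train.txt and test.txt are placeholder file names, and each line of the training file is expected to look like "__label__<tag> <text>".

# pip install fasttext
import fasttext

# train.txt / test.txt are placeholder file names; each line looks like:
#   __label__positive this movie was great
model = fasttext.train_supervised(input="train.txt", epoch=5, lr=0.5, wordNgrams=2)

print(model.predict("this movie was great"))   # (predicted labels, probabilities)
print(model.test("test.txt"))                  # (n_samples, precision@1, recall@1)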

To better understand how fastText works, I reproduced it myself. My code only implements the simplest variant, averaging word-level embeddings, and does not use bigram embeddings, so its classification performance will be lower than Facebook's open-source library: https://github.com/KaiyuanGao/text_claasification/tree/master/fastText

 

(Figure: fastText model architecture)

Paper Overview

We can train fastText on more than one billion words in less than ten minutes using a standard multicore CPU, and classify half a million sentences among 312K classes in less than a minute.

The quote above, taken straight from the paper, shows how the authors themselves describe fastText's performance.

The model in this paper is extremely simple; anyone who has looked at word2vec before will notice that it closely resembles the CBOW architecture.

In the model above, suppose the input is a sentence; x1 through xn are the words (or n-grams) of that sentence. Each one maps to a vector, averaging those vectors gives the text representation, and that average vector is used to predict the label. When there are only a few classes, this is just a plain softmax; when the number of labels is huge, hierarchical softmax is used instead.
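To make that concrete, here is a toy NumPy sketch of just the forward pass described above; all of the sizes, ids, and names are made up for illustration.

import numpy as np

vocab_size, embed_dim, num_classes = 10000, 100, 4
E = np.random.randn(vocab_size, embed_dim) * 0.1   # word / n-gram embedding table
W = np.random.randn(embed_dim, num_classes) * 0.1  # linear classifier weights
b = np.zeros(num_classes)

token_ids = np.array([12, 7, 431, 9])               # x1..xn for one sentence
text_vec = E[token_ids].mean(axis=0)                # average the token vectors
logits = text_vec @ W + b
probs = np.exp(logits - logits.max())
probs /= probs.sum()                                 # plain softmax over the labels
print(probs)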

The model really is that simple, so there is not much more to say about it. Here are the two tricks mentioned in the paper:

  • hierarchical softmax
    • When the number of classes is large, a Huffman tree is built over the labels to speed up the softmax layer; this is the same trick used in word2vec.
  • N-gram features
    • Using only unigrams throws away word-order information, so N-gram features are added to compensate.
    • Hashing is used to keep the memory footprint of the N-grams bounded (see the sketch after this list).
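A rough sketch of the second trick: bigram features are hashed into a fixed number of buckets so the embedding table stays bounded. The hash and the bucket/vocabulary sizes below are only illustrative, not the exact scheme in the C++ implementation.

BUCKETS = 2000000            # fixed budget of extra rows in the embedding table
VOCAB_SIZE = 10000           # unigram ids occupy [0, VOCAB_SIZE)

def bigram_ids(token_ids, buckets=BUCKETS, vocab_size=VOCAB_SIZE):
    """Map each adjacent pair of word ids to a row in [vocab_size, vocab_size + buckets)."""
    ids = []
    for a, b in zip(token_ids, token_ids[1:]):
        h = (a * 116049371 + b) % buckets   # illustrative hash; collisions are simply tolerated
        ids.append(vocab_size + h)
    return ids

# the sentence is then represented by its unigram ids plus the hashed bigram ids,
# and all of those vectors are averaged together exactly as before
tokens = [12, 7, 431, 9]
print(tokens + bigram_ids(tokens))

The bucket count trades memory for a bit of feature noise from collisions, which in practice barely hurts accuracy.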

Model Performance

Now for the experimental results: it is remarkable that such a simple model achieves such strong numbers!

That said, some people have pointed out that the datasets chosen in the paper are not very sensitive to word order within a sentence, so the reported results are not all that surprising.

Summary

In terms of modeling, the paper is clearly not very novel. The network structure is lifted from word2vec, with the target word simply replaced by a label, and the other tricks had already appeared in earlier work. Even so, fastText has had a huge impact, and I think the main reason is its excellent open-source code, which is hugely popular on GitHub.

Code Implementation

In hindsight, it probably would have been better to just go read their open-source code directly... In any case, here is my simplified TensorFlow version, starting with the model itself:

import tensorflow as tf   # TensorFlow 1.x


class fastTextModel(BaseModel):   # BaseModel is a thin base class defined elsewhere in the repo
    """
    A simple implementation of fastText for text classification:
    average the word embeddings of a sentence, then apply a linear + softmax classifier.
    """
    def __init__(self, sequence_length, num_classes, vocab_size,
                 embedding_size, learning_rate, decay_steps, decay_rate,
                 l2_reg_lambda, is_training=True,
                 initializer=tf.random_normal_initializer(stddev=0.1)):
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.num_classes = num_classes
        self.sequence_length = sequence_length
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.is_training = is_training
        self.l2_reg_lambda = l2_reg_lambda
        self.initializer = initializer

        # inputs: padded word-id sequences and one-hot labels
        self.input_x = tf.placeholder(tf.int32, [None, self.sequence_length], name='input_x')
        self.input_y = tf.placeholder(tf.int32, [None, self.num_classes], name='input_y')

        self.global_step = tf.Variable(0, trainable=False, name='global_step')
        self.instantiate_weight()
        self.logits = self.inference()
        self.loss_val = self.loss()
        self.train_op = self.train()

        self.predictions = tf.argmax(self.logits, axis=1, name='predictions')
        correct_prediction = tf.equal(self.predictions, tf.argmax(self.input_y, 1))
        self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'), name='accuracy')

    def instantiate_weight(self):
        with tf.name_scope('weights'):
            self.Embedding = tf.get_variable('Embedding',
                                             shape=[self.vocab_size, self.embedding_size],
                                             initializer=self.initializer)
            self.W_projection = tf.get_variable('W_projection',
                                                shape=[self.embedding_size, self.num_classes],
                                                initializer=self.initializer)
            self.b_projection = tf.get_variable('b_projection', shape=[self.num_classes])

    def inference(self):
        """
        1. word embedding
        2. average embedding
        3. linear classifier
        """
        # embedding layer: look up each word id, then average over the sentence
        with tf.name_scope('embedding'):
            words_embedding = tf.nn.embedding_lookup(self.Embedding, self.input_x)
            self.average_embedding = tf.reduce_mean(words_embedding, axis=1)
        logits = tf.matmul(self.average_embedding, self.W_projection) + self.b_projection
        return logits

    def loss(self):
        with tf.name_scope('loss'):
            # cast the one-hot int labels to float to match the logits dtype
            losses = tf.nn.softmax_cross_entropy_with_logits(
                labels=tf.cast(self.input_y, tf.float32), logits=self.logits)
            data_loss = tf.reduce_mean(losses)
            # L2 penalty on all non-bias trainable variables (lambda applied once)
            l2_loss = tf.add_n([tf.nn.l2_loss(cand_var) for cand_var in tf.trainable_variables()
                                if 'bias' not in cand_var.name])
            data_loss += l2_loss * self.l2_reg_lambda
        return data_loss

    def train(self):
        with tf.name_scope('train'):
            learning_rate = tf.train.exponential_decay(self.learning_rate, self.global_step,
                                                       self.decay_steps, self.decay_rate,
                                                       staircase=True)
            train_op = tf.contrib.layers.optimize_loss(self.loss_val, global_step=self.global_step,
                                                       learning_rate=learning_rate, optimizer='Adam')
        return train_op
The training script below reuses the data-processing helpers from my earlier CNN text-classification code:

import tensorflow as tf
import numpy as np
import os
import time
import datetime
from cnn_classification import data_process
from fastText import fastTextModel
from tensorflow.contrib import learn

# define parameters
# data loading params
tf.flags.DEFINE_string("positive_data_file", "../cnn_classification/data/rt-polarity.pos", "Data source for the positive data.")
tf.flags.DEFINE_string("negative_data_file", "../cnn_classification/data/rt-polarity.neg", "Data source for the negative data.")

# model / training configuration
tf.flags.DEFINE_float("learning_rate", 0.01, "learning rate")
tf.flags.DEFINE_integer("num_epochs", 60, "number of training epochs")
tf.flags.DEFINE_integer("batch_size", 100, "Batch size for training/evaluating.")
tf.flags.DEFINE_integer("decay_steps", 12000, "how many steps before decaying the learning rate.")
tf.flags.DEFINE_float("decay_rate", 0.9, "Rate of decay for the learning rate.")
tf.flags.DEFINE_string("ckpt_dir", "text_fastText_checkpoint/", "checkpoint location for the model")
tf.flags.DEFINE_integer('num_checkpoints', 10, 'number of checkpoints to keep')
tf.flags.DEFINE_integer("sequence_length", 300, "max sentence length")
tf.flags.DEFINE_integer("embedding_size", 128, "embedding size")
tf.flags.DEFINE_boolean("is_training", True, "is training. true: training, false: testing/inference")
tf.flags.DEFINE_integer("validate_every", 1, "Validate every validate_every epochs.")
tf.flags.DEFINE_float("dev_sample_percentage", .1, "Percentage of the training data to use for validation")
tf.flags.DEFINE_integer('dev_sample_max_cnt', 1000, 'max number of validation samples; too many dev samples makes evaluation heavy')
tf.flags.DEFINE_float("dropout_keep_prob", 0.5, "Dropout keep probability (default: 0.5)")
tf.flags.DEFINE_float("l2_reg_lambda", 0.0001, "L2 regularization lambda (default: 0.0)")
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")

FLAGS = tf.flags.FLAGS


def preprocess():
    """
    Load and preprocess the data.
    """
    print("Loading data...")
    x_text, y = data_process.load_data_and_labels(FLAGS.positive_data_file, FLAGS.negative_data_file)

    # build vocabulary and map each sentence to a padded sequence of word ids
    max_document_length = max(len(x.split(' ')) for x in x_text)
    vocab_processor = learn.preprocessing.VocabularyProcessor(max_document_length)
    x = np.array(list(vocab_processor.fit_transform(x_text)))

    # shuffle
    np.random.seed(10)
    shuffle_indices = np.random.permutation(np.arange(len(y)))
    x_shuffled = x[shuffle_indices]
    y_shuffled = y[shuffle_indices]

    # split train/dev datasets
    dev_sample_index = -1 * int(FLAGS.dev_sample_percentage * float(len(y)))
    x_train, x_dev = x_shuffled[:dev_sample_index], x_shuffled[dev_sample_index:]
    y_train, y_dev = y_shuffled[:dev_sample_index], y_shuffled[dev_sample_index:]
    del x, y, x_shuffled, y_shuffled

    print('Vocabulary Size: {:d}'.format(len(vocab_processor.vocabulary_)))
    print('Train/Dev split: {:d}/{:d}'.format(len(y_train), len(y_dev)))
    return x_train, y_train, vocab_processor, x_dev, y_dev


def train(x_train, y_train, vocab_processor, x_dev, y_dev):
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            # allow TensorFlow to fall back on a device that implements the operation
            allow_soft_placement=FLAGS.allow_soft_placement,
            # log which devices (CPU or GPU) operations are placed on
            log_device_placement=FLAGS.log_device_placement
        )
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            # initialize the fastText model
            fasttext = fastTextModel(sequence_length=x_train.shape[1],
                                     num_classes=y_train.shape[1],
                                     vocab_size=len(vocab_processor.vocabulary_),
                                     embedding_size=FLAGS.embedding_size,
                                     l2_reg_lambda=FLAGS.l2_reg_lambda,
                                     is_training=True,
                                     learning_rate=FLAGS.learning_rate,
                                     decay_steps=FLAGS.decay_steps,
                                     decay_rate=FLAGS.decay_rate)

            # output dir for models and summaries
            timestamp = str(time.time())
            out_dir = os.path.abspath(os.path.join(os.path.curdir, 'run', timestamp))
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)
            print('Writing to {} \n'.format(out_dir))

            # checkpointing: saving the parameters of the model so they can be restored later on
            checkpoint_dir = os.path.abspath(os.path.join(out_dir, FLAGS.ckpt_dir))
            checkpoint_prefix = os.path.join(checkpoint_dir, 'model')
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)

            # write vocabulary
            vocab_processor.save(os.path.join(out_dir, 'vocab'))

            # initialize all variables
            sess.run(tf.global_variables_initializer())

            def train_step(x_batch, y_batch):
                """A single training step."""
                feed_dict = {
                    fasttext.input_x: x_batch,
                    fasttext.input_y: y_batch,
                }
                _, step, loss, accuracy = sess.run(
                    [fasttext.train_op, fasttext.global_step, fasttext.loss_val, fasttext.accuracy],
                    feed_dict=feed_dict)
                time_str = datetime.datetime.now().isoformat()
                print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))

            def dev_step(x_batch, y_batch):
                """Evaluate the model on the dev set (no parameter update)."""
                feed_dict = {
                    fasttext.input_x: x_batch,
                    fasttext.input_y: y_batch,
                }
                step, loss, accuracy = sess.run(
                    [fasttext.global_step, fasttext.loss_val, fasttext.accuracy],
                    feed_dict=feed_dict)
                time_str = datetime.datetime.now().isoformat()
                print("dev results: {}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))

            # generate batches
            batches = data_process.batch_iter(list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)
            # training loop
            for batch in batches:
                x_batch, y_batch = zip(*batch)
                train_step(x_batch, y_batch)
                current_step = tf.train.global_step(sess, fasttext.global_step)
                if current_step % FLAGS.validate_every == 0:
                    print('\nEvaluation:')
                    dev_step(x_dev, y_dev)
                    print('')
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    print('Saved model checkpoint to {}\n'.format(path))


def main(argv=None):
    x_train, y_train, vocab_processor, x_dev, y_dev = preprocess()
    train(x_train, y_train, vocab_processor, x_dev, y_dev)


if __name__ == '__main__':
    tf.app.run()

Oh, and the dataset I use here is still the same one from the earlier CNN training.

That's all~

 
