
LSTM Text Sentiment Classification in a Deeplearning4j + Spark Environment: A Scala Walkthrough


I have recently been working on user recommendation. From my research, the algorithms currently used for search and recommendation are dominated by DIN and similar models from teams such as Alimama and Meituan, so a neural network is called for. However, our company's big-data infrastructure still has room to grow, and building models directly on high-dimensional item and user features would pose real engineering difficulties. I therefore opted for a recurrent-network approach: treat each page a user clicks or browses as a "word", and use the sequence to estimate the user's next point of interest in products and campaigns. An LSTM is also relatively simple to implement; both TensorFlow and Keras in Python ship ready-made layer-level wrappers. Since my team develops mainly in Java, I set the Python stack aside and tried whether Spark + Scala could run an LSTM end to end.

The framework with the best Spark and Scala/Java integration today is probably Deeplearning4j, open-sourced by Skymind (https://deeplearning4j.org/cn/spark). Its ScalNet module can fairly be described as Keras for Scala engineers. Deeplearning4j is an excellent open-source library, well suited to Java engineers building neural networks, and its integration with Spark and Scala is equally strong. While writing this code I relied heavily on the series of articles at https://blog.csdn.net/wangongxi/article/details/60775940.

To get the LSTM running in Scala, I also consulted the Java LSTM text-classification code on wangongxi's blog. The text is segmented into words and encoded as word indices (one-hot), fed into the network's input layer, embedded into low-dimensional vectors, and then trained through the recurrent layers. For the recommendation work to follow, I only need to replace the segmented words with a user's browsing and click history, and replace the sentiment labels with the user's future click/browse scores for product and campaign pages, to obtain an LSTM-based product recommender. For data-security reasons, this post therefore demonstrates Deeplearning4j on text sentiment classification; Java readers can go straight to wangongxi's blog. I use exactly the same corpus, namely the text from http://spaces.ac.cn/archives/3414/. My thanks to both authors.
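To make that substitution concrete, here is a hypothetical sketch (the case class, page IDs, and interest label are all invented for illustration) of how a click stream maps onto the same (label, token sequence) pair that the sentiment pipeline below consumes:

    // Hypothetical shape only: page IDs stand in for words,
    // and a future-interest score stands in for the sentiment label.
    case class UserHistory(userId: String, clickedPages: Seq[String], futureInterest: String)

    val history = UserHistory("u42", Seq("首页", "产品A详情页", "活动X落地页"), "高兴趣")
    val sample: (String, String) = (history.futureInterest, history.clickedPages.mkString(" "))
    // sample now has exactly the ("label", "token token token") form built in step 2 below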

0. First, the imports used to build the model. They fall into three groups. The first targets Spark: creating the SparkSession and JavaRDD — note that Deeplearning4j's training expects a Java SparkContext and a JavaRDD. The second is the jieba segmenter, com.huaban.analysis.jieba.JiebaSegmenter; there is plenty of material online about word segmentation in Java, so I will not repeat it here. The third covers Deeplearning4j itself plus ND4J.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
    import org.deeplearning4j.nn.api.OptimizationAlgorithm
    import org.deeplearning4j.nn.conf.{BackpropType, NeuralNetConfiguration}
    import org.deeplearning4j.nn.conf.layers.{RnnOutputLayer, LSTM, EmbeddingLayer, GravesLSTM}
    import org.deeplearning4j.optimize.listeners.ScoreIterationListener
    import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer
    import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster
    import org.nd4j.linalg.activations.Activation
    import org.nd4j.linalg.api.ndarray.INDArray
    import org.nd4j.linalg.dataset.DataSet
    import org.nd4j.linalg.factory.Nd4j
    import com.huaban.analysis.jieba.JiebaSegmenter
    import org.nd4j.linalg.lossfunctions.LossFunctions
    import org.deeplearning4j.nn.conf.inputs.InputType
    import org.nd4j.linalg.learning.config.Adam
    import org.nd4j.evaluation.classification.Evaluation
    import scala.collection.mutable.ArrayBuffer

1. Create the SparkSession. Note that there are two handles: a Java SparkContext (JavaSC) and the SparkSession itself, and they can be used side by side. The SparkSession is convenient for reading CSV files, while the Java context is needed later for model training. The jieba segmenter object is also created here.

    val spark: SparkSession = SparkSession
      .builder()
      .master("local")
      .appName("Spark LSTM Emotion Analysis")
      .getOrCreate()
    import spark.implicits._
    val JavaSC = JavaSparkContext.fromSparkContext(spark.sparkContext)
    val segmenter = new JiebaSegmenter
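As a quick sanity check on the segmenter (the sample sentence is arbitrary), sentenceProcess returns the tokens as a java.util.List, which can be joined with spaces exactly as the pipeline in step 2 does:

    // prints the space-joined tokens, roughly: 这家 酒店 的 服务 很 好
    val demoTokens = segmenter.sentenceProcess("这家酒店的服务很好")
    println(demoTokens.toArray().mkString(" "))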

2. Read the CSV files, segment the reviews, and build the JavaRDD[DataSet] the trainer needs.

    def getTrainingData(spark: SparkSession): JavaRDD[DataSet] = {
      // Pipeline: file -> String -> (label, segmented text) -> RDD -> JavaRDD[DataSet]
      // Positive reviews
      val goodCSV = spark.read.format("csv").load("/home/lorry/下载/pos.csv").toDF("describe")
      // Negative reviews
      val badCSV = spark.read.format("csv").load("/home/lorry/下载/neg.csv").toDF("describe")
      // Regex that strips punctuation and digits
      val regEx = """[`~!@#$%^&*()+=|{}':;',『』[\-]《》\\[\\][\"][ ]\[\][0123456789].<>/?~!@#¥%……&*()——+|{}【】‘;:”“’。,、?]"""
      val signPattern = regEx.r
      val targetColumns = goodCSV.columns
      // Segment each review in the RDD and attach its label
      val goodRDD = goodCSV.select(targetColumns.head, targetColumns.tail: _*).rdd.map(x => {
        val word: String = x(0).asInstanceOf[String]
        val wordSplit = segmenter.sentenceProcess(signPattern.replaceAllIn(word.trim, "")).toArray().mkString(" ")
        ("正面", wordSplit)
      }).filter(row => row._2.nonEmpty)
      // Same for the negative reviews
      val badRDD = badCSV.select(targetColumns.head, targetColumns.tail: _*).rdd.map(x => {
        val word: String = x(0).asInstanceOf[String]
        val wordSplit = segmenter.sentenceProcess(signPattern.replaceAllIn(word.trim, "")).toArray().mkString(" ")
        ("负面", wordSplit)
      }).filter(row => row._2.nonEmpty)
      // Combine both classes
      val totalRDD = goodRDD.union(badRDD)
      // Materialize once as a sanity check
      totalRDD.count()
      // Word counts in descending order, used to build a word -> index map
      val WORD_TO_INT: Map[String, Int] = {
        val VOCAB = totalRDD.flatMap(x => x._2.split(" "))
          .map(word => (word, 1))
          .reduceByKey(_ + _)
          .sortBy(_._2, false)
          .map(row => row._1)
          .collect()
        VOCAB.zipWithIndex.toMap
      }
      // Cap sequences at 200 tokens. The true maximum is around 1,500, but that length kept
      // triggering out-of-memory errors; since most reviews run about 100 tokens, 200 is a
      // reasonable ceiling.
      val maxCorpusLength = 200 // totalRDD.map(row => row._2.split(" ").length).collect().max
      // Vocabulary size, around 58,000
      val VOCAB_SIZE = totalRDD.flatMap(x => x._2.split(" ")).collect.distinct.length
      // The labels; there are only two
      val labelWord = totalRDD.flatMap(x => x._1.split(" ")).collect.distinct
      // Build the final JavaRDD[DataSet]. This DataSet is ND4J's, not Spark's: it is a set
      // of tensors and comes in two forms, one holding just the input and output tensors,
      // the other adding a features mask and a labels mask that flag which positions
      // actually carry data. See wangongxi's article, or the DataSet class itself, for
      // details. Following that article, masks are used here, with the maximum length 200.
      val totalDataSet = totalRDD.map(row => {
        val listWords = row._2.split(" ").take(maxCorpusLength) // truncate long reviews
        val label = row._1
        val features: INDArray = Nd4j.create(1, 1, maxCorpusLength)
        val labels = Nd4j.create(1, labelWord.length, maxCorpusLength)
        val featuresMask = Nd4j.zeros(1.toLong, maxCorpusLength.toLong)
        val labelsMask = Nd4j.zeros(1.toLong, maxCorpusLength.toLong)
        var i: Int = 0
        for (word <- listWords) {
          // the first two dimensions (example, channel) have size 1, so their index is 0
          features.putScalar(Array(0, 0, i), WORD_TO_INT(word))
          featuresMask.putScalar(Array(0, i), 1)
          i += 1
        }
        // the label is scored only at the last real time step
        val lastIdx: Int = listWords.size
        val idx = labelWord.indexOf(label)
        labels.putScalar(Array[Int](0, idx, lastIdx - 1), 1.0)
        labelsMask.putScalar(Array[Int](0, lastIdx - 1), 1.0)
        new DataSet(features, labels, featuresMask, labelsMask)
      })
      totalDataSet.toJavaRDD()
    }
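To make the mask layout concrete, here is a minimal toy sketch (invented word indices, two labels, a maximum length of 5) of the four tensors built above, for a three-token sentence whose label index is 1:

    val toyLen = 5
    val toyFeatures = Nd4j.create(1, 1, toyLen)          // word indices, zero-padded past step 2
    val toyFeaturesMask = Nd4j.zeros(1L, toyLen.toLong)  // becomes 1 1 1 0 0
    val toyLabels = Nd4j.create(1, 2, toyLen)            // one-hot label per time step
    val toyLabelsMask = Nd4j.zeros(1L, toyLen.toLong)    // becomes 0 0 1 0 0
    for (i <- 0 until 3) {                               // three real tokens
      toyFeatures.putScalar(Array(0, 0, i), i + 10)      // made-up word indices 10, 11, 12
      toyFeaturesMask.putScalar(Array(0, i), 1)
    }
    toyLabels.putScalar(Array(0, 1, 2), 1.0)             // label 1, set at the last real step
    toyLabelsMask.putScalar(Array(0, 2), 1.0)            // loss is computed at step 2 only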

3. Configure the LSTM network. The vocabulary size is 50,758. In step 2 each word was replaced with its frequency rank, so the most common word, "的", becomes 0, and so on. Layer 0 is the embedding layer, which turns each index into a 256-dimensional vector; this feeds a GravesLSTM layer and then an RNN output layer, which need no further explanation. Because this example uses the 1.0.0-beta3 release of Deeplearning4j, the syntax differs slightly from wangongxi's: the learning rate and related parameters are now configured on the updater. Next, set the parameter-averaging frequency, the number of examples per batch, and so on, and build a ParameterAveragingTrainingMaster — the parameter-averaging controller that distributed training with Deeplearning4j requires. Finally, combine JavaSC, the network conf, and the TrainingMaster into the SparkDl4jMultiLayer used for training.

    val VOCAB_SIZE = 50758
    val conf = new NeuralNetConfiguration.Builder()
      .seed(1234)
      .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
      .updater(Adam.builder().learningRate(1e-4).beta1(0.9).beta2(0.999).build())
      .l2(5 * 1e-4)
      .list()
      .layer(0, new EmbeddingLayer.Builder().nIn(VOCAB_SIZE).nOut(256).activation(Activation.IDENTITY).build())
      .layer(1, new GravesLSTM.Builder().nIn(256).nOut(256).activation(Activation.SOFTSIGN).build())
      .layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
        .activation(Activation.SOFTMAX).nIn(256).nOut(2).build())
      .pretrain(false).backprop(true)
      .setInputType(InputType.recurrent(VOCAB_SIZE))
      .build()
    val examplesPerDataSetObject = 1
    val averagingFrequency: Int = 5
    val batchSizePerWorker: Int = 20
    val trainMaster = new ParameterAveragingTrainingMaster.Builder(examplesPerDataSetObject)
      .workerPrefetchNumBatches(0)
      .saveUpdater(true)
      .averagingFrequency(averagingFrequency)
      .batchSizePerWorker(batchSizePerWorker)
      .build()
    val sparkNetwork: SparkDl4jMultiLayer = new SparkDl4jMultiLayer(JavaSC, conf, trainMaster)

4. Train for 10 epochs. The data is split 50/50 into training and test sets, and after each epoch the prediction accuracy on both is printed.

    val emotionWordData = getTrainingData(spark)
    val Array(trainingData, testingData) = emotionWordData.randomSplit(Array(0.5, 0.5))
    testingData.count()
    val resultArray = new ArrayBuffer[Array[String]](0)
    for (numEpoch <- 1 to 10) {
      sparkNetwork.fit(trainingData)
      val trainEvaluation: Evaluation = sparkNetwork.evaluate(trainingData)
      val trainAccuracy = trainEvaluation.accuracy()
      val testEvaluation: Evaluation = sparkNetwork.evaluate(testingData)
      val testAccuracy = testEvaluation.accuracy()
      println("====================================================================")
      println("Epoch " + numEpoch + " Has Finished")
      println("Train Accuracy: " + trainAccuracy)
      println("Test Accuracy: " + testAccuracy)
      println("====================================================================")
      resultArray.append(Array(trainAccuracy.toString, testAccuracy.toString))
    }
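Once training finishes, the fitted network inside sparkNetwork can score new text locally. Below is a minimal sketch, assuming WORD_TO_INT and maxCorpusLength from step 2 have been hoisted out of getTrainingData so they are in scope here; for simplicity, words outside the vocabulary are simply dropped:

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork

    val net: MultiLayerNetwork = sparkNetwork.getNetwork
    // Tokenize a new review exactly as in step 2
    val newTokens = segmenter.sentenceProcess("酒店很干净,服务也周到")
      .toArray().map(_.toString).filter(WORD_TO_INT.contains)
    val feats = Nd4j.create(1, 1, maxCorpusLength)
    val featsMask = Nd4j.zeros(1L, maxCorpusLength.toLong)
    newTokens.take(maxCorpusLength).zipWithIndex.foreach { case (w, i) =>
      feats.putScalar(Array(0, 0, i), WORD_TO_INT(w))
      featsMask.putScalar(Array(0, i), 1)
    }
    // output(input, train, featuresMask, labelsMask) returns per-time-step class
    // probabilities; read the last real time step for the sentence-level prediction
    val probs = net.output(feats, false, featsMask, null)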

5. That is the end of the code. For reference, here is the full set of dependencies I used. Deeplearning4j evolves quickly and its API changes substantially between versions, so take care to match these versions when adding dependencies. Some of the packages below are for other models I have been developing; trim them as appropriate. Spark runs here in local mode for debugging; switch the master to yarn when moving to a cluster. The rest of the code is as above; wrapping it in an Object in IDEA is enough. Thanks again to wangongxi for the Java version and to the author of http://spaces.ac.cn/archives/3414/ for the corpus.
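For the cluster run mentioned above, the only code change needed is dropping the hard-coded local master so that spark-submit's --master yarn can supply it; a sketch:

    val spark: SparkSession = SparkSession
      .builder()
      // .master("local")  <- removed; pass --master yarn to spark-submit instead
      .appName("Spark LSTM Emotion Analysis")
      .getOrCreate()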

    <properties>
      <maven.compiler.source>1.7</maven.compiler.source>
      <maven.compiler.target>1.7</maven.compiler.target>
      <encoding>UTF-8</encoding>
      <scala.version>2.11.8</scala.version>
      <scala.compat.version>2.11</scala.compat.version>
      <nd4j.backend>nd4j-native-platform</nd4j.backend>
      <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
      <shadedClassifier>bin</shadedClassifier>
      <java.version>1.8</java.version>
      <nd4j.version>1.0.0-beta3</nd4j.version>
      <dl4j.version>1.0.0-beta3</dl4j.version>
      <datavec.version>1.0.0-beta3</datavec.version>
      <arbiter.version>1.0.0-beta3</arbiter.version>
      <guava.version>25.1-jre</guava.version>
      <jfreechart.version>1.5.0</jfreechart.version>
      <dl4j.spark.version>1.0.0-beta3</dl4j.spark.version>
      <aws.sdk.version>1.11.109</aws.sdk.version>
      <jcommander.version>1.72</jcommander.version>
      <scala.binary.version>2.11</scala.binary.version>
      <hadoop.version>2.7.4</hadoop.version>
    </properties>
    <dependencyManagement>
      <dependencies>
        <dependency>
          <groupId>org.nd4j</groupId>
          <artifactId>nd4j-api</artifactId>
          <version>${nd4j.version}</version>
        </dependency>
        <dependency>
          <groupId>org.nd4j</groupId>
          <artifactId>nd4j-common</artifactId>
          <version>${nd4j.version}</version>
        </dependency>
        <dependency>
          <groupId>org.nd4j</groupId>
          <artifactId>nd4j-cuda-9.2-platform</artifactId>
          <version>${nd4j.version}</version>
          <scope>test</scope>
        </dependency>
      </dependencies>
    </dependencyManagement>
    <dependencies>
      <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-core_2.11</artifactId>
        <version>2.3.1</version>
      </dependency>
      <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-sql_2.11</artifactId>
        <version>2.3.1</version>
      </dependency>
      <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-mllib_2.11</artifactId>
        <version>2.3.1</version>
      </dependency>
      <dependency>
        <groupId>ml.dmlc</groupId>
        <artifactId>xgboost4j</artifactId>
        <version>0.80</version>
      </dependency>
      <dependency>
        <groupId>ml.dmlc</groupId>
        <artifactId>xgboost4j-spark</artifactId>
        <version>0.80</version>
      </dependency>
      <dependency>
        <groupId>org.apache.predictionio</groupId>
        <artifactId>apache-predictionio-core_2.11</artifactId>
        <version>0.13.0</version>
        <scope>provided</scope>
      </dependency>
      <!-- ND4J backend. Every DL4J project needs one. The artifactId is usually
           "nd4j-native-platform" or "nd4j-cuda-7.5-platform" -->
      <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>${nd4j.backend}</artifactId>
        <version>${dl4j.version}</version>
      </dependency>
      <!-- DL4J core functionality -->
      <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>${dl4j.version}</version>
      </dependency>
      <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-nlp</artifactId>
        <version>${dl4j.version}</version>
      </dependency>
      <!-- Pin the guava version used with the UI/HistogramIterationListener -->
      <dependency>
        <groupId>com.google.guava</groupId>
        <artifactId>guava</artifactId>
        <version>${guava.version}</version>
      </dependency>
      <!-- datavec-data-codec: only used to load video data in the video-processing example -->
      <dependency>
        <artifactId>datavec-data-codec</artifactId>
        <groupId>org.datavec</groupId>
        <version>${datavec.version}</version>
      </dependency>
      <!-- For the feedforward/classification/MLP* and feedforward/regression/RegressionMathFunctions examples -->
      <dependency>
        <groupId>org.jfree</groupId>
        <artifactId>jfreechart</artifactId>
        <version>${jfreechart.version}</version>
      </dependency>
      <!-- Arbiter: for the hyperparameter-optimization example -->
      <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>arbiter-deeplearning4j</artifactId>
        <version>${arbiter.version}</version>
      </dependency>
      <!-- https://mvnrepository.com/artifact/org.datavec/datavec-api -->
      <dependency>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-hadoop</artifactId>
        <version>${dl4j.version}</version>
      </dependency>
      <!-- Logging dependencies -->
      <dependency>
        <groupId>com.typesafe.scala-logging</groupId>
        <artifactId>scala-logging_2.11</artifactId>
        <version>3.5.0</version>
      </dependency>
      <!-- ND4J -->
      <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-native-platform</artifactId>
        <version>${nd4j.version}</version>
      </dependency>
      <dependency>
        <groupId>junit</groupId>
        <artifactId>junit</artifactId>
        <version>4.12</version>
        <scope>test</scope>
      </dependency>
      <dependency>
        <groupId>org.apache.commons</groupId>
        <artifactId>commons-collections4</artifactId>
        <version>4.1</version>
      </dependency>
      <!-- ND4J -->
      <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-api</artifactId>
        <version>${nd4j.version}</version>
      </dependency>
      <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>scalnet_2.11</artifactId>
        <version>1.0.0-beta2</version>
      </dependency>
      <dependency>
        <groupId>org.slf4j</groupId>
        <artifactId>slf4j-log4j12</artifactId>
        <version>1.7.25</version>
        <scope>runtime</scope>
        <exclusions>
          <exclusion>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-log4j12</artifactId>
          </exclusion>
          <exclusion>
            <groupId>log4j</groupId>
            <artifactId>log4j</artifactId>
          </exclusion>
          <exclusion>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
          </exclusion>
        </exclusions>
      </dependency>
      <dependency>
        <groupId>com.amazonaws</groupId>
        <artifactId>aws-java-sdk-emr</artifactId>
        <version>${aws.sdk.version}</version>
        <scope>provided</scope>
      </dependency>
      <dependency>
        <groupId>com.amazonaws</groupId>
        <artifactId>aws-java-sdk-s3</artifactId>
        <version>${aws.sdk.version}</version>
        <scope>provided</scope>
      </dependency>
      <dependency>
        <groupId>com.beust</groupId>
        <artifactId>jcommander</artifactId>
        <version>${jcommander.version}</version>
      </dependency>
      <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>dl4j-spark_${scala.binary.version}</artifactId>
        <version>${dl4j.spark.version}_spark_2</version>
      </dependency>
      <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>dl4j-spark-parameterserver_${scala.binary.version}</artifactId>
        <version>${dl4j.spark.version}_spark_2</version>
      </dependency>
      <dependency>
        <groupId>org.jsoup</groupId>
        <artifactId>jsoup</artifactId>
        <version>1.11.3</version>
      </dependency>
      <dependency>
        <groupId>com.huaban</groupId>
        <artifactId>jieba-analysis</artifactId>
        <version>1.0.2</version>
      </dependency>
      <dependency>
        <groupId>org.apache.poi</groupId>
        <artifactId>poi</artifactId>
        <version>3.14</version>
      </dependency>
      <dependency>
        <groupId>org.apache.poi</groupId>
        <artifactId>poi-ooxml</artifactId>
        <version>3.14</version>
      </dependency>
      <!-- Excel handling; same purpose as the POI entries above -->
      <dependency>
        <groupId>net.sourceforge.jexcelapi</groupId>
        <artifactId>jxl</artifactId>
        <version>2.6.12</version>
      </dependency>
    </dependencies>

 
