当前位置:   article > 正文

NLP任务样本数据不均衡问题解决方案的总结和数据增强回译的实战展示_nlp回译csdn

nlp回译csdn

目录

一、数据层面

1、欠采样(under-sampling)

2、过采样

二、算法层面

1、权重设置

2、新的损失函数——Focal Loss

三、评价方式

四、数据增强实战——回译(back translate)

1、Translator

2、TextBlob

3、百度翻译API


        在做NLP分类标注等任务的时候,避免不了会遇到样本不均衡的情况,那么我们就需要处理这个问题,这样才能使模型有良好的表现。为此,在收集了一些资料以后,做了一个简单总结,方便以后回顾(怕跳槽面试的时候问道答不上来)。主要是从数据、算法和模型评价标准这个三个方面,来减少数据不平衡对模型性能的影响。

一、数据层面

      当数据极度不平衡的时候,最容易相到的解决方案,就是从数据层面出发,小类数据太少了,那么就增加小类数据;大类样本太多了就删除一些样本。不管是2分类还是多分类,样本不均衡的表现都是样本数据数目之间存在着很大的差异。为了克服这个问题,实质上就要把数据经过一定的处理,变得不那么不均衡,比例适当一些。有实验表明,只要数据之间的比例超过了1:4,就会对算法造成偏差影响。针对数据比重失调,就可以对原始数据集进行采样调整,这里主要是欠采样和过采样。

1、欠采样(under-sampling)

           对大类的数据样本进行采样来减少该类数据的样本个数。使用的一般经验规则,一般而言是对样本数目超过1W,10W 甚至更多,进行欠采样。一般简单的做法,就是随机的删除部分样本。注意的是,一般很少使用欠采样,标注数据的成本比较高,而深度学习的方法是数据量越高越好,所以一般都是使用过采样。

2、过采样

        对小类数据的样本进行采样来增加小类样本数据的个数。Smote算法(它就是在少数类样本中用KNN方法合成了新样本)一般用来进行过采样的操作,这里有一点不方便的地方就是NLP任务中,不好使用Smote算法,我们的样本一般都是文本数据,不是直接的数字数据,只有把文本数据转化为数字数据才能进行smote操作。另外现在一般都是基于预训练模型做微调的,文本的向量表示也是变化的,所有不能进行smote算法来增加小类数据。那么针对NLP进行过采样的一些方法有那些呢?

  1. 最简单的就是直接复制小类样本,从而达到增加小类样本数据的目的。这样的方法缺点也是很明显的,实际上样本中并没有加入新的特征,特征还是很少,那么就会出现过拟合的问题。
  2. 对小类样本数据经过一定的处理,做一些小的改变。例如随机的打乱词的顺序,句子的顺序;随机的删除一些词,一些句子;裁剪文本的开头或者结尾等。我认为这些小方法至合适对语序不是特别重要的任务,像一些对语序特征特别重要的序列任务这种操做就不太恰当。
  3. 复述生成:这个就属性seq2seq任务,根据原始问题成成格式更好的问题,然后把新问题替换到问答系统中。
  4. EDA:同义词替换、随机插入和随机交换
  5. 回译(back translation) 把中文——英文(其他的语言)——中文
  6. 生成对抗网络——GAN

个人认为使用复述生成和回译以及生成对抗网络应该是最有效的,因为它们在做数据增强的时候,对原始数据做的处理使得语义发生了变化,但同时又保证了整个语义的完整性。随机删除的词,打乱顺序的方式,我认为对数据的整个语义破坏太大了。当然,这些技巧都值得在具体的数据集下做对应的实验,说不定它恰好就在这个数据集上起很重要的作用。

另外我自己做过的一些实践,回译是比较不错的,在百度翻译API免费的前提下,几乎没有成本。另外的复述生成和生成对抗网络不知道,听说生成对抗网络很难也很麻烦。

二、算法层面

1、权重设置

        在训练的时候给损失函数直接设定一定的比例,使得算法能够对小类数据更多的注意力。例如在深度学习中,做一个3分类任务,标签a、b、c的样本比例为1:1:8。在我们的交叉熵损失函数中就可以用类似这样的权重设置:

torch.nn.CrossEntropyLoss(weight=torch.from_numpy(np.array([8,8,1])).float().to(device))

2、新的损失函数——Focal Loss

         另外还可以从设计新的损失函数的角度来出发,有大牛设计这样的一个损失函数,专门用来解决多分类或者二分类中样本不均衡的问题——Focal Loss。focal loss,这个损失函数是在标准交叉熵损失基础上修改得到的。这个函数可以通过减少易分类样本的权重,使得模型在训练时更专注于难分类的样本。焦点损失函数旨在通过降低内部加权(简单样本)来解决类别不平衡问题,这样即使简单样本的数量很大,但它们对总损失的贡献却很小。也就是说,该函数侧重于用困难样本稀疏的数据集来训练。本文不做原理分析,大牛们都分析的很清楚了。这里仅仅改方法展示出来,知道有这么个方法,知道怎么用就OK。借用大牛实现的代码,直接给出Focal Loss的实现:https://github.com/yatengLG/Focal-Loss-Pytorch/blob/master/Focal_Loss.py,该作者的github上有使用例子。

  1. import torch
  2. from torch import nn
  3. from torch.nn import functional as F
  4. import time
  5. class focal_loss(nn.Module):
  6. """
  7. 需要保证每个batch的长度一样,不然会报错。
  8. """
  9. def __init__(self,alpha=0.25,gamma = 2, num_classes = 2, size_average =True):
  10. """
  11. focal_loss损失函数, -α(1-yi)**γ *ce_loss(xi,yi) = -α(1-yi)**γ * log(yi)
  12. :param alpha:
  13. :param gamma:
  14. :param num_classes:
  15. :param size_average:
  16. """
  17. super(focal_loss, self).__init__()
  18. self.size_average = size_average
  19. if isinstance(alpha,list):
  20. # α可以以list方式输入,size:[num_classes] 用于对不同类别精细地赋予权重
  21. assert len(alpha) == num_classes
  22. print("Focal_loss alpha = {},对每一类权重进行精细化赋值".format(alpha))
  23. self.alpha = torch.tensor(alpha)
  24. else:
  25. assert alpha<1 #如果α为一个常数,则降低第一类的影响
  26. print("--- Focal_loss alpha = {},将对背景类或者大类负样本进行权重衰减".format(alpha))
  27. self.alpha = torch.zeros(num_classes)
  28. self.alpha[0] += alpha
  29. self.alpha[1:] += (1-alpha)
  30. self.gamma = gamma
  31. def forward(self, preds,labels):
  32. """
  33. focal_loss损失计算
  34. :param preds: 预测类别. size:[B,N,C] or [B,C] B:batch N:检测框数目 C:类别数
  35. :param labels: 实际类别. size:[B,N] or [B]
  36. :return:
  37. """
  38. preds = preds.view(-1, preds.size(-1))
  39. self.alpha = self.alpha.to(preds.device)
  40. # 这里并没有直接使用log_softmax, 因为后面会用到softmax的结果(当然你也可以使用log_softmax,然后进行exp操作)
  41. preds_softmax = F.softmax(preds,dim=1)
  42. preds_logsoft = torch.log(preds_softmax)
  43. # 这部分实现nll_loss ( crossempty = log_softmax + nll )
  44. preds_softmax = preds_softmax.gather(1,labels.view(-1,1))
  45. preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1))
  46. self.alpha = self.alpha.gather(0,labels.view(-1))
  47. loss = -torch.mul(torch.pow((1-preds_softmax),self.gamma),preds_logsoft)
  48. loss = torch.mul(self.alpha,loss.t())
  49. if self.size_average:
  50. loss = loss.mean()
  51. else:
  52. loss = loss.sum()
  53. return loss

有一个要注意的是

self.alpha = self.alpha.gather(0,labels.view(-1))

当传入的labels长度不一致,就会使得self.alpha的长度不一样,进而报错。所有要保证训练的时候每个bath传入的数据长度要一致。

三、评价方式

在模型评价的时候,我们一般简单的采用accuracy就可以了。但是在样本数据极度不平衡,特别是那种重点关注小类识别准确率的时候,就不能使用accuracy来评价模型了。要使用precision和recall来综合考虑模型的性能,降低小类分错的几率。在pytorch中,一般使用tensor来计算,下面给出关于tensor计算precision和recall的代码,主要是熟悉tensor的操作——孰能生巧。

  1. correct += (predict == label).sum().item()
  2. total += label.size(0)
  3. train_acc = correct / total
  4. #精确率、recall和F1的计算
  5. for i in range(self.number_of_classes):
  6. if i == self.none_label:
  7. continue
  8. #TP和FP
  9. self._true_positives += ((predictions==i)*(gold_labels==i)*mask.bool()).sum()
  10. self._false_positives += ((predictions==i)*(gold_labels!=i)*mask.bool()).sum()
  11. #TN和FN
  12. self._true_negatives += ((predictions!=i)*(gold_labels!=i)*mask.bool()).sum()
  13. self._false_negatives += ((predictions!=i)*(gold_labels==i)*mask.bool()).sum()
  14. #精确率、
  15. precision = float(self._true_positives) / (float(self._true_positives + self._false_positives) + 1e-13)
  16. #recall
  17. recall = float(self._true_positives) / (float(self._true_positives + self._false_negatives) + 1e-13)
  18. #F1
  19. f1_measure = 2. * ((precision * recall) / (precision + recall + 1e-13))

 

四、数据增强实战——回译(back translate)

    尝试过的库或者API分别是Translator、TextBlob 和百度翻译的API。其实这些方法都是在网上都有,这里我做一个总结吧。

1、Translator

首选看看Translator,这个翻译的库用的是MyMeory的API,免费的限制是每天1000words。安装Translator

from translate import Translator

直接看示例:

  1. from translate import Translator
  2. def translation_translate(text):
  3. print(text)
  4. translator = Translator(from_lang="chinese", to_lang="english")
  5. translation = translator.translate(text)
  6. print(translation)
  7. print(len(translation))
  8. if len(translation)> 500:
  9. translation = translation[0:500]
  10. print(translation)
  11. translator = Translator(from_lang="english", to_lang="chinese")
  12. translation = translator.translate(translation)
  13. print(len(translation))
  14. print(translation)
  15. return translation
  16. if __name__ == '__main__':
  17. text = '国家“十五”重大专项“创新药物和中药现代化”(863计划2004AA2Z3380)。基因工程药物注射给药存在着:血浆半衰期较短,生物利用度不高;抗原性较强,易引起过敏等不良反应。'
  18. translation_translate(text)

    注意的是text的长度不能超过500。所以这个做回译还是有一定的限制的,要是text翻译到中间语言,中间语言的长度超过了500,要做截取处理,语义就会丢失很多。而且每天的字数也有限制,一天1000字太少了。但是翻译效果还不错,如下:

  1. 国家“十五”重大专项“创新药物和中药现代化”(863计划2004AA2Z3380)。基因工程药物注射给药存在着:血浆半衰期较短,生物利用度不高;抗原性较强,易引起过敏等不良反应。
  2. The national "Fifth Five-Year Plan" major special "innovative drugs and modernization of Chinese medicine" (863 plan 2004A2Z380). The injection of genetically engineered drugs exists: plasma half-life is short, bioavailability is not high, antigen is strong, easy to cause allergic reactions and other adverse reactions.
  3. 320
  4. The national "Fifth Five-Year Plan" major special "innovative drugs and modernization of Chinese medicine" (863 plan 2004A2Z380). The injection of genetically engineered drugs exists: plasma half-life is short, bioavailability is not high, antigen is strong, easy to cause allergic reactions and other adverse reactions.
  5. 85
  6. 国家"十五"重大专项"创新药物与中药现代化"863计划2004A2Z380)。基因工程药物的注射存在:血浆半寿命短,生物利用度不高,抗原强,容易引起过敏反应等不良反应。

2、TextBlob

      类似Translator的使用,但是这个是调用Google翻译的API,内网用不了。

3、百度翻译API

      使用这个来做翻译的话,需要使用import http.client模块儿来实现,百度也给出了详细的教程。我这里的一个需求是需要做数据增强,每条数据需要,使用6种语言来做回译,才能配平样本比例。直接上代码:

     核心函数:

  1. def baidu_translate(content,from_lang,to_lang):
  2. appid = '×××××××××'
  3. secretKey = '××××××××××××××'
  4. httpClient = None
  5. myurl = '/api/trans/vip/translate'
  6. q = content
  7. fromLang = from_lang # 源语言
  8. toLang = to_lang # 翻译后的语言
  9. salt = random.randint(32768, 65536)
  10. sign = appid + q + str(salt) + secretKey
  11. sign = hashlib.md5(sign.encode()).hexdigest()
  12. myurl = myurl + '?appid=' + appid + '&q=' + urllib.parse.quote(
  13. q) + '&from=' + fromLang + '&to=' + toLang + '&salt=' + str(
  14. salt) + '&sign=' + sign
  15. try:
  16. httpClient = http.client.HTTPConnection('api.fanyi.baidu.com')
  17. httpClient.request('GET', myurl)
  18. # response是HTTPResponse对象
  19. response = httpClient.getresponse()
  20. jsonResponse = response.read().decode("utf-8") # 获得返回的结果,结果为json格式
  21. js = json.loads(jsonResponse) # 将json格式的结果转换字典结构
  22. dst = str(js["trans_result"][0]["dst"]) # 取得翻译后的文本结果
  23. # print(dst) # 打印结果
  24. return dst
  25. except Exception as e:
  26. print('err:',e)
  27. finally:
  28. if httpClient:
  29. httpClient.close()
  30. def do_translate(content,from_lang,to_lang):
  31. if len(content)>= 260:
  32. content = content[0:260]
  33. temp = baidu_translate(content,from_lang,to_lang)
  34. time.sleep(1)#百度API免费调用的QPS=1,所以要1s以后才能调用
  35. if temp is None:
  36. temp = 0
  37. if len(temp) >= 1500:
  38. temp = temp[0:1500]
  39. res = baidu_translate(temp,to_lang,from_lang)
  40. return res

遇到的一些坑:

注意到,这里使用的是标准版,没有收费,目前是免费的,但是以后说不定就不会开放免费的版本了。另外QPS=1,也就是1秒内并发能力只有1,所有这个在代码中,用了time.sleep(1),保证API被及时调用,而不会报错。最后由于我的中文预料长度很长大都在100-500之间,翻译成其他语言,字符数就有1500-2000多,虽然百度API对字符数长度放宽了,但是不做长度处理还是会报错,这个就需要自己有针对性的调整了。

看一看下面的回译的结果,原始的中文就没有展示出来,这个6种不同语言,回译的文本。信息都有缺失,但是整体都还在,做数据增强就很不错了。

  1. 本项目属于有色金属材料制备加工技术领域。通过系统研究,证明了铜合金纳米强化相的形核、长大机理和强化机理,突破了引入纳米强化相、控制弥散分布等常见技术难题,开发了高强高导铜合金的关键制备和加工技术,提高了我国关键领域铜合金的综合性能。它满足了我国特高压/特高压电器对高强度、高导电性铜合金的迫切需求,打破了国外对集成电路引线框架带的市场和价格垄断,为我国航天和武器装备的关键部件提供了物质支持。该成果已在全国12家企业应用,近3年新增销售金额45.51亿元。
  2. 这个项目属于有色金属材料的调制加工技术领域。通过系统研究,明确了在铜合金中纳米强化相核增长机构和强化机构,突破了纳米强化相的引进和扩散分布控制的共同技术课题,开发了高强导铜合金的重要调制加工技术,实现了我国重点领域铜合金的综合性能提高。我国超高压电器产品满足了对高强度铜合金的紧急需要,打破了对国外集成电路领先框架的领先市场和价格垄断,为我国航天飞行、武器装备的关键部件提供了材料保障。其成果是,在全国12家企业中运用,在这3年间增加了45.51亿元的营业额。
  3. 该项目属于有色金属材料制品加工技术领域。通过系统研究,通过联合合金中纳米通过象形核成长机制和加强机制的引进纳米强化奖和分布控制的共同技术难题,开发了高强度铜合金的关键技术加工技术,开发韩国重点重点在领域提高了同合金综合性能。我国初/特高压电器对高强度铜合金的需求,满足国外集成电路界面框架的市场和价格垄断,对我国航空航空航空航空航空航空航空航天设备的核心部件提供了材料保障。成果是全国1 2家企业的应用,最近3年新销售额为45.51亿韩元。
  4. 该项目是有色金属材料制造技术领域的一部分。通过系统研究,确定了在铜合金中生长和加强强化纳米芯的机制。1.克服引入和控制增强纳米散射的共同技术困难;开发高导电铜合金的临界加工技术,并改善铜合金在我国优先领域的组合性能。打破外国在集成电路市场上的垄断和集成电路的价格,确保我国航空器和武器设备的关键部件的物质安全。新增营业额45.51亿元。
  5. 该项目是有色金属制备和加工技术领域的一部分。通过引入纳米增强相、控制分散分布等常见技术问题,突破了铜合金中纳米增强相的核与生长机理和强化机理,发展高强高导铜合金的关键制备和加工技术,提高我国关键领域铜合金的综合性能,满足我国特高压、特高压电气设备对高强高导铜合金的迫切需求,打破了国外对集成电路引线框架带的市场和价格垄断,为我国航空武器装备的关键部件提供了物质支持。近三年新销售额45.51亿元。
  6. 本项目属于有色金属的加工和加工技术领域。通过系统的研究,已经确定了纳米热的增长和加强机制。(a)在铜合金中强化核,克服了采用纳米热相和控制弥散分布的一般技术困难;开发了处理高强度铜合金的关键技术,我国重点地区铜合金的综合性能实现了现代化对于高强度铜合金,打破了对市场的外部垄断和集成电路导线范围内材料的价格,为航空航天技术关键部件提供材料。全国12家企业的应用成果,近3年的销售额增加了45.51亿美元。

 

完整代码:

  1. import pandas as pd
  2. import http.client
  3. import hashlib
  4. import json
  5. import urllib
  6. import random
  7. import time
  8. from tqdm import tqdm
  9. import csv
  10. def baidu_translate(content,from_lang,to_lang):
  11. appid = '××××××××××××××
  12. secretKey = '××××××××××××××××××××'
  13. httpClient = None
  14. myurl = '/api/trans/vip/translate'
  15. q = content
  16. fromLang = from_lang # 源语言
  17. toLang = to_lang # 翻译后的语言
  18. salt = random.randint(32768, 65536)
  19. sign = appid + q + str(salt) + secretKey
  20. sign = hashlib.md5(sign.encode()).hexdigest()
  21. myurl = myurl + '?appid=' + appid + '&q=' + urllib.parse.quote(
  22. q) + '&from=' + fromLang + '&to=' + toLang + '&salt=' + str(
  23. salt) + '&sign=' + sign
  24. try:
  25. httpClient = http.client.HTTPConnection('api.fanyi.baidu.com')
  26. httpClient.request('GET', myurl)
  27. # response是HTTPResponse对象
  28. response = httpClient.getresponse()
  29. jsonResponse = response.read().decode("utf-8") # 获得返回的结果,结果为json格式
  30. js = json.loads(jsonResponse) # 将json格式的结果转换字典结构
  31. dst = str(js["trans_result"][0]["dst"]) # 取得翻译后的文本结果
  32. # print(dst) # 打印结果
  33. return dst
  34. except Exception as e:
  35. print('err:',e)
  36. finally:
  37. if httpClient:
  38. httpClient.close()
  39. def do_translate(content,from_lang,to_lang):
  40. if len(content)>= 260:
  41. content = content[0:260]
  42. temp = baidu_translate(content,from_lang,to_lang)
  43. time.sleep(1)#百度API免费调用的QPS=1,所以要1s以后才能调用
  44. if temp is None:
  45. temp = 0
  46. if len(temp) >= 1500:
  47. temp = temp[0:1500]
  48. res = baidu_translate(temp,to_lang,from_lang)
  49. return res
  50. def back_translate(A_title,R_title,A_content,R_content,level,writer):
  51. new_A_titles = []
  52. new_R_titles = []
  53. new_A_contents = []
  54. new_R_contents = []
  55. new_levels = []
  56. fromlang_tolangs = [
  57. ('zh', 'en'),
  58. ('zh', 'jp'),
  59. ('zh', 'kor'),
  60. ('zh', 'fra'),
  61. ('zh', 'de'),
  62. ('zh', 'ru')
  63. ]
  64. for ele in fromlang_tolangs:
  65. from_lang = ele[0]
  66. to_lang = ele[1]
  67. A_content_new = do_translate(A_content,from_lang,to_lang)
  68. time.sleep(1)#百度API免费调用的QPS=1,所以要1s以后才能调用
  69. R_content_new = do_translate(R_content,from_lang,to_lang)
  70. time.sleep(1) # 百度API免费调用的QPS=1,所以要1s以后才能调用
  71. new_A_titles.append(A_title)
  72. new_R_titles.append(R_title)
  73. new_A_contents.append(A_content_new)
  74. new_R_contents.append(R_content_new)
  75. new_levels.append(level)
  76. writer.writerow([A_title,A_content_new,R_title,R_content_new,level])
  77. return new_A_titles,new_R_titles,new_A_contents,new_R_contents,new_levels
  78. if __name__ == '__main__':
  79. orginal_data = pd.read_csv('data/interrelation_final.csv',sep='\t')
  80. print(orginal_data.groupby(['Level']).size())
  81. A_titles = orginal_data[(orginal_data['Level']==3) | (orginal_data['Level']==4) ]['A_title'].values.tolist()
  82. R_titles = orginal_data[(orginal_data['Level']==3) | (orginal_data['Level']==4)]['R_title'].values.tolist()
  83. A_contents = orginal_data[(orginal_data['Level']==3) | (orginal_data['Level']==4)]['A_content'].values.tolist()
  84. R_contents = orginal_data[(orginal_data['Level']==3) | (orginal_data['Level']==4)]['R_content'].values.tolist()
  85. levels = orginal_data[(orginal_data['Level']==3) | (orginal_data['Level']==4)]['Level'].values.tolist()
  86. A_title_new = []
  87. R_title_new = []
  88. A_content_new = []
  89. R_content_new = []
  90. levels_new = []
  91. count = 0
  92. csv_header = ['A_title','A_content','R_title','R_content','Level']
  93. with open('data/final_augment_data_2.csv','w') as f:
  94. writer = csv.writer(f)
  95. writer.writerow(csv_header)
  96. for A_title,R_title,A_content,R_content,level in tqdm(list(zip(A_titles,R_titles,A_contents,R_contents,levels)),desc='回译执行:'):
  97. if count>= 311:
  98. new_A_titles,new_R_titles,new_A_contents,new_R_contents,new_levels = back_translate(A_title,R_title,A_content,R_content,level,writer)
  99. A_title_new.extend(new_A_titles)
  100. R_title_new.extend(new_R_titles)
  101. A_content_new.extend(new_A_contents)
  102. R_content_new.extend(new_R_contents)
  103. levels_new.extend(new_levels)
  104. count += 1

 

参考文章:

欠采样(undersampling)和过采样(oversampling)会对模型带来怎样的影响?

NLP数据增强方法

https://github.com/yatengLG/Focal-Loss-Pytorch/blob/master/Focal_Loss.py

Focal Loss

使用python调用百度翻译api进行翻译

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/349100
推荐阅读
相关标签
  

闽ICP备14008679号