前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >在数据增强、蒸馏剪枝下ERNIE3.0分类模型性能提升

在数据增强、蒸馏剪枝下ERNIE3.0分类模型性能提升

原创
作者头像
汀丶人工智能
修改2022-11-13 23:25:15
2810
修改2022-11-13 23:25:15
举报
文章被收录于专栏:NLP/KGNLP/KG

在数据增强、蒸馏剪枝下ERNIE3.0模型性能提升

项目链接:

https://aistudio.baidu.com/aistudio/projectdetail/4436131?contributionType=1

以CBLUE数据集中医疗搜索检索词意图分类为例:

本项目首先讲解了数据增强和数据蒸馏的方案,并在后面章节进行效果展示,结果预览:

| 模型 | ACC |Precision |Recall| F1 |average_of_acc_and_f1 |

| -------- | -------- | -------- | -------- | -------- |-------- |

| ERNIE 3.0 Base| 0.80255 | 0.9317147 |0.908284 | 0.919850 |0.86120 |

| ERNIE 3.0 Base+数据增强 | 0.7979539 | 0.901004 |0.92899 | 0.91478 |0.8563 |

| ERNIE 3.0 Base+剪裁保留比0.5 | 0.79846 | 0.951257 |0.89497 | 0.92225 |0.8603 |

| ERNIE 3.0 Base +剪裁保留比2/3 | 0.8092071 | 0.9415384 |0.905325 | 0.923076 |0.86614 |

gensim安装最新版本:pip install gensim

tqdm安装:pip install tqdm

LAC安装最新版本:pip install lac


**Gensim库介绍**

Gensim是在做自然语言处理时较为经常用到的一个工具库,主要用来以无监督的方式从原始的非结构化文本当中来学习到文本隐藏层的主题向量表达。

主要包括TF-IDF,LSA,LDA,word2vec,doc2vec等多种模型。

**Tqdm**

是一个快速,可扩展的Python进度条,可以在 Python 长循环中添加一个进度提示信息,用户只需要封装任意的迭代器 tqdm(iterator)。目的为了程序显示的美观

**中文词法分析-LAC**

LAC是一个联合的词法分析模型,整体性地完成中文分词、词性标注、专名识别任务。LAC既可以认为是Lexical Analysis of Chinese的首字母缩写,也可以认为是LAC Analyzes Chinese的递归缩写。

LAC基于一个堆叠的双向GRU结构,在长文本上准确复刻了百度AI开放平台上的词法分析算法。效果方面,分词、词性、专名识别的整体准确率95.5%;单独评估专名识别任务,F值87.1%(准确90.3,召回85.4%),总体略优于开放平台版本。在效果优化的基础上,LAC的模型简洁高效,内存开销不到100M,而速度则比百度AI开放平台提高了57%

LAC链接:https://www.paddlepaddle.org.cn/modelbasedetail/lac

!pip install --upgrade paddlenlp

!pip install gensim

!pip install tqdm

!pip install lac

2.数据增强方案介绍

数据增强工具提供4种增强策略:遮盖、删除、同词性词替换、词向量近义词替换

在这里插入图片描述
在这里插入图片描述

!unzip ERNIE-.zip -d ./ERNIE

#添加ERNIE工具包

代码语言:txt
复制
如果程序报错:

可以发现提示有一个.ipynb\_checkpoints的文件。但当我去对应的文件夹找时根本看不到这个文件,所以猜测是一个隐藏文件。所以通过终端进入对应的目录:输入cd coco进入对应目录,输入ls -a显示所有文件。然后输入rm -rf .ipynb\_checkpoints删除该文件。再次输入ls -a查看文件是否被删除。

下载词表,词表有1.7G会花点时间。下面以情感分析数据样例展示demo,看看数据增强的效果。

代码语言:python
复制
!wget -q --no-check-certificate http://bj.bcebos.com/wenxin-models/vec2.txt

**python data_aug.py "输入文件夹的目录" "输出文件夹的目录"**

* **data_aug.py脚本传参说明**

代码语言:txt
复制
shell输入:

    python data\_aug.py -h



shell输出:

    usage: data\_aug.py [-h] [-n AUG\_TIMES] [-c COLUMN\_NUMBER] [-u UNK]

                       [-t TRUNCATE] [-r POS\_REPLACE] [-w W2V\_REPLACE]

                       [-e ERNIE\_REPLACE] [--unk\_token UNK\_TOKEN]

                       input output

    

    main

    

    positional arguments:

      input                                                #原始待增强数据文件所在文件夹,带label的,一个或多个文本列

      output                                               #输出文件路径

    

    optional arguments:

      -h, --help            show this help message and exit

      -n AUG\_TIMES, --aug\_times AUG\_TIMES                  #数据集数目放大n倍,output行数为input的n+1倍      

      -c COLUMN\_NUMBER, --column\_number COLUMN\_NUMBER      #明文文件中所要增强列的列序号,多列用逗号分割,如:1,2

      -u UNK, --unk UNK                                    #unk 增强策略的概率

      -t TRUNCATE, --truncate TRUNCATE                     #truncate 增强策略的概率

      -r POS\_REPLACE, --pos\_replace POS\_REPLACE            #pos\_replace 增强策略的概率

      -w W2V\_REPLACE, --w2v\_replace W2V\_REPLACE            #w2v\_replace 增强策略的概率

      --unk\_token UNK\_TOKEN                    

分类问题中:推荐使用前三种即可,w2v词向量近义词替换可以不用,花费时间太长。

代码语言:python
复制
!python data\_aug.py --unk 0.25 --truncate 0.25 --pos 0.5 --w2v 0 ./train ./output
代码语言:python
复制
demo结果展示:



机器 背面 似乎 被 撕 了 张 什么 标签 , 残 胶 还在 。 但是 又 看 不 出 是 什么 标签 不见 了 , 该 有 的 都 在 , 怪    0

机器 背面 似乎 被 撕 了 张 什么 标签 , 胶 还在 。 但是 又 看 不 出 是 什么 标签 不见 了 , 该 有 的 都 在 , 怪    0

机器 背面 了 张 什么 标签 , 残 胶 还在 。 但是 又 看 不 出 是 什么 标签  了 , 该在 , 怪    0

呵呵 , 虽然 表皮 看上去 不错 很 精致 , 但是 我 还是 能 看得出来 是 盗 的 。 但是 里面 的 内容 真 的 不错 , 我 妈 爱 看 , 我自己 也 学 着 找 一些 穴位 。    0

呵呵 , 虽然 表皮 看上去 不错 很 精致 , 但是 我 还是 能 看得出来 是 盗 的 。 但是 里面 的 内容 真 的 不错 , 我?妈 爱 看 , 我自己 也 学 着 找 一些 穴位 ?    0

呵呵 , 虽然 表皮 看上去 不错 很 精致 , 但是 我 还? 能 看得出来 是 盗???。 但是 里面 的 内容 真 的 不错 , 我 妈 爱 看 ,???????学 着 找 ???????    0

?????虽然 表皮 看上去 不错 很 精致 , 但是 我 还是 能 看得出来 是 盗 的 。 但是 里面 的 内容 真 的 不错 , 我 妈 爱 看 , 我自己 也 学 着 找 一些 穴位 。    0

??????? 表皮 看上去 不错 很 精致 , 但是 我 还是 能 看得出来 是 盗 的 。 但是 里面 的 内容 真 的 不错 , 我 妈 爱 看 , 我自己 也 学 着 找 一些 穴位 。    0

地理 位置 佳 , 在 市中心 。 酒店 服务 好 、 早餐 品种 丰富 。 我 住 的 商务 数码 房 电脑 宽带 速度 满意 , 房间 还算 干净 , 离 湖南路小吃街 近 。    1

地理 位置 佳 , 在 市中心 。 酒店 服务 好 、 早餐 品种 丰富 。 我 住 的 商务 数码 房 电脑 宽带 速度 满意 , 房间 还算 干净 , 离 湖南路小吃街 近。。    1

地理 位置 佳 , 在 市中心 。 酒店 服务 好 、 早餐 品种 丰富 。 我 住 的 商务 数码 房 电脑 宽带 速度 满意 , 机器 还算 干净 , 离 湖南路小吃街 近 。    1

地理 位置 佳 , 在 市中心 。 酒店 服务 好 、 早餐 品种 丰富 。 我 住 的 商务 数码 房 电脑 宽带 速度 满意 , 房间 还算 干净 , 离 湖南路小吃街 近 。    1

地理 位置 佳 , 在 市中心 。 酒店 服务 好 、 早餐 品种 丰富 。 我 住 的 商务 数码 房 电脑 宽

我 看 是 书 的 还 可以 , 但是 我 订 的 书 迟迟 还 到 能 半个月 , 都 没有 收到 打电话 也 没

2.0 补充nlpcda一键中文数据增强工具(NLP Chinese Data Augmentation )

一键中文数据增强工具,支持:

1.随机实体替换

2.近义词

3.近义近音字替换

4.随机字删除(内部细节:数字时间日期片段,内容不会删)

5.NER类 BIO 数据增强

6.随机置换邻近的字:研表究明,汉字序顺并不定一影响文字的阅读理解<<是乱序的

7.中文等价字替换(1 一 壹 ①,2 二 贰 ②)

8.翻译互转实现的增强

9.使用simbert做生成式相似句生成

参考链接:

一键中文数据增强包 ; NLP数据增强、bert数据增强、EDA:pip install nlpcda

nlpcda一键中文数据增强工具

3.数据蒸馏技术

ERNIE数据蒸馏三步

**Step 1.** 使用ERNIE模型对输入标注数据对进行fine-tune,得到Teacher Model

**Step 2**. 使用ERNIE Service对以下无监督数据进行预测:

* 用户提供的大规模无标注数据,需与标注数据同源

* 对标注数据进行数据增强,具体增强策略

* 对无标注数据和数据增强数据进行一定比例混合

**Step 3.** 使用步骤2的数据训练出Student Model

数据增强

目前采用三种数据增强策略策略,对于不用的任务可以特定的比例混合。三种数据增强策略包括:

添加噪声:对原始样本中的词,以一定的概率(如0.1)替换为”UNK”标签

同词性词替换:对原始样本中的所有词,以一定的概率(如0.1)替换为本数据集钟随机一个同词性的词

N-sampling:从原始样本中,随机选取位置截取长度为m的片段作为新的样本,其中片段的长度m为0到原始样本长度之间的随机值

在这里插入图片描述
在这里插入图片描述

具体效果在下一节展现,先安装好paddleslim库

4.基于ERNIR3.0文本模型微调

加载已有数据集:CBLUE数据集中医疗搜索检索词意图分类(训练)

数据集定义:

以公开数据集CBLUE数据集中医疗搜索检索词意图分类(KUAKE-QIC)任务为示例,在训练集上进行模型微调,并在开发集上使用准确率Accuracy评估模型表现。

数据集默认为:默认为"cblue"。

**save_dir**:保存训练模型的目录;默认保存在当前目录checkpoint文件夹下。

**dataset**:训练数据集;默认为"cblue"。

<font color="red">**dataset_dir**:本地数据集路径,数据集路径中应包含train.txt,dev.txt和label.txt文件;默认为None。</font>

**task_name**:训练数据集;默认为"KUAKE-QIC"。

max_seq_length:ERNIE模型使用的最大序列长度,最大不能超过512, 若出现显存不足,请适当调低这一参数;默认为128。

<font color="red">**model_name**:选择预训练模型;默认为"ernie-3.0-base-zh"。</font>

<font color="red">**device**: 选用什么设备进行训练,可选cpu、gpu、xpu、npu。如使用gpu训练,可使用参数gpus指定GPU卡号。</font>

**batch_size**:批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为32。

**learning_rate**:Fine-tune的最大学习率;默认为6e-5。

**weight_decay**:控制正则项力度的参数,用于防止过拟合,默认为0.01。

**early_stop**:选择是否使用早停法(EarlyStopping);默认为False。

<font color="red">**early_stop_nums**:在设定的早停训练轮次内,模型在开发集上表现不再上升,训练终止;默认为4。

epochs: 训练轮次,默认为100。</font>

**warmup**:是否使用学习率warmup策略;默认为False。

**warmup_proportion**:学习率warmup策略的比例数,如果设为0.1,则学习率会在前10%steps数从0慢慢增长到learning_rate, 而后再缓慢衰减;默认为0.1。

**logging_steps**: 日志打印的间隔steps数,默认5。

**init_from_ckpt**: 模型初始checkpoint参数地址,默认None。

**seed**:随机种子,默认为3。

代码语言:python
复制
#修改后的训练文件train\_new2.py ,主要使用了paddlenlp.metrics.glue的AccuracyAndF1:准确率及F1-score,可用于GLUE中的MRPC 和QQP任务

#不过吐槽一下:    return (acc,precision,recall,f1,(acc + f1) / 2,) 最后一个指标竟然是加权平均.....

!python train\_new2.py --warmup --early\_stop --epochs 10 --save\_dir "./checkpoint2" --batch\_size 16 --model\_name ernie-3.0-base-zh

训练结果部分展示:

代码语言:txt
复制
[2022-08-16 19:58:36,834] [    INFO] - global step 1280, epoch: 3, batch: 412, loss: 0.23292, acc: 0.87106, speed: 16.54 step/s

[2022-08-16 19:58:37,392] [    INFO] - global step 1290, epoch: 3, batch: 422, loss: 0.22339, acc: 0.87130, speed: 17.94 step/s

[2022-08-16 19:58:37,960] [    INFO] - global step 1300, epoch: 3, batch: 432, loss: 0.22791, acc: 0.87182, speed: 17.68 step/s

(acc, precision, recall, f1, average\_of\_acc\_and\_f1):(0.8025575447570332, 0.9317147192716236, 0.908284023668639, 0.9198501872659175, 0.8612038660114754)

2022-08-16 20:01:36,060 - Early stop!

2022-08-16 20:01:36,060 - Save best accuracy text classification model in ./checkpoint2

4.1 加载自定义数据集(并通过数据增强训练)

**从本地文件创建数据集**

**使用本地数据集来训练我们的文本分类模型,本项目支持使用固定格式本地数据集文件进行训练**

如果需要对本地数据集进行数据标注,可以参考文本分类任务doccano数据标注使用指南进行文本分类数据标注。**这个放到下个项目讲解**

本项目将以CBLUE数据集中医疗搜索检索词意图分类(KUAKE-QIC)任务为例进行介绍如何加载本地固定格式数据集进行训练:

**本地数据集目录结构如下:**

代码语言:txt
复制
data/

├── train.txt # 训练数据集文件

├── dev.txt # 开发数据集文件

├── label.txt # 分类标签文件

└── data.txt # 可选,待预测数据文件

部分结果展示

代码语言:txt
复制
[2022-08-16 23:43:18,093] [    INFO] - global step 2400, epoch: 2, batch: 234, loss: 0.60859, acc: 0.84437, speed: 19.27 step/s

(acc, precision, recall, f1, average\_of\_acc\_and\_f1):(0.7979539641943734, 0.9010043041606887, 0.9289940828402367, 0.9147851420247632, 0.8563695531095683)

[2022-08-16 23:43:24,522] [    INFO] - Save best F1 text classification model in ./checkpoint3

[2022-08-16 23:43:24,523] [    INFO] - best F1 performence has been updated: 0.91450 --> 0.91479

4.2 数据蒸馏

代码语言:python
复制
!unset CUDA\_VISIBLE\_DEVICES

!python -m paddle.distributed.launch --gpus "0" prune.py \

    --device "gpu" \

    --output\_dir "./prune" \

    --per\_device\_train\_batch\_size 32 \

    --per\_device\_eval\_batch\_size 32 \

    --learning\_rate 3e-5 \

    --num\_train\_epochs 5 \

    --logging\_steps 10 \

    --save\_steps 50 \

    --seed 3 \

    --dataset\_dir "KUAKE\_QIC" \

    --max\_seq\_length 128 \

    --params\_dir "./checkpoint3" \

    --width\_mult '0.5'

部分结果展示:

代码语言:txt
复制
[2022-08-17 14:22:30,954] [    INFO] - width\_mult: 0.5, eval loss: 0.63535, acc: 0.79847

(acc, precision, recall, f1, average\_of\_acc\_and\_f1):(0.7984654731457801, 0.9512578616352201, 0.8949704142011834, 0.9222560975609755, 0.8603607853533778)

[2022-08-17 14:22:35,870] [    INFO] - Save best F1 text classification model in ./prune/0.5

[2022-08-17 14:22:35,870] [    INFO] - best F1 performence has been updated: 0.92226 --> 0.92226
代码语言:python
复制
!unset CUDA\_VISIBLE\_DEVICES

!python -m paddle.distributed.launch --gpus "0" prune.py \

    --device "gpu" \

    --output\_dir "./prune" \

    --per\_device\_train\_batch\_size 32 \

    --per\_device\_eval\_batch\_size 32 \

    --learning\_rate 3e-5 \

    --num\_train\_epochs 5 \

    --logging\_steps 10 \

    --save\_steps 50 \

    --seed 3 \

    --dataset\_dir "KUAKE\_QIC" \

    --max\_seq\_length 128 \

    --params\_dir "./checkpoint3" \

    --width\_mult '2/3'
代码语言:txt
复制
2022-08-17 14:53:45,544] [    INFO] - global step 3070, epoch: 2, batch: 904, loss: 0.709566, speed: 9.93 step/s

[2022-08-17 14:53:46,550] [    INFO] - global step 3080, epoch: 2, batch: 914, loss: 0.607238, speed: 9.94 step/s

[2022-08-17 14:53:47,558] [    INFO] - global step 3090, epoch: 2, batch: 924, loss: 0.718484, speed: 9.93 step/s

[2022-08-17 14:53:48,563] [    INFO] - global step 3100, epoch: 2, batch: 934, loss: 0.546288, speed: 9.95 step/s

[2022-08-17 14:53:50,206] [    INFO] - teacher model, eval loss: 0.66438, acc: 0.80358

[2022-08-17 14:53:50,207] [    INFO] - eval done total : 1.6434180736541748 s

[2022-08-17 14:53:53,568] [    INFO] - width\_mult: 0.6666666666666666, eval loss: 0.60219, acc: 0.80921

(acc, precision, recall, f1, average\_of\_acc\_and\_f1):(0.8092071611253197, 0.9415384615384615, 0.9053254437869822, 0.923076923076923, 0.8661420421011213)

[2022-08-17 14:53:58,489] [    INFO] - Save best F1 text classification model in ./prune/0.6666666666666666

[2022-08-17 14:53:58,489] [    INFO] - best F1 performence has been updated: 0.92308 --> 0.92308

4.3 模型预测

输入待预测数据和数据标签对照列表,模型预测数据对应的标签

使用默认数据进行预测:

代码语言:python
复制
#也可以选择使用本地数据文件data/data.txt进行预测:

!python predict.py --params\_path ./checkpoint3/ --dataset\_dir ./KUAKE\_QIC --device "cpu"
代码语言:python
复制
黑苦荞茶的功效与作用及食用方法 功效作用

交界痣会凸起吗 疾病表述

检查是否能怀孕挂什么科 就医建议

鱼油怎么吃咬破吃还是直接咽下去 其他

幼儿挑食的生理原因是 病因分析
代码语言:python
复制
!python predict.py \

    --device "cpu" \

    --dataset\_dir ./KUAKE\_QIC \

    --params\_path "./prune/0.5" \

5.总结

本项目首先讲解了数据增强和数据蒸馏的方案,并在后面章节进行效果展示,现在进行汇总

| 模型 | ACC |Precision |Recall| F1 |average_of_acc_and_f1 |

| -------- | -------- | -------- | -------- | -------- |-------- |

| ERNIE 3.0 Base| 0.80255 | 0.9317147 |0.908284 | 0.919850 |0.86120 |

| ERNIE 3.0 Base+数据增强 | 0.7979539 | 0.901004 |0.92899 | 0.91478 |0.8563 |

| ERNIE 3.0 Base+剪裁保留比0.5 | 0.79846 | 0.951257 |0.89497 | 0.92225 |0.8603 |

| ERNIE 3.0 Base +剪裁保留比2/3 | 0.8092071 | 0.9415384 |0.905325 | 0.923076 |0.86614 |

分析可得,

* 首先数据增强后导致性能部分下降部分和预期的原因:

随机mask、删除会产生过多噪声样本影响结果,推荐只使用同义词替换,本次样本数据量足够,且ERNIE性能本就优越,数据增强对结果提升在较大样本集可以忽略。

* 其次,可以看到通过数据蒸馏后,模型性能变化不大,甚至在剪裁1/3之后,性能有小幅度提升

本次主要对分类模型加入数据增强、数据蒸馏,已经对性能指标进行细化,不只是ACC,个人比较关注F1情况,并作为保存模型依据。

**展望:** 后续将完善动态图和静态图转化部分,让蒸馏下来模型可以继续线上加载使用;其次将会考虑小样本学习在分类模型应用情况;最后将完成模型融合环节提升性能,并做可解释性分析。

**本人博客:**https://blog.csdn.net/sinat_39620217?type=blog

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 在数据增强、蒸馏剪枝下ERNIE3.0模型性能提升
  • 2.数据增强方案介绍
  • 2.0 补充nlpcda一键中文数据增强工具(NLP Chinese Data Augmentation )
  • 3.数据蒸馏技术
    • ERNIE数据蒸馏三步
    • 4.基于ERNIR3.0文本模型微调
      • 4.1 加载自定义数据集(并通过数据增强训练)
        • 4.2 数据蒸馏
          • 4.3 模型预测
          • 5.总结
          相关产品与服务
          NLP 服务
          NLP 服务(Natural Language Process,NLP)深度整合了腾讯内部的 NLP 技术,提供多项智能文本处理和文本生成能力,包括词法分析、相似词召回、词相似度、句子相似度、文本润色、句子纠错、文本补全、句子生成等。满足各行业的文本智能需求。
          领券
          问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
          http://www.vxiaotou.com