CPM:A large-scale generative chinese pre-trained lanuage model

2023-10-09 23:59

本文主要是介绍CPM:A large-scale generative chinese pre-trained lanuage model,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

GitHub - yangjianxin1/CPM: Easy-to-use CPM for Chinese text generation(基于CPM的中文文本生成)Easy-to-use CPM for Chinese text generation(基于CPM的中文文本生成) - GitHub - yangjianxin1/CPM: Easy-to-use CPM for Chinese text generation(基于CPM的中文文本生成)https://github.com/yangjianxin1/CPM论文《CPM: A Large-scale Generative Chinese Pre-trained Language Model》_陈欢伯的博客-CSDN博客1. IntroductionGPT-3含有175B参数使用了570GB的数据进行训练。但大多数语料是基于英文(93%),并且GPT-3的参数没有分布,所以提出了CPM(Chinese Pretrained language Model):包含2.6B参数,使用100GB中文训练数据。CPM可以对接下游任务:对话、文章生成、完形填空、语言理解。随着参数规模的增加,CPM在一些数据集上表现更好,表示大模型在语言生成和理解上面更有效。文章的主要贡献发布了一个CPM:2.6B参数,100GB中文训练https://blog.csdn.net/mark_technology/article/details/118680728https://github.com/leeguandong/CPMhttps://github.com/leeguandong/CPM文章本身写的非常简单,至于模型结构这块,可以看一下放出来的代码,还挺好用的,我跑一个电商场景的推荐文章生成模型,效果也不错。在生成模型上还是很建议尝试一下CPM,整体采用transformer中的代码实现,比较简洁。第三个链接是我在电商数据上训练的cpm,提供了权重。

中文版GPT-3来了?智源、清华发布清源 CPM——以中文为核心的大规模预训练模型

上面计算时间为使用单块NVIDIA V100 GPU训练的估计时间。

1.Approach

1.1 Chinese PLM(pretrained lanuage model)

上面是CPM的模型参数版本,其中small版本至少我是可以在gtx1080ti上训练的,后面我会添加我的具体训练参数。

稍微过一下CPM的模型结构,其实就是gpt2的模型:

GPT2LMHeadModel((transformer): GPT2Model((wte): Embedding(30000, 768)(wpe): Embedding(1024, 768)(drop): Dropout(p=0.1, inplace=False)(h): ModuleList((0): GPT2Block((ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(attn): GPT2Attention((c_attn): Conv1D()(c_proj): Conv1D()(attn_dropout): Dropout(p=0.1, inplace=False)(resid_dropout): Dropout(p=0.1, inplace=False))(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(mlp): GPT2MLP((c_fc): Conv1D()(c_proj): Conv1D()(dropout): Dropout(p=0.1, inplace=False)))(1): GPT2Block((ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(attn): GPT2Attention((c_attn): Conv1D()(c_proj): Conv1D()(attn_dropout): Dropout(p=0.1, inplace=False)(resid_dropout): Dropout(p=0.1, inplace=False))(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(mlp): GPT2MLP((c_fc): Conv1D()(c_proj): Conv1D()(dropout): Dropout(p=0.1, inplace=False)))(2): GPT2Block((ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(attn): GPT2Attention((c_attn): Conv1D()(c_proj): Conv1D()(attn_dropout): Dropout(p=0.1, inplace=False)(resid_dropout): Dropout(p=0.1, inplace=False))(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(mlp): GPT2MLP((c_fc): Conv1D()(c_proj): Conv1D()(dropout): Dropout(p=0.1, inplace=False)))(3): GPT2Block((ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(attn): GPT2Attention((c_attn): Conv1D()(c_proj): Conv1D()(attn_dropout): Dropout(p=0.1, inplace=False)(resid_dropout): Dropout(p=0.1, inplace=False))(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(mlp): GPT2MLP((c_fc): Conv1D()(c_proj): Conv1D()(dropout): Dropout(p=0.1, inplace=False)))(4): GPT2Block((ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(attn): GPT2Attention((c_attn): Conv1D()(c_proj): Conv1D()(attn_dropout): Dropout(p=0.1, inplace=False)(resid_dropout): Dropout(p=0.1, inplace=False))(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(mlp): GPT2MLP((c_fc): Conv1D()(c_proj): Conv1D()(dropout): Dropout(p=0.1, inplace=False)))(5): GPT2Block((ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(attn): GPT2Attention((c_attn): Conv1D()(c_proj): Conv1D()(attn_dropout): Dropout(p=0.1, inplace=False)(resid_dropout): Dropout(p=0.1, inplace=False))(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(mlp): GPT2MLP((c_fc): Conv1D()(c_proj): Conv1D()(dropout): Dropout(p=0.1, inplace=False)))(6): GPT2Block((ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(attn): GPT2Attention((c_attn): Conv1D()(c_proj): Conv1D()(attn_dropout): Dropout(p=0.1, inplace=False)(resid_dropout): Dropout(p=0.1, inplace=False))(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(mlp): GPT2MLP((c_fc): Conv1D()(c_proj): Conv1D()(dropout): Dropout(p=0.1, inplace=False)))(7): GPT2Block((ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(attn): GPT2Attention((c_attn): Conv1D()(c_proj): Conv1D()(attn_dropout): Dropout(p=0.1, inplace=False)(resid_dropout): Dropout(p=0.1, inplace=False))(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(mlp): GPT2MLP((c_fc): Conv1D()(c_proj): Conv1D()(dropout): Dropout(p=0.1, inplace=False)))(8): GPT2Block((ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(attn): GPT2Attention((c_attn): Conv1D()(c_proj): Conv1D()(attn_dropout): Dropout(p=0.1, inplace=False)(resid_dropout): Dropout(p=0.1, inplace=False))(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(mlp): GPT2MLP((c_fc): Conv1D()(c_proj): Conv1D()(dropout): Dropout(p=0.1, inplace=False)))(9): GPT2Block((ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(attn): GPT2Attention((c_attn): Conv1D()(c_proj): Conv1D()(attn_dropout): Dropout(p=0.1, inplace=False)(resid_dropout): Dropout(p=0.1, inplace=False))(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(mlp): GPT2MLP((c_fc): Conv1D()(c_proj): Conv1D()(dropout): Dropout(p=0.1, inplace=False)))(10): GPT2Block((ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(attn): GPT2Attention((c_attn): Conv1D()(c_proj): Conv1D()(attn_dropout): Dropout(p=0.1, inplace=False)(resid_dropout): Dropout(p=0.1, inplace=False))(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(mlp): GPT2MLP((c_fc): Conv1D()(c_proj): Conv1D()(dropout): Dropout(p=0.1, inplace=False)))(11): GPT2Block((ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(attn): GPT2Attention((c_attn): Conv1D()(c_proj): Conv1D()(attn_dropout): Dropout(p=0.1, inplace=False)(resid_dropout): Dropout(p=0.1, inplace=False))(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(mlp): GPT2MLP((c_fc): Conv1D()(c_proj): Conv1D()(dropout): Dropout(p=0.1, inplace=False))))(ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True))(lm_head): Linear(in_features=768, out_features=30000, bias=False)
)Process finished with exit code 0
transformer.wte.weight   [30000, 768]
transformer.wpe.weight   [1024, 768]
transformer.h.0.ln_1.weight   [768]
transformer.h.0.ln_1.bias   [768]
transformer.h.0.attn.bias   [1, 1, 1024, 1024]
transformer.h.0.attn.masked_bias   []
transformer.h.0.attn.c_attn.weight   [768, 2304]
transformer.h.0.attn.c_attn.bias   [2304]
transformer.h.0.attn.c_proj.weight   [768, 768]
transformer.h.0.attn.c_proj.bias   [768]
transformer.h.0.ln_2.weight   [768]
transformer.h.0.ln_2.bias   [768]
transformer.h.0.mlp.c_fc.weight   [768, 3072]
transformer.h.0.mlp.c_fc.bias   [3072]
transformer.h.0.mlp.c_proj.weight   [3072, 768]
transformer.h.0.mlp.c_proj.bias   [768]
transformer.h.1.ln_1.weight   [768]
transformer.h.1.ln_1.bias   [768]
transformer.h.1.attn.bias   [1, 1, 1024, 1024]
transformer.h.1.attn.masked_bias   []
transformer.h.1.attn.c_attn.weight   [768, 2304]
transformer.h.1.attn.c_attn.bias   [2304]
transformer.h.1.attn.c_proj.weight   [768, 768]
transformer.h.1.attn.c_proj.bias   [768]
transformer.h.1.ln_2.weight   [768]
transformer.h.1.ln_2.bias   [768]
transformer.h.1.mlp.c_fc.weight   [768, 3072]
transformer.h.1.mlp.c_fc.bias   [3072]
transformer.h.1.mlp.c_proj.weight   [3072, 768]
transformer.h.1.mlp.c_proj.bias   [768]
transformer.h.2.ln_1.weight   [768]
transformer.h.2.ln_1.bias   [768]
transformer.h.2.attn.bias   [1, 1, 1024, 1024]
transformer.h.2.attn.masked_bias   []
transformer.h.2.attn.c_attn.weight   [768, 2304]
transformer.h.2.attn.c_attn.bias   [2304]
transformer.h.2.attn.c_proj.weight   [768, 768]
transformer.h.2.attn.c_proj.bias   [768]
transformer.h.2.ln_2.weight   [768]
transformer.h.2.ln_2.bias   [768]
transformer.h.2.mlp.c_fc.weight   [768, 3072]
transformer.h.2.mlp.c_fc.bias   [3072]
transformer.h.2.mlp.c_proj.weight   [3072, 768]
transformer.h.2.mlp.c_proj.bias   [768]
transformer.h.3.ln_1.weight   [768]
transformer.h.3.ln_1.bias   [768]
transformer.h.3.attn.bias   [1, 1, 1024, 1024]
transformer.h.3.attn.masked_bias   []
transformer.h.3.attn.c_attn.weight   [768, 2304]
transformer.h.3.attn.c_attn.bias   [2304]
transformer.h.3.attn.c_proj.weight   [768, 768]
transformer.h.3.attn.c_proj.bias   [768]
transformer.h.3.ln_2.weight   [768]
transformer.h.3.ln_2.bias   [768]
transformer.h.3.mlp.c_fc.weight   [768, 3072]
transformer.h.3.mlp.c_fc.bias   [3072]
transformer.h.3.mlp.c_proj.weight   [3072, 768]
transformer.h.3.mlp.c_proj.bias   [768]
transformer.h.4.ln_1.weight   [768]
transformer.h.4.ln_1.bias   [768]
transformer.h.4.attn.bias   [1, 1, 1024, 1024]
transformer.h.4.attn.masked_bias   []
transformer.h.4.attn.c_attn.weight   [768, 2304]
transformer.h.4.attn.c_attn.bias   [2304]
transformer.h.4.attn.c_proj.weight   [768, 768]
transformer.h.4.attn.c_proj.bias   [768]
transformer.h.4.ln_2.weight   [768]
transformer.h.4.ln_2.bias   [768]
transformer.h.4.mlp.c_fc.weight   [768, 3072]
transformer.h.4.mlp.c_fc.bias   [3072]
transformer.h.4.mlp.c_proj.weight   [3072, 768]
transformer.h.4.mlp.c_proj.bias   [768]
transformer.h.5.ln_1.weight   [768]
transformer.h.5.ln_1.bias   [768]
transformer.h.5.attn.bias   [1, 1, 1024, 1024]
transformer.h.5.attn.masked_bias   []
transformer.h.5.attn.c_attn.weight   [768, 2304]
transformer.h.5.attn.c_attn.bias   [2304]
transformer.h.5.attn.c_proj.weight   [768, 768]
transformer.h.5.attn.c_proj.bias   [768]
transformer.h.5.ln_2.weight   [768]
transformer.h.5.ln_2.bias   [768]
transformer.h.5.mlp.c_fc.weight   [768, 3072]
transformer.h.5.mlp.c_fc.bias   [3072]
transformer.h.5.mlp.c_proj.weight   [3072, 768]
transformer.h.5.mlp.c_proj.bias   [768]
transformer.h.6.ln_1.weight   [768]
transformer.h.6.ln_1.bias   [768]
transformer.h.6.attn.bias   [1, 1, 1024, 1024]
transformer.h.6.attn.masked_bias   []
transformer.h.6.attn.c_attn.weight   [768, 2304]
transformer.h.6.attn.c_attn.bias   [2304]
transformer.h.6.attn.c_proj.weight   [768, 768]
transformer.h.6.attn.c_proj.bias   [768]
transformer.h.6.ln_2.weight   [768]
transformer.h.6.ln_2.bias   [768]
transformer.h.6.mlp.c_fc.weight   [768, 3072]
transformer.h.6.mlp.c_fc.bias   [3072]
transformer.h.6.mlp.c_proj.weight   [3072, 768]
transformer.h.6.mlp.c_proj.bias   [768]
transformer.h.7.ln_1.weight   [768]
transformer.h.7.ln_1.bias   [768]
transformer.h.7.attn.bias   [1, 1, 1024, 1024]
transformer.h.7.attn.masked_bias   []
transformer.h.7.attn.c_attn.weight   [768, 2304]
transformer.h.7.attn.c_attn.bias   [2304]
transformer.h.7.attn.c_proj.weight   [768, 768]
transformer.h.7.attn.c_proj.bias   [768]
transformer.h.7.ln_2.weight   [768]
transformer.h.7.ln_2.bias   [768]
transformer.h.7.mlp.c_fc.weight   [768, 3072]
transformer.h.7.mlp.c_fc.bias   [3072]
transformer.h.7.mlp.c_proj.weight   [3072, 768]
transformer.h.7.mlp.c_proj.bias   [768]
transformer.h.8.ln_1.weight   [768]
transformer.h.8.ln_1.bias   [768]
transformer.h.8.attn.bias   [1, 1, 1024, 1024]
transformer.h.8.attn.masked_bias   []
transformer.h.8.attn.c_attn.weight   [768, 2304]
transformer.h.8.attn.c_attn.bias   [2304]
transformer.h.8.attn.c_proj.weight   [768, 768]
transformer.h.8.attn.c_proj.bias   [768]
transformer.h.8.ln_2.weight   [768]
transformer.h.8.ln_2.bias   [768]
transformer.h.8.mlp.c_fc.weight   [768, 3072]
transformer.h.8.mlp.c_fc.bias   [3072]
transformer.h.8.mlp.c_proj.weight   [3072, 768]
transformer.h.8.mlp.c_proj.bias   [768]
transformer.h.9.ln_1.weight   [768]
transformer.h.9.ln_1.bias   [768]
transformer.h.9.attn.bias   [1, 1, 1024, 1024]
transformer.h.9.attn.masked_bias   []
transformer.h.9.attn.c_attn.weight   [768, 2304]
transformer.h.9.attn.c_attn.bias   [2304]
transformer.h.9.attn.c_proj.weight   [768, 768]
transformer.h.9.attn.c_proj.bias   [768]
transformer.h.9.ln_2.weight   [768]
transformer.h.9.ln_2.bias   [768]
transformer.h.9.mlp.c_fc.weight   [768, 3072]
transformer.h.9.mlp.c_fc.bias   [3072]
transformer.h.9.mlp.c_proj.weight   [3072, 768]
transformer.h.9.mlp.c_proj.bias   [768]
transformer.h.10.ln_1.weight   [768]
transformer.h.10.ln_1.bias   [768]
transformer.h.10.attn.bias   [1, 1, 1024, 1024]
transformer.h.10.attn.masked_bias   []
transformer.h.10.attn.c_attn.weight   [768, 2304]
transformer.h.10.attn.c_attn.bias   [2304]
transformer.h.10.attn.c_proj.weight   [768, 768]
transformer.h.10.attn.c_proj.bias   [768]
transformer.h.10.ln_2.weight   [768]
transformer.h.10.ln_2.bias   [768]
transformer.h.10.mlp.c_fc.weight   [768, 3072]
transformer.h.10.mlp.c_fc.bias   [3072]
transformer.h.10.mlp.c_proj.weight   [3072, 768]
transformer.h.10.mlp.c_proj.bias   [768]
transformer.h.11.ln_1.weight   [768]
transformer.h.11.ln_1.bias   [768]
transformer.h.11.attn.bias   [1, 1, 1024, 1024]
transformer.h.11.attn.masked_bias   []
transformer.h.11.attn.c_attn.weight   [768, 2304]
transformer.h.11.attn.c_attn.bias   [2304]
transformer.h.11.attn.c_proj.weight   [768, 768]
transformer.h.11.attn.c_proj.bias   [768]
transformer.h.11.ln_2.weight   [768]
transformer.h.11.ln_2.bias   [768]
transformer.h.11.mlp.c_fc.weight   [768, 3072]
transformer.h.11.mlp.c_fc.bias   [3072]
transformer.h.11.mlp.c_proj.weight   [3072, 768]
transformer.h.11.mlp.c_proj.bias   [768]
transformer.ln_f.weight   [768]
transformer.ln_f.bias   [768]
lm_head.weight   [30000, 768]

1.2 data processing

CPM的词汇表有3w个。丰富的中文训练数据,中文数据其实比较好搞,直接网上爬就可以,git上作为提供了一个作文预训练的模型,在这个预训练模型上finetune效果也不错,我的训练数据大概有7-8w的标题-文本对数据。

1.3 pr-training details

 lr=1.5x10-4,batch_size=3072,max_len:1024(训练时,输入数据的最大长度),steps=2000(前500轮warmup),optimizer=adam,64*v100训了2周。

2x1080ti:cpm-small版本,max_len:200,lr=0.00015,batch_size:16,steps:100,adamw。

transformer=4.6.0

2.后面是cpm在一些任务上的实验。

这篇关于CPM:A large-scale generative chinese pre-trained lanuage model的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/176486

相关文章

论文翻译:arxiv-2024 Benchmark Data Contamination of Large Language Models: A Survey

Benchmark Data Contamination of Large Language Models: A Survey https://arxiv.org/abs/2406.04244 大规模语言模型的基准数据污染:一项综述 文章目录 大规模语言模型的基准数据污染:一项综述摘要1 引言 摘要 大规模语言模型(LLMs),如GPT-4、Claude-3和Gemini的快

MVC(Model-View-Controller)和MVVM(Model-View-ViewModel)

1、MVC MVC(Model-View-Controller) 是一种常用的架构模式,用于分离应用程序的逻辑、数据和展示。它通过三个核心组件(模型、视图和控制器)将应用程序的业务逻辑与用户界面隔离,促进代码的可维护性、可扩展性和模块化。在 MVC 模式中,各组件可以与多种设计模式结合使用,以增强灵活性和可维护性。以下是 MVC 各组件与常见设计模式的关系和作用: 1. Model(模型)

[论文笔记]LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale

引言 今天带来第一篇量化论文LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale笔记。 为了简单,下文中以翻译的口吻记录,比如替换"作者"为"我们"。 大语言模型已被广泛采用,但推理时需要大量的GPU内存。我们开发了一种Int8矩阵乘法的过程,用于Transformer中的前馈和注意力投影层,这可以将推理所需

深度学习--对抗生成网络(GAN, Generative Adversarial Network)

对抗生成网络(GAN, Generative Adversarial Network)是一种深度学习模型,由Ian Goodfellow等人在2014年提出。GAN主要用于生成数据,通过两个神经网络相互对抗,来生成以假乱真的新数据。以下是对GAN的详细阐述,包括其概念、作用、核心要点、实现过程、代码实现和适用场景。 1. 概念 GAN由两个神经网络组成:生成器(Generator)和判别器(D

高精度打表-Factoring Large Numbers

求斐波那契数,不打表的话会超时,打表的话普通的高精度开不出来那么大的数组,不如一个int存8位,特殊处理一下,具体看代码 #include<stdio.h>#include<string.h>#define MAX_SIZE 5005#define LEN 150#define to 100000000/*一个int存8位*/int num[MAX_SIZE][LEN];void

diffusion model 合集

diffusion model 整理 DDPM: 前向一步到位,从数据集里的图片加噪声,根据随机到的 t t t 决定混合的比例,反向要慢慢迭代,DDPM是用了1000步迭代。模型的输入是带噪声图和 t,t 先生成embedding后,用通道和的方式加到每一层中间去: 训练过程是对每个样本分配一个随机的t,采样一个高斯噪声 ϵ \epsilon ϵ,然后根据 t 对图片和噪声进行混合,将加噪

android xml之动画篇 alpha、scale、translate、rotate、set的属性及用法 和

1.简介 Android的补间动画TweenAnimation由四种类型组成:alpha、scale、translate、rotate,对应android官方文档地址:《Animation Resources》 逐帧动画 FrameAnimation(也称 Drawable Animation  ):animation-list alpha 渐变透明度动画效果 scale 渐变

[论文笔记]Making Large Language Models A Better Foundation For Dense Retrieval

引言 今天带来北京智源研究院(BAAI)团队带来的一篇关于如何微调LLM变成密集检索器的论文笔记——Making Large Language Models A Better Foundation For Dense Retrieval。 为了简单,下文中以翻译的口吻记录,比如替换"作者"为"我们"。 密集检索需要学习具有区分性的文本嵌入,以表示查询和文档之间的语义关系。考虑到大语言模

Android AnimationDrawable资源 set[translate,alpha,scale,rotate]

本文内容摘自《疯狂Android讲义 第三版-李刚著作》 xml <?xml version="1.0" encoding="utf-8"?><set xmlns:android="http://schemas.android.com/apk/res/android"android:duration="1000"android:fillAfter="true"android:f

【机器学习】生成对抗网络(Generative Adversarial Networks, GANs)详解

🌈个人主页: 鑫宝Code 🔥热门专栏: 闲话杂谈| 炫酷HTML | JavaScript基础 ​💫个人格言: "如无必要,勿增实体" 文章目录 生成对抗网络(Generative Adversarial Networks, GANs)详解GANs的基本原理GANs的训练过程GANs的发展历程GANs在实际任务中的应用小结 生成对