代码编织梦想

目录

一、克隆 ChatGLM-6B 源码到本地

二、下载数据集ADGEN 数据集

三、代码修改

四、Kaggle代码运行

五、最后将模型加载到本地用CPU进行效果测试


一、克隆 ChatGLM-6B 源码到本地

        (1)下载压缩包到本地

        (2)Git 克隆

git clone https://github.com/THUDM/ChatGLM-6B.git

二、下载数据集ADGEN 数据集

    数据集下载地址:

# Google Drive:
https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/viewTsinghua 
# Cloud:
https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1 

  下载处理好的 ADGEN 数据集,将解压后的AdvertiseGen目录放到本目录下(也就是放在\ChatGLM-6B\ptuning文件夹下)。

三、代码修改

        (1)配置文件修改(ChatGLM-6B文件夹下的requirements.txt),删除torch>=1.10(kaggle自带相对应的cuda版本),并添加 rouge_chinese nltk jieba datasets 这几个依赖包。

         处理后的文件

        (2) 切换到ChatGLM-6B\ptuning目录下,修改训练的配置文件train.sh。需要修改的参数有:

2.1 CUDA_VISIBLE_DEVICES=0修改为CUDA_VISIBLE_DEVICES=0,1

     因为我们用的kaggle的T4*2单机双卡训练,这里把CUDA_VISIBLE_DEVICES可用的GPU设置为2张。

     注意:目前不知道什么原因,用kaggle的P100进行单卡训练时,出现以下报错(RuntimeError: CUDA error: no kernel image is available for execution on the device),可能是GPU版本问题,切换为T4*2双卡下训练就不会报错了。

2.2 --model_name_or_path参数:THUDM/chatglm-6b修改为THUDM/chatglm-6b-int4。

      原因很简单,显存有限,只能使用-int4量化的版本来训练,总的参数也有30亿多了。

2.3 --output_dir参数output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR修改为/kaggle/working/output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR

2.4 (可选)--max_steps,--logging_steps,--save_steps,为了演示方便,这里把这几个参数设置小一点

最终的文件设置如下:

代码如下:

PRE_SEQ_LEN=128
LR=2e-2

CUDA_VISIBLE_DEVICES=0,1 python3 main.py \
    --do_train \
    --train_file AdvertiseGen/train.json \
    --validation_file AdvertiseGen/dev.json \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path THUDM/chatglm-6b-int4 \
    --output_dir /kaggle/working/output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --max_steps 20 \
    --logging_steps 10 \
    --save_steps 10 \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4

        (3)修改测试的配置文件evaluate.sh。

 

 代码如下:

PRE_SEQ_LEN=128
LR=2e-2
CHECKPOINT=adgen-chatglm-6b-pt-128-2e-2
STEP=20

CUDA_VISIBLE_DEVICES=0,1 python3 main.py \
    --do_predict \
    --validation_file AdvertiseGen/dev.json \
    --test_file AdvertiseGen/dev.json \
    --overwrite_cache \
    --prompt_column content \
    --response_column summary \
    --model_name_or_path THUDM/chatglm-6b-int4 \
    --ptuning_checkpoint /kaggle/working/output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR/checkpoint-$STEP \
    --output_dir /kaggle/working/output/$CHECKPOINT \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_eval_batch_size 1 \
    --predict_with_generate \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4

四、Kaggle代码运行

代码如下:

pip install -r /kaggle/input/datas-chat-glm-xiaolvs/ChatGLM-6B/requirements.txt
# 修改工作路劲
import os, sys

path = "/kaggle/input/datas-chat-glm-xiaolvs/ChatGLM-6B/ptuning"

# 查看当前工作目录
retval = os.getcwd()
print("当前工作目录为 %s" % retval)

# 修改当前工作目录
os.chdir( path )

# 查看修改后的工作目录
retval = os.getcwd()

print("目录修改成功 %s" % retval)
 # 关闭 wandb ,不然训练是会报错。
!wandb off
# 训练
! bash train.sh
# 测试
! evaluate.sh

运行过程如图:

         !!成功跑通!!

        代码在kaggle上已经公开,可以直接拷贝下来自己去运行试试。

        链接:https://www.kaggle.com/lvcuiyan/chatglm-6b-pt-xiaolvExplore and run machine learning code with Kaggle Notebooks | Using data from datas_chat_glm_xiaolvshttps://www.kaggle.com/lvcuiyan/chatglm-6b-pt-xiaolv

五、最后将模型加载到本地用CPU进行效果测试

        这里使用本地CPU进行推理。        

# !/usr/bin/env python
# -*-coding:utf-8 -*-

"""
# File       : pre.py
# Time       :2023/5/26 16:53
# Author     :小吕同学
"""
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

# 原始模型的路径:THUDM/chatglm-6b-int4,本地文件路径也可。
model_path = r"THUDM/chatglm-6b-int4"
# 经过微调后保存的模型,这里选择最后一个模型 output\adgen-chatglm-6b-pt-128-2e-2\checkpoint-20\pytorch_model.bin
prefix_state_dict = r"\output\adgen-chatglm-6b-pt-128-2e-2\checkpoint-20\pytorch_model.bin"
# 载入Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)


# ------------ vvvv 测试原本模型可以注释掉 vvvvv ------------
# prefix_state_dict = torch.load(prefix_state_dict,map_location=torch.device('cpu'))
# new_prefix_state_dict = {}
# for k, v in prefix_state_dict.items():
#     if k.startswith("transformer.prefix_encoder."):
#         new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
# model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
# ------------ ^^^^^^^^^^^^^^^^^^^^^  ------------

# 模型量化
model = model.quantize(4)
model = model.float()
model.transformer.prefix_encoder.float()
model = model.eval()
print('模型总的参数:', sum(p.numel() for p in model.parameters()))

""" 
通过设置top_p=0.01, temperature=0.01,来控制模型输出的模型,数值越小,模型越准确。
"""
response, history = model.chat(tokenizer, "类型#上衣*材质#牛仔布*颜色#白色*风格#简约*图案#刺绣*衣样式#外套*衣款式#破洞", history=[],do_sample=True, top_p=0.01, temperature=0.01)
print(response)
print(history)

# 原始数据
# {"content": "类型#上衣*材质#牛仔布*颜色#白色*风格#简约*图案#刺绣*衣样式#外套*衣款式#破洞",
# "summary": "简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。"}

# 没微调前的模型输出
# [('类型#上衣*材质#牛仔布*颜色#白色*风格#简约*图案#刺绣*衣样式#外套*衣款式#破洞', '类型:\n\n上衣材质:牛仔布\n颜色:白色\n风格:简约\n图案:刺绣\n衣款式:外套\n\n破洞')]

# 微调后模型的输出
# {"labels": "<image_-100> 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。",
# "predict": "小清新风格的牛仔外套,简约又无味,轻松搭配牛仔裤,搭配高跟鞋,完美诠释了甜美小清新的气息。拼接关系的刺绣元素,凸显美感,破洞的设计,播放出时尚感,可爱甜美,可爱本人了。"}



代码地址:

openi:

xiaolv/chatglm_kaggle: 利用Kaggle的GPU T4*2资源进行单机双卡训练微调Chatglm模型,最后进行本地CPU运行验证。 - chatglm_kaggle - OpenI - 启智AI开源社区提供普惠算力!chatglm_kaggle - 利用Kaggle的GPU T4*2资源进行单机双卡训练微调Chatglm模型,最后进行本地CPU运行验证。https://openi.pcl.ac.cn/xiaolv/chatglm_kaggle.gitkaggle:

ChatGLM-6B-PT-xiaolv | Kaggle 代码Explore and run machine learning code with Kaggle Notebooks | Using data from datas_chat_glm_xiaolvshttps://www.kaggle.com/lvcuiyan/chatglm-6b-pt-xiaolv


参考链接:

ChatGLM-6B源码https://github.com/THUDM/ChatGLM-6B
官方微调代码https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/README.md

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_72632426/article/details/130898002