from transformers import AutoModel, AutoTokenizer
import gradio as gr
import mdtex2html
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
model = model.eval()
"""Override Chatbot.postprocess"""
def postprocess(self, y):
if y is None:
return []
for i, (message, response) in enumerate(y):
y[i] = (
None if message is None else mdtex2html.convert((message)),
None if response is None else mdtex2html.convert(response),
)
return y
gr.Chatbot.postprocess = postprocess
def parse_text(text):
"""copy from https://github.***/GaiZhenbiao/ChuanhuChatGPT/"""
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split('`')
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = f'<br></code></pre>'
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", "\`")
line = line.replace("<", "<")
line = line.replace(">", ">")
line = line.replace(" ", " ")
line = line.replace("*", "*")
line = line.replace("_", "_")
line = line.replace("-", "-")
line = line.replace(".", ".")
line = line.replace("!", "!")
line = line.replace("(", "(")
line = line.replace(")", ")")
line = line.replace("$", "$")
lines[i] = "<br>"+line
text = "".join(lines)
return text
def predict(input, chatbot, max_length, top_p, temperature, history, past_key_values):
chatbot.append((parse_text(input), ""))
for response, history, past_key_values in model.stream_chat(tokenizer, input, history, past_key_values=past_key_values,
return_past_key_values=True,
max_length=max_length, top_p=top_p,
temperature=temperature):
chatbot[-1] = (parse_text(input), parse_text(response))
yield chatbot, history, past_key_values
def reset_user_input():
return gr.update(value='')
def reset_state():
return [], [], None
with gr.Blocks() as demo:
gr.HTML("""<h1 align="center">ChatGLM2-6B</h1>""")
chatbot = gr.Chatbot()
with gr.Row():
with gr.Column(scale=4):
with gr.Column(scale=12):
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
container=False)
with gr.Column(min_width=32, scale=1):
submitBtn = gr.Button("Submit", variant="primary")
with gr.Column(scale=1):
emptyBtn = gr.Button("Clear History")
max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
history = gr.State([])
past_key_values = gr.State(None)
submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history, past_key_values],
[chatbot, history, past_key_values], show_progress=True)
submitBtn.click(reset_user_input, [], [user_input])
emptyBtn.click(reset_state, outputs=[chatbot, history, past_key_values], show_progress=True)
demo.queue().launch(share=False, inbrowser=True)
这是一个使用Hugging Face的transformers库和gradio库搭建的机器人聊天程序的Python代码,模型预训练的版本为"THUDM/chatglm2-6b",而且使用了mdtex2html来把markdown格式转化为HTML格式。我会逐行解释它的含义:
1-3: 导入必要的模块。transformers
是Hugging Face发布的处理NLP任务的库,提供了大量的预训练模型和相关工具。gradio
是一个创建快速demo用户界面的库,可以让模型快速地在web页面上进行展示和交互。mdtex2html
用来将markdown格式的文本转化为HTML格式。
5-6: 从 Hugging Face 模型库中加载了一个预训练的模型及其相应的词条化器(tokenizer),模型的名字是 "THUDM/chatglm2-6b"。trust_remote_code=True
是指示允许从远程代码中加载特殊的tokenizer或model代码。
7: 把模型放在CUDA上,这会让模型运行在GPU上,加快运算速度。
8: 把模型设定为评估模式(eval mode),在这种模式下,模型中的某些层(如 Dropout 或 BatchNorm)的行为会与训练模式有所不同。
10-19: 定义了一个postprocess
函数,它是为了对模型的输出进行后处理。在这个函数中,它会将每一条聊天信息和相应的回复转化为HTML格式。它将此函数绑定到gradio的Chatbot类的后处理函数上。
21-42: parse_text
函数用于处理用户的输入。它主要是处理markdown的代码部分,并把它们转化为HTML格式。
44-52: predict
函数接受用户的输入和一些模型参数,它首先把用户的输入添加到聊天记录中,然后用模型生成回复,最后返回聊天记录和新的历史记录。
54-56: 定义了两个函数来重置用户的输入和模型的状态。
58-89: 在这部分中,它用gradio
创建了一个用户界面。这个界面包括一个聊天机器人,一个用户输入框,一个提交按钮,一个清空历史按钮,以及一些滑块用于控制模型的参数。
91: 启动这个用户界面。share=False
表示这个界面不会被分享,inbrowser=True
表示这个界面会在浏览器中打开。
总的来说,这是一个使用预训练的模型来生成对话的聊天机器人的程序,它有一个简单的用户界面,用户可以通过这个界面和机器人进行交流。
我会继续分析此代码中的一些关键部分:
AutoTokenizer.from_pretrained
和 AutoModel.from_pretrained
方法从预训练的模型库中加载了一个模型及其对应的词条化器。它们能够自动地识别模型的类型,并加载相应的模型和词条化器。
model.cuda()
将模型加载到GPU上进行计算。这个操作是在PyTorch中进行的,其目的是利用GPU进行更快的计算。这只有在你的机器上有可用的Nvidia GPU,并且你的PyTorch版本支持CUDA时才会有效。
gr.Chatbot.postprocess = postprocess
这行代码将定义的postprocess
函数绑定到gr.Chatbot
类的postprocess
方法。这使得我们可以修改gr.Chatbot
的行为,以便于使用mdtex2html.convert
函数处理消息和回应,从而将markdown格式的文本转化为HTML格式。
函数 parse_text
是用于处理输入的文本,主要是处理markdown的代码部分,并把它们转化为HTML格式。
函数 predict
是这个程序的核心部分。它使用了模型进行预测,生成聊天机器人的回应。
最后,这个程序使用gradio
库来创建一个用户界面。这个界面包括一个聊天机器人,一个用户输入框,一个提交按钮,一个清空历史按钮,以及一些滑块用于控制模型的参数。用户可以通过这个界面和聊天机器人进行交流。
整体来看,这个程序是一个基于预训练模型和gradio库的聊天机器人。它能够处理用户的输入,生成聊天机器人的回应,并通过用户界面和用户进行交互。
。
reset_user_input
和reset_state
函数被设计用来重置用户输入和清空模型的状态。它们主要被绑定到了前端的按钮上,当用户点击这些按钮时,会触发相应的函数。
接下来,这个脚本使用gr.Blocks()
来创建一个应用的界面。在这个界面中,gr.Chatbot()
实例化了一个聊天机器人,gr.Textbox()
实例化了一个输入框,用户可以在里面输入文本,然后点击gr.Button()
实例化的提交按钮,这会触发绑定到按钮上的函数,把用户的输入提交到聊天机器人,并接收聊天机器人的回应。另外,gr.Slider()
实例化了一些滑块,用于调整模型的参数。
gr.State([])
和gr.State(None)
实例化了两个状态对象,这些对象用于保存聊天的历史记录和模型的内部状态。这些状态在每次预测时都会更新。
然后,submitBtn.click()
和emptyBtn.click()
将前面定义的函数绑定到了按钮的点击事件上。当用户点击这些按钮时,会触发绑定到按钮上的函数。
最后,demo.queue().launch(share=False, inbrowser=True)
启动了这个应用。用户可以在浏览器中访问这个应用,并与聊天机器人进行交互。
这个脚本实现了一个完整的基于预训练模型的聊天机器人应用,包括用户界面和模型的预测逻辑。它展示了如何使用Hugging Face的transformers库和gradio库来实现一个可交互的机器学习应用。