from transformers import AutoModel, AutoTokenizer
import gradio as gr
import mdtex2html
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()
"""Override Chatbot.postprocess"""
def postprocess(self, y):
if y is None:
return []
for i, (message, response) in enumerate(y):
y[i] = (
None if message is None else mdtex2html.convert((message)),
None if response is None else mdtex2html.convert(response),
)
return y
gr.Chatbot.postprocess = postprocess
def parse_text(text):
"""copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split('`')
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = f'<br></code></pre>'
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", "\`")
line = line.replace("<", "<")
line = line.replace(">", ">")
line = line.replace(" ", " ")
line = line.replace("*", "*")
line = line.replace("_", "_")
line = line.replace("-", "-")
line = line.replace(".", ".")
line = line.replace("!", "!")
line = line.replace("(", "(")
line = line.replace(")", ")")
line = line.replace("$", "$")
lines[i] = "<br>"+line
text = "".join(lines)
return text
def predict(input, chatbot, max_length, top_p, temperature, history):
chatbot.append((parse_text(input), ""))
for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
temperature=temperature):
chatbot[-1] = (parse_text(input), parse_text(response))
yield chatbot, history
def reset_user_input():
return gr.update(value='')
def reset_state():
return [], []
with gr.Blocks() as demo:
gr.HTML("""<h1 align="center">ChatGLM</h1>""")
chatbot = gr.Chatbot()
with gr.Row():
with gr.Column(scale=4):
with gr.Column(scale=12):
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
container=False)
with gr.Column(min_width=32, scale=1):
submitBtn = gr.Button("Submit", variant="primary")
with gr.Column(scale=1):
emptyBtn = gr.Button("Clear History")
max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
history = gr.State([])
submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
show_progress=True)
submitBtn.click(reset_user_input, [], [user_input])
emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
demo.queue().launch(share=False, inbrowser=True)
このコードの主な目的は、事前トレーニングされた ChatGPT モデル (「THUDM/chatglm-6b」) を使用して、対話型の Web ベースのチャットボットを構築することです。コード内の各行またはコード ブロックの説明は次のとおりです。
-
必要なライブラリをインポートします。
pythonコードをコピー
from transformers import AutoModel, AutoTokenizer
-
import gradio as gr
-
import mdtex2html
transformers
は、事前トレーニングされたモデルを多数提供する NLP ライブラリです。AutoModel
と は、対応するモデルとトークナイザーをAutoTokenizer
自動的にロードするために使用されるメソッドです。インタラクティブな UI を構築するためのライブラリですが、マークダウンを HTML に変換するためのツールです。gradio
mdtex2html
-
事前トレーニングされたモデルとトークナイザーをロードします。
pythonコードをコピー
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() model = model.eval()
これらのコード行は、「THUDM/chatglm-6b」という名前の事前トレーニング済みモデルとそのトークナイザーを読み込みます。
half()
このメソッドは、モデルのデータ型を半精度浮動小数点に変換し、GPU での計算を高速化します。cuda()
モデルを GPU に移動するメソッド。model.eval()
モデルを評価モードに設定します。
-
書き換え
Chatbot.postprocess
方法:pythonコードをコピー
def postprocess(self, y): if y is None: return [] for i, (message, response) in enumerate(y): y[i] = ( None if message is None else mdtex2html.convert((message)), None if response is None else mdtex2html.convert(response), ) return y gr.Chatbot.postprocess = postprocess
コード
gr.Chatbot
のこの部分postprocess
は、入力メッセージとロボットの応答をマークダウン形式から HTML 形式に変換するメソッドを書き換えます。 -
parse_text
関数を定義します。pythonコードをコピー
def parse_text(text): ... return text
特定の文字の置換やHTMLタグの追加など、入力テキストを処理する機能です。
-
predict
関数を定義します。pythonコードをコピー
def predict(input, chatbot, max_length, top_p, temperature, history): ...
この関数は、事前トレーニングされたモデルを使用してチャットボット応答を生成します。入力テキスト、チャット履歴、生成の最大長、生成確率のしきい値、生成の「温度」などのパラメーターを受け取ります。
-
定義
reset_user_input
とreset_state
機能:pythonコードをコピー
def reset_user_input(): return gr.update(value='') def reset_state(): return [], []
これらの関数は、ユーザーの入力とチャットボットの状態をリセットするために使用されます。
-
Gradio UI インターフェイスを作成します。
pythonコードをコピー
with gr.Blocks() as demo: ... demo.queue().launch(share=False, inbrowser
-
Gradio UI インターフェイスを作成します (続き):
pythonコードをコピー
with gr.Blocks() as demo: gr.HTML("""<h1 align="center">ChatGLM</h1>""") chatbot = gr.Chatbot() with gr.Row(): with gr.Column(scale=4): with gr.Column(scale=12): user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(container=False) with gr.Column(min_width=32, scale=1): submitBtn = gr.Button("Submit", variant="primary") with gr.Column(scale=1): emptyBtn = gr.Button("Clear History") max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True) temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True) history = gr.State([]) submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], show_progress=True) submitBtn.click(reset_user_input, [], [user_input]) emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True) demo.queue().launch(share=False, inbrowser=True)
コードのこの部分では、Gradio ライブラリを使用して、入力ボックス (
user_input
)、送信ボタン (submitBtn
)、クリア ボタン (emptyBtn
)、およびモデルを制御するためのいくつかのスライダーを含むWeb インターフェイスを作成します。インターフェイスでテキストを入力し、送信ボタンをクリックして、ボットに応答を生成させることができます。[クリア] ボタンは、チャット履歴とモデルのステータスをクリアするために使用します。history = gr.State([])
このコード行は、チャット履歴を保存するためのState オブジェクトを作成します。submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], show_progress=True)
このコード行は、送信ボタンのクリック イベントをpredict
関数に接続します。送信ボタンをクリックすると、predict
関数が呼び出され、入力ボックス内のテキスト、チャットボット オブジェクト、スライダーの値、チャット履歴がパラメーターとして渡されます。submitBtn.click(reset_user_input, [], [user_input])
このコード行は、クリック イベントとreset_user_input
送信ボタンの機能を接続するため、送信ボタンをクリックすると、入力ボックス内のテキストがクリアされます。emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
このコード行は、クリア ボタンのクリック イベントをreset_state
関数に接続するため、クリア ボタンをクリックすると、チャットボットの履歴とステータスがクリアされます。demo.queue().launch(share=False, inbrowser=True)
このコード行は、インターフェースをローカル Web サーバーにデプロイし、ブラウザーで開きます。share=False
これは、このインターフェイスを他のユーザーと共有しないことを意味し、このインターフェイスを現在のブラウザで開くことをinbrowser=True
意味します。
コード全体を説明したので、よりよく理解していただくために、このコードの主な機能をもう一度説明します。
このスクリプトは、ユーザーが Web インターフェイス上で操作できる事前トレーニング済みモデル「THUDM/chatglm-6b」を使用してチャットボットを作成します。ユーザーがメッセージを入力すると、ボットがメッセージの内容に基づいて応答を生成します。
具体的には、ユーザーがインターフェイス上で実行できる操作は次のとおりです。
-
メッセージを入力する: ユーザーはテキスト ボックスにメッセージを入力し、[送信] ボタンをクリックしてメッセージをボットに送信できます。
-
モデルの制御: ユーザーは、スライダーを使用して、世代の最大長 (「最大長」)、世代の確率しきい値 (「トップ P」 )、「温度」などのモデルのいくつかのパラメーターを制御できます。世代(「温度」)の。
-
履歴のクリア: ユーザーは「履歴のクリア」ボタンをクリックして、チャット履歴とモデルのステータスをクリアできます。
さらに、スクリプトは、入出力テキストをマークダウン形式から HTML形式に変換したり、応答の生成後に入力ボックスを自動的にクリアしたりするなど、追加の処理も実行します。