※この記事はChatアプリをPythonだけで作る記事です, Open AI(Chat GPT)等の生成AIそのものの記事では有りません.
いつかは野球AIと一緒に漫才風にLTをしたいと思っている人です, ネタじゃなくてマジです*1.
以前Open AIでやるチャットアプリのバックエンドどうしよう?的な記事を書いたのですが,
- AIで何かをやるからにはバックエンドのAPIだけじゃどうしようもできない(そら、そうよ).
- Open AIやGoogle CloudのAI(まだ使ったこと無いが)で何かやるんだったらチャットもしくは検索エンジン風のUIが無いと成立しない.
- まずやりたいのはお遊びで作って動かす(コンサル風に言うとPoC)程度なので本格的なUI(フロントエンド)を作るのはちょっと🤔
と思っていた矢先, 久しぶりに彼が私の前に現れました.
彼の名はStremalit, 以前の発表やブログ(PyCon JP 2021, このブログの解説記事など)や以前の仕事(アレとコレ)で大変お世話になったStremlitです*2.
最近はDashをメインで使っていて*3何気に出番が減っていたStreamlitさんですが,
ワイ「え, Streamlitいつの間にかChatアプリ作れるようになってるやん!?」
と気が付き変な声出ました.
先週今週と勉強と個人タスクをサボる技術的なキャッチアップ新しいネタを仕入れたい一心で試してみた所,
150行ちょいのコードで意外といい感じにできてしまった(コードは公開しています)ので,
- Streamlitで作るChatアプリケーションの簡単な紹介
- 実装時のポイント
- 実際に運用する時に気をつけること
Tipsを残す観点で以上を紹介したいと思います.
TL;DR
数人程度で使うAI Chatのお試し(PoC)だったらStreamlitで十分行ける.
この記事を読む前に
前提条件として,
- Pythonを書く・動かす環境と知識があること. 知識は「独学プログラマー」に書いてあることが理解できれば問題なし.
- Jupyter Lab(notebook), PandasおよびPythonのグラフライブラリ「Plotly」*4を使ったことがあるとなお良い.
- 余談ですが, コード内のパッケージ管理はpoetryを使っています.
この程度の知識を必要とします(初心者から脱出手前の中級者ぐらいのイメージ).
ちなみにStreamlit, Pandas, Plotlyそのものは解説しません.
StreamlitのChatに集中した解説が中心となります.
本記事のコード
MITライセンスで公開しています, 好きに触ってください.
StreamlitでChat開発
事前準備およびやり方はめっちゃ簡単です.
- 必要なライブラリを入れる(自分のコードだと
poetry install
で一発です). app.py
とか適当なファイルでStreamlitのアプリを作る.streamlit run app.py
で動かして後はチャットでお遊び.
これで行けます.
なお, 私のコードはStreamlit公式のチュートリアル「Build conversational apps 」内にある「Build a simple chatbot GUI with streaming」を元に作りました.
Chatアプリの流れ
StreamlitはJupyter notebookをWebアプリとしていい感じに動かすためのFWです(雑な要約).
純然たるJupyter notebookにはChatの機能は無い為, ちょっとした設定と一工夫が必要です.
初期設定
streamlit(大抵の場合Aliasとしてst
とつける)他, 色々importしたりデータを設定しています(ソースだとここ).
import streamlit as st import time import pandas as pd import plotly.graph_objs as go from graph import bar # Dataset df_team_batting: pd.DataFrame = pd.read_csv('./data/20230930_172103_team_stats_batting.csv') df_team_pitching: pd.DataFrame = pd.read_csv('./data/20230930_172103_team_stats_pitcing.csv') # データ集計した日(固定値) STATS_DATE: str = '2023/9/30' st.set_page_config(layout="wide") st.title("阪神タイガースの優勝を知るChat AI(もどき)") # Initialize chat history if "messages" not in st.session_state: st.session_state.messages = []
importの後はDataFrame作ったり, タイトルやらページ設定やらを埋め込んでいるおまじない的なものです.
今回着目してほしいのはこちらです.
# Initialize chat history if "messages" not in st.session_state: st.session_state.messages = []
どうやらChatアプリを作る際はStreamlitのSession Stateを必ず用いる必要があるみたいです*5.
上記のコードではSession Stateを初期化しています, 何故こんな事しているかは後ほど判明します.
プロンプトを受け取る
プロンプト(自分が出した命令)の受け取りはこちらです.
prompt: str = "" # Accept user input if prompt := st.chat_input("阪神タイガースは今年優勝しましたか?"): st.session_state.messages.append({"role": "user", "content": prompt}) with st.chat_message("user"): st.markdown(prompt)
st.chat_input
(APIドキュメントはこちら)でプロンプト(prompt
)を受け取り, Session Stateに追加しています.
追加しているメッセージのスキーマは思いっきりOpen AIっぽさあります(この謎は後半に判明します).
その後, st.chat_message
(APIドキュメント)で表示しているような動きになっています.
ちなみに入力値のチェック・Validationをする場合はpromptを受け取ってすぐ, st.session_state.messages.append({"role": "user", "content": prompt})
の前でやる必要がありそうです(試していませんが).
会話する
肝心の会話ですが...私のサンプルではif文で場合分けしています(コードだとここ).
# Display assistant response in chat message container with st.chat_message("assistant"): message_placeholder = st.empty() full_response = "" assistant_response = "" fig = None if prompt.startswith("阪神タイガースは今年"): assistant_response = "優勝しましたなー" elif prompt.startswith("本当ですか"): assistant_response = "本当やで" elif prompt.startswith("どうして優勝したの"): assistant_response = "たくさんお散歩して、相手のお散歩の邪魔をしたからや" elif prompt.startswith("証拠見せてよ"): assistant_response = "これは打者のデータなんだけど、どのチームより沢山散歩しているのがわかるかな" fig = bar( df=df_team_batting.sort_values(['pa_bb'], ascending=False), x=['pa_k', 'pa_bb'], y='team', x_dtick=2, x_range=[0, 18], title=f'【チーム打撃成績】三振および四球獲得までの平均打席数 ※{STATS_DATE}集計', x_title='pa_k(三振するまでの平均打席数), ab_bb(四球獲得までの平均打席数)', y_title='チーム名(pa_bbの昇順)' ) elif prompt.startswith("すごいなー、ピッチャーは"): assistant_response = "投手は逆に相手の打者に四球、すなわちお散歩を許していないんだ" fig = bar( df=df_team_pitching.sort_values(['bb_p'], ascending=False), x=['so_p', 'bb_p'], y='team', x_dtick=1, x_range=[0, 9], title=f"【チーム投手成績】奪三振率と与四球率の比較 ※{STATS_DATE}集計", x_title='so_p(奪三振率), bb_p(与四球率)', y_title='チーム名(与四球率の昇順)' ) elif prompt.startswith("CSと日本シリーズも期待していいかな?"): assistant_response = "知らんがな" elif prompt.startswith("もうええわ"): assistant_response = "どうも、ありがとうございましたー" else: assistant_response = "なんでや!阪神!!関係ないやろ!!!" for chunk in assistant_response.split(): full_response += chunk + " " time.sleep(0.05) message_placeholder.markdown(full_response + "▌") message_placeholder.markdown(full_response) if fig: # Display chart message_placeholder = st.empty() message_placeholder.plotly_chart(fig, use_container_width=True)
やってることは超シンプルで, prompt
の先頭一致(String startwith
)に合わせて喋りたい事を変数に突っ込み,
message_placeholder.markdown(full_response)
このタイミングでつぶやかせています(コードだとここ), ちなみにmessage_placeholder
はst.emptyで作られた空のコンテナ(入れ物)です.
グラフを描画する
テキストはシンプルな仕組みかつ, 「Build a simple chatbot GUI with streaming」の通りにやったらできましたが, グラフは案外難産でした(めっちゃハマった*6).
書き方のお作法から言うと,
- st.emptyで空のコンテナを作る(テキスト会話と同じ).
- グラフオブジェクト(Plotlyの
go.Figure
)をコンテナに突っ込む.
elif prompt.startswith("証拠見せてよ"): assistant_response = "これは打者のデータなんだけど、どのチームより沢山散歩しているのがわかるかな" fig = bar( df=df_team_batting.sort_values(['pa_bb'], ascending=False), x=['pa_k', 'pa_bb'], y='team', x_dtick=2, x_range=[0, 18], title=f'【チーム打撃成績】三振および四球獲得までの平均打席数 ※{STATS_DATE}集計', x_title='pa_k(三振するまでの平均打席数), ab_bb(四球獲得までの平均打席数)', y_title='チーム名(pa_bbの昇順)' ) # (中略) if fig: # Display chart message_placeholder = st.empty() message_placeholder.plotly_chart(fig, use_container_width=True)
コンテナに突っ込む以外は通常のstreamlitアプリを作る時と同じといえば同じです.
セッションに保存する
実は上記だけだと会話は保存されず消えちゃうので, Session Stateに保存します(コードだとここ).
# Add session state st.session_state.messages.append({"role": "assistant", "content": full_response}) if fig: st.session_state.messages.append({"role": "assistant", "content": fig})
最初に初期化したSession Stateに足すことで(ブラウザのセッションが有効な間)保存がされます.
再描画
どうやらStreamlitの動き的に, Session Stateを使うと再描画が必要っぽいです(デバッグしてたらわかった).
以下のコードでSession Stateの中身を描画しています(ここです).
# Display chat messages from history on app rerun for message in st.session_state.messages: with st.chat_message(message["role"]): if type(message["content"]) == go.Figure: # Display chart st.plotly_chart(message["content"], use_container_width=True) else: # Display chat st.markdown(message["content"])
Streamlitのアプリ領域(多分グローバル変数っぽい所, 理解違ってたらすいません)に書き込む動きをしないと行けない模様です.
ここでもグラフ描画. 具体的にはPlotlyのgo.Figure
の描画でハマりました(st.markdown
だと変数の中身を書き出してしまう).
使いたい命令はst.plotly_chart
だったので,
if type(message["content"]) == go.Figure: # Display chart st.plotly_chart(message["content"], use_container_width=True)
オブジェクトの型チェックで切り抜けました, ここが一番時間が溶けたかも...
とりあえずこれでグラフ描画付きのチャットは作れました.
応用的な使い方
先に言っておくとまだ試していません.
が, 公式ドキュメントに乗っていたのでちょこっと応用的な内容として残します.
Open AIを使う
Open AIを使う場合は...なんと, 公式がチュートリアルを用意してくれているのでこれに従うと良さそうです.
やり方はシンプルで, Open AIのAPIキーを定義, Session Stateに保存したメッセージを丸投げで終わるみたいです.
for response in openai.ChatCompletion.create( model=st.session_state["openai_model"], messages=[{"role": m["role"], "content": m["content"]} for m in st.session_state.messages], stream=True, ): full_response += response.choices[0].delta.get("content", "") message_placeholder.markdown(full_response + "▌") message_placeholder.markdown(full_response) st.session_state.messages.append({"role": "assistant", "content": full_response})
Session Stateのメッセージ定義をOpen AIに揃えているのはこのためだったか...
Chatの記録を保存する
Streamlitのチュートリアルも私のコードもそうですが, ブラウザをリロードするとChatの会話歴が消えます.
Chatの会話歴はSession Stateのmessagesに残っているので,
- Session Stateそのものをシリアライズして適当に保存する(Amazon S3等のストレージやFirebase Store等のDBなど). Session Stateのシリアライズ/デシリアライズはできるので, そこを保存して読み込めば良さげ.
- ブラウザのSession Storage等に残す. ただ, 公式ではこの手段は用意されていないっぽい?探しても出てこなかった&ご存じの方教えてください🙏
なお, どっちもそこそこリスクが有るやり方(どっちもコード書きにくそう)なので,
- 外部で何かしらのユーザー認証の手段を作る(入れる).
- ユーザーごとに会話歴を残せる&取得できるようなバックエンドを別に用意する.
- Streamlit Chatアプリを認証付きで入れさせて裏で会話歴を保存/読み込むバックエンドを呼ぶ(Session Stateの値をコントロールする).
のが急がば回れ, 的な方法で確実かなと思いました.
ここまでやるならStremlitでお茶を濁すのではなく, ガチでChatアプリ作っても良い気がしますが笑.
結び
というわけで, 「Streamlitでチャットアプリを作りつつ, 一緒に漫才をする」方法を紹介しました.
TL;DRに書いた通り,
数人程度で使うAI Chatのお試し(PoC)だったらStreamlitで十分行ける.
これは手応えとして感じました, 書き方とかロジックはちょっと独特ですが実用には十分耐えそうですPoCなら*7.
せっかく作ったのでこのブログとコードの公開という形で残しましたが, 今度どこかでこれを使ったLTしたいと思いました*8.
最後までお読み頂きありがとうございました.
Appendix
コード全体像
app.py
import streamlit as st import time import pandas as pd import plotly.graph_objs as go from graph import bar # Dataset df_team_batting: pd.DataFrame = pd.read_csv('./data/20230930_172103_team_stats_batting.csv') df_team_pitching: pd.DataFrame = pd.read_csv('./data/20230930_172103_team_stats_pitcing.csv') # データ集計した日(固定値) STATS_DATE: str = '2023/9/30' st.set_page_config(layout="wide") st.title("阪神タイガースの優勝を知るChat AI(もどき)") # Initialize chat history if "messages" not in st.session_state: st.session_state.messages = [] # Display chat messages from history on app rerun for message in st.session_state.messages: with st.chat_message(message["role"]): if type(message["content"]) == go.Figure: # Display chart st.plotly_chart(message["content"], use_container_width=True) else: # Display chat st.markdown(message["content"]) prompt: str = "" # Accept user input if prompt := st.chat_input("阪神タイガースは今年優勝しましたか?"): st.session_state.messages.append({"role": "user", "content": prompt}) with st.chat_message("user"): st.markdown(prompt) # Display assistant response in chat message container with st.chat_message("assistant"): message_placeholder = st.empty() full_response = "" assistant_response = "" fig = None if prompt.startswith("阪神タイガースは今年"): assistant_response = "優勝しましたなー" elif prompt.startswith("本当ですか"): assistant_response = "本当やで" elif prompt.startswith("どうして優勝したの"): assistant_response = "たくさんお散歩して、相手のお散歩の邪魔をしたからや" elif prompt.startswith("証拠見せてよ"): assistant_response = "これは打者のデータなんだけど、どのチームより沢山散歩しているのがわかるかな" fig = bar( df=df_team_batting.sort_values(['pa_bb'], ascending=False), x=['pa_k', 'pa_bb'], y='team', x_dtick=2, x_range=[0, 18], title=f'【チーム打撃成績】三振および四球獲得までの平均打席数 ※{STATS_DATE}集計', x_title='pa_k(三振するまでの平均打席数), ab_bb(四球獲得までの平均打席数)', y_title='チーム名(pa_bbの昇順)' ) elif prompt.startswith("すごいなー、ピッチャーは"): assistant_response = "投手は逆に相手の打者に四球、すなわちお散歩を許していないんだ" fig = bar( df=df_team_pitching.sort_values(['bb_p'], ascending=False), x=['so_p', 'bb_p'], y='team', x_dtick=1, x_range=[0, 9], title=f"【チーム投手成績】奪三振率と与四球率の比較 ※{STATS_DATE}集計", x_title='so_p(奪三振率), bb_p(与四球率)', y_title='チーム名(与四球率の昇順)' ) elif prompt.startswith("CSと日本シリーズも期待していいかな?"): assistant_response = "知らんがな" elif prompt.startswith("もうええわ"): assistant_response = "どうも、ありがとうございましたー" else: assistant_response = "なんでや!阪神!!関係ないやろ!!!" for chunk in assistant_response.split(): full_response += chunk + " " time.sleep(0.05) message_placeholder.markdown(full_response + "▌") message_placeholder.markdown(full_response) if fig: # Display chart message_placeholder = st.empty() message_placeholder.plotly_chart(fig, use_container_width=True) # Add session state st.session_state.messages.append({"role": "assistant", "content": full_response}) if fig: st.session_state.messages.append({"role": "assistant", "content": fig})
graph.py(グラフ描画)
from typing import List, Tuple import plotly.express as px import plotly.graph_objs as go import pandas as pd WIDTH = 1400 HEIGHT = 700 def bar(df: pd.DataFrame, x: List[str], y: str, title: str, x_title: str, y_title: str, x_dtick: int, x_range: Tuple[int], autosize: bool = True) -> go.Figure: fig = px.bar(df, x=x, y=y, barmode='group', ) if autosize: # 画面に合わせるパターン fig.update_layout( title=title, autosize=True, legend_title=None ) else: # 固定長 fig.update_layout( title=title, width=WIDTH, height=HEIGHT, legend_title=None ) fig.update_yaxes(title=y_title) fig.update_xaxes(title=x_title, dtick=x_dtick, range=x_range) return fig
参考文献
StreamlitのChatAPI公式document
公式のチュートリアル
*1:余談ですがPyCon JPで一度だけ二人登壇はやったことあります, 漫才っていうよりジョイントコント風の出来でしたが笑
*2:前職における仕事では間違いなくMVP級に活躍してくれたライブラリでした, まじで感謝しています.
*3:なお, DashとStreamlitはどっちもLow CodeなWebアプリ風ダッシュボードを指向しており被りそうに見えますが, 自分としては微妙に用途が異なる認識でして...って記事はいつか書きたい(使い分けしてます私は).
*4:Plotly以外のライブラリでもいいかもしれません, 私はPlotlyが好きなのでそうしました.
*5:結構Streamlitを使ってた私もこの機能は初めて知りました&読んだり書いたりするうちにこれは結構キモとなる機能だと改めて理解.
*6:なんせ, らしい解説が無かったので. 故にこのブログを書いています.
*7:なお, 長く使ってるとブラウザのメモリを結構食う感じあるので常時利用は無理かもしれません.
*8:先日のPyLadies Tokyo 9周年パーティーに間に合うように頑張りましたが, 間に合わなかったのでこのブログで供養しています.