From 794502df696300d4f9cc6b0d7bd4eedeb6cef750 Mon Sep 17 00:00:00 2001 From: wangjiacai Date: Mon, 3 Apr 2023 01:00:24 +0800 Subject: [PATCH] implement chat --- config.ini | 5 +- project/__init__.py | 4 ++ project/main.py | 57 ++++++++++++++++++-- project/models.py | 11 ++++ project/templates/base.html | 8 ++- project/templates/chat.html | 101 ++++++++++++++++++++++++++++++++++++ requirements.txt | 2 + 7 files changed, 178 insertions(+), 10 deletions(-) create mode 100644 project/templates/chat.html diff --git a/config.ini b/config.ini index c892d48..27a2d63 100644 --- a/config.ini +++ b/config.ini @@ -7,6 +7,7 @@ SQLALCHEMY_DATABASE_URI=sqlite:///sqlite.db [network] PROXY=http://127.0.0.1:7890 -[gpt] -SECRET_KEY= +[openai] +API_KEY= MODEL_NAME=gpt-3.5-turbo +PROMPT=You are a helpful assistant diff --git a/project/__init__.py b/project/__init__.py index 82fb7d4..a82327b 100644 --- a/project/__init__.py +++ b/project/__init__.py @@ -14,6 +14,10 @@ def create_app(): app.config['SECRET_KEY'] = conf['app']['SECRET_KEY'] app.config['SQLALCHEMY_DATABASE_URI'] = conf['app']['SQLALCHEMY_DATABASE_URI'] + app.config['NETWORK_PROXY'] = conf['network']['PROXY'] + app.config['OPENAI_API_KEY'] = conf['openai']['API_KEY'] + app.config['OPENAI_MODEL_NAME'] = conf['openai']['MODEL_NAME'] + app.config['OPENAI_PROMPT'] = conf['openai']['PROMPT'] db.init_app(app) login_manager = LoginManager() diff --git a/project/main.py b/project/main.py index 500fc1f..047677f 100644 --- a/project/main.py +++ b/project/main.py @@ -1,6 +1,8 @@ -from flask import Blueprint, render_template, request, flash, redirect, url_for +from flask import Blueprint, render_template, current_app, request, flash, redirect, url_for from flask_login import login_required, current_user, login_manager -from .models import User +from .models import User, Conversation +from . import db +import openai main = Blueprint('main', __name__) @@ -40,4 +42,53 @@ def manage(): @main.route('/chat') @login_required def chat(): - return "暂未实现" + if current_user.isActivated: + return render_template('chat.html', user=current_user) + else: + flash("您的账户暂未被激活") + return redirect(url_for('main.index')) + return redirect(url_for('main.index')) + + +@main.route('/chat', methods=['POST']) +@login_required +def chat_post(): + openai.api_key = current_app.config['OPENAI_API_KEY'] + openai.proxy = current_app.config['NETWORK_PROXY'] + + msg = request.form.get("msg") + + new_conversation = Conversation(userid=current_user.id, + useremail=current_user.email, + username=current_user.name, + request=msg, + response="") + db.session.add(new_conversation) + db.session.commit() + + if current_user.is_authenticated and current_user.isActivated: + openai_resp = openai.ChatCompletion.create( + model=current_app.config['OPENAI_MODEL_NAME'], + messages=[ + {"role": "system", + "content": current_app.config['OPENAI_PROMPT']}, + {"role": "user", "content": msg} + ] + ) + msg_resp = openai_resp['choices'][0]['message']['content'] + if msg_resp: + response = {"message": msg_resp, "status": "success"} + else: + response = {"message": "请求错误", "status": "success"} + else: + response = {"message": "请先登录/激活", "status": "fail"} + + new_conversation = Conversation(userid=current_user.id, + useremail=current_user.email, + username=current_user.name, + request=msg, + response=msg_resp) + db.session.add(new_conversation) + db.session.commit() + + return response diff --git a/project/models.py b/project/models.py index c2e1dc0..ea04e47 100644 --- a/project/models.py +++ b/project/models.py @@ -1,5 +1,6 @@ from flask_login import UserMixin from . import db +from sqlalchemy.sql import func class User(UserMixin, db.Model): @@ -10,3 +11,13 @@ class User(UserMixin, db.Model): name = db.Column(db.String(100), nullable=False) role = db.Column(db.String(100), nullable=False) isActivated = db.Column(db.Boolean, nullable=False) + + +class Conversation(db.Model): + id = db.Column(db.Integer, primary_key=True) + userid = db.Column(db.Integer) + useremail = db.Column(db.String(100), nullable=False) + username = db.Column(db.String(100), nullable=False) + request = db.Column(db.String(10000)) + response = db.Column(db.String(10000)) + datetime = db.Column(db.DateTime, server_default=func.now()) diff --git a/project/templates/base.html b/project/templates/base.html index 4cf96bd..c8813a9 100644 --- a/project/templates/base.html +++ b/project/templates/base.html @@ -49,11 +49,9 @@ -
-
- {% block content %} - {% endblock content %} -
+
+ {% block content %} + {% endblock content %}
+ + {% else %} +

您的账号暂未激活,请等待管理员激活此账号。

+ {% endif %} + {% else %} + + + + 或 + + + + {% endif %} + {% with messages = get_flashed_messages() %} + {% if messages %}
{{ messages[0] }}
{% endif %} + {% endwith %} +{% endblock content %} diff --git a/requirements.txt b/requirements.txt index 1520d25..7205a12 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ Flask==2.2.3 Flask_Login==0.6.2 flask_sqlalchemy==3.0.3 +openai==0.27.2 +SQLAlchemy==2.0.7 Werkzeug==2.2.3