1039 lines
38 KiB
Python
1039 lines
38 KiB
Python
from channels.layers import get_channel_layer
|
|
from channels.db import database_sync_to_async
|
|
from rest_framework_simplejwt.views import TokenObtainPairView
|
|
from rest_framework_simplejwt.tokens import RefreshToken
|
|
from rest_framework import permissions, status
|
|
from .serializers import (
|
|
MyTokenObtainPairSerializer,
|
|
CustomUserSerializer,
|
|
BasicUserSerializer,
|
|
AnnouncmentSerializer,
|
|
CompanySerializer,
|
|
ConversationSerializer,
|
|
PromptSerializer,
|
|
FeedbackSerializer,
|
|
DocumentWorkspaceSerializer,
|
|
DocumentSerializer
|
|
)
|
|
from rest_framework.views import APIView
|
|
from rest_framework.response import Response
|
|
from .models import (
|
|
CustomUser,
|
|
Announcement,
|
|
Conversation,
|
|
Prompt,
|
|
Feedback,
|
|
PromptMetric,
|
|
DocumentWorkspace,
|
|
Document
|
|
)
|
|
from django.views.decorators.cache import never_cache
|
|
from django.http import JsonResponse
|
|
from datetime import datetime
|
|
from .client import LlamaClient
|
|
from asgiref.sync import sync_to_async, async_to_sync
|
|
from channels.generic.websocket import AsyncWebsocketConsumer
|
|
from langchain_ollama.llms import OllamaLLM
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
from langchain_core.messages import HumanMessage, SystemMessage
|
|
from langchain.chains import RetrievalQA
|
|
import re
|
|
import os
|
|
from django.conf import settings
|
|
import json
|
|
import base64
|
|
import pandas as pd
|
|
|
|
# For email support
|
|
from django.core.mail import EmailMultiAlternatives
|
|
from django.template.loader import render_to_string
|
|
from django.utils.html import strip_tags
|
|
from django.template.loader import get_template
|
|
from django.template import Context
|
|
from django.utils import timezone
|
|
from django.core.files import File
|
|
from django.core.files.base import ContentFile
|
|
import math
|
|
import datetime
|
|
import pytz
|
|
from langchain_community.embeddings import OllamaEmbeddings
|
|
|
|
from dateutil.relativedelta import relativedelta
|
|
from django.views.decorators.csrf import csrf_exempt
|
|
|
|
from .utils import last_day_of_month
|
|
from .services.llm_service import AsyncLLMService
|
|
from .services.rag_services import AsyncRAGService
|
|
from .services.title_generator import title_generator
|
|
from .services.moderation_classifier import moderation_classifier, ModerationLabel
|
|
from .services.prompt_classifier import prompt_classifier, PromptType
|
|
|
|
|
|
from langchain.chains import create_retrieval_chain
|
|
from langchain.chains.combine_documents import create_stuff_documents_chain
|
|
from langchain_ollama import ChatOllama
|
|
|
|
# Channel-layer group that ConversationDetailView.post broadcasts history to.
CHANNEL_NAME: str = "llm_messages"
# Model name recorded on PromptMetric rows and used for the module-level OllamaLLM.
MODEL_NAME: str = "llama3"
|
|
|
|
|
|
# Create your views here.
|
|
class CustomObtainTokenView(TokenObtainPairView):
    """Issue a JWT access/refresh pair using ``MyTokenObtainPairSerializer``.

    Any caller may use it (``AllowAny``), since it is the credential exchange
    that produces the token in the first place.
    """

    serializer_class = MyTokenObtainPairSerializer
    permission_classes = (permissions.AllowAny,)
|
|
|
|
|
|
class CustomUserCreate(APIView):
    """Register a new user from the posted payload.

    Authentication is disabled and every caller is permitted, so anonymous
    clients can sign up.
    """

    authentication_classes = ()
    permission_classes = (permissions.AllowAny,)

    def post(self, request, format="json"):
        """Validate and save the user.

        Returns 201 with the serialized user on success. Otherwise returns 400
        with the serializer errors.
        """
        serializer = CustomUserSerializer(data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
        if serializer.save():
            return Response(serializer.data, status=status.HTTP_201_CREATED)
        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
|
|
|
|
|
def send_invite_email(slug, email_to_invite):
    """Email ``email_to_invite`` an invitation with a set-password link for ``slug``.

    Renders the ``emails/invite_email`` HTML and text templates with the link
    and sends them as one multipart message. Send errors are suppressed
    (``fail_silently=True``).
    """
    print("Sending invite email")
    link = f"https://www.chat.aimloperations.com/set_password?slug={slug}"
    print(f"url : {link}")
    context = {"url": link}
    html_body = get_template(r"emails/invite_email.html").render(context)
    text_body = get_template(r"emails/invite_email.txt").render(context)
    email = EmailMultiAlternatives(
        "Welcome to AI ML Operations, LLC Chat Services",
        text_body,
        "ryan@aimloperations.com",
        [email_to_invite],
    )
    email.attach_alternative(html_body, "text/html")
    email.send(fail_silently=True)
|
|
|
|
|
|
def send_feedback_email(feedback_obj):
    """Email the site owner that new feedback was submitted.

    The ``emails/feedback_email`` templates receive the feedback's ``title``
    and ``text``. Send errors are suppressed (``fail_silently=True``).
    """
    print("Sending feedback email")
    context = {"title": feedback_obj.title, "feedback_text": feedback_obj.text}
    html_body = get_template(r"emails/feedback_email.html").render(context)
    text_body = get_template(r"emails/feedback_email.txt").render(context)
    owner_address = "ryan@aimloperations.com"
    email = EmailMultiAlternatives(
        "New Feedback for Chat by AI ML Operations, LLC",
        text_body,
        owner_address,
        [owner_address],
    )
    email.attach_alternative(html_body, "text/html")
    email.send(fail_silently=True)
|
|
|
|
|
|
def send_password_reset_email(slug, email_to_invite):
    """Email ``email_to_invite`` a password-reset link for the account ``slug``.

    Uses the ``emails/reset_email`` HTML and text templates. Send errors are
    suppressed (``fail_silently=True``).
    """
    print("Sending Password reset email")
    context = {"url": f"https://www.chat.aimloperations.com/set_password?slug={slug}"}
    html_body = get_template(r"emails/reset_email.html").render(context)
    text_body = get_template(r"emails/reset_email.txt").render(context)
    email = EmailMultiAlternatives(
        "Password reset for Chat by AI ML Operations, LLC",
        text_body,
        "ryan@aimloperations.com",
        [email_to_invite],
    )
    email.attach_alternative(html_body, "text/html")
    email.send(fail_silently=True)
|
|
|
|
|
|
class CustomUserInvite(APIView):
    """Let a company manager invite a new user by email.

    Creates a passwordless ``CustomUser`` in the manager's company and emails
    them a set-password link.
    """

    http_method_names = ["post"]

    def post(self, request, format="json"):
        """Create and invite a user.

        Returns 400 in any of these cases:
        - the caller is not a company manager;
        - ``email`` is missing or malformed;
        - a user with that email (case-insensitive) already exists.

        Returns 201 once the user is created and the invite email is sent.
        """
        from django.core.exceptions import ValidationError
        from django.core.validators import validate_email

        # .get() so a missing key yields 400 instead of an unhandled KeyError.
        email_to_invite = str(request.data.get("email") or "").strip()

        if not request.user.is_company_manager or not email_to_invite:
            return Response(status=status.HTTP_400_BAD_REQUEST)
        try:
            # Django's validator accepts real addresses the old hand-rolled
            # regex rejected (uppercase letters, subdomains, "+" tags, ...).
            validate_email(email_to_invite)
        except ValidationError:
            return Response(status=status.HTTP_400_BAD_REQUEST)

        # Make sure there isn't a user with this email already. Emails are
        # compared case-insensitively so "A@x.com" cannot duplicate "a@x.com".
        if CustomUser.objects.filter(email__iexact=email_to_invite).exists():
            return Response(status=status.HTTP_400_BAD_REQUEST)

        user = CustomUser.objects.create(
            email=email_to_invite,
            username=email_to_invite,
            company=request.user.company,
        )
        send_invite_email(user.slug, email_to_invite)
        return Response(status=status.HTTP_201_CREATED)
|
|
|
|
def _verify_recaptcha(token, min_score=0.5):
    """Return True when Google reCAPTCHA accepts ``token`` with score >= ``min_score``.

    Uses the standard library only (the module never imports ``requests``).
    Network or parse failures count as a failed check, so the view does not
    crash.
    """
    from urllib.parse import urlencode
    from urllib.request import urlopen

    if not token:
        return False
    payload = urlencode(
        {"secret": settings.CAPTCHA_SECRET_KEY, "response": token}
    ).encode("utf-8")
    try:
        with urlopen(
            "https://www.google.com/recaptcha/api/siteverify", data=payload, timeout=10
        ) as resp:
            result = json.loads(resp.read().decode("utf-8"))
    except (OSError, ValueError):
        return False
    if not isinstance(result, dict):
        return False
    # ``score`` may be absent from the response; treat that as 0 instead of
    # comparing None with a float (TypeError).
    return bool(result.get("success")) and (result.get("score") or 0) >= min_score


@csrf_exempt
def reset_password(request):
    """Disable a user's password and email them a reset link (plain Django view).

    Expects a JSON body with ``email`` and ``recaptchaToken``.

    Responses:
    - 405 for non-POST requests;
    - 400 for a malformed body or a failed captcha;
    - 200 otherwise, including when no user has that email, so the endpoint
      does not reveal which addresses are registered.

    NOTE(review): the password is made unusable as soon as a reset is
    requested. Anyone who passes the captcha can therefore lock an account
    until its owner follows the emailed link.
    """
    if request.method != "POST":
        return JsonResponse({}, status=405)
    try:
        data = json.loads(request.body)
    except ValueError:
        return JsonResponse({}, status=400)
    if not isinstance(data, dict) or not _verify_recaptcha(data.get("recaptchaToken")):
        return JsonResponse({}, status=400)

    email = data.get("email")
    user = CustomUser.objects.filter(email=email).first() if email else None
    if user:
        user.set_unusable_password()
        user.save()
        send_password_reset_email(user.slug, email)
    return JsonResponse({}, status=200)
|
|
|
|
class ResetUserPassword(APIView):
    """DRF endpoint: disable a user's password and email them a set-password link."""

    http_method_names = [
        "post",
    ]
    permission_classes = (permissions.AllowAny,)
    authentication_classes = ()

    @staticmethod
    def _captcha_passed(token, min_score=0.5):
        """Return True when Google reCAPTCHA accepts ``token`` with score >= ``min_score``.

        Uses the standard library only (the module never imports ``requests``).
        Network or parse failures count as a failed check.
        """
        from urllib.parse import urlencode
        from urllib.request import urlopen

        if not token:
            return False
        payload = urlencode(
            {"secret": settings.CAPTCHA_SECRET_KEY, "response": token}
        ).encode("utf-8")
        try:
            with urlopen(
                "https://www.google.com/recaptcha/api/siteverify",
                data=payload,
                timeout=10,
            ) as resp:
                result = json.loads(resp.read().decode("utf-8"))
        except (OSError, ValueError):
            return False
        if not isinstance(result, dict):
            return False
        # ``score`` may be missing; treat it as 0 rather than comparing None.
        return bool(result.get("success")) and (result.get("score") or 0) >= min_score

    def post(self, request, format="json"):
        """
        Send an email with a set password link to the set password page.
        Also disable the account.

        Always responds 200, so callers cannot tell whether an email is
        registered or whether the captcha check failed.

        NOTE(review): the password is disabled immediately on request, so
        anyone who passes the captcha can lock an account until its owner
        follows the emailed link.
        """
        # The old code referenced the undefined names ``recaptchaToken`` and
        # ``email`` (NameError); read both from the request body.
        token = request.data.get("recaptchaToken")
        email = request.data.get("email")
        print("Password reset requested")

        if self._captcha_passed(token):
            user = CustomUser.objects.filter(email=email).first() if email else None
            if user:
                user.set_unusable_password()
                user.save()
                send_password_reset_email(user.slug, email)
        else:
            print('Captcha secret failed')

        return Response(status=status.HTTP_200_OK)
|
|
|
|
|
|
class SetUserPassword(APIView):
    """Set the password of the user identified by ``slug`` (invite and reset links)."""

    http_method_names = ["post", "get"]
    permission_classes = (permissions.AllowAny,)
    authentication_classes = ()

    @staticmethod
    def _get_user(slug):
        """Return the user whose ``slug`` matches, or None when there is none or the slug is malformed."""
        from django.core.exceptions import ValidationError

        try:
            return CustomUser.objects.get(slug=slug)
        except (CustomUser.DoesNotExist, ValueError, ValidationError):
            return None

    def get(self, request, slug):
        """Report whether the link can still be used to set a password.

        Returns 200 when it can, 401 when the account already has a usable
        password, and 404 for an unknown slug.
        """
        user = self._get_user(slug)
        if user is None:
            return Response(status=status.HTTP_404_NOT_FOUND)
        if user.has_usable_password():
            return Response(status=status.HTTP_401_UNAUTHORIZED)
        return Response(status=status.HTTP_200_OK)

    def post(self, request, slug, format="json"):
        """Set ``password`` for the account behind ``slug``.

        Returns:
        - 200 on success;
        - 400 when ``password`` is missing or empty;
        - 401 when the account already has a usable password;
        - 404 for an unknown slug.
        """
        user = self._get_user(slug)
        if user is None:
            return Response(status=status.HTTP_404_NOT_FOUND)
        # Mirror get(): only accounts whose password was never set or was
        # disabled may be (re)set here. Without this check, anyone who learns
        # a user's slug could overwrite an active account's password.
        if user.has_usable_password():
            return Response(status=status.HTTP_401_UNAUTHORIZED)
        password = request.data.get("password")
        if not password:
            return Response(status=status.HTTP_400_BAD_REQUEST)
        user.set_password(password)
        user.save()
        return Response(status=status.HTTP_200_OK)
|
|
|
|
|
|
class CustomUserGet(APIView):
    """Return the serialized profile of the authenticated user."""

    http_method_names = ["get", "head", "post"]

    def get(self, request, format="json"):
        """Serialize ``request.user`` and return it with status 200.

        The already-authenticated user object is serialized directly. This
        avoids re-querying by email, which costs an extra query and raises
        ``MultipleObjectsReturned`` if two accounts share an email.
        """
        serializer = CustomUserSerializer(request.user)
        return Response(serializer.data, status=status.HTTP_200_OK)
|
|
|
|
|
|
class FeedbackView(APIView):
    """Create feedback entries for the current user and list that user's feedback."""

    http_method_names = ["post", "get"]

    def post(self, request, format="json"):
        """Save feedback owned by the caller and email a notification.

        Returns 201 with the serialized feedback, or 400 with validation errors.
        """
        serializer = FeedbackSerializer(data=request.data)
        if not serializer.is_valid():
            print(serializer.errors)
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
        # Pass the owner to save() so the row is inserted once, already owned,
        # rather than inserted without a user and then updated.
        feedback_obj = serializer.save(user=request.user)
        send_feedback_email(feedback_obj)
        return Response(serializer.data, status=status.HTTP_201_CREATED)

    def get(self, request, format="json"):
        """Return all feedback submitted by the caller."""
        feedback_objs = Feedback.objects.filter(user=request.user)
        serializer = FeedbackSerializer(feedback_objs, many=True)
        return Response(serializer.data, status=status.HTTP_200_OK)
|
|
|
|
|
|
class AcknowledgeTermsOfService(APIView):
    """Record that the current user has accepted the terms of service."""

    http_method_names = ["post"]

    def post(self, request, format="json"):
        """Set ``has_signed_tos`` on the caller and return 200."""
        request.user.has_signed_tos = True
        # Write only this column so a concurrent update of other user fields
        # is not overwritten with stale values.
        request.user.save(update_fields=["has_signed_tos"])
        return Response(status=status.HTTP_200_OK)
|
|
|
|
|
|
class CompanyUsersView(APIView):
    """Company-manager endpoints to list, update and delete users of the caller's company.

    Every method returns 401 unless the caller is a company manager attached
    to a company.
    """

    @staticmethod
    def _is_manager(request):
        """True when the caller is a company manager who belongs to a company."""
        return bool(request.user.is_company_manager) and request.user.company_id is not None

    @staticmethod
    def _company_user(request):
        """Return the user named by ``request.data["email"]`` if they share the caller's company.

        Returns None when the email is unknown, is ambiguous, or belongs to
        another company. Callers answer 401 in all of those cases, so the
        response does not reveal users outside the caller's company.
        """
        try:
            user = CustomUser.objects.get(email=request.data.get("email"))
        except (CustomUser.DoesNotExist, CustomUser.MultipleObjectsReturned):
            return None
        return user if user.company_id == request.user.company_id else None

    def get(self, request, format="json"):
        """List the users of the caller's company (basic fields only)."""
        if not self._is_manager(request):
            return Response(status=status.HTTP_401_UNAUTHORIZED)
        users = CustomUser.objects.filter(company_id=request.user.company_id)
        serializer = BasicUserSerializer(users, many=True)
        return Response(serializer.data, status=status.HTTP_200_OK)

    def post(self, request, format="json"):
        """Toggle one account attribute of a user in the caller's company.

        ``field`` selects the change:
        - ``"is_active"`` flips ``is_active``;
        - ``"company_manager"`` flips ``is_company_manager``;
        - ``"has_password"`` disables a usable password (no-op if already unusable).

        Returns 200 on success, 400 for an unknown field or an invalid update,
        and 401 when the caller may not manage the target user.
        """
        if not self._is_manager(request):
            return Response(status=status.HTTP_401_UNAUTHORIZED)
        user = self._company_user(request)
        if user is None:
            return Response(status=status.HTTP_401_UNAUTHORIZED)

        field = request.data.get("field")
        if field == "has_password":
            if user.has_usable_password():
                user.set_unusable_password()
                # Persist explicitly instead of relying on the serializer's
                # update() to save an attribute it did not validate.
                user.save(update_fields=["password"])
            return Response(status=status.HTTP_200_OK)

        toggles = {
            "is_active": ("is_active", user.is_active),
            "company_manager": ("is_company_manager", user.is_company_manager),
        }
        if field not in toggles:
            return Response(status=status.HTTP_400_BAD_REQUEST)
        attribute, current_value = toggles[field]
        serializer = CustomUserSerializer(user, {attribute: not current_value}, partial=True)
        if serializer.is_valid():
            serializer.save()
            return Response(status=status.HTTP_200_OK)
        return Response(status=status.HTTP_400_BAD_REQUEST)

    def delete(self, request, format="json"):
        """Delete a user of the caller's company; 200 on success, otherwise 401."""
        if not self._is_manager(request):
            return Response(status=status.HTTP_401_UNAUTHORIZED)
        user = self._company_user(request)
        if user is None:
            return Response(status=status.HTTP_401_UNAUTHORIZED)
        user.delete()
        return Response(status=status.HTTP_200_OK)
|
|
|
|
|
|
class AnnouncmentView(APIView):
    """Publicly list every ``Announcement`` (permission ``AllowAny``)."""

    serializer_class = AnnouncmentSerializer
    permission_classes = (permissions.AllowAny,)

    def get(self, request, format="json"):
        """Return all announcements serialized with ``AnnouncmentSerializer``."""
        payload = AnnouncmentSerializer(Announcement.objects.all(), many=True).data
        return Response(payload, status=status.HTTP_200_OK)
|
|
|
|
|
|
class LogoutAndBlacklistRefreshTokenForUserView(APIView):
    """Log out by blacklisting the client's refresh token."""

    permission_classes = (permissions.AllowAny,)
    authentication_classes = ()

    def post(self, request):
        """Blacklist the posted ``refresh_token``.

        Returns 205 on success and 400 when the token is missing, malformed or
        expired. Other errors (for example a misconfigured blacklist app) now
        surface instead of being reported as a client error.
        """
        from rest_framework_simplejwt.exceptions import TokenError

        refresh_token = request.data.get("refresh_token")
        if not refresh_token:
            return Response(status=status.HTTP_400_BAD_REQUEST)
        try:
            RefreshToken(refresh_token).blacklist()
        except TokenError:
            return Response(status=status.HTTP_400_BAD_REQUEST)
        return Response(status=status.HTTP_205_RESET_CONTENT)
|
|
|
|
|
|
@never_cache
def is_authenticated(request):
    """Return an empty JSON body: 200 if ``request.user`` is authenticated, else 401.

    Decorated with ``never_cache`` so the answer is never served from a cache.
    """
    code = (
        status.HTTP_200_OK
        if request.user.is_authenticated
        else status.HTTP_401_UNAUTHORIZED
    )
    return JsonResponse({}, status=code)
|
|
|
|
|
|
class ConversationsView(APIView):
    """List and create the current user's conversations."""

    def get(self, request, format="json"):
        """Return the caller's non-deleted conversations, ordered by ``created``.

        A truthy ``conversation_order`` sorts oldest first; a falsy one sorts
        newest first.
        """
        order = "created" if request.user.conversation_order else "-created"
        conversations = Conversation.objects.filter(
            user=request.user, deleted=False
        ).order_by(order)
        serializer = ConversationSerializer(conversations, many=True)
        return Response(serializer.data, status=status.HTTP_200_OK)

    def post(self, request, format="json"):
        """
        Create a blank conversation and return the title and id number.

        The body's ``name`` becomes the title. Responds 201 with
        ``{"title", "id"}``.
        """
        title = request.data.get("name")
        # Create the row already owned: one INSERT instead of INSERT plus two
        # UPDATEs, and no moment where the conversation exists without a user.
        conversation = Conversation.objects.create(title=title, user_id=request.user.id)
        return Response(
            {"title": title, "id": conversation.id}, status=status.HTTP_201_CREATED
        )
|
|
|
|
|
|
class ConversationPreferences(APIView):
    """Read and toggle the caller's conversation sort-order preference."""

    def get(self, request, format="json"):
        """Return ``{"order": <conversation_order>}`` for the caller."""
        return Response(
            {"order": request.user.conversation_order}, status=status.HTTP_200_OK
        )

    def post(self, request, format="json"):
        """Flip ``conversation_order`` and return the new value."""
        user = request.user
        user.conversation_order = not user.conversation_order
        # Write only the toggled column so concurrent changes to other user
        # fields are not overwritten with stale values.
        user.save(update_fields=["conversation_order"])
        return Response({"order": user.conversation_order}, status=status.HTTP_200_OK)
|
|
|
|
|
|
class ConversationDetailView(APIView):
    """Read, extend and soft-delete one of the caller's conversations."""

    def get(self, request, format="json"):
        """Return the prompts of ``conversation_id`` (query param) if the caller owns it.

        The owner filter stops users from reading other users' conversations
        by guessing ids. An unknown or foreign id yields an empty list.
        """
        conversation_id = request.query_params.get("conversation_id")
        prompts = Prompt.objects.filter(
            conversation__id=conversation_id, conversation__user=request.user
        )
        serializer = PromptSerializer(prompts, many=True)
        return Response(serializer.data, status=status.HTTP_200_OK)

    def post(self, request, format="json"):
        """Append a prompt to one of the caller's conversations.

        The body carries ``prompt`` (text), ``conversation_id`` and ``is_user``.
        For user-authored prompts, the whole conversation history is sent to
        the ``CHANNEL_NAME`` channel-layer group as a ``receive`` event.

        Returns 200 on success, 404 for an unknown or foreign conversation,
        and 400 when the prompt fails validation.
        """
        prompt = request.data.get("prompt")
        conversation_id = request.data.get("conversation_id")
        is_user = bool(request.data.get("is_user"))

        try:
            conversation = Conversation.objects.get(
                id=conversation_id, user=request.user
            )
        except (Conversation.DoesNotExist, ValueError):
            print(f"Error trying to submit to conversation_id: {conversation_id}")
            return Response(status=status.HTTP_404_NOT_FOUND)

        serializer = PromptSerializer(
            data={
                "message": prompt,
                "user_created": is_user,
                # timezone.now(): the module-level name ``datetime`` is rebound
                # to the module by ``import datetime``, so ``datetime.now()``
                # raised AttributeError (previously hidden by a bare except).
                "created": timezone.now(),
            }
        )
        if not serializer.is_valid():
            print(f"Invalid prompt for conversation {conversation_id}: {serializer.errors}")
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
        prompt_instance = serializer.save()
        prompt_instance.conversation_id = conversation.id
        prompt_instance.save()

        if is_user:
            messages = [
                {
                    "content": prompt_obj.message,
                    "role": "user" if prompt_obj.user_created else "assistant",
                }
                for prompt_obj in Prompt.objects.filter(conversation__id=conversation.id)
            ]
            channel_layer = get_channel_layer()
            print(f"Sending to the channel: {CHANNEL_NAME}")
            async_to_sync(channel_layer.group_send)(
                CHANNEL_NAME, {"type": "receive", "content": messages}
            )

        return Response(status=status.HTTP_200_OK)

    def delete(self, request, format="json"):
        """Soft-delete one of the caller's conversations (sets ``deleted``).

        Returns 202 on success and 404 for an unknown or foreign conversation.
        """
        conversation_id = request.data.get("conversation_id")
        try:
            conversation = Conversation.objects.get(id=conversation_id, user=request.user)
        except (Conversation.DoesNotExist, ValueError):
            return Response(status=status.HTTP_404_NOT_FOUND)
        conversation.deleted = True
        conversation.save()
        return Response(status=status.HTTP_202_ACCEPTED)
|
|
|
|
|
|
class UserPromptAnalytics(APIView):
    """Per-month prompt counts for the caller, their company and all users."""

    @staticmethod
    def _month_bounds(months_ago):
        """Return ``(start, end)`` for the calendar month ``months_ago`` months before the current one.

        ``start`` is midnight on the 1st and ``end`` is midnight on the 1st of
        the next month. Filter with ``created__gte=start, created__lt=end``:
        the half-open interval covers the whole last day, whatever its time.
        When ``timezone.now()`` is aware, both bounds are made aware in the
        current time zone, so they compare correctly with ``DateTimeField``s.
        """
        import datetime as dt

        now = timezone.now()
        aware = timezone.is_aware(now)
        if aware:
            now = timezone.localtime(now)
        index = now.year * 12 + (now.month - 1) - months_ago
        start = dt.datetime(index // 12, index % 12 + 1, 1)
        end = dt.datetime((index + 1) // 12, (index + 1) % 12 + 1, 1)
        if aware:
            start, end = timezone.make_aware(start), timezone.make_aware(end)
        return start, end

    def get(self, request, format="json"):
        """Return prompt statistics for the last three months, oldest first.

        Each entry has:
        - ``month``: full month name;
        - ``you``: prompts in the caller's conversations;
        - ``others``: average prompts per user of the caller's company;
        - ``all``: average prompts per user across every user (all prompts
          attached to a conversation, not only those in conversations
          created that month).
        """
        number_of_months = 3
        company_user_ids = list(
            CustomUser.objects.filter(company=request.user.company).values_list(
                "id", flat=True
            )
        )
        company_user_count = len(company_user_ids)
        total_users = CustomUser.objects.count()

        result = []
        for months_ago in range(number_of_months):
            start, end = self._month_bounds(months_ago)
            in_month = {"created__gte": start, "created__lt": end}
            # .count() runs COUNT(*) in SQL instead of loading every row.
            my_prompts = Prompt.objects.filter(
                conversation__user=request.user, **in_month
            ).count()
            company_prompts = Prompt.objects.filter(
                conversation__user__id__in=company_user_ids, **in_month
            ).count()
            total_prompts = Prompt.objects.filter(
                conversation__isnull=False, **in_month
            ).count()
            result.append(
                {
                    "month": start.strftime("%B"),
                    "you": my_prompts,
                    "others": company_prompts / company_user_count if company_user_count else 0,
                    "all": total_prompts / total_users if total_users else 0,
                }
            )

        return Response(result[::-1], status=status.HTTP_200_OK)
|
|
|
|
|
|
class UserConversationAnalytics(APIView):
    """Per-month conversation counts for the caller, their company and all users."""

    @staticmethod
    def _month_bounds(months_ago):
        """Return ``(start, end)`` for the calendar month ``months_ago`` months before the current one.

        Half-open interval: ``start`` is midnight on the 1st and ``end`` is
        midnight on the next month's 1st. Both are timezone-aware when
        ``timezone.now()`` is aware.
        """
        import datetime as dt

        now = timezone.now()
        aware = timezone.is_aware(now)
        if aware:
            now = timezone.localtime(now)
        index = now.year * 12 + (now.month - 1) - months_ago
        start = dt.datetime(index // 12, index % 12 + 1, 1)
        end = dt.datetime((index + 1) // 12, (index + 1) % 12 + 1, 1)
        if aware:
            start, end = timezone.make_aware(start), timezone.make_aware(end)
        return start, end

    def get(self, request, format="json"):
        """Return conversation statistics for the last three months, oldest first.

        Each entry has:
        - ``month``: full month name;
        - ``you``: conversations the caller created;
        - ``others``: average per user of the caller's company;
        - ``all``: average per user across every user.
        """
        number_of_months = 3
        company_user_ids = list(
            CustomUser.objects.filter(company=request.user.company).values_list(
                "id", flat=True
            )
        )
        company_user_count = len(company_user_ids)
        total_users = CustomUser.objects.count()

        result = []
        for months_ago in range(number_of_months):
            start, end = self._month_bounds(months_ago)
            in_month = {"created__gte": start, "created__lt": end}
            total_conversations = Conversation.objects.filter(**in_month).count()
            company_conversations = Conversation.objects.filter(
                user__id__in=company_user_ids, **in_month
            ).count()
            my_conversations = Conversation.objects.filter(
                user=request.user, **in_month
            ).count()
            result.append(
                {
                    "month": start.strftime("%B"),
                    "you": my_conversations,
                    "others": company_conversations / company_user_count if company_user_count else 0,
                    "all": total_conversations / total_users if total_users else 0,
                }
            )

        return Response(result[::-1], status=status.HTTP_200_OK)
|
|
|
|
|
|
class CompanyUsageAnalytics(APIView):
    """Per-month count of company users who did and did not start a conversation."""

    @staticmethod
    def _month_bounds(months_ago):
        """Return ``(start, end)`` for the calendar month ``months_ago`` months before the current one.

        Half-open interval: ``start`` is midnight on the 1st and ``end`` is
        midnight on the next month's 1st. Both are timezone-aware when
        ``timezone.now()`` is aware.
        """
        import datetime as dt

        now = timezone.now()
        aware = timezone.is_aware(now)
        if aware:
            now = timezone.localtime(now)
        index = now.year * 12 + (now.month - 1) - months_ago
        start = dt.datetime(index // 12, index % 12 + 1, 1)
        end = dt.datetime((index + 1) // 12, (index + 1) % 12 + 1, 1)
        if aware:
            start, end = timezone.make_aware(start), timezone.make_aware(end)
        return start, end

    def get(self, request, format="json"):
        """Return ``{month, used, not_used}`` for the last three months, oldest first.

        ``used`` is the number of distinct company users who created at least
        one conversation that month; ``not_used`` is the remaining users.
        """
        number_of_months = 3
        company_user_ids = list(
            CustomUser.objects.filter(company=request.user.company).values_list(
                "id", flat=True
            )
        )
        company_user_count = len(company_user_ids)

        result = []
        for months_ago in range(number_of_months):
            start, end = self._month_bounds(months_ago)
            active_users = (
                Conversation.objects.filter(
                    user__id__in=company_user_ids,
                    created__gte=start,
                    created__lt=end,
                )
                # Clear any Meta.ordering: ordering columns would otherwise be
                # added to the SELECT and defeat DISTINCT on user ids.
                .order_by()
                .values_list("user__id", flat=True)
                .distinct()
                .count()
            )
            result.append(
                {
                    "month": start.strftime("%B"),
                    "used": active_users,
                    "not_used": company_user_count - active_users,
                }
            )
        return Response(result[::-1], status=status.HTTP_200_OK)
|
|
|
|
|
|
class AdminAnalytics(APIView):
    """Per-month response-duration statistics computed from ``PromptMetric`` rows.

    NOTE(review): only the project's default permission classes apply here.
    Confirm whether this should be restricted to staff.
    """

    @staticmethod
    def _month_bounds(months_ago):
        """Return ``(start, end)`` for the calendar month ``months_ago`` months before the current one.

        Half-open interval: ``start`` is midnight on the 1st and ``end`` is
        midnight on the next month's 1st. Both are timezone-aware when
        ``timezone.now()`` is aware.
        """
        import datetime as dt

        now = timezone.now()
        aware = timezone.is_aware(now)
        if aware:
            now = timezone.localtime(now)
        index = now.year * 12 + (now.month - 1) - months_ago
        start = dt.datetime(index // 12, index % 12 + 1, 1)
        end = dt.datetime((index + 1) // 12, (index + 1) % 12 + 1, 1)
        if aware:
            start, end = timezone.make_aware(start), timezone.make_aware(end)
        return start, end

    def get(self, request, format="json"):
        """Return ``{month, range, avg, median}`` for the last three months, oldest first.

        ``range`` is ``[min, max]`` of ``get_duration()``. Months without data
        report zeros. ``median`` is the true median: the mean of the two middle
        values for an even count.
        """
        import statistics

        number_of_months = 3
        result = []

        for months_ago in range(number_of_months):
            start, end = self._month_bounds(months_ago)
            # Skip metrics without a duration. Presumably these are unfinished
            # prompts; get_duration is defined on the model, so confirm.
            durations = [
                duration
                for duration in (
                    metric.get_duration()
                    for metric in PromptMetric.objects.filter(
                        created__gte=start, created__lt=end
                    )
                )
                if duration is not None
            ]
            month = start.strftime("%B")
            if not durations:
                result.append({"month": month, "range": [0, 0], "avg": 0, "median": 0})
                continue

            result.append(
                {
                    "month": month,
                    "range": [min(durations), max(durations)],
                    "avg": sum(durations) / len(durations),
                    "median": statistics.median(durations),
                }
            )

        return Response(result[::-1], status=status.HTTP_200_OK)
|
|
|
|
|
|
# Module-level LangChain objects, built at import time.
# NOTE(review): neither ``prompt`` nor ``llm`` is referenced elsewhere in this
# file (functions and the consumer bind their own local ``prompt``). Confirm
# that no other module imports them before removing them.
_default_messages = [("system", "You are a helpful assistant."), ("user", "{input}")]
prompt = ChatPromptTemplate.from_messages(_default_messages)

llm = OllamaLLM(model=MODEL_NAME)
|
|
|
|
|
|
@database_sync_to_async
def create_conversation(prompt, email, title):
    """Create a conversation titled ``title`` for the user with ``email``; return its id.

    ``prompt`` is accepted for interface compatibility and is not used.
    Raises ``CustomUser.DoesNotExist`` when no user has ``email``.
    """
    # Resolve the owner first, so a bad email no longer leaves an ownerless
    # conversation row behind. Then insert once instead of saving three times.
    user = CustomUser.objects.get(email=email)
    conversation = Conversation.objects.create(title=title, user_id=user.id)
    return conversation.id
|
|
|
|
|
|
@database_sync_to_async
def get_workspace(conversation_id):
    """Return the DocumentWorkspace of the company whose user owns ``conversation_id``.

    Runs through ``database_sync_to_async``. ``DoesNotExist`` and
    ``MultipleObjectsReturned`` from either lookup propagate to the caller.
    """
    owner_company = Conversation.objects.get(id=conversation_id).user.company
    return DocumentWorkspace.objects.get(company=owner_company)
|
|
|
|
def _read_attachment(field_file):
    """Return the full contents of a stored prompt attachment, closing it afterwards."""
    with field_file.open("rb") as handle:
        return handle.read()


def _with_attachment(content, field_file, file_type):
    """Append a stored attachment's contents to ``content`` for the LLM.

    ``file_type`` is matched by substring, in the order csv, xlsx, txt. Any
    other value returns ``content`` unchanged without reading the file.
    """
    import io

    if "csv" in file_type:
        label = "csv"
        body = _read_attachment(field_file).decode("utf-8", errors="replace")
    elif "xlsx" in file_type:
        label = "xlsx"
        # Wrap the bytes in BytesIO: pandas expects a path or file-like object.
        body = pd.read_excel(io.BytesIO(_read_attachment(field_file)))
    elif "txt" in file_type:
        # Text attachments used to be mislabelled as "csv" in the prompt.
        label = "txt"
        body = _read_attachment(field_file).decode("utf-8", errors="replace")
    else:
        return content
    return f"{content}\n The file type is {label} and the file contents are: {body}"


@database_sync_to_async
def get_messages(conversation_id, prompt, file_string: str = None, file_type: str = ""):
    """Store a new user prompt (plus optional attachment) and return the conversation as LangChain messages.

    Args:
        conversation_id: id of an existing Conversation.
        prompt: the user's message text.
        file_string: raw attachment content. The consumer passes the
            base64-decoded bytes despite the ``str`` annotation.
        file_type: normalized attachment type ("csv", "xlsx", "txt", ...).
            It is used in the stored file name and on the Prompt row.

    Returns:
        ``(messages, prompt_instance)``.
        - ``messages`` holds every prompt of the conversation, oldest first.
          User prompts become ``HumanMessage`` and assistant replies become
          ``AIMessage``.
        - Attachments are inlined into the message text.

    Raises:
        Conversation.DoesNotExist: unknown ``conversation_id``.
        rest_framework ValidationError: the prompt fails validation.
    """
    from langchain_core.messages import AIMessage

    conversation = Conversation.objects.get(id=conversation_id)

    serializer = PromptSerializer(
        data={
            "message": prompt,
            "user_created": True,
            "created": timezone.now(),
        }
    )
    serializer.is_valid(raise_exception=True)
    prompt_instance = serializer.save()
    prompt_instance.conversation_id = conversation.id
    prompt_instance.save()
    if file_string:
        file_name = f"prompt_{prompt_instance.id}_data.{file_type}"
        prompt_instance.file_type = file_type
        # FieldFile.save() also saves the model instance (save=True by default).
        prompt_instance.file.save(file_name, ContentFile(file_string, name=file_name))

    transformed_messages = []
    # Explicit ordering: the caller replaces the last element assuming it is
    # the newest prompt.
    for prompt_obj in Prompt.objects.filter(conversation__id=conversation_id).order_by(
        "created", "id"
    ):
        content = prompt_obj.message
        if prompt_obj.file_exists() and prompt_obj.file_type is not None:
            content = _with_attachment(content, prompt_obj.file, prompt_obj.file_type)
        # Assistant turns are AIMessage: as SystemMessage, earlier model output
        # was treated as system instructions.
        if prompt_obj.user_created:
            transformed_messages.append(HumanMessage(content=content))
        else:
            transformed_messages.append(AIMessage(content=content))

    return transformed_messages, prompt_instance
|
|
|
|
|
|
@database_sync_to_async
def save_generated_message(conversation_id, message):
    """Persist an assistant (non-user) reply ``message`` in ``conversation_id``.

    Raises ``Conversation.DoesNotExist`` for an unknown conversation. As before,
    a payload that fails validation is skipped rather than raised; it is now
    logged.
    """
    conversation = Conversation.objects.get(id=conversation_id)

    serializer = PromptSerializer(
        data={
            "message": message,
            "user_created": False,
            "created": timezone.now(),
        }
    )
    if not serializer.is_valid():
        print(
            f"Could not save generated message for conversation {conversation_id}: "
            f"{serializer.errors}"
        )
        return
    prompt_instance = serializer.save()
    # Attach the conversation and save the instance directly. A second
    # serializer.save() goes through update() and re-applies validated data.
    prompt_instance.conversation_id = conversation.id
    prompt_instance.save()
|
|
|
|
|
|
@database_sync_to_async
def create_prompt_metric(
    prompt_id, prompt, has_file, file_type, model_name, conversation_id
):
    """Create and return a PromptMetric whose ``start_time`` is now.

    ``prompt`` is the prompt text; only its length is stored.
    """
    # objects.create() already inserts the row; the extra save() was a
    # redundant second write.
    return PromptMetric.objects.create(
        prompt_id=prompt_id,
        start_time=timezone.now(),
        prompt_length=len(prompt),
        has_file=has_file,
        file_type=file_type,
        model_name=model_name,
        conversation_id=conversation_id,
    )
|
|
|
|
|
|
@database_sync_to_async
def update_prompt_metric(prompt_metric, status):
    """Set ``prompt_metric.event`` to ``status`` and persist only that column.

    The ``status`` parameter shadows the module-level DRF ``status`` import
    inside this function. The name is kept for keyword-argument callers.
    """
    prompt_metric.event = status
    # Same narrow write as finish_prompt_metric, so other columns are not
    # overwritten with stale in-memory values.
    prompt_metric.save(update_fields=["event"])
|
|
|
|
|
|
@database_sync_to_async
def finish_prompt_metric(prompt_metric, response_length):
    """Mark ``prompt_metric`` FINISHED, stamping its end time and response length.

    Only ``end_time``, ``reponse_length`` (spelled as the field is used here)
    and ``event`` are written.
    """
    print(f"finish_prompt_metric: {response_length}")
    changes = {
        "end_time": timezone.now(),
        "reponse_length": response_length,
        "event": "FINISHED",
    }
    for field_name, value in changes.items():
        setattr(prompt_metric, field_name, value)
    prompt_metric.save(update_fields=list(changes))
    print("finish_prompt_metric saved")
|
|
|
|
@database_sync_to_async
def get_retriever(conversation_id):
    """Return a retriever over the persistent Chroma store in ``./chroma_db/``.

    Looks up the conversation and its owner's company workspace first, so
    ``DoesNotExist`` is raised when either is missing.

    NOTE(review): the retriever is not filtered by workspace. Every company
    queries the same store. Confirm whether documents need a workspace filter.
    """
    # ``Chroma`` was referenced but never imported in this module.
    from langchain_community.vectorstores import Chroma

    print(f"getting workspace from conversation: {conversation_id}")
    conversation = Conversation.objects.get(id=conversation_id)
    print(f"Got conversation: {conversation}")
    workspace = DocumentWorkspace.objects.get(company=conversation.user.company)
    print(f"Got workspace: {workspace}")
    vectorstore = Chroma(
        persist_directory="./chroma_db/",
        # Chroma's constructor takes ``embedding_function``; ``embedding=``
        # is not a parameter and raised TypeError.
        embedding_function=OllamaEmbeddings(model="llama3.2"),
    )
    return vectorstore.as_retriever()
|
|
|
|
class ChatConsumerAgain(AsyncWebsocketConsumer):
    """WebSocket consumer that answers chat prompts by streaming model output.

    Each text frame is a JSON object with ``message`` and, optionally,
    ``conversation_id``, ``email``, ``file`` (base64), ``fileType`` (MIME
    type) and ``modelName``.

    The reply is a sequence of text frames:
    1. ``CONVERSATION_ID``
    2. the conversation id
    3. ``START_OF_THE_STREAM_ENDER_GAME_42``
    4. the answer, possibly in several chunks
    5. ``END_OF_THE_STREAM_ENDER_GAME_42``

    NOTE(review): ``email`` and ``conversation_id`` come straight from the
    client and are never checked against an authenticated user (for example
    ``self.scope["user"]``). A client can therefore create conversations for,
    and append to, other users' conversations. Confirm the routing's auth
    middleware and enforce ownership here.
    """

    _STREAM_START = "START_OF_THE_STREAM_ENDER_GAME_42"
    _STREAM_END = "END_OF_THE_STREAM_ENDER_GAME_42"
    _NSFW_RESPONSE = (
        "Prompt has been marked as NSFW. If this is in error, "
        "submit a feedback with the prompt text."
    )
    _IMAGE_RESPONSE = (
        "Image Generation is not supported at this time, but it will be soon."
    )

    async def connect(self):
        """Accept every incoming connection."""
        await self.accept()

    async def disconnect(self, close_code):
        """Nothing to clean up.

        Channels calls this after the connection has closed, so the previous
        ``await self.close()`` only tried to send a redundant close frame.
        """

    async def receive(self, text_data=None, bytes_data=None):
        """Handle a frame: JSON text frames are chat requests; binary frames are only logged."""
        if text_data:
            await self._answer(text_data)
        if bytes_data:
            print("we have byte data")

    @staticmethod
    def _attach_file(message, decoded_file, file_type):
        """Normalize the MIME ``file_type`` and build the model-facing message text.

        Returns ``(normalized_type, altered_message)``. Unknown types map to
        ``"Not Sure"`` and leave the message unchanged. Previously that case
        left ``altered_message`` unbound and crashed with UnboundLocalError.
        """
        import io

        if "csv" in file_type:
            text = decoded_file.decode("utf-8", errors="replace")
            return "csv", f"{message}\n The file type is csv and the file contents are: {text}"
        # Match the spreadsheet MIME type specifically: the wider
        # "xmlformats-officedocument" also matched .docx/.pptx, which
        # read_excel cannot parse.
        if "spreadsheetml" in file_type:
            # pandas expects a path or file-like object, not raw bytes.
            df = pd.read_excel(io.BytesIO(decoded_file))
            return "xlsx", f"{message}\n The file type is xlsx and the file contents are: {df}"
        if "text" in file_type:
            text = decoded_file.decode("utf-8", errors="replace")
            return "txt", f"{message}\n The file type is txt and the file contents are: {text}"
        return "Not Sure", message

    async def _send_stream_header(self, conversation_id):
        """Send the conversation id and the start-of-stream marker."""
        await self.send("CONVERSATION_ID")
        await self.send(str(conversation_id))
        await self.send(self._STREAM_START)

    async def _answer(self, text_data):
        """Store the prompt, stream the model's answer back, and record metrics."""
        try:
            data = json.loads(text_data)
        except ValueError:
            print("Ignoring chat frame that is not valid JSON")
            return
        if not isinstance(data, dict):
            return

        message = data.get("message")
        if not message:
            print("Ignoring chat frame without a message")
            return
        conversation_id = data.get("conversation_id")
        email = data.get("email")
        file = data.get("file")
        file_type = data.get("fileType") or ""
        # NOTE(review): the client's "modelName" is not used; metrics always
        # record MODEL_NAME.

        if not conversation_id:
            # New conversation: generate a title for it as well.
            title = await title_generator.generate_async(message)
            conversation_id = await create_conversation(message, email, title)
        if not conversation_id:
            return

        decoded_file = None
        altered_message = message
        if file:
            decoded_file = base64.b64decode(file)
            file_type, altered_message = self._attach_file(message, decoded_file, file_type)

        print(f'received: "{message}" for conversation {conversation_id}')

        if await moderation_classifier.classify_async(message) == ModerationLabel.NSFW:
            print("this prompt has been marked as NSFW")
            await self._send_stream_header(conversation_id)
            await self.send(self._NSFW_RESPONSE)
            await self.send(self._STREAM_END)
            await save_generated_message(conversation_id, self._NSFW_RESPONSE)
            return

        messages, prompt = await get_messages(
            conversation_id, message, decoded_file, file_type
        )
        prompt_type = await prompt_classifier.classify_async(message)
        print(f"prompt_type: {prompt_type} for {message}")

        prompt_metric = await create_prompt_metric(
            prompt.id,
            prompt.message,
            bool(file),
            file_type,
            MODEL_NAME,
            conversation_id,
        )
        if file:
            # The newest message carries the attachment inline for the model.
            messages = messages[:-1] + [HumanMessage(content=altered_message)]

        await self._send_stream_header(conversation_id)
        response = ""
        if prompt_type == PromptType.IMAGE_GENERATION:
            response = self._IMAGE_RESPONSE
            await self.send(response)
        else:
            if prompt_type == PromptType.RAG:
                workspace = await get_workspace(conversation_id)
                chunks = AsyncRAGService().generate_response(
                    messages, prompt.message, workspace
                )
            else:
                chunks = AsyncLLMService().generate_response(messages, prompt.message)
            async for chunk in chunks:
                response += chunk
                await self.send(chunk)
        await self.send(self._STREAM_END)

        await save_generated_message(conversation_id, response)
        await finish_prompt_metric(prompt_metric, len(response))
|
|
|
|
# Document Views
|
|
class DocumentWorkspaceView(APIView):
    """List and create document workspaces of the caller's company.

    No permission classes are set here (the ``IsAuthenticated`` line is
    commented out), so the project's defaults apply.
    NOTE(review): post() allows several workspaces per company, while other
    views look one up with ``.get(company=...)``. Confirm that this is intended.
    """

    # permission_classes = [permissions.IsAuthenticated]

    def get(self, request):
        """Return every workspace of the caller's company."""
        company_workspaces = DocumentWorkspace.objects.filter(company=request.user.company)
        return Response(DocumentWorkspaceSerializer(company_workspaces, many=True).data)

    def post(self, request):
        """Create a workspace for the caller's company; 201 on success, else 400 with errors."""
        serializer = DocumentWorkspaceSerializer(data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
        serializer.save(company=request.user.company)
        return Response(serializer.data, status=status.HTTP_201_CREATED)
|
|
|
|
class DocumentUploadView(APIView):
    """List the documents in the company workspace and upload new ones into it."""

    # permission_classes = [permissions.IsAuthenticated]

    @staticmethod
    def _company_workspace(request):
        """Return the caller's company workspace, or None when it is missing or ambiguous."""
        try:
            return DocumentWorkspace.objects.get(company=request.user.company)
        except (DocumentWorkspace.DoesNotExist, DocumentWorkspace.MultipleObjectsReturned):
            return None

    def get(self, request):
        """Return the workspace's documents; 404 when the workspace is not found."""
        workspace = self._company_workspace(request)
        if workspace is None:
            return Response({'error': "Workspace not found"}, status=status.HTTP_404_NOT_FOUND)
        serializer = DocumentSerializer(Document.objects.filter(workspace=workspace), many=True)
        return Response(serializer.data, status=status.HTTP_200_OK)

    def post(self, request):
        """Store the uploaded ``file`` as a Document and index it.

        Returns 201 with the document, 404 when the workspace is not found,
        and 400 when no file is provided.
        """
        workspace = self._company_workspace(request)
        if workspace is None:
            return Response({'error': "Workspace not found"}, status=status.HTTP_404_NOT_FOUND)

        file = request.FILES.get('file')
        if not file:
            return Response({"error": "No file provided"}, status=status.HTTP_400_BAD_REQUEST)

        document = Document.objects.create(workspace=workspace, file=file)
        # Runs synchronously within this request.
        self.process_document(document)

        serializer = DocumentSerializer(document)
        return Response(serializer.data, status=status.HTTP_201_CREATED)

    def process_document(self, document):
        """Add ``document`` to the RAG store, then mark it processed and active."""
        import inspect

        service = AsyncRAGService()
        result = service.add_files_to_store(
            [(document.file.path, document.file.name, document.workspace_id)],
            workspace_id=document.workspace_id,
        )
        # NOTE(review): AsyncRAGService's other entry points are async. If
        # add_files_to_store is a coroutine function, its result must be
        # awaited or the work never runs. Both shapes are handled until this
        # is confirmed.
        if inspect.isawaitable(result):
            async def _await(awaitable):
                return await awaitable

            async_to_sync(_await)(result)

        # Flag the document only after indexing succeeds, so a failure does not
        # leave it marked processed and active.
        document.processed = True
        document.active = True
        document.save()
|
|
|
|
class DocumentDetailView(APIView):
    """Return a single document from the caller's company workspace."""

    # permission_classes = [permissions.IsAuthenticated]

    def get(self, request, document_id):
        """Serialize document ``document_id`` if it belongs to the caller's workspace.

        Returns 404 when the workspace or the document is not found.
        """
        try:
            workspace = DocumentWorkspace.objects.get(company=request.user.company)
            document = Document.objects.get(workspace=workspace, id=document_id)
        except (
            DocumentWorkspace.DoesNotExist,
            DocumentWorkspace.MultipleObjectsReturned,
            Document.DoesNotExist,
            ValueError,
        ):
            return Response({'error': "Document not found"}, status=status.HTTP_404_NOT_FOUND)

        # Serialize the document itself. The old code passed an undefined
        # ``workspaces`` name (NameError) to the workspace serializer.
        return Response(DocumentSerializer(document).data)