Files
chat_backend/llm_be/chat_backend/views.py

1039 lines
38 KiB
Python

from channels.layers import get_channel_layer
from channels.db import database_sync_to_async
from rest_framework_simplejwt.views import TokenObtainPairView
from rest_framework_simplejwt.tokens import RefreshToken
from rest_framework import permissions, status
from .serializers import (
MyTokenObtainPairSerializer,
CustomUserSerializer,
BasicUserSerializer,
AnnouncmentSerializer,
CompanySerializer,
ConversationSerializer,
PromptSerializer,
FeedbackSerializer,
DocumentWorkspaceSerializer,
DocumentSerializer
)
from rest_framework.views import APIView
from rest_framework.response import Response
from .models import (
CustomUser,
Announcement,
Conversation,
Prompt,
Feedback,
PromptMetric,
DocumentWorkspace,
Document
)
from django.views.decorators.cache import never_cache
from django.http import JsonResponse
from datetime import datetime
from .client import LlamaClient
from asgiref.sync import sync_to_async, async_to_sync
from channels.generic.websocket import AsyncWebsocketConsumer
from langchain_ollama.llms import OllamaLLM
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage, SystemMessage
from langchain.chains import RetrievalQA
import re
import os
from django.conf import settings
import json
import base64
import pandas as pd
# For email support
from django.core.mail import EmailMultiAlternatives
from django.template.loader import render_to_string
from django.utils.html import strip_tags
from django.template.loader import get_template
from django.template import Context
from django.utils import timezone
from django.core.files import File
from django.core.files.base import ContentFile
import math
import datetime
import pytz
from langchain_community.embeddings import OllamaEmbeddings
from dateutil.relativedelta import relativedelta
from django.views.decorators.csrf import csrf_exempt
from .utils import last_day_of_month
from .services.llm_service import AsyncLLMService
from .services.rag_services import AsyncRAGService
from .services.title_generator import title_generator
from .services.moderation_classifier import moderation_classifier, ModerationLabel
from .services.prompt_classifier import prompt_classifier, PromptType
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_ollama import ChatOllama
# Channel-layer group name used by ConversationDetailView.post when it
# broadcasts a conversation's messages.
CHANNEL_NAME: str = "llm_messages"
# Model name recorded on PromptMetric rows and used to build the module-level
# OllamaLLM client.
MODEL_NAME: str = "llama3"

# View definitions follow.
class CustomObtainTokenView(TokenObtainPairView):
    """Issue a JWT access/refresh pair through the project's custom serializer.

    Open to unauthenticated callers, because this is how a client obtains
    its first token.
    """

    serializer_class = MyTokenObtainPairSerializer
    permission_classes = (permissions.AllowAny,)
class CustomUserCreate(APIView):
    """Register a new user account; open to unauthenticated callers."""

    permission_classes = (permissions.AllowAny,)
    # No authentication backends: registration must work without credentials.
    authentication_classes = ()

    def post(self, request, format="json"):
        """Validate ``request.data`` with CustomUserSerializer and create the user.

        Returns 201 with the serialized user on success, otherwise 400 with the
        serializer's validation errors.
        """
        serializer = CustomUserSerializer(data=request.data)
        if serializer.is_valid():
            user = serializer.save()
            if user:
                # Named ``payload`` rather than ``json`` so the module-level
                # ``json`` import is not shadowed.
                payload = serializer.data
                return Response(payload, status=status.HTTP_201_CREATED)
        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
# Sender address for every transactional email sent from this module.
_EMAIL_SENDER = "ryan@aimloperations.com"
# Set-password page; the account is identified by its slug.
_SET_PASSWORD_URL = "https://www.chat.aimloperations.com/set_password?slug={slug}"


def _send_templated_email(subject, to, template_base, context):
    """Render ``<template_base>.txt`` / ``.html`` with *context* and email them to *to*.

    The text template is the main body and the HTML template is attached as
    an alternative. Sending uses ``fail_silently=True``, so mail errors do not
    propagate to the calling view.
    """
    html_content = get_template(f"{template_base}.html").render(context)
    text_content = get_template(f"{template_base}.txt").render(context)
    msg = EmailMultiAlternatives(subject, text_content, _EMAIL_SENDER, [to])
    msg.attach_alternative(html_content, "text/html")
    msg.send(fail_silently=True)


def send_invite_email(slug, email_to_invite):
    """Email *email_to_invite* an invitation link to set their password.

    The set-password URL is deliberately not printed. Anyone who can read the
    logs could otherwise use it to set the account's password.
    """
    print("Sending invite email")
    _send_templated_email(
        "Welcome to AI ML Operations, LLC Chat Services",
        email_to_invite,
        "emails/invite_email",
        {"url": _SET_PASSWORD_URL.format(slug=slug)},
    )


def send_feedback_email(feedback_obj):
    """Forward *feedback_obj*'s title and text to the fixed sender inbox."""
    print("Sending feedback email")
    _send_templated_email(
        "New Feedback for Chat by AI ML Operations, LLC",
        _EMAIL_SENDER,
        "emails/feedback_email",
        {"title": feedback_obj.title, "feedback_text": feedback_obj.text},
    )


def send_password_reset_email(slug, email_to_invite):
    """Email *email_to_invite* a link to choose a new password for account *slug*."""
    print("Sending Password reset email")
    _send_templated_email(
        "Password reset for Chat by AI ML Operations, LLC",
        email_to_invite,
        "emails/reset_email",
        {"url": _SET_PASSWORD_URL.format(slug=slug)},
    )
class CustomUserInvite(APIView):
    """Let a company manager invite a new user into their company by email."""

    http_method_names = ["post"]

    # Deliberately narrow: a lowercase alphanumeric local part with at most one
    # "." or "_" inside it, then "@", one domain label, "." and a TLD.
    # Used with fullmatch, so trailing newlines and other junk are rejected.
    _EMAIL_RE = re.compile(r"[a-z0-9]+[._]?[a-z0-9]+@\w+\.\w+")

    def post(self, request, format="json"):
        """Create a user for ``request.data["email"]`` in the caller's company and email an invite.

        Responds 400 when the email is missing, not a string, or malformed,
        when the caller is not a company manager, or when a user with that
        email already exists. Otherwise responds 201.
        """
        # .get() so a missing key is a 400 rather than an unhandled KeyError.
        email_to_invite = request.data.get("email", "")
        if (
            not isinstance(email_to_invite, str)
            or not self._EMAIL_RE.fullmatch(email_to_invite)
            or not request.user.is_company_manager
        ):
            return Response(status=status.HTTP_400_BAD_REQUEST)
        # Refuse to invite an address that already has an account.
        if CustomUser.objects.filter(email=email_to_invite).exists():
            return Response(status=status.HTTP_400_BAD_REQUEST)
        user = CustomUser.objects.create(
            email=email_to_invite,
            username=email_to_invite,
            company=request.user.company,
        )
        send_invite_email(user.slug, email_to_invite)
        return Response(status=status.HTTP_201_CREATED)
@csrf_exempt
def reset_password(request):
    """Plain-Django endpoint for a reCAPTCHA-checked password reset.

    Expects a JSON body with ``recaptchaToken`` and ``email``. If Google
    reCAPTCHA reports success with a score of at least 0.5 and a user with
    that email exists, the user's password is made unusable and a reset link
    is emailed.

    Responds 200 whenever the captcha passes, whether or not the email
    matched an account. Responds 400 for non-POST requests, malformed JSON,
    a failed captcha, or an unreachable verification service.
    """
    # Stdlib HTTP client: the previously used ``requests`` was never imported.
    import urllib.parse
    import urllib.request

    if request.method != "POST":
        return JsonResponse({}, status=400)
    try:
        data = json.loads(request.body)
    except (TypeError, ValueError):
        return JsonResponse({}, status=400)
    if not isinstance(data, dict):
        return JsonResponse({}, status=400)

    payload = urllib.parse.urlencode(
        {
            "secret": settings.CAPTCHA_SECRET_KEY,
            "response": data.get("recaptchaToken") or "",
        }
    ).encode()
    try:
        with urllib.request.urlopen(
            "https://www.google.com/recaptcha/api/siteverify",
            data=payload,
            timeout=10,
        ) as resp:
            result = json.load(resp)
    except (OSError, ValueError):
        # URLError/HTTPError/timeouts are OSError subclasses; ValueError covers bad JSON.
        return JsonResponse({}, status=400)

    # A missing score counts as 0 instead of raising TypeError on the comparison.
    if result.get("success") and (result.get("score") or 0) >= 0.5:
        email = data.get("email")
        user = CustomUser.objects.filter(email=email).first() if email else None
        if user:
            user.set_unusable_password()
            user.save()
            send_password_reset_email(user.slug, email)
        return JsonResponse({}, status=200)
    return JsonResponse({}, status=400)
class ResetUserPassword(APIView):
    """DRF endpoint for a reCAPTCHA-checked password reset by email."""

    http_method_names = [
        "post",
    ]
    permission_classes = (permissions.AllowAny,)
    authentication_classes = ()

    @staticmethod
    def _captcha_passed(token):
        """Return True when Google reCAPTCHA accepts *token* with a score >= 0.5.

        Network errors, timeouts and undecodable responses count as a failed check.
        """
        # Stdlib HTTP client: the previously used ``requests`` was never imported.
        import urllib.parse
        import urllib.request

        payload = urllib.parse.urlencode(
            {"secret": settings.CAPTCHA_SECRET_KEY, "response": token or ""}
        ).encode()
        try:
            with urllib.request.urlopen(
                "https://www.google.com/recaptcha/api/siteverify",
                data=payload,
                timeout=10,
            ) as resp:
                result = json.load(resp)
        except (OSError, ValueError):
            return False
        # A missing score counts as 0 instead of raising TypeError.
        return bool(result.get("success")) and (result.get("score") or 0) >= 0.5

    def post(self, request, format="json"):
        """Email a set-password link to ``request.data["email"]`` and disable the account's password.

        Always responds 200, so the response reveals neither whether the
        captcha passed nor whether the email belongs to an account.
        """
        token = request.data.get("recaptchaToken")
        email = request.data.get("email")
        if self._captcha_passed(token):
            user = CustomUser.objects.filter(email=email).first() if email else None
            if user:
                user.set_unusable_password()
                user.save()
                send_password_reset_email(user.slug, email)
        else:
            print("Captcha secret failed")
        return Response(status=status.HTTP_200_OK)
class SetUserPassword(APIView):
    """Backend for the set-password page; the account is identified by its ``slug``.

    GET reports whether the slug can be used to set a password. POST sets
    the password. Both act only on accounts whose password is currently
    unusable.
    """

    http_method_names = ["post", "get"]
    permission_classes = (permissions.AllowAny,)
    authentication_classes = ()

    def get(self, request, slug):
        """Return 200 if *slug*'s account awaits a password, 401 if it already has one, 404 if unknown."""
        try:
            user = CustomUser.objects.get(slug=slug)
        except CustomUser.DoesNotExist:
            return Response(status=status.HTTP_404_NOT_FOUND)
        if user.has_usable_password():
            return Response(status=status.HTTP_401_UNAUTHORIZED)
        return Response(status=status.HTTP_200_OK)

    def post(self, request, slug, format="json"):
        """Set *slug*'s password from ``request.data["password"]``.

        Applies the same check as GET and answers 401 when the account already
        has a usable password, so an old or leaked slug link cannot overwrite
        an active password. Answers 400 when no password is supplied and 404
        for an unknown slug.
        """
        try:
            user = CustomUser.objects.get(slug=slug)
        except CustomUser.DoesNotExist:
            return Response(status=status.HTTP_404_NOT_FOUND)
        if user.has_usable_password():
            return Response(status=status.HTTP_401_UNAUTHORIZED)
        password = request.data.get("password")
        if not password:
            return Response(status=status.HTTP_400_BAD_REQUEST)
        user.set_password(password)
        user.save()
        return Response(status=status.HTTP_200_OK)
class CustomUserGet(APIView):
    """Return the requesting user's own profile."""

    # "post" is listed but there is no post() handler, so DRF answers it with 405.
    http_method_names = ["get", "head", "post"]

    def get(self, request, format="json"):
        """Serialize ``request.user`` with CustomUserSerializer (200).

        Serializes the authenticated user object directly. The previous
        re-query by email was redundant and would raise MultipleObjectsReturned
        if emails were ever duplicated. Assumes the project's default
        permission classes require authentication; TODO confirm.
        """
        serializer = CustomUserSerializer(request.user)
        return Response(serializer.data, status=status.HTTP_200_OK)
class FeedbackView(APIView):
    """Submit feedback (POST) or list the caller's own feedback (GET)."""

    http_method_names = ["post", "get"]

    def post(self, request, format="json"):
        """Create a Feedback owned by the caller and send it via send_feedback_email.

        Returns 201 with the serialized feedback, or 400 with validation errors.
        """
        serializer = FeedbackSerializer(data=request.data)
        if not serializer.is_valid():
            print(serializer.errors)
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
        # Pass the owner to save() so the row is written once, already owned,
        # instead of being saved and then updated.
        feedback_obj = serializer.save(user=request.user)
        send_feedback_email(feedback_obj)
        return Response(serializer.data, status=status.HTTP_201_CREATED)

    def get(self, request, format="json"):
        """Return every Feedback row owned by the caller."""
        feedback_objs = Feedback.objects.filter(user=request.user)
        serializer = FeedbackSerializer(feedback_objs, many=True)
        return Response(serializer.data, status=status.HTTP_200_OK)
class AcknowledgeTermsOfService(APIView):
    """Record that the requesting user has accepted the terms of service."""

    http_method_names = ["post"]

    def post(self, request, format="json"):
        """Set ``has_signed_tos`` on the caller, save it, and answer 200."""
        current_user = request.user
        current_user.has_signed_tos = True
        current_user.save()
        return Response(status=status.HTTP_200_OK)
class CompanyUsersView(APIView):
    """Let a company manager list, modify, and delete users in the manager's company.

    Non-managers receive 401. A target email that is unknown or belongs to
    another company also receives 401, so the two cases look the same to
    the caller.
    """

    def get(self, request, format="json"):
        """List every user in the caller's company (managers only)."""
        if not request.user.is_company_manager:
            return Response(status=status.HTTP_401_UNAUTHORIZED)
        users = CustomUser.objects.filter(company_id=request.user.company_id)
        serializer = BasicUserSerializer(users, many=True)
        return Response(serializer.data, status=status.HTTP_200_OK)

    def _get_company_user(self, request):
        """Resolve ``request.data["email"]`` to a user in the caller's company.

        Returns ``(user, None)`` on success, or ``(None, error_response)`` when
        the caller is not a manager or the target is unknown or in another
        company.
        """
        if not request.user.is_company_manager:
            return None, Response(status=status.HTTP_401_UNAUTHORIZED)
        try:
            user = CustomUser.objects.get(email=request.data.get("email"))
        except CustomUser.DoesNotExist:
            return None, Response(status=status.HTTP_401_UNAUTHORIZED)
        if user.company_id != request.user.company_id:
            return None, Response(status=status.HTTP_401_UNAUTHORIZED)
        return user, None

    def post(self, request, format="json"):
        """Change one setting on a company user, selected by ``request.data["field"]``.

        ``"is_active"`` and ``"company_manager"`` flip the matching boolean.
        ``"has_password"`` makes a usable password unusable and never restores
        one. Any other value changes nothing and still answers 200 if the
        serializer validates.
        """
        user, error = self._get_company_user(request)
        if error is not None:
            return error
        field = request.data.get("field")
        data = {}
        if field == "is_active":
            data["is_active"] = not user.is_active
        elif field == "company_manager":
            data["is_company_manager"] = not user.is_company_manager
        elif field == "has_password" and user.has_usable_password():
            # Not a serializer field. Presumably persisted by serializer.save()
            # below (ModelSerializer.update calls instance.save()); confirm if
            # CustomUserSerializer overrides update().
            user.set_unusable_password()
        serializer = CustomUserSerializer(user, data, partial=True)
        if serializer.is_valid():
            serializer.save()
            return Response(status=status.HTTP_200_OK)
        return Response(status=status.HTTP_400_BAD_REQUEST)

    def delete(self, request, format="json"):
        """Delete a user from the caller's company (managers only)."""
        user, error = self._get_company_user(request)
        if error is not None:
            return error
        user.delete()
        return Response(status=status.HTTP_200_OK)
class AnnouncmentView(APIView):
    """List every announcement; the AllowAny permission lets anyone call it."""

    serializer_class = AnnouncmentSerializer
    permission_classes = (permissions.AllowAny,)

    def get(self, request, format="json"):
        """Serialize all Announcement rows and answer 200."""
        payload = AnnouncmentSerializer(Announcement.objects.all(), many=True).data
        return Response(payload, status=status.HTTP_200_OK)
class LogoutAndBlacklistRefreshTokenForUserView(APIView):
    """Log a client out by blacklisting its refresh token."""

    permission_classes = (permissions.AllowAny,)
    authentication_classes = ()

    def post(self, request):
        """Blacklist ``request.data["refresh_token"]``.

        Returns 205 on success. Returns 400 when the token is missing/empty or
        simplejwt rejects it (TokenError).
        """
        from rest_framework_simplejwt.exceptions import TokenError

        refresh_token = request.data.get("refresh_token")
        if not refresh_token:
            # RefreshToken(None) creates a brand-new token instead of failing.
            return Response(status=status.HTTP_400_BAD_REQUEST)
        try:
            RefreshToken(refresh_token).blacklist()
        except TokenError:
            return Response(status=status.HTTP_400_BAD_REQUEST)
        return Response(status=status.HTTP_205_RESET_CONTENT)
@never_cache
def is_authenticated(request):
    """Report whether Django's ``request.user`` is authenticated.

    Answers with an empty JSON object: 200 when authenticated, 401 otherwise.
    ``never_cache`` marks the response as not cacheable.
    """
    code = (
        status.HTTP_200_OK
        if request.user.is_authenticated
        else status.HTTP_401_UNAUTHORIZED
    )
    return JsonResponse({}, status=code)
class ConversationsView(APIView):
    """List the caller's conversations (GET) or create an empty one (POST)."""

    def get(self, request, format="json"):
        """Return the caller's non-deleted conversations.

        Sorted oldest-first when ``conversation_order`` is truthy, otherwise
        newest-first.
        """
        order = "created" if request.user.conversation_order else "-created"
        conversations = Conversation.objects.filter(
            user=request.user, deleted=False
        ).order_by(order)
        serializer = ConversationSerializer(conversations, many=True)
        return Response(serializer.data, status=status.HTTP_200_OK)

    def post(self, request, format="json"):
        """Create a blank conversation titled ``request.data["name"]`` for the caller.

        Returns 201 with ``{"title": ..., "id": ...}``.
        """
        title = request.data.get("name")
        # One INSERT that already carries the owner, instead of create + two saves.
        conversation = Conversation.objects.create(
            title=title, user_id=request.user.id
        )
        # TODO: once a conversation is created when its first prompt is sent,
        # generate the title from that prompt instead of taking it from the client.
        return Response(
            {"title": title, "id": conversation.id}, status=status.HTTP_201_CREATED
        )
class ConversationPreferences(APIView):
    """Read (GET) or flip (POST) the caller's conversation sort-order flag."""

    def get(self, request, format="json"):
        """Answer ``{"order": <conversation_order>}``."""
        return Response(
            {"order": request.user.conversation_order}, status=status.HTTP_200_OK
        )

    def post(self, request, format="json"):
        """Invert ``conversation_order``, save the user, and answer with the new value."""
        current_user = request.user
        current_user.conversation_order = not current_user.conversation_order
        current_user.save()
        return Response(
            {"order": current_user.conversation_order}, status=status.HTTP_200_OK
        )
class ConversationDetailView(APIView):
    """Read, append to, or soft-delete one of the caller's conversations.

    Every lookup is scoped to ``request.user``, so callers cannot read or
    modify other users' conversations by guessing ids.
    """

    # Errors that an ORM lookup by a client-supplied id can raise.
    _LOOKUP_ERRORS = (Conversation.DoesNotExist, ValueError, TypeError)

    def get(self, request, format="json"):
        """Return the prompts of ``?conversation_id=`` if the caller owns it (else an empty list)."""
        conversation_id = request.query_params.get("conversation_id")
        prompts = Prompt.objects.filter(
            conversation__id=conversation_id, conversation__user=request.user
        )
        serializer = PromptSerializer(prompts, many=True)
        return Response(serializer.data, status=status.HTTP_200_OK)

    def post(self, request, format="json"):
        """Append ``prompt`` to the caller's conversation ``conversation_id``.

        When ``is_user`` is truthy, the conversation history is then broadcast
        to the ``CHANNEL_NAME`` channel-layer group as a ``receive`` event.
        Returns 200, 404 for an unknown or foreign conversation, or 400 with
        the serializer errors.
        """
        prompt = request.data.get("prompt")
        conversation_id = request.data.get("conversation_id")
        is_user = bool(request.data.get("is_user"))
        try:
            conversation = Conversation.objects.get(
                id=conversation_id, user=request.user
            )
        except self._LOOKUP_ERRORS:
            print(
                f"Error trying to submit to conversation_id: {conversation_id} with request.data: {request.data}"
            )
            return Response(status=status.HTTP_404_NOT_FOUND)

        serializer = PromptSerializer(
            data={
                "message": prompt,
                "user_created": is_user,
                # timezone.now(): the module-level name ``datetime`` is bound
                # to the module (``import datetime``), so ``datetime.now()``
                # raises AttributeError here.
                "created": timezone.now(),
            }
        )
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
        prompt_instance = serializer.save()
        prompt_instance.conversation_id = conversation.id
        prompt_instance.save()

        if is_user:
            messages = [
                {
                    "content": prompt_obj.message,
                    "role": "user" if prompt_obj.user_created else "assistant",
                }
                for prompt_obj in Prompt.objects.filter(
                    conversation__id=conversation.id
                )
            ]
            # Best-effort broadcast: the prompt is already stored, so a
            # channel-layer failure is logged and does not fail the request.
            try:
                channel_layer = get_channel_layer()
                print(f"Sending to the channel: {CHANNEL_NAME}")
                async_to_sync(channel_layer.group_send)(
                    CHANNEL_NAME, {"type": "receive", "content": messages}
                )
            except Exception as exc:
                print(f"Broadcast to {CHANNEL_NAME} failed: {exc!r}")
        return Response(status=status.HTTP_200_OK)

    def delete(self, request, format="json"):
        """Soft-delete (``deleted = True``) one of the caller's conversations; 202 or 404."""
        conversation_id = request.data.get("conversation_id")
        try:
            conversation = Conversation.objects.get(
                id=conversation_id, user=request.user
            )
        except self._LOOKUP_ERRORS:
            return Response(status=status.HTTP_404_NOT_FOUND)
        conversation.deleted = True
        conversation.save()
        return Response(status=status.HTTP_202_ACCEPTED)
class UserPromptAnalytics(APIView):
    """Monthly prompt counts for the caller compared with company and global per-user averages."""

    def get(self, request, format="json"):
        """Return one entry per month for the last three months, oldest first.

        Each entry contains:
        - ``month``: the month name.
        - ``you``: the caller's prompts in that month.
        - ``others``: the company's prompts divided by the company's user count.
        - ``all``: prompts in conversations created that month divided by all users.
        Averages are 0 when the divisor is 0.
        """
        now = timezone.now()
        result = []
        number_of_months = 3
        company_user_ids = CustomUser.objects.filter(
            company=request.user.company
        ).values_list("id", flat=True)
        # Loop-invariant values, computed once instead of once per month.
        company_user_count = company_user_ids.count()
        total_users = CustomUser.objects.count()
        my_conversations = Conversation.objects.filter(user=request.user)
        company_conversations = Conversation.objects.filter(
            user__id__in=company_user_ids
        )
        for i in range(number_of_months):
            next_year = now.year
            next_month = now.month - i
            while next_month < 1:
                next_year -= 1
                next_month += 12
            start_date = datetime.datetime(next_year, next_month, 1)
            end_date = last_day_of_month(start_date)
            in_window = {"created__gte": start_date, "created__lte": end_date}
            month_conversations = Conversation.objects.filter(**in_window)
            # .count() runs SELECT COUNT(*); len() would load every row into memory.
            total_prompts = Prompt.objects.filter(
                conversation__id__in=month_conversations, **in_window
            ).count()
            my_prompts = Prompt.objects.filter(
                conversation__in=my_conversations, **in_window
            ).count()
            company_prompts = Prompt.objects.filter(
                conversation__in=company_conversations, **in_window
            ).count()
            result.append(
                {
                    "month": start_date.strftime("%B"),
                    "you": my_prompts,
                    "others": (
                        company_prompts / company_user_count
                        if company_user_count
                        else 0
                    ),
                    "all": total_prompts / total_users if total_users else 0,
                }
            )
        return Response(result[::-1], status=status.HTTP_200_OK)
class UserConversationAnalytics(APIView):
    """Monthly conversation counts for the caller compared with company and global per-user averages."""

    def get(self, request, format="json"):
        """Return one entry per month for the last three months, oldest first.

        Each entry contains:
        - ``month``: the month name.
        - ``you``: conversations the caller created that month.
        - ``others``: company conversations divided by the company's user count.
        - ``all``: all conversations divided by all users.
        Averages are 0 when the divisor is 0.
        """
        now = timezone.now()
        result = []
        number_of_months = 3
        company_user_ids = CustomUser.objects.filter(
            company=request.user.company
        ).values_list("id", flat=True)
        # Loop-invariant counts, computed once.
        company_user_count = company_user_ids.count()
        total_users = CustomUser.objects.count()
        for i in range(number_of_months):
            next_year = now.year
            next_month = now.month - i
            while next_month < 1:
                next_year -= 1
                next_month += 12
            start_date = datetime.datetime(next_year, next_month, 1)
            end_date = last_day_of_month(start_date)
            month_conversations = Conversation.objects.filter(
                created__gte=start_date, created__lte=end_date
            )
            # .count() runs SELECT COUNT(*) instead of materializing rows like len().
            total_conversations = month_conversations.count()
            company_conversations = month_conversations.filter(
                user__id__in=company_user_ids
            ).count()
            my_conversations = month_conversations.filter(user=request.user).count()
            result.append(
                {
                    "month": start_date.strftime("%B"),
                    "you": my_conversations,
                    "others": (
                        company_conversations / company_user_count
                        if company_user_count
                        else 0
                    ),
                    "all": (
                        total_conversations / total_users if total_users else 0
                    ),
                }
            )
        return Response(result[::-1], status=status.HTTP_200_OK)
class CompanyUsageAnalytics(APIView):
    """Monthly count of company users who did or did not start a conversation."""

    def get(self, request, format="json"):
        """Return ``month``, ``used`` and ``not_used`` for the last three months, oldest first."""
        now = timezone.now()
        result = []
        number_of_months = 3
        company_user_ids = CustomUser.objects.filter(
            company=request.user.company
        ).values_list("id", flat=True)
        company_user_count = company_user_ids.count()
        for i in range(number_of_months):
            next_year = now.year
            next_month = now.month - i
            while next_month < 1:
                next_year -= 1
                next_month += 12
            start_date = datetime.datetime(next_year, next_month, 1)
            end_date = last_day_of_month(start_date)
            active_users = (
                Conversation.objects.filter(
                    user__id__in=company_user_ids,
                    created__gte=start_date,
                    created__lte=end_date,
                )
                # Clear any default ordering so it cannot add extra columns
                # to the DISTINCT and inflate the count.
                .order_by()
                .values_list("user__id", flat=True)
                .distinct()
                .count()
            )
            result.append(
                {
                    "month": start_date.strftime("%B"),
                    "used": active_users,
                    "not_used": company_user_count - active_users,
                }
            )
        return Response(result[::-1], status=status.HTTP_200_OK)
class AdminAnalytics(APIView):
    """Monthly PromptMetric duration statistics for the last three months.

    NOTE(review): no admin-only permission is declared here; confirm that the
    project's default permissions restrict this view as intended.
    """

    def get(self, request, format="json"):
        """Return ``month``, ``range`` [min, max], ``avg`` and ``median`` per month, oldest first.

        Months without durations report zeros. ``median`` is the true median,
        i.e. the mean of the two middle values when the count is even.
        """
        import statistics

        number_of_months = 3
        result = []
        now = timezone.now()
        for i in range(number_of_months):
            next_year = now.year
            next_month = now.month - i
            while next_month < 1:
                next_year -= 1
                next_month += 12
            start_date = datetime.datetime(next_year, next_month, 1)
            end_date = last_day_of_month(start_date)
            metrics = PromptMetric.objects.filter(
                created__gte=start_date, created__lte=end_date
            )
            # Skip None durations (possibly unfinished metrics; confirm against
            # PromptMetric.get_duration()), because sum()/min() would fail on them.
            durations = sorted(
                duration
                for duration in (item.get_duration() for item in metrics)
                if duration is not None
            )
            if not durations:
                result.append(
                    {
                        "month": start_date.strftime("%B"),
                        "range": [0, 0],
                        "avg": 0,
                        "median": 0,
                    }
                )
                continue
            result.append(
                {
                    "month": start_date.strftime("%B"),
                    # The list is sorted, so the ends are min and max.
                    "range": [durations[0], durations[-1]],
                    "avg": sum(durations) / len(durations),
                    "median": statistics.median(durations),
                }
            )
        return Response(result[::-1], status=status.HTTP_200_OK)
# Module-level chat template: a fixed system instruction, then the user's
# text substituted into the ``{input}`` placeholder.
_DEFAULT_CHAT_TURNS = [
    ("system", "You are a helpful assistant."),
    ("user", "{input}"),
]
prompt = ChatPromptTemplate.from_messages(_DEFAULT_CHAT_TURNS)

# Built at import time from MODEL_NAME.
llm = OllamaLLM(model=MODEL_NAME)
@database_sync_to_async
def create_conversation(prompt, email, title):
    """Create a conversation titled *title* for the user with *email* and return its id.

    *prompt* is accepted for call compatibility but is not used. Raises
    ``CustomUser.DoesNotExist`` if no user has *email*. The user is looked
    up first, so that failure happens before any conversation row is
    written and no ownerless conversation is left behind.
    """
    user = CustomUser.objects.get(email=email)
    # One INSERT that already carries the owner.
    conversation = Conversation.objects.create(title=title, user_id=user.id)
    return conversation.id
@database_sync_to_async
def get_workspace(conversation_id):
    """Return the DocumentWorkspace of the company whose user owns *conversation_id*.

    Wrapped with ``database_sync_to_async`` so async code can await it. Any
    DoesNotExist/MultipleObjectsReturned from the ORM ``get`` calls propagates.
    """
    owner_company = Conversation.objects.get(id=conversation_id).user.company
    return DocumentWorkspace.objects.get(company=owner_company)
def _inline_file_contents(content, file_field, file_type):
    """Return *content* with the attached file's contents appended for the LLM.

    *file_type* is matched by substring. csv and txt files are decoded as
    UTF-8 (undecodable bytes are replaced). xlsx files are parsed with pandas
    and rendered via the DataFrame's string form. Other types leave *content*
    unchanged.
    """
    import io

    if "csv" in file_type:
        kind = "csv"
    elif "xlsx" in file_type:
        kind = "xlsx"
    elif "txt" in file_type:
        kind = "txt"
    else:
        return content
    # Open explicitly so the storage-backed handle is closed after reading.
    with file_field.open("rb") as fh:
        raw = fh.read()
    if kind == "xlsx":
        # pandas deprecated passing raw bytes to read_excel; wrap them in a buffer.
        body = pd.read_excel(io.BytesIO(raw))
    else:
        # Decode so the model sees the text rather than a b'...' repr.
        body = raw.decode("utf-8", errors="replace")
    return f"{content}\n The file type is {kind} and the file contents are: {body}"


@database_sync_to_async
def get_messages(conversation_id, prompt, file_string: str = None, file_type: str = ""):
    """Store a new user prompt, then return the conversation as LangChain messages.

    *prompt* is saved as a user-created Prompt in *conversation_id*. If
    *file_string* is given, it is stored as the prompt's attachment named
    ``prompt_<id>_data.<file_type>``.

    Returns ``(messages, prompt_instance)``:
    - ``messages`` is the full history ordered by creation time, with
      HumanMessage for user turns and SystemMessage for assistant turns, and
      attachment contents inlined.
    - ``prompt_instance`` is the saved Prompt.

    Raises the serializer's ValidationError if the prompt is invalid.
    """
    conversation = Conversation.objects.get(id=conversation_id)
    serializer = PromptSerializer(
        data={
            "message": prompt,
            "user_created": True,
            "created": timezone.now(),
        }
    )
    serializer.is_valid(raise_exception=True)
    prompt_instance = serializer.save()
    prompt_instance.conversation_id = conversation.id
    prompt_instance.save()
    if file_string:
        file_name = f"prompt_{prompt_instance.id}_data.{file_type}"
        prompt_instance.file.save(file_name, ContentFile(file_string, name=file_name))
        prompt_instance.file_type = file_type
        prompt_instance.save()

    transformed_messages = []
    # Explicit ordering: callers replace messages[-1], so the newest prompt must be last.
    history = Prompt.objects.filter(conversation__id=conversation_id).order_by(
        "created", "id"
    )
    for prompt_obj in history:
        content = prompt_obj.message
        if prompt_obj.file_exists() and prompt_obj.file_type is not None:
            content = _inline_file_contents(content, prompt_obj.file, prompt_obj.file_type)
        # NOTE(review): assistant turns are sent as SystemMessage rather than
        # AIMessage; confirm this is what the LLM/RAG services expect.
        if prompt_obj.user_created:
            transformed_messages.append(HumanMessage(content=content))
        else:
            transformed_messages.append(SystemMessage(content=content))
    return transformed_messages, prompt_instance
@database_sync_to_async
def save_generated_message(conversation_id, message):
    """Store an assistant (non-user) reply *message* in conversation *conversation_id*.

    If PromptSerializer rejects the payload (for example an empty reply),
    nothing is saved and the errors are printed instead of being dropped
    silently.
    """
    conversation = Conversation.objects.get(id=conversation_id)
    serializer = PromptSerializer(
        data={
            "message": message,
            "user_created": False,
            "created": timezone.now(),
        }
    )
    if not serializer.is_valid():
        print(
            f"save_generated_message: not saved for conversation {conversation_id}: {serializer.errors}"
        )
        return
    prompt_instance = serializer.save()
    # Attach to the conversation with a plain model save; calling
    # serializer.save() a second time would re-run update() on the serializer.
    prompt_instance.conversation_id = conversation.id
    prompt_instance.save()
@database_sync_to_async
def create_prompt_metric(
    prompt_id, prompt, has_file, file_type, model_name, conversation_id
):
    """Insert and return a PromptMetric whose ``start_time`` is now.

    ``prompt_length`` is ``len(prompt)``. ``objects.create`` already writes
    the row, so no follow-up ``save()`` is needed.
    """
    return PromptMetric.objects.create(
        prompt_id=prompt_id,
        start_time=timezone.now(),
        prompt_length=len(prompt),
        has_file=has_file,
        file_type=file_type,
        model_name=model_name,
        conversation_id=conversation_id,
    )
@database_sync_to_async
def update_prompt_metric(prompt_metric, status):
    """Set ``prompt_metric.event`` to *status* and write only that column.

    Restricting the UPDATE to ``event`` avoids overwriting other columns with
    stale in-memory values. Inside this function the *status* parameter
    shadows the module-level DRF ``status`` import; the name is kept for
    caller compatibility.
    """
    prompt_metric.event = status
    prompt_metric.save(update_fields=["event"])
@database_sync_to_async
def finish_prompt_metric(prompt_metric, response_length):
    """Mark *prompt_metric* FINISHED as of now and record *response_length*.

    Writes only ``end_time``, ``reponse_length`` and ``event``.
    NOTE(review): ``reponse_length`` must match the PromptMetric field name
    exactly. It looks like a typo for "response_length"; confirm against the
    model before renaming either side.
    """
    print(f"finish_prompt_metric: {response_length}")
    changes = {
        "end_time": timezone.now(),
        "reponse_length": response_length,
        "event": "FINISHED",
    }
    for field_name, value in changes.items():
        setattr(prompt_metric, field_name, value)
    prompt_metric.save(update_fields=list(changes))
    print("finish_prompt_metric saved")
@database_sync_to_async
def get_retriever(conversation_id):
    """Return a retriever over the Chroma store persisted in ``./chroma_db/``.

    Embeddings come from ``OllamaEmbeddings(model="llama3.2")``. The
    conversation owner's company workspace is looked up, and a missing
    workspace raises DoesNotExist, but retrieval results are not filtered
    by that workspace.
    """
    # Chroma was referenced but never imported at module level.
    from langchain_community.vectorstores import Chroma

    print(f'getting workspace from conversation: {conversation_id}')
    conversation = Conversation.objects.get(id=conversation_id)
    print(f'Got conversation: {conversation}')
    workspace = DocumentWorkspace.objects.get(company=conversation.user.company)
    print(f'Got workspace: {workspace}')
    # TODO(review): scope retrieval to ``workspace`` (e.g. a metadata filter).
    vectorstore = Chroma(
        persist_directory="./chroma_db/",
        # Chroma's constructor takes ``embedding_function``; ``embedding=`` is rejected.
        embedding_function=OllamaEmbeddings(model="llama3.2"),
    )
    return vectorstore.as_retriever()
class ChatConsumerAgain(AsyncWebsocketConsumer):
    """WebSocket chat endpoint that streams LLM or RAG replies to the client.

    Each text frame is JSON with these keys:
    - ``message``: the prompt.
    - ``conversation_id``: optional; a new conversation is created when absent.
    - ``email``: used when creating a conversation.
    - ``file``: optional, base64-encoded.
    - ``fileType``: type hint for ``file``.
    - ``modelName``: currently unused.

    The reply is a sequence of frames: "CONVERSATION_ID", the id,
    ``STREAM_START``, the content chunks, then ``STREAM_END``.

    NOTE(review): ``email`` and ``conversation_id`` are taken from the client
    and not checked against an authenticated user (e.g. ``self.scope["user"]``).
    Confirm whether ownership is enforced elsewhere.
    """

    STREAM_START = "START_OF_THE_STREAM_ENDER_GAME_42"
    STREAM_END = "END_OF_THE_STREAM_ENDER_GAME_42"

    async def connect(self):
        await self.accept()

    async def disconnect(self, close_code):
        # The connection is already closed when this runs; sending another
        # close here is unnecessary, and some ASGI servers reject it.
        pass

    async def receive(self, text_data=None, bytes_data=None):
        """Handle one client frame: text frames are chat requests; bytes are only logged."""
        if text_data:
            await self._handle_chat(text_data)
        if bytes_data:
            print("we have byte data")

    @staticmethod
    def _describe_attachment(message, decoded_file, file_type):
        """Normalize *file_type* and build the prompt text that inlines the file.

        Returns ``(normalized_type, altered_message)``. Unrecognized types are
        reported as "Not Sure" and *message* is returned unchanged, so the
        caller always has text to send.
        """
        import io

        if "csv" in file_type:
            text = decoded_file.decode("utf-8", errors="replace")
            return "csv", f"{message}\n The file type is csv and the file contents are: {text}"
        if "xmlformats-officedocument" in file_type:
            # pandas deprecated raw bytes for read_excel; wrap them in a buffer.
            df = pd.read_excel(io.BytesIO(decoded_file))
            return "xlsx", f"{message}\n The file type is xlsx and the file contents are: {df}"
        if "text" in file_type:
            text = decoded_file.decode("utf-8", errors="replace")
            return "txt", f"{message}\n The file type is txt and the file contents are: {text}"
        return "Not Sure", message

    async def _send_stream_header(self, conversation_id):
        """Send the conversation-id frames followed by the start-of-stream marker."""
        await self.send("CONVERSATION_ID")
        await self.send(str(conversation_id))
        await self.send(self.STREAM_START)

    async def _stream(self, chunks):
        """Forward each chunk of async iterable *chunks* to the client and return them joined."""
        parts = []
        async for chunk in chunks:
            parts.append(chunk)
            await self.send(chunk)
        return "".join(parts)

    async def _handle_chat(self, text_data):
        """Persist the prompt, pick LLM/RAG/image handling, and stream the reply."""
        data = json.loads(text_data)
        message = data.get("message", None)
        conversation_id = data.get("conversation_id", None)
        email = data.get("email", None)
        file = data.get("file", None)
        file_type = data.get("fileType", "")
        if not message:
            # Nothing to classify, store or answer.
            print("Ignoring frame without a message")
            return

        if not conversation_id:
            # New conversation: generate a title from the first prompt.
            title = await title_generator.generate_async(message)
            conversation_id = await create_conversation(message, email, title)
        if not conversation_id:
            return

        decoded_file = None
        altered_message = message
        if file:
            decoded_file = base64.b64decode(file)
            file_type, altered_message = self._describe_attachment(
                message, decoded_file, file_type
            )
        print(f'received: "{message}" for conversation {conversation_id}')

        if await moderation_classifier.classify_async(message) == ModerationLabel.NSFW:
            response = "Prompt has been marked as NSFW. If this is in error, submit a feedback with the prompt text."
            print("this prompt has been marked as NSFW")
            await self._send_stream_header(conversation_id)
            await self.send(response)
            await self.send(self.STREAM_END)
            await save_generated_message(conversation_id, response)
            return

        messages, prompt = await get_messages(
            conversation_id, message, decoded_file, file_type
        )
        prompt_type = await prompt_classifier.classify_async(message)
        print(f"prompt_type: {prompt_type} for {message}")
        prompt_metric = await create_prompt_metric(
            prompt.id,
            prompt.message,
            bool(file),
            file_type,
            MODEL_NAME,
            conversation_id,
        )
        if file:
            # Replace the stored last turn with the text that inlines the upload.
            messages = messages[:-1] + [HumanMessage(content=altered_message)]

        await self._send_stream_header(conversation_id)
        if prompt_type == PromptType.RAG:
            service = AsyncRAGService()
            workspace = await get_workspace(conversation_id)
            print('Time to get the rag response')
            response = await self._stream(
                service.generate_response(messages, prompt.message, workspace)
            )
        elif prompt_type == PromptType.IMAGE_GENERATION:
            response = "Image Generation is not supported at this time, but it will be soon."
            await self.send(response)
        else:
            service = AsyncLLMService()
            response = await self._stream(
                service.generate_response(messages, prompt.message)
            )
        await self.send(self.STREAM_END)
        await save_generated_message(conversation_id, response)
        await finish_prompt_metric(prompt_metric, len(response))
# Document Views
class DocumentWorkspaceView(APIView):
    """List (GET) or create (POST) document workspaces for the caller's company."""

    # permission_classes = [permissions.IsAuthenticated]

    def get(self, request):
        """Serialize every workspace that belongs to ``request.user.company``."""
        company_workspaces = DocumentWorkspace.objects.filter(
            company=request.user.company
        )
        return Response(
            DocumentWorkspaceSerializer(company_workspaces, many=True).data
        )

    def post(self, request):
        """Create a workspace owned by the caller's company: 201, or 400 with errors."""
        serializer = DocumentWorkspaceSerializer(data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
        serializer.save(company=request.user.company)
        return Response(serializer.data, status=status.HTTP_201_CREATED)
class DocumentUploadView(APIView):
    """List (GET) or upload (POST) documents in the caller's company workspace."""

    # permission_classes = [permissions.IsAuthenticated]

    def _company_workspace(self, request):
        """Return the caller's company workspace, or None if there is not exactly one.

        ``get`` raises both when no workspace exists and when several do;
        both cases map to None, which the views report as 404.
        """
        try:
            return DocumentWorkspace.objects.get(company=request.user.company)
        except (
            DocumentWorkspace.DoesNotExist,
            DocumentWorkspace.MultipleObjectsReturned,
        ):
            return None

    def get(self, request):
        """Serialize every document in the caller's workspace (404 if none resolved)."""
        workspace = self._company_workspace(request)
        if workspace is None:
            return Response({'error': "Workspace not found"}, status=status.HTTP_404_NOT_FOUND)
        serializer = DocumentSerializer(Document.objects.filter(workspace=workspace), many=True)
        return Response(serializer.data, status=status.HTTP_200_OK)

    def post(self, request):
        """Store ``request.FILES["file"]`` in the caller's workspace and index it.

        Returns 201 with the serialized document, 404 without a workspace, or
        400 when no file was sent.
        """
        workspace = self._company_workspace(request)
        if workspace is None:
            return Response({'error': "Workspace not found"}, status=status.HTTP_404_NOT_FOUND)
        file = request.FILES.get('file')
        if not file:
            return Response({"error": "No file provided"}, status=status.HTTP_400_BAD_REQUEST)
        document = Document.objects.create(workspace=workspace, file=file)
        # Indexing runs synchronously inside this request.
        self.process_document(document)
        serializer = DocumentSerializer(document)
        return Response(serializer.data, status=status.HTTP_201_CREATED)

    def process_document(self, document):
        """Index *document* into the RAG store, then flag it processed and active.

        If ``add_files_to_store`` returns an awaitable (the service is named
        Async*), it is driven to completion via ``async_to_sync``. Calling it
        bare would only create an un-awaited coroutine.
        """
        import inspect

        file_path = os.path.join(settings.MEDIA_ROOT, document.file.name)
        service = AsyncRAGService()
        outcome = service.add_files_to_store(
            [(file_path, document.file.name, document.workspace_id)],
            workspace_id=document.workspace_id,
        )
        if inspect.isawaitable(outcome):
            async def _await_outcome():
                return await outcome

            async_to_sync(_await_outcome)()
        # Flag only after indexing returned, so a failure is not recorded as processed.
        document.processed = True
        document.active = True
        document.save()
class DocumentDetailView(APIView):
    """Fetch a single document from the caller's company workspace."""

    # permission_classes = [permissions.IsAuthenticated]

    def get(self, request, document_id):
        """Return the serialized document, or 404 if it is not in the caller's workspace."""
        try:
            workspace = DocumentWorkspace.objects.get(company=request.user.company)
            document = Document.objects.get(workspace=workspace, id=document_id)
        except (
            DocumentWorkspace.DoesNotExist,
            DocumentWorkspace.MultipleObjectsReturned,
            Document.DoesNotExist,
        ):
            return Response({'error': "Document not found"}, status=status.HTTP_404_NOT_FOUND)
        # Serialize the single document that was fetched.
        return Response(DocumentSerializer(document).data)