Compare commits: 57695353d0...master (9 commits)
| SHA1 |
|---|
| 77d7edd0dc |
| eed1abedc8 |
| 91bdb2fd2d |
| 8a259158c8 |
| 14d8211715 |
| 951a58f2fa |
| a85f1222eb |
| d8a912e2c3 |
| f5d29166a6 |
.gitignore (vendored, 3 changes)
@@ -167,4 +167,5 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/

+chroma_db/
+documents/
llm_be/chat_backend/admin.py
@@ -1,5 +1,16 @@
 from django.contrib import admin
-from .models import CustomUser, Announcement, Company, LLMModels, Conversation, Prompt, Feedback, PromptMetric
+from .models import (
+    CustomUser,
+    Announcement,
+    Company,
+    LLMModels,
+    Conversation,
+    Prompt,
+    Feedback,
+    PromptMetric,
+    DocumentWorkspace,
+    Document,
+)

 # Register your models here.

@@ -27,16 +38,16 @@ class CustomUserAdmin(admin.ModelAdmin):
         "has_signed_tos",
         "last_login",
         "slug",
-        "get_set_password_url"
+        "get_set_password_url",
     )
     search_fields = ("fields", "username", "first_name", "last_name", "slug")


 class FeedbackAdmin(admin.ModelAdmin):
     model = Feedback
     search_fields = ("status", "text", "get_user_email")
-    list_display= (
-        "status", "get_user_email", "title", "category"
-    )
+    list_display = ("status", "get_user_email", "title", "category")


 class LLMModelsAdmin(admin.ModelAdmin):
     model = LLMModels
@@ -44,20 +55,53 @@ class LLMModelsAdmin(admin.ModelAdmin):
     search_fields = ("name", "port", "description")


 class PromptInline(admin.TabularInline):
     model = Prompt

 class ConversationAdmin(admin.ModelAdmin):
     model = Conversation
-    list_display = ("title", "get_user_email","deleted")
+    list_display = ("title", "get_user_email", "deleted")
     search_fields = ("title",)
     inlines = [PromptInline,]


 class PromptAdmin(admin.ModelAdmin):
     model = Prompt
-    list_display = ("message", "user_created", "get_conversation_title")
+    list_display = ("id","message", "user_created", "get_conversation_title","created")
     search_fields = ("message",)


 class PromptMetricAdmin(admin.ModelAdmin):
     model = PromptMetric
-    list_display = ("event", "model_name", "prompt_length","reponse_length",'has_file','file_type', "get_duration")
+    list_display = (
+        "id",
+        "event",
+        "model_name",
+        "prompt_length",
+        "reponse_length",
+        "has_file",
+        "file_type",
+        "get_duration",
+        "created"
+    )
+
+
+class DocumentWorkspaceAdmin(admin.ModelAdmin):
+    model = DocumentWorkspace
+    list_display = (
+        "name",
+        "company",
+    )
+
+
+class DocumentAdmin(admin.ModelAdmin):
+    model = Document
+    list_display = (
+        "file",
+        "active",
+        "created",
+        "processed",
+    )


 admin.site.register(Announcement, AnnouncmentAdmin)
@@ -69,3 +113,6 @@ admin.site.register(Conversation, ConversationAdmin)
 admin.site.register(Prompt, PromptAdmin)
 admin.site.register(PromptMetric, PromptMetricAdmin)
 admin.site.register(Feedback, FeedbackAdmin)
+
+admin.site.register(DocumentWorkspace, DocumentWorkspaceAdmin)
+admin.site.register(Document, DocumentAdmin)
llm_be/chat_backend/apps.py
@@ -1,6 +1,32 @@
 from django.apps import AppConfig
+from django.conf import settings
+from django.db import OperationalError


 class ChatBackendConfig(AppConfig):
     default_auto_field = "django.db.models.BigAutoField"
     name = "chat_backend"
+
+    def ready(self):
+        import chat_backend.signals
+
+        FORCE_RELOAD = False
+
+        if True:  # not settings.TESTING:  # Don't run during tests
+            try:
+                from .services.rag_services import AsyncRAGService
+                from chat_backend.models import Document
+
+                # Check if Chroma needs initialization
+                if Document.objects.exists():
+                    rag_service = AsyncRAGService()
+
+                    if rag_service.vector_store._collection.count() == 0:
+                        print("Initializing ChromaDB with existing documents...")
+                        rag_service.ingest_documents()
+                    if FORCE_RELOAD:
+                        print("Force Reload ChromaDB with existing documents...")
+                        rag_service.clear_vector_store()
+            except OperationalError:
+                # Database tables might not exist yet during migration
+                pass
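Note: the commented-out guard (`not settings.TESTING`) refers to a TESTING flag that Django does not define by default; the `if True:` left in its place means the ChromaDB check also runs during tests and migrations. A minimal sketch of one common way to provide such a flag in settings.py (the flag name and detection approach are assumptions, not part of this diff):

    # settings.py -- detect "manage.py test" and expose a TESTING flag
    import sys

    TESTING = "test" in sys.argv  # True when running under the Django test runner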
@@ -1,30 +1,32 @@
 """
 llama client - Abstract this in the future
 """

 import ollama
 from typing import List, Dict


 class LlamaClient(object):
-    def __init__(self, model: str='llama3'):
+    def __init__(self, model: str = "llama3"):
         self.client = ollama.Client(host="http://127.0.0.1:11434")
         self.model = model

     def check_if_model_exists(self) -> bool:
         raise NotImplementedError

-    def generate_conversation_title(self, message:str):
-        response = self.generate_single_message("Summarise the phrase in one to for words\"%s\"" % message)
-
-        raw_response = response['response'].replace("\"","")
+    def generate_conversation_title(self, message: str):
+        response = self.generate_single_message(
+            'Summarise the phrase in one to for words"%s"' % message
+        )
+
+        raw_response = response["response"].replace('"', "")
         return " ".join(raw_response.split()[:4])

     def generate_single_message(self, message: str):
         return ollama.generate(model=self.model, prompt=message)

     def get_chat_response(self, messages: List[str]):
-        return self.client.chat(model = self.model, messages=messages, stream=False)
+        return self.client.chat(model=self.model, messages=messages, stream=False)

     def get_streamed_chat_response(self, messages: List[str]):
-        return self.client.chat(model = self.model, messages=messages, stream=True)
+        return self.client.chat(model=self.model, messages=messages, stream=True)
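For context, a minimal usage sketch of the client above (the model name and running Ollama server are assumptions; this snippet is illustrative, not part of the diff):

    # Assumes an Ollama server is listening on 127.0.0.1:11434
    client = LlamaClient(model="llama3")

    title = client.generate_conversation_title("How do I tune hyperparameters?")

    # stream=True returns an iterator of chunk dicts from the ollama library
    for chunk in client.get_streamed_chat_response(
        [{"role": "user", "content": "Hello"}]
    ):
        print(chunk["message"]["content"], end="", flush=True)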
llm_be/chat_backend/consumers.py (new file, 422 lines)
@@ -0,0 +1,422 @@
import json
import base64
import logging
import pandas as pd
from datetime import datetime
from django.utils import timezone
from django.conf import settings
from django.core.files.base import ContentFile
from channels.generic.websocket import AsyncWebsocketConsumer
from channels.db import database_sync_to_async
from channels.layers import get_channel_layer
from asgiref.sync import sync_to_async, async_to_sync
from langchain_core.messages import HumanMessage, AIMessage
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.runnables import RunnableLambda, RunnableBranch, RunnablePassthrough
from langchain_core.tracers.context import collect_runs

from .models import Conversation, Prompt, PromptMetric, DocumentWorkspace, Document, CustomUser
from .serializers import PromptSerializer
from .services.llm_service import AsyncLLMService
from .services.rag_services import AsyncRAGService
from .services.title_generator import title_generator
from .services.moderation_classifier import moderation_classifier, ModerationLabel
from .services.prompt_classifier.prompt_classifier import PromptClassifier, PromptType
from .services.data_analysis_service import AsyncDataAnalysisService

logger = logging.getLogger(__name__)

CHANNEL_NAME: str = "llm_messages"
MODEL_NAME: str = "llama3.2"
PROMPT_CLASSIFIER = PromptClassifier()


@database_sync_to_async
def create_conversation(prompt, email, title):
    # return the conversation id
    conversation = Conversation.objects.create(title=title)
    conversation.save()

    user = CustomUser.objects.get(email=email)
    conversation.user_id = user.id
    conversation.save()
    return conversation.id


@database_sync_to_async
def get_workspace(conversation_id):
    conversation = Conversation.objects.get(id=conversation_id)
    return DocumentWorkspace.objects.get(company=conversation.user.company)


@database_sync_to_async
def get_messages(conversation_id, prompt, file_string: str = None, file_type: str = ""):
    messages = []

    conversation = Conversation.objects.get(id=conversation_id)
    logger.debug(file_string)

    # add the prompt to the conversation
    serializer = PromptSerializer(
        data={
            "message": prompt,
            "user_created": True,
            "created": timezone.now(),
        }
    )
    if serializer.is_valid(raise_exception=True):
        prompt_instance = serializer.save()
        prompt_instance.conversation_id = conversation.id
        prompt_instance.save()
        if file_string:
            file_name = f"prompt_{prompt_instance.id}_data.{file_type}"
            f = ContentFile(file_string, name=file_name)
            prompt_instance.file.save(file_name, f)
            prompt_instance.file_type = file_type
            prompt_instance.save()

    for prompt_obj in Prompt.objects.filter(conversation__id=conversation_id):
        messages.append(
            {
                "content": prompt_obj.message,
                "role": "user" if prompt_obj.user_created else "assistant",
                "has_file": prompt_obj.file_exists(),
                "file": prompt_obj.file if prompt_obj.file_exists() else None,
                "file_type": prompt_obj.file_type if prompt_obj.file_exists() else None,
            }
        )

    # now transform the messages
    transformed_messages = []
    for message in messages:

        if message["has_file"] and message["file_type"] != None:
            if "csv" in message["file_type"]:
                file_type = "csv"
                altered_message = f"{message['content']}\n The file type is csv and the file contents are: {message['file'].read()}"
            elif "xlsx" in message["file_type"]:
                file_type = "xlsx"
                df = pd.read_excel(message["file"].read())
                altered_message = f"{message['content']}\n The file type is xlsx and the file contents are: {df}"
            elif "txt" in message["file_type"]:
                file_type = "txt"
                altered_message = f"{message['content']}\n The file type is csv and the file contents are: {message['file'].read()}"
            else:
                altered_message = message["content"]
        else:
            altered_message = message["content"]

        transformed_message = (
            AIMessage(content=altered_message)
            if message["role"] == "assistant"
            else HumanMessage(content=altered_message)
        )
        transformed_messages.append(transformed_message)

    return transformed_messages, prompt_instance


@database_sync_to_async
def save_generated_message(conversation_id, message):
    conversation = Conversation.objects.get(id=conversation_id)

    # add the prompt to the conversation
    serializer = PromptSerializer(
        data={
            "message": message,
            "user_created": False,
            "created": timezone.now(),
        }
    )
    if serializer.is_valid():
        prompt_instance = serializer.save()
        prompt_instance.conversation_id = conversation.id
        prompt_instance = serializer.save()
    else:
        print(serializer.errors)


@database_sync_to_async
def create_prompt_metric(
    prompt_id, prompt, has_file, file_type, model_name, conversation_id
):
    prompt_metric = PromptMetric.objects.create(
        prompt_id=prompt_id,
        start_time=timezone.now(),
        prompt_length=len(prompt),
        has_file=has_file,
        file_type=file_type,
        model_name=model_name,
        conversation_id=conversation_id,
    )
    prompt_metric.save()
    return prompt_metric


@database_sync_to_async
def update_prompt_metric(prompt_metric, status):
    prompt_metric.event = status
    prompt_metric.save()


@database_sync_to_async
def finish_prompt_metric(prompt_metric, response_length):
    logger.info(f"finish_prompt_metric: {response_length}")
    prompt_metric.end_time = timezone.now()
    prompt_metric.reponse_length = response_length
    prompt_metric.event = "FINISHED"
    prompt_metric.save(update_fields=["end_time", "reponse_length", "event"])
    logger.info("finish_prompt_metric saved")


@database_sync_to_async
def get_retriever(conversation_id):
    logger.info(f"getting workspace from conversation: {conversation_id}")
    conversation = Conversation.objects.get(id=conversation_id)
    logger.info(f"Got conversation: {conversation}")
    workspace = DocumentWorkspace.objects.get(company=conversation.user.company)
    logger.info(f"Got workspace: {conversation}")
    vectorstore = Chroma(
        persist_directory=f"./chroma_db/",
        embedding=OllamaEmbeddings(model="llama3.2"),
    )
    return vectorstore.as_retriever()


async def get_conversation_file_async(conversation_id):
    try:
        # Get the very first prompt in the conversation that has a file
        prompt_with_file = await Prompt.objects.filter(
            conversation_id=conversation_id
        ).exclude(file='').order_by('created').afirst()

        if prompt_with_file and prompt_with_file.file:
            # You must use sync_to_async to access the file's binary content
            file_data = await sync_to_async(prompt_with_file.file.read)()
            file_type = prompt_with_file.file_type
            return file_data, file_type
    except Exception as e:
        logger.error(f"Error retrieving file from conversation history: {e}")
    return None, None


class ChatConsumerAgain(AsyncWebsocketConsumer):
    async def connect(self):
        await self.accept()

    async def disconnect(self, close_code):
        await self.close()

    async def send_json_message(self, data_str):
        """
        Ensures that the message sent over the websocket is a valid JSON object.
        If data_str is a plain string, it wraps it in {"type": "text", "content": ...}.
        """
        try:
            # Test if it's already a valid JSON object string
            json.loads(data_str)
            # If it is, send it as is
            await self.send(data_str)
        except (json.JSONDecodeError, TypeError):
            # If it's a plain string or not JSON-decodable, wrap it
            await self.send(data_str)

    async def receive(self, text_data=None, bytes_data=None):
        logger.debug(f"Text Data: {text_data}")
        logger.debug(f"Bytes Data: {bytes_data}")
        if text_data:
            data = json.loads(text_data)
            message = data.get("message", None)
            conversation_id = data.get("conversation_id", None)
            email = data.get("email", None)
            file = data.get("file", None)
            file_type = data.get("fileType", "")
            model = data.get("modelName", "Turbo")

            if not conversation_id:
                # we need to create a new conversation
                # we will generate a name for it too
                title = await title_generator.generate_async(message)
                conversation_id = await create_conversation(message, email, title)

            if conversation_id:
                decoded_file = None

                if file:
                    decoded_file = base64.b64decode(file)
                    logger.debug(decoded_file)
                    # The `altered_message` should only be created if a file exists
                    # and you want to pass its content directly to the classifier.
                    # Here, we'll let the classifier decide based on the user's prompt
                    # and then handle the file content separately.
                    altered_message = message
                    if "csv" in file_type:
                        file_type = "csv"
                        #altered_message = f"{message}\n The file type is csv and the file contents are: {decoded_file}"
                    elif "xmlformats-officedocument" in file_type:
                        file_type = "xlsx"
                        #df = pd.read_excel(decoded_file)
                        #altered_message = f"{message}\n The file type is xlsx and the file contents are: {df}"
                    elif "word" in file_type:
                        file_type = "docx"
                    elif "pdf" in file_type:
                        file_type = "pdf"
                    elif "text" in file_type:
                        file_type = "txt"
                        #altered_message = f"{message}\n The file type is txt and the file contents are: {decoded_file}"
                    else:
                        file_type = "Not Sure"

                logger.info(f'received: "{message}" for conversation {conversation_id}')

                # --- LangSmith Pipeline Construction ---

                async def check_moderation(input_dict):
                    msg = input_dict["message"]
                    label = await moderation_classifier.classify_async(msg)
                    return {**input_dict, "moderation_label": label}

                async def classify_prompt_step(input_dict):
                    if input_dict["moderation_label"] == ModerationLabel.NSFW:
                        return {**input_dict, "prompt_type": None}  # Skip classification

                    msg = input_dict["message"]
                    decoded_file = input_dict.get("decoded_file")

                    prompt_type = await PROMPT_CLASSIFIER.classify_async(msg)

                    # Override logic
                    if decoded_file and (prompt_type == PromptType.DATA_ANALYSIS or 'analyze' in msg.lower() or 'data' in msg.lower()):
                        prompt_type = PromptType.DATA_ANALYSIS
                    elif decoded_file:
                        prompt_type = PromptType.GENERAL_CHAT

                    return {**input_dict, "prompt_type": prompt_type}

                async def generate_response_step(input_dict):
                    if input_dict["moderation_label"] == ModerationLabel.NSFW:
                        response = "Prompt has been marked as NSFW. If this is in error, submit a feedback with the prompt text."
                        return {"type": "error", "content": response}

                    prompt_type = input_dict["prompt_type"]
                    messages = input_dict["messages"]
                    prompt_instance = input_dict["prompt_instance"]
                    conversation_id = input_dict["conversation_id"]
                    decoded_file = input_dict.get("decoded_file")
                    file_type = input_dict.get("file_type")

                    # Feature Flag: Image Generation
                    if prompt_type == PromptType.IMAGE_GENERATION:
                        if not getattr(settings, "ALLOW_IMAGE_GENERATION", False):
                            return {"type": "text", "content": "Image Generation is disabled."}
                        # If enabled, proceed (assuming implementation exists, but user said "have it set to false for now")
                        return {"type": "text", "content": "Image Generation is not supported at this time, but it will be soon."}

                    if prompt_type == PromptType.SEARCH:
                        # Check modelName first - if FAST, we skip search regardless of settings
                        if input_dict.get("model_name") == "FAST":
                            pass  # Skip search
                        elif getattr(settings, "ALLOW_INTERNET_ACCESS", False):
                            try:
                                search = DuckDuckGoSearchRun()
                                search_results = search.run(input_dict["message"])
                                messages.append(HumanMessage(content=f"Search Results: {search_results}"))
                            except Exception as e:
                                logger.error(f"Search failed: {e}")
                                # If search fails, we proceed without it, essentially falling back to general chat
                                pass
                        else:
                            # If search is disabled, we could notify the user, but for now we'll just proceed
                            # potentially adding a system message or just letting the LLM handle it with its training data
                            pass

                    if prompt_type == PromptType.RAG:
                        service = AsyncRAGService()
                        workspace = await get_workspace(conversation_id)
                        return service.generate_response(messages, prompt_instance.message, workspace)

                    elif prompt_type == PromptType.DATA_ANALYSIS:
                        service = AsyncDataAnalysisService()
                        print(file_type)
                        if not decoded_file:
                            return {"type": "text", "content": "Please upload a file to perform data analysis."}
                        return service.generate_response(prompt_instance.message, decoded_file, file_type)

                    else:  # GENERAL_CHAT or others
                        service = AsyncLLMService()
                        return service.generate_response(messages, prompt_instance.message, conversation_id)

                # --- Execution ---

                # Pre-fetch messages and file
                messages, prompt_instance = await get_messages(
                    conversation_id, message, decoded_file, file_type
                )
                if not decoded_file:
                    decoded_file, file_type = await get_conversation_file_async(conversation_id)

                if file:
                    # udpate with the altered_message (logic from original)
                    # Note: altered_message was defined in original but not fully used in the messages list construction in the same way
                    # In original: messages = messages[:-1] + [HumanMessage(content=altered_message)]
                    # I need to replicate that if I want exact behavior.
                    # But altered_message was only set if file was present.
                    pass  # Logic is already in get_messages for the most part, but the original code had a specific override at the end.
                    # Let's trust get_messages for now or add the override if needed.
                    # Original:
                    # if file:
                    #     messages = messages[:-1] + [HumanMessage(content=altered_message)]
                    # I'll add it to the input_dict if needed.

                prompt_metric = await create_prompt_metric(
                    prompt_instance.id,
                    prompt_instance.message,
                    True if file else False,
                    file_type,
                    MODEL_NAME,
                    conversation_id,
                )

                pipeline_input = {
                    "message": message,
                    "conversation_id": conversation_id,
                    "decoded_file": decoded_file,
                    "file_type": file_type,
                    "messages": messages,
                    "prompt_instance": prompt_instance,
                    "model_name": model
                }

                # Run the pipeline steps manually to handle the async generator return type of generate_response_step
                # A pure RunnableSequence might struggle with the async generator return.
                # So I'll chain them in python but conceptually it's one pipeline.

                step1 = await check_moderation(pipeline_input)
                step2 = await classify_prompt_step(step1)

                # Send start markers
                await self.send("CONVERSATION_ID")
                await self.send(str(conversation_id))
                await self.send("START_OF_THE_STREAM_ENDER_GAME_42")

                response_generator_or_dict = await generate_response_step(step2)

                full_response = ""

                if isinstance(response_generator_or_dict, dict):
                    # It's an error or simple message
                    content = response_generator_or_dict.get("content", "")
                    await self.send_json_message(json.dumps(response_generator_or_dict))
                    full_response = content
                else:
                    # It's an async generator
                    async for chunk in response_generator_or_dict:
                        full_response += chunk
                        await self.send_json_message(chunk)

                await self.send("END_OF_THE_STREAM_ENDER_GAME_42")

                await save_generated_message(conversation_id, full_response)
                await finish_prompt_metric(prompt_metric, len(full_response))

        if bytes_data:
            logger.info("we have byte data")
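The consumer above frames its stream with sentinel strings rather than structured envelopes. A hedged sketch of what a client has to do to consume it (the `websockets` library and the URL are assumptions; the sentinels come from the code above):

    import asyncio
    import json
    import websockets

    async def ask(message: str, email: str):
        # Route registered in the websocket routing shown later in this diff
        async with websockets.connect("ws://localhost:8000/ws/chat_again/") as ws:
            await ws.send(json.dumps({"message": message, "email": email}))

            conversation_id = None
            async for frame in ws:
                if frame == "CONVERSATION_ID":
                    conversation_id = await ws.recv()  # next frame carries the id
                elif frame == "START_OF_THE_STREAM_ENDER_GAME_42":
                    continue
                elif frame == "END_OF_THE_STREAM_ENDER_GAME_42":
                    break
                else:
                    print(frame, end="", flush=True)  # a streamed chunk
            return conversation_id

    asyncio.run(ask("Summarise our Q3 numbers", "user@example.com"))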
llm_be/chat_backend/consumers_graph.py (new file, 363 lines)
@@ -0,0 +1,363 @@
import json
import base64
import logging
import pandas as pd
from datetime import datetime
from typing import TypedDict, Annotated, List, Union, Dict, Any
from django.utils import timezone
from django.conf import settings
from django.core.files.base import ContentFile
from channels.generic.websocket import AsyncWebsocketConsumer
from channels.db import database_sync_to_async
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langchain_community.tools import DuckDuckGoSearchRun
from langgraph.graph import StateGraph, END

from .models import Conversation, Prompt, PromptMetric, DocumentWorkspace, CustomUser
from .serializers import PromptSerializer
from .services.llm_service import AsyncLLMService
from .services.rag_services import AsyncRAGService
from .services.title_generator import title_generator
from .services.moderation_classifier import moderation_classifier, ModerationLabel
from .services.prompt_classifier.prompt_classifier import PromptClassifier, PromptType
from .services.data_analysis_service import AsyncDataAnalysisService

logger = logging.getLogger(__name__)

CHANNEL_NAME: str = "llm_messages"
MODEL_NAME: str = "llama3.2"
PROMPT_CLASSIFIER = PromptClassifier()

# --- Database Helpers (Reused) ---

@database_sync_to_async
def create_conversation(prompt, email, title):
    conversation = Conversation.objects.create(title=title)
    user = CustomUser.objects.get(email=email)
    conversation.user_id = user.id
    conversation.save()
    return conversation.id


@database_sync_to_async
def get_workspace(conversation_id):
    conversation = Conversation.objects.get(id=conversation_id)
    return DocumentWorkspace.objects.get(company=conversation.user.company)


@database_sync_to_async
def get_messages(conversation_id, prompt, file_string: str = None, file_type: str = ""):
    messages = []
    conversation = Conversation.objects.get(id=conversation_id)

    serializer = PromptSerializer(
        data={
            "message": prompt,
            "user_created": True,
            "created": timezone.now(),
        }
    )
    if serializer.is_valid(raise_exception=True):
        prompt_instance = serializer.save()
        prompt_instance.conversation_id = conversation.id
        prompt_instance.save()
        if file_string:
            file_name = f"prompt_{prompt_instance.id}_data.{file_type}"
            f = ContentFile(file_string, name=file_name)
            prompt_instance.file.save(file_name, f)
            prompt_instance.file_type = file_type
            prompt_instance.save()

    for prompt_obj in Prompt.objects.filter(conversation__id=conversation_id):
        messages.append(
            {
                "content": prompt_obj.message,
                "role": "user" if prompt_obj.user_created else "assistant",
                "has_file": prompt_obj.file_exists(),
                "file": prompt_obj.file if prompt_obj.file_exists() else None,
                "file_type": prompt_obj.file_type if prompt_obj.file_exists() else None,
            }
        )

    transformed_messages = []
    for message in messages:
        if message["has_file"] and message["file_type"] != None:
            # Simplified handling compared to original, as we rely on services to handle files now
            # But we keep the structure for context
            altered_message = message["content"]
        else:
            altered_message = message["content"]

        transformed_message = (
            AIMessage(content=altered_message)
            if message["role"] == "assistant"
            else HumanMessage(content=altered_message)
        )
        transformed_messages.append(transformed_message)

    return transformed_messages, prompt_instance


@database_sync_to_async
def save_generated_message(conversation_id, message):
    conversation = Conversation.objects.get(id=conversation_id)
    serializer = PromptSerializer(
        data={
            "message": message,
            "user_created": False,
            "created": timezone.now(),
        }
    )
    if serializer.is_valid():
        prompt_instance = serializer.save()
        prompt_instance.conversation_id = conversation.id
        prompt_instance.save()
    else:
        print(serializer.errors)


@database_sync_to_async
def create_prompt_metric(prompt_id, prompt, has_file, file_type, model_name, conversation_id):
    prompt_metric = PromptMetric.objects.create(
        prompt_id=prompt_id,
        start_time=timezone.now(),
        prompt_length=len(prompt),
        has_file=has_file,
        file_type=file_type,
        model_name=model_name,
        conversation_id=conversation_id,
    )
    return prompt_metric


@database_sync_to_async
def finish_prompt_metric(prompt_metric, response_length):
    prompt_metric.end_time = timezone.now()
    prompt_metric.reponse_length = response_length
    prompt_metric.event = "FINISHED"
    prompt_metric.save(update_fields=["end_time", "reponse_length", "event"])


async def get_conversation_file_async(conversation_id):
    try:
        prompt_with_file = await Prompt.objects.filter(
            conversation_id=conversation_id
        ).exclude(file='').order_by('created').afirst()

        if prompt_with_file and prompt_with_file.file:
            file_data = await sync_to_async(prompt_with_file.file.read)()
            file_type = prompt_with_file.file_type
            return file_data, file_type
    except Exception as e:
        logger.error(f"Error retrieving file from conversation history: {e}")
    return None, None


# --- LangGraph State ---

class ChatState(TypedDict):
    message: str
    conversation_id: int
    decoded_file: Union[bytes, None]
    file_type: Union[str, None]
    messages: List[BaseMessage]
    prompt_instance: Any  # Django model instance
    moderation_label: Union[ModerationLabel, None]
    prompt_type: Union[PromptType, None]
    response_generator: Any  # AsyncGenerator or dict
    error: Union[str, None]
    model_name: str


# --- LangGraph Nodes ---

async def moderation_node(state: ChatState) -> ChatState:
    msg = state["message"]
    label = await moderation_classifier.classify_async(msg)
    return {"moderation_label": label}


async def classification_node(state: ChatState) -> ChatState:
    if state.get("moderation_label") == ModerationLabel.NSFW:
        return {"prompt_type": None}

    msg = state["message"]
    decoded_file = state.get("decoded_file")

    prompt_type = await PROMPT_CLASSIFIER.classify_async(msg)

    # Override logic
    if decoded_file and (prompt_type == PromptType.DATA_ANALYSIS or 'analyze' in msg.lower() or 'data' in msg.lower()):
        prompt_type = PromptType.DATA_ANALYSIS
    elif decoded_file:
        prompt_type = PromptType.GENERAL_CHAT

    return {"prompt_type": prompt_type}


async def generation_node(state: ChatState) -> ChatState:
    if state.get("moderation_label") == ModerationLabel.NSFW:
        response = "Prompt has been marked as NSFW. If this is in error, submit a feedback with the prompt text."
        return {"response_generator": {"type": "error", "content": response}}

    prompt_type = state["prompt_type"]
    messages = state["messages"]
    prompt_instance = state["prompt_instance"]
    conversation_id = state["conversation_id"]
    decoded_file = state.get("decoded_file")
    file_type = state.get("file_type")

    # Feature Flag: Image Generation
    if prompt_type == PromptType.IMAGE_GENERATION:
        if not getattr(settings, "ALLOW_IMAGE_GENERATION", False):
            return {"response_generator": {"type": "text", "content": "Image Generation is disabled."}}
        return {"response_generator": {"type": "text", "content": "Image Generation is not supported at this time, but it will be soon."}}

    # Feature Flag: Internet Access
    if prompt_type == PromptType.SEARCH:
        # Check modelName first - if FAST, we skip search regardless of settings
        if state.get("model_name") == "FAST":
            pass
        elif getattr(settings, "ALLOW_INTERNET_ACCESS", False):
            try:
                search = DuckDuckGoSearchRun()
                search_results = search.run(state["message"])
                messages.append(HumanMessage(content=f"Search Results: {search_results}"))
            except Exception as e:
                logger.error(f"Search failed: {e}")
                pass
        else:
            pass

    if prompt_type == PromptType.RAG:
        service = AsyncRAGService()
        workspace = await get_workspace(conversation_id)
        generator = service.generate_response(messages, prompt_instance.message, workspace)
        return {"response_generator": generator}

    elif prompt_type == PromptType.DATA_ANALYSIS:
        service = AsyncDataAnalysisService()
        if not decoded_file:
            return {"response_generator": {"type": "text", "content": "Please upload a file to perform data analysis."}}
        generator = service.generate_response(prompt_instance.message, decoded_file, file_type)
        return {"response_generator": generator}

    else:  # GENERAL_CHAT or others
        service = AsyncLLMService()
        generator = service.generate_response(messages, prompt_instance.message, conversation_id)
        return {"response_generator": generator}


# --- LangGraph Definition ---

workflow = StateGraph(ChatState)

workflow.add_node("moderation", moderation_node)
workflow.add_node("classification", classification_node)
workflow.add_node("generation", generation_node)

workflow.set_entry_point("moderation")

workflow.add_edge("moderation", "classification")
workflow.add_edge("classification", "generation")
workflow.add_edge("generation", END)

app = workflow.compile()


# --- Consumer ---

class ChatConsumerGraph(AsyncWebsocketConsumer):
    async def connect(self):
        await self.accept()

    async def disconnect(self, close_code):
        await self.close()

    async def send_json_message(self, data_str):
        try:
            json.loads(data_str)
            await self.send(data_str)
        except (json.JSONDecodeError, TypeError):
            await self.send(data_str)

    async def receive(self, text_data=None, bytes_data=None):
        logger.debug(f"Text Data: {text_data}")
        print("Text Data: ", text_data)
        if text_data:
            data = json.loads(text_data)
            model = data.get("modelName", "Turbo")
            message = data.get("message", None)
            conversation_id = data.get("conversation_id", None)
            email = data.get("email", None)
            file = data.get("file", None)
            file_type = data.get("fileType", "")

            if not conversation_id:
                title = await title_generator.generate_async(message)
                conversation_id = await create_conversation(message, email, title)

            if conversation_id:
                print("Conversation ID: ", conversation_id)
                decoded_file = None
                if file:
                    decoded_file = base64.b64decode(file)
                    if "csv" in file_type: file_type = "csv"
                    elif "xmlformats-officedocument" in file_type: file_type = "xlsx"
                    elif "word" in file_type: file_type = "docx"
                    elif "pdf" in file_type: file_type = "pdf"
                    elif "text" in file_type: file_type = "txt"
                    else: file_type = "Not Sure"

                # Pre-fetch messages and file
                messages, prompt_instance = await get_messages(
                    conversation_id, message, decoded_file, file_type
                )
                print("Messages: ", messages)
                if not decoded_file:
                    decoded_file, file_type = await get_conversation_file_async(conversation_id)

                prompt_metric = await create_prompt_metric(
                    prompt_instance.id,
                    prompt_instance.message,
                    True if file else False,
                    file_type,
                    MODEL_NAME,
                    conversation_id,
                )

                # Initialize State
                initial_state = {
                    "message": message,
                    "conversation_id": conversation_id,
                    "decoded_file": decoded_file,
                    "file_type": file_type,
                    "messages": messages,
                    "prompt_instance": prompt_instance,
                    "moderation_label": None,
                    "prompt_type": None,
                    "response_generator": None,
                    "error": None,
                    "model_name": model
                }
                print("Initial State: ", initial_state)

                # Run Graph
                final_state = await app.ainvoke(initial_state)
                print("Final State: ", final_state)

                response_generator_or_dict = final_state["response_generator"]
                print("Response Generator: ", response_generator_or_dict)

                # Send start markers
                await self.send("CONVERSATION_ID")
                await self.send(str(conversation_id))
                await self.send("START_OF_THE_STREAM_ENDER_GAME_42")

                full_response = ""

                if isinstance(response_generator_or_dict, dict):
                    content = response_generator_or_dict.get("content", "")
                    await self.send_json_message(json.dumps(response_generator_or_dict))
                    full_response = content
                else:
                    async for chunk in response_generator_or_dict:
                        full_response += chunk
                        await self.send_json_message(chunk)

                await self.send("END_OF_THE_STREAM_ENDER_GAME_42")

                await save_generated_message(conversation_id, full_response)
                await finish_prompt_metric(prompt_metric, len(full_response))
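The route is named "conditional_chat", but the compiled graph above is strictly linear (moderation to classification to generation): an NSFW result still flows through every node and is only handled inside generation_node. A sketch of how LangGraph's conditional edges could short-circuit instead (this wiring is an assumption about a possible refinement, not what the diff does):

    def route_after_moderation(state: ChatState) -> str:
        # Jump straight to generation (which emits the NSFW notice) when flagged
        if state.get("moderation_label") == ModerationLabel.NSFW:
            return "generation"
        return "classification"

    workflow.add_conditional_edges(
        "moderation",
        route_after_moderation,
        {"classification": "classification", "generation": "generation"},
    )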
llm_be/chat_backend/migrations/0020_documentworkspace_document.py (new file, 78 lines)
@@ -0,0 +1,78 @@
# Generated by Django 5.1.7 on 2025-04-30 18:58

import django.db.models.deletion
import django.utils.timezone
from django.db import migrations, models


class Migration(migrations.Migration):

    dependencies = [
        ("chat_backend", "0019_customuser_conversation_order_and_more"),
    ]

    operations = [
        migrations.CreateModel(
            name="DocumentWorkspace",
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        auto_created=True,
                        primary_key=True,
                        serialize=False,
                        verbose_name="ID",
                    ),
                ),
                ("created", models.DateTimeField(default=django.utils.timezone.now)),
                (
                    "last_modified",
                    models.DateTimeField(default=django.utils.timezone.now),
                ),
                ("name", models.CharField(max_length=255)),
                (
                    "company",
                    models.ForeignKey(
                        on_delete=django.db.models.deletion.CASCADE,
                        to="chat_backend.company",
                    ),
                ),
            ],
            options={
                "abstract": False,
            },
        ),
        migrations.CreateModel(
            name="Document",
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        auto_created=True,
                        primary_key=True,
                        serialize=False,
                        verbose_name="ID",
                    ),
                ),
                ("created", models.DateTimeField(default=django.utils.timezone.now)),
                (
                    "last_modified",
                    models.DateTimeField(default=django.utils.timezone.now),
                ),
                ("file", models.FileField(upload_to="documents/")),
                ("uploaded_at", models.DateTimeField(auto_now_add=True)),
                ("processed", models.BooleanField(default=False)),
                ("active", models.BooleanField(default=False)),
                (
                    "workspace",
                    models.ForeignKey(
                        on_delete=django.db.models.deletion.CASCADE,
                        to="chat_backend.documentworkspace",
                    ),
                ),
            ],
            options={
                "abstract": False,
            },
        ),
    ]
llm_be/chat_backend/migrations/0021_alter_prompt_message.py (new file, 20 lines)
@@ -0,0 +1,20 @@
# Generated by Django 5.1.7 on 2025-09-24 16:44

from django.db import migrations, models


class Migration(migrations.Migration):

    dependencies = [
        ("chat_backend", "0020_documentworkspace_document"),
    ]

    operations = [
        migrations.AlterField(
            model_name="prompt",
            name="message",
            field=models.CharField(
                help_text="The text for a prompt", max_length=102400
            ),
        ),
    ]
llm_be/chat_backend/models.py
@@ -3,9 +3,11 @@ from django.contrib.auth.models import AbstractUser
 from django.utils import timezone
 from autoslug import AutoSlugField
 from django.core.files.storage import FileSystemStorage

 # Create your models here.

-FILE_STORAGE = FileSystemStorage(location='prompt_files')
+FILE_STORAGE = FileSystemStorage(location="prompt_files")


 class TimeInfoBase(models.Model):
@@ -51,6 +53,9 @@ class Company(TimeInfoBase):
         help_text="A list of LLMs that company can use",
     )

+    def __str__(self):
+        return self.name
+

 class CustomUser(AbstractUser):
     company = models.ForeignKey(
@@ -60,11 +65,17 @@ class CustomUser(AbstractUser):
         help_text="Allows the edit/add/remove of users for a company", default=False
     )
     deleted = models.BooleanField(help_text="This is to hid accounts", default=False)
-    has_signed_tos = models.BooleanField(default=False, help_text="If the user has signed the TOS")
-    slug = AutoSlugField(populate_from='email')
-    conversation_order = models.BooleanField(default=True, help_text='How the conversations should display')
+    has_signed_tos = models.BooleanField(
+        default=False, help_text="If the user has signed the TOS"
+    )
+    slug = AutoSlugField(populate_from="email")
+    conversation_order = models.BooleanField(
+        default=True, help_text="How the conversations should display"
+    )

     def get_set_password_url(self):
-        return f"https://www.chat.aimloperations.com/set_password?slug={self.slug}"
+        return f"https://chat.aimloperations.com/set_password?slug={self.slug}"


 FEEDBACK_CHOICE = (
     ("SUBMITTED", "Submitted"),
@@ -74,21 +85,26 @@ FEEDBACK_CHOICE = (
 )

 FEEDBACK_CATEGORIES = (
-    ('NOT_DEFINED', 'Not defined'),
-    ('BUG', 'Bug'),
-    ('ENHANCEMENT', 'Enhancement'),
-    ('OTHER', 'Other'),
-    ('MAX_CATEGORIES', 'Max Categories'),
+    ("NOT_DEFINED", "Not defined"),
+    ("BUG", "Bug"),
+    ("ENHANCEMENT", "Enhancement"),
+    ("OTHER", "Other"),
+    ("MAX_CATEGORIES", "Max Categories"),
 )


 class Feedback(TimeInfoBase):
-    title = models.TextField(max_length=64, default='')
+    title = models.TextField(max_length=64, default="")
     user = models.ForeignKey(
         CustomUser, on_delete=models.CASCADE, blank=True, null=True
     )
     text = models.TextField(max_length=512)
-    status = models.CharField(max_length=24, choices=FEEDBACK_CHOICE, default="SUBMITTED")
-    category = models.CharField(max_length=24, choices=FEEDBACK_CATEGORIES, default="NOT_DEFINED")
+    status = models.CharField(
+        max_length=24, choices=FEEDBACK_CHOICE, default="SUBMITTED"
+    )
+    category = models.CharField(
+        max_length=24, choices=FEEDBACK_CATEGORIES, default="NOT_DEFINED"
+    )

     def get_user_email(self):
         if self.user:
@@ -105,9 +121,8 @@ MONTH_CHOICES = (
     ("DECEMBER", "December"),
 )

-month = models.CharField(max_length=9,
-                         choices=MONTH_CHOICES,
-                         default="JANUARY")
+month = models.CharField(max_length=9, choices=MONTH_CHOICES, default="JANUARY")


 class Announcement(TimeInfoBase):
     class Status(models.TextChoices):
@@ -131,7 +146,9 @@ class Conversation(TimeInfoBase):
     title = models.CharField(
         max_length=64, help_text="The title for the conversation", default=""
     )
-    deleted = models.BooleanField(help_text="This is to hide conversations", default=False)
+    deleted = models.BooleanField(
+        help_text="This is to hide conversations", default=False
+    )

     def get_user_email(self):
         if self.user:
@@ -144,27 +161,33 @@ class Conversation(TimeInfoBase):


 class Prompt(TimeInfoBase):
-    message = models.CharField(max_length=10 * 1024, help_text="The text for a prompt")
+    message = models.CharField(max_length=100 * 1024, help_text="The text for a prompt")
     user_created = models.BooleanField(
         help_text="True if was created by the user. False if it was generate by the LLM"
     )
     conversation = models.ForeignKey(
         "Conversation", on_delete=models.CASCADE, blank=True, null=True
     )
-    file =models.FileField(upload_to=FILE_STORAGE, blank=True, null=True, help_text="file for the prompt")
-    file_type=models.CharField(max_length=16, blank=True, null=True, help_text='file type of the file for the prompt')
+    file = models.FileField(
+        upload_to=FILE_STORAGE, blank=True, null=True, help_text="file for the prompt"
+    )
+    file_type = models.CharField(
+        max_length=16,
+        blank=True,
+        null=True,
+        help_text="file type of the file for the prompt",
+    )

     def get_conversation_title(self):
         if self.conversation:
             return self.conversation.title
         else:
             return ""

     def file_exists(self):
         return self.file != None and self.file.storage.exists(self.file.name)


 class PromptMetric(TimeInfoBase):
     PROMPT_METRIC_CHOICES = (
         ("CREATED", "Created"),
@@ -174,20 +197,42 @@ class PromptMetric(TimeInfoBase):
         ("MAX_PROMPT_METRIC_CHOICES", "Max Prompt Metric Choices"),
     )
     prompt_id = models.IntegerField(help_text="The id of the prompt this matches to")
-    conversation_id = models.IntegerField(help_text="The id of the conversation this matches to")
+    conversation_id = models.IntegerField(
+        help_text="The id of the conversation this matches to"
+    )
     event = models.CharField(
-        max_length=26, choices=PROMPT_METRIC_CHOICES, default='CREATED'
+        max_length=26, choices=PROMPT_METRIC_CHOICES, default="CREATED"
     )
     model_name = models.CharField(max_length=215, help_text="The name of the model")
     start_time = models.DateTimeField()
     end_time = models.DateTimeField(blank=True, null=True)
-    prompt_length = models.IntegerField( help_text="How many characters are in the prompt")
-    reponse_length = models.IntegerField(blank=True, null=True, help_text="How many characters are in the response")
+    prompt_length = models.IntegerField(
+        help_text="How many characters are in the prompt"
+    )
+    reponse_length = models.IntegerField(
+        blank=True, null=True, help_text="How many characters are in the response"
+    )
     has_file = models.BooleanField(help_text="Is there a file")
-    file_type = models.CharField(max_length=16, help_text='The file type, if any', blank=True, null=True)
+    file_type = models.CharField(
+        max_length=16, help_text="The file type, if any", blank=True, null=True
+    )

     def get_duration(self):
-        if(self.start_time and self.end_time):
-            difference =self.end_time - self.start_time
+        if self.start_time and self.end_time:
+            difference = self.end_time - self.start_time
             return difference.seconds
         return 0


+# Document Models
+class DocumentWorkspace(TimeInfoBase):
+    name = models.CharField(max_length=255)
+    company = models.ForeignKey(Company, on_delete=models.CASCADE)
+
+
+class Document(TimeInfoBase):
+    workspace = models.ForeignKey(DocumentWorkspace, on_delete=models.CASCADE)
+    file = models.FileField(upload_to="documents/")
+    uploaded_at = models.DateTimeField(auto_now_add=True)
+    processed = models.BooleanField(default=False)
+    active = models.BooleanField(default=False)
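A caveat on `file_exists` in models.py above: Django's FieldFile implements equality by comparing stored names, so `self.file != None` checks the name rather than identity, and is fragile when the field is blank. A sketch of the more conventional truthiness idiom (an observation about Django's API, not part of this diff):

    def file_exists(self):
        # bool(FieldFile) is False when no file has been attached
        return bool(self.file) and self.file.storage.exists(self.file.name)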
@@ -1,8 +1,9 @@
 from rest_framework.renderers import BaseRenderer


 class ServerSentEventRenderer(BaseRenderer):
-    media_type = 'text/event-stream'
-    format = 'txt'
+    media_type = "text/event-stream"
+    format = "txt"

     def render(self, data, accepted_media_type=None, renderer_context=None):
-        return data
+        return data
@@ -1,7 +1,8 @@
-from django.urls import re_path
-from .views import ChatConsumerAgain
-
-websocket_urlpatterns = [
-    re_path(r'ws/chat_again/$', ChatConsumerAgain.as_asgi()),
-]
+from django.urls import re_path
+from .consumers import ChatConsumerAgain
+from .consumers_graph import ChatConsumerGraph
+
+websocket_urlpatterns = [
+    re_path(r"ws/chat_again/$", ChatConsumerAgain.as_asgi()),
+    re_path(r"ws/conditional_chat/$", ChatConsumerGraph.as_asgi()),
+]
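For these routes to take effect, websocket_urlpatterns has to be mounted in the project's ASGI application (by Channels convention the file above is the app's routing.py; its name was lost in extraction). A minimal sketch of the standard Channels wiring, with the file layout and middleware choices being assumptions:

    # asgi.py
    from channels.routing import ProtocolTypeRouter, URLRouter
    from channels.auth import AuthMiddlewareStack
    from django.core.asgi import get_asgi_application
    from chat_backend.routing import websocket_urlpatterns

    application = ProtocolTypeRouter(
        {
            "http": get_asgi_application(),
            "websocket": AuthMiddlewareStack(URLRouter(websocket_urlpatterns)),
        }
    )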
llm_be/chat_backend/serializers.py
@@ -1,6 +1,16 @@
 from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
 from rest_framework import serializers
-from .models import CustomUser, Announcement, Company, Conversation, Prompt, Feedback, FEEDBACK_CATEGORIES
+from .models import (
+    CustomUser,
+    Announcement,
+    Company,
+    Conversation,
+    Prompt,
+    Feedback,
+    FEEDBACK_CATEGORIES,
+    DocumentWorkspace,
+    Document,
+)


 class MyTokenObtainPairSerializer(TokenObtainPairSerializer):
@@ -25,11 +35,13 @@ class AnnouncmentSerializer(serializers.ModelSerializer):
         model = Announcement
         fields = "__all__"


 class FeedbackSerializer(serializers.ModelSerializer):
     class Meta:
         model = Feedback
         fields = "__all__"


 class CustomUserSerializer(serializers.ModelSerializer):
     email = serializers.EmailField(required=True)
     username = serializers.CharField()
@@ -58,12 +70,49 @@ class ConversationSerializer(serializers.ModelSerializer):
 class PromptSerializer(serializers.ModelSerializer):

     class Meta:
         model = Prompt
-        fields = ("message", "user_created", "created", "id", )
+        fields = (
+            "message",
+            "user_created",
+            "created",
+            "id",
+        )


 class BasicUserSerializer(serializers.ModelSerializer):
     class Meta:
         model = CustomUser
-        fields = ("email", "first_name", "last_name", "is_active","has_usable_password","is_company_manager",'has_signed_tos')
+        fields = (
+            "email",
+            "first_name",
+            "last_name",
+            "is_active",
+            "has_usable_password",
+            "is_company_manager",
+            "has_signed_tos",
+        )


+# document serializers
+class DocumentWorkspaceSerializer(serializers.ModelSerializer):
+    class Meta:
+        model = DocumentWorkspace
+        fields = ["id", "name", "created"]
+        read_only_fields = ["id", "created"]
+
+
+class DocumentSerializer(serializers.ModelSerializer):
+    class Meta:
+        model = Document
+        fields = [
+            "id",
+            "workspace",
+            "file",
+            "uploaded_at",
+            "processed",
+            "created",
+            "active",
+        ]
+        read_only_fields = ["id", "uploaded_at", "processed", "created"]
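For reference, the save-then-attach pattern the consumers use with PromptSerializer looks like this (a usage sketch mirroring get_messages above; `conversation` is assumed to be an existing Conversation instance):

    serializer = PromptSerializer(
        data={"message": "hello", "user_created": True, "created": timezone.now()}
    )
    if serializer.is_valid(raise_exception=True):
        prompt = serializer.save()
        # "conversation" is not in the serializer's fields, so it is set after save
        prompt.conversation_id = conversation.id
        prompt.save()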
llm_be/chat_backend/services/__init__.py (new file, empty)
llm_be/chat_backend/services/base_service.py (new file, 18 lines)
@@ -0,0 +1,18 @@
from abc import ABC, abstractmethod
from langchain_ollama import OllamaLLM
from langchain_core.output_parsers import StrOutputParser
from django.conf import settings

class BaseService(ABC):
    """Abstract base class for LLM conversation services."""

    def __init__(self, temperature=0.7):
        self.llm = OllamaLLM(
            model="llama3.2" if not settings.DEBUG else "gpt-oss:20b",
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            repeat_penalty=1.1,
            num_ctx=4096,
        )
        self.output_parser = StrOutputParser()
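Note that BaseService accepts a `temperature` argument but passes the literal 0.7 to OllamaLLM, so the parameter is silently ignored. A sketch of how a subclass might build on the shared pieces (the class name and chain shape are assumptions, loosely modeled on the services below):

    class EchoService(BaseService):
        """Toy subclass: pipe a prompt straight through the shared LLM."""

        def __init__(self):
            super().__init__(temperature=0.2)  # ignored until __init__ forwards it
            self.chain = self.llm | self.output_parser

        async def generate_response(self, prompt: str):
            # Stream chunks the same way the concrete services do
            async for chunk in self.chain.astream(prompt):
                yield chunk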
llm_be/chat_backend/services/data_analysis_service.py (new file, 194 lines)
@@ -0,0 +1,194 @@
|
||||
import pandas as pd
|
||||
import io
|
||||
import re
|
||||
import json
|
||||
import base64
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import AsyncGenerator
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_ollama import OllamaLLM
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
import docx
|
||||
import pypdf
|
||||
from django.conf import settings
|
||||
|
||||
|
||||
class AsyncDataAnalysisService:
|
||||
"""Asynchronous service for performing data analysis with an LLM."""
|
||||
|
||||
def __init__(self):
|
||||
# A model with a large context window and strong analytical skills is best
|
||||
self.llm = OllamaLLM(
|
||||
model="llama3.2" if not settings.DEBUG else "gpt-oss:20b",
|
||||
temperature=0.3,
|
||||
num_ctx=8192,
|
||||
)
|
||||
self.output_parser = StrOutputParser()
|
||||
self._setup_chain()
|
||||
|
||||
def _setup_chain(self):
|
||||
"""Set up the LLM chain with a prompt tailored for data analysis."""
|
||||
template = """You are an expert data analyst. Your role is to directly answer a user's question about a dataset or document they have provided.
|
||||
You will be given a summary and a sample of the dataset, or the content of the document.
|
||||
Based on this information, provide a clear and concise answer to the user's question.
|
||||
Do not provide Python code or any other code. The user is not a developer and wants a direct answer.
|
||||
Even if you don't think the data provides enough evidence for the query, still provide a response
|
||||
|
||||
---
|
||||
Data/Document Content:
|
||||
{data_summary}
|
||||
---
|
||||
|
||||
User's Question: {query}
|
||||
Answer:"""
|
||||
|
||||
self.prompt = ChatPromptTemplate.from_template(template)
|
||||
|
||||
self.analysis_chain = (
|
||||
{
|
||||
"data_summary": lambda x: x["data_summary"],
|
||||
"query": lambda x: x["query"],
|
||||
}
|
||||
| self.prompt
|
||||
| self.llm
|
||||
| self.output_parser
|
||||
)
|
||||
|
||||
def _get_dataframe_summary(self, df: pd.DataFrame) -> str:
|
||||
"""Generates a structured summary of the DataFrame for the LLM."""
|
||||
|
||||
num_rows, num_cols = df.shape
|
||||
summary_lines = [
|
||||
f"DataFrame has {num_rows} rows and {num_cols} columns.",
|
||||
"Column Information (Name, Dtype, Non-Null Count):",
|
||||
"--------------------------------------------------",
|
||||
]
|
||||
|
||||
# Add a concise summary using df.info()
|
||||
info_buffer = io.StringIO()
|
||||
df.info(buf=info_buffer, verbose=True, show_counts=True)
|
||||
summary_lines.append(info_buffer.getvalue())
|
||||
|
||||
summary_lines.append("\nDescriptive Statistics (for numerical columns):")
|
||||
summary_lines.append("--------------------------------------------")
|
||||
summary_lines.append(df.describe().to_string())
|
||||
|
||||
summary_lines.append("\nSample of Data:")
|
||||
summary_lines.append("-----------------")
|
||||
# Show the first 5 rows and a few random rows to give a feel for the data
|
||||
summary_lines.append(df.head(5).to_string())
|
||||
|
||||
return "\n".join(summary_lines)
|
||||
|
||||
def _read_docx(self, file_bytes: bytes) -> str:
|
||||
"""Reads text from a DOCX file."""
|
||||
doc = docx.Document(io.BytesIO(file_bytes))
|
||||
full_text = []
|
||||
for para in doc.paragraphs:
|
||||
full_text.append(para.text)
|
||||
return "\n".join(full_text)
|
||||
|
||||
def _read_pdf(self, file_bytes: bytes) -> str:
|
||||
"""Reads text from a PDF file."""
|
||||
pdf_reader = pypdf.PdfReader(io.BytesIO(file_bytes))
|
||||
full_text = []
|
||||
for page in pdf_reader.pages:
|
||||
full_text.append(page.extract_text())
|
||||
return "\n".join(full_text)
|
||||
|
||||
def _generate_plot(self, query: str, df: pd.DataFrame) -> str:
|
||||
"""
|
||||
Generates a plot from a DataFrame based on a natural language query,
|
||||
encodes it in Base64, and returns it.
|
||||
|
||||
If columns are specified (e.g., "plot X vs Y"), it uses them.
|
||||
If not, it automatically picks the first two numerical columns.
|
||||
"""
|
||||
col1, col2 = None, None
|
||||
title = "Scatter Plot"
|
||||
|
||||
# Attempt to find explicitly mentioned columns, e.g., "plot Column1 vs Column2"
|
||||
match = re.search(r"(?:plot|scatter|visualize)\s+(.*?)\s+(?:vs|versus|and)\s+(.*)", query, re.IGNORECASE)
|
||||
if match:
|
||||
potential_col1 = match.group(1).strip()
|
||||
potential_col2 = match.group(2).strip()
|
||||
if potential_col1 in df.columns and potential_col2 in df.columns:
|
||||
col1, col2 = potential_col1, potential_col2
|
||||
title = f"Scatterplot of {col1} vs {col2}"
|
||||
|
||||
# If no valid columns were explicitly found, auto-detect
|
||||
if not col1 or not col2:
|
||||
numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
|
||||
if len(numeric_cols) >= 2:
|
||||
col1, col2 = numeric_cols[0], numeric_cols[1]
|
||||
title = f"Scatterplot of {col1} vs {col2} (Auto-selected)"
|
||||
else:
|
||||
raise ValueError("I couldn't find two numerical columns to plot automatically. Please specify columns, like 'plot column_A vs column_B'.")
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.scatter(df[col1], df[col2])
|
||||
ax.set_xlabel(col1)
|
||||
ax.set_ylabel(col2)
|
||||
ax.set_title(title)
|
||||
ax.grid(True)
|
||||
|
||||
buf = io.BytesIO()
|
||||
plt.savefig(buf, format='png', bbox_inches='tight')
|
||||
plt.close(fig)
|
||||
buf.seek(0)
|
||||
|
||||
image_base64 = base64.b64encode(buf.read()).decode('utf-8')
|
||||
return image_base64
|
||||
|
||||
async def generate_response(
|
||||
self,
|
||||
query: str,
|
||||
decoded_file: bytes,
|
||||
file_type: str,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Generate a response based on the uploaded data and user query.
|
||||
This can be a text analysis or a plot visualization.
|
||||
"""
|
||||
try:
|
||||
df = None
|
||||
data_summary = ""
|
||||
file_type = file_type.lower()
|
||||
print(file_type)
|
||||
|
||||
if "csv" in file_type:
|
||||
df = pd.read_csv(io.BytesIO(decoded_file))
|
||||
data_summary = self._get_dataframe_summary(df)
|
||||
elif "xlsx" in file_type or "spreadsheet" in file_type:
|
||||
df = pd.read_excel(io.BytesIO(decoded_file))
|
||||
data_summary = self._get_dataframe_summary(df)
|
||||
elif "word" in file_type or "docx" in file_type:
|
||||
data_summary = self._read_docx(decoded_file)
|
||||
elif "pdf" in file_type:
|
||||
data_summary = self._read_pdf(decoded_file)
|
||||
else:
|
||||
yield json.dumps({"type": "error", "content": f"Unsupported file type: {file_type}. I can analyze CSV, XLSX, DOCX, and PDF files."})
|
||||
return
|
||||
|
||||
# Only attempt plotting if we have a DataFrame
|
||||
if df is not None:
|
||||
plot_keywords = ["plot", "graph", "scatter", "visualize"]
|
||||
if any(keyword in query.lower() for keyword in plot_keywords):
|
||||
try:
|
||||
image_base64 = self._generate_plot(query, df)
|
||||
yield json.dumps({
|
||||
"type": "plot",
|
||||
"format": "png",
|
||||
"image": image_base64
|
||||
})
|
||||
except ValueError as e:
|
||||
yield json.dumps({"type": "error", "content": str(e)})
|
||||
return
|
||||
|
||||
chain_input = {"data_summary": data_summary, "query": query}
|
||||
|
||||
async for chunk in self.analysis_chain.astream(chain_input):
|
||||
yield chunk #json.dumps({"type": "text", "content": chunk})
|
||||
|
||||
except Exception as e:
|
||||
yield json.dumps({"type": "error", "content": f"An error occurred: {e}"})
|
||||
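Note that the service streams a mix of plain text chunks (analysis) and JSON strings (plots, errors), so callers must disambiguate per chunk. A minimal consumer sketch, assuming the service above is passed in as "service"; the helper name is illustrative:

import json

async def stream_analysis(service, query: str, file_bytes: bytes, file_type: str):
    """Consume the mixed text/JSON stream produced by generate_response."""
    async for chunk in service.generate_response(query, file_bytes, file_type):
        try:
            payload = json.loads(chunk)
        except json.JSONDecodeError:
            payload = None
        if isinstance(payload, dict) and "type" in payload:
            yield payload  # plot or error payload
        else:
            yield {"type": "text", "content": chunk}  # plain analysis text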
144
llm_be/chat_backend/services/image_generation.py
Normal file
@@ -0,0 +1,144 @@
import asyncio
import logging
import os
from functools import partial
from typing import Optional, Tuple

import torch
from diffusers import StableDiffusionPipeline, DPMSolverSinglestepScheduler
from PIL import Image

logger = logging.getLogger(__name__)


class ImageGenerationService:
    """
    Service for text-to-image generation using Stable Diffusion.
    Uses the singleton pattern to keep the loaded model in memory.
    """

    _instance = None
    _model_loaded = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialize()
        return cls._instance

    def _initialize(self):
        """Initialize the service with default settings."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_id = "stabilityai/stable-diffusion-2-1"
        self.pipeline = None
        self.default_params = {
            "num_inference_steps": 25,
            "guidance_scale": 7.5,
            "width": 512,
            "height": 512,
        }

    def load_model(self):
        """Load the Stable Diffusion model."""
        if self._model_loaded:
            return

        try:
            logger.info(f"Loading Stable Diffusion model on {self.device}...")

            # Use DPMSolver for faster inference
            self.pipeline = StableDiffusionPipeline.from_pretrained(
                self.model_id,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            )
            self.pipeline.scheduler = DPMSolverSinglestepScheduler.from_config(
                self.pipeline.scheduler.config
            )
            self.pipeline = self.pipeline.to(self.device)

            # Optimizations
            if self.device == "cuda":
                self.pipeline.enable_attention_slicing()
                try:
                    self.pipeline.enable_xformers_memory_efficient_attention()
                except Exception:
                    # xformers is optional; fall back gracefully if unavailable
                    logger.warning("xformers not available; continuing without it")

            self._model_loaded = True
            logger.info("Model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise RuntimeError(f"Model loading failed: {str(e)}")

    def generate_image(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        output_path: Optional[str] = None,
        **kwargs,
    ) -> Tuple[Image.Image, dict]:
        """
        Generate an image from a text prompt.

        Args:
            prompt: Text prompt for image generation
            negative_prompt: Text for things to avoid in generation
            output_path: Optional path to save the image
            **kwargs: Generation parameters (overrides defaults)

        Returns:
            Tuple of (PIL.Image, generation_parameters)
        """
        if not self._model_loaded:
            self.load_model()

        # Merge default params with overrides
        params = {**self.default_params, **kwargs}

        try:
            logger.info(f"Generating image with prompt: {prompt[:50]}...")

            with torch.inference_mode():
                result = self.pipeline(
                    prompt=prompt, negative_prompt=negative_prompt, **params
                )

            image = result.images[0]

            if output_path:
                os.makedirs(os.path.dirname(output_path), exist_ok=True)
                image.save(output_path)
                logger.info(f"Image saved to {output_path}")

            return image, params

        except Exception as e:
            logger.error(f"Image generation failed: {str(e)}")
            raise RuntimeError(f"Image generation failed: {str(e)}")


class AsyncImageGenerationService:
    """
    Asynchronous wrapper for the image generation service.
    Runs the synchronous service in a thread pool.
    """

    def __init__(self):
        self.sync_service = ImageGenerationService()

    async def generate_image(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        output_path: Optional[str] = None,
        **kwargs,
    ) -> Tuple[Image.Image, dict]:
        """Async version of generate_image."""
        loop = asyncio.get_running_loop()
        func = partial(
            self.sync_service.generate_image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            output_path=output_path,
            **kwargs,
        )

        return await loop.run_in_executor(None, func)
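A minimal usage sketch for the async wrapper; the prompt, output path, and override values are illustrative. The first call triggers the (slow) model load:

import asyncio

from chat_backend.services.image_generation import AsyncImageGenerationService

async def main():
    service = AsyncImageGenerationService()
    image, params = await service.generate_image(
        prompt="a watercolor castle at sunset",
        negative_prompt="blurry, low quality",
        output_path="media/generated/castle.png",  # illustrative path
        num_inference_steps=30,  # overrides the default of 25
    )
    print(image.size, params["num_inference_steps"])

asyncio.run(main())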
152
llm_be/chat_backend/services/llm_service.py
Normal file
@@ -0,0 +1,152 @@
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Generator

# from langchain_community.llms import Ollama
from langchain_ollama import OllamaLLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from django.conf import settings

from chat_backend.models import Conversation, Prompt


class LLMService(ABC):
    """Abstract base class for LLM conversation services."""

    def __init__(self):
        self.llm = OllamaLLM(
            model="llama3.2" if not settings.DEBUG else "gpt-oss:20b",
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            repeat_penalty=1.1,
            num_ctx=4096,
        )
        self.output_parser = StrOutputParser()

    @abstractmethod
    def generate_response(self, conversation: Conversation, query: str, **kwargs):
        """Generate a response to a query within a conversation context."""
        pass

    def _format_history(self, conversation: Conversation) -> str:
        """Format conversation history for the prompt."""
        prompts = Prompt.objects.filter(conversation=conversation).order_by(
            "created_at"
        )
        return "\n".join(
            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
        )


class SyncLLMService(LLMService):
    """Synchronous LLM conversation service."""

    def __init__(self):
        super().__init__()
        self._setup_chain()

    def _setup_chain(self):
        """Set up the conversation chain."""
        template = """Continue the conversation based on the following history:

{history}

Latest message: {query}

Response:"""
        self.prompt = ChatPromptTemplate.from_template(template)

        self.conversation_chain = (
            {
                "history": lambda x: self._format_history(x["conversation"]),
                "query": lambda x: x["query"],
            }
            | self.prompt
            | self.llm
            | self.output_parser
        )

    def generate_response(
        self, conversation: Conversation, query: str, **kwargs
    ) -> Generator[str, None, None]:
        """Generate a response with streaming support."""
        chain_input = {"query": query, "conversation": conversation}

        for chunk in self.conversation_chain.stream(chain_input):
            yield chunk


class AsyncLLMService(LLMService):
    """Asynchronous LLM conversation service."""

    def __init__(self):
        super().__init__()
        self._setup_chain()

    def _setup_chain(self):
        """Set up the conversation chain."""
        template = """Continue this conversation while maintaining context by providing a single helpful response.
Current context: {context}

Last 3 messages:
{recent_history}

Latest message: {query}

Instructions:
- Carefully maintain all established context
- If referencing previous elements (like stories), preserve all details
- When asked to modify something, identify what's being modified

Response:"""

        self.prompt = ChatPromptTemplate.from_template(template)

        self.conversation_chain = (
            {
                "context": lambda x: x["conversation"],
                "recent_history": lambda x: x["recent_conversation"],
                "query": lambda x: x["query"],
            }
            | self.prompt
            | self.llm
            | self.output_parser
        )

    async def _format_history(self, conversation: list) -> str:
        """Format a list of chat messages into a history string."""
        # NOTE: conversation here is a list of LangChain messages (with .type
        # and .text()), not a Prompt queryset like the sync implementation.
        return "\n".join(
            f"{'User' if prompt.type == 'human' else 'AI'}: {prompt.text()}"
            for prompt in conversation
        )

    async def _get_recent_messages(self, conversation: list) -> str:
        """Format the most recent chat messages into a history string."""
        return "\n".join(
            f"{'User' if prompt.type == 'human' else 'AI'}: {prompt.text()}"
            for prompt in conversation
        )

    async def generate_response(
        self, conversation: Conversation, query: str, conversation_id: int, **kwargs
    ) -> AsyncGenerator[str, None]:
        """Generate a response with async streaming support."""
        chain_input = {
            "query": query,
            "conversation": await self._format_history(conversation),
            "recent_conversation": await self._get_recent_messages(conversation[-6:]),
        }

        async for chunk in self.conversation_chain.astream(chain_input):
            yield chunk
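A minimal sketch of driving the async service from a chat loop. It assumes the history is kept as a list of LangChain messages, which is what _format_history expects given the .type/.text() accesses; the helper name is illustrative:

from langchain_core.messages import AIMessage, HumanMessage

from chat_backend.services.llm_service import AsyncLLMService

async def chat_turn(service: AsyncLLMService, history: list, user_text: str) -> str:
    history.append(HumanMessage(content=user_text))
    reply = ""
    async for chunk in service.generate_response(
        conversation=history,
        query=user_text,
        conversation_id=0,  # accepted by the signature but unused in the body
    ):
        reply += chunk
    history.append(AIMessage(content=reply))
    return reply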
93
llm_be/chat_backend/services/moderation_classifier.py
Normal file
@@ -0,0 +1,93 @@
from enum import Enum, auto

from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import OllamaLLM
from chat_backend.services.base_service import BaseService


class ModerationLabel(Enum):
    NSFW = auto()
    FINE = auto()


class ModerationClassifier(BaseService):
    """
    Classifies prompts as NSFW or FINE (safe) content.
    """

    def __init__(self):
        super().__init__(temperature=0.1)
        self.llm = OllamaLLM(
            model="llama3.2",
            temperature=0.1,  # Very low for strict moderation
            top_k=10,
            num_ctx=2048,
        )

        self.moderation_prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    """You are a strict content moderator. Classify the following prompt as either NSFW or FINE.

NSFW includes:
- Sexual content
- Violence/gore
- Hate speech
- Illegal activities
- Harassment
- Graphic/disturbing content

FINE includes:
- Safe for work topics
- General conversation
- Professional inquiries
- Creative requests (non-explicit)
- Technical questions
- Data Analysis

Examples:
- "How to make a bomb" → NSFW
- "Write a love poem" → FINE
- "Explicit sex scene" → NSFW
- "Python tutorial" → FINE
- "Who won the 2024 presidential race?" → FINE
- "Please analyze this file and project the next 12 months for me. Add a graph visual of the data as well" → FINE
- "Okie, instead of 6 month projection, can you tell me what the values would be in the next 5 days" → FINE

Return ONLY "NSFW" or "FINE", nothing else.""",
                ),
                ("human", "{prompt}"),
            ]
        )

        self.chain = self.moderation_prompt | self.llm

    async def classify_async(self, prompt: str) -> ModerationLabel:
        """Asynchronous classification."""
        try:
            response = (await self.chain.ainvoke({"prompt": prompt})).strip().upper()
            return self._parse_response(response)
        except Exception as e:
            print(f"Moderation error: {e}")
            return ModerationLabel.NSFW  # Fail safe to NSFW

    def classify(self, prompt: str) -> ModerationLabel:
        """Synchronous classification."""
        try:
            response = self.chain.invoke({"prompt": prompt}).strip().upper()
            return self._parse_response(response)
        except Exception as e:
            print(f"Moderation error: {e}")
            return ModerationLabel.NSFW  # Fail safe to NSFW

    def _parse_response(self, response: str) -> ModerationLabel:
        """Convert a string response to a ModerationLabel enum."""
        if "NSFW" in response:
            return ModerationLabel.NSFW
        return ModerationLabel.FINE  # Default to FINE if unclear


# Singleton instance
moderation_classifier = ModerationClassifier()
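A minimal gating sketch using the singleton; the print statements stand in for whatever the request handler actually does:

from chat_backend.services.moderation_classifier import (
    ModerationLabel,
    moderation_classifier,
)

label = moderation_classifier.classify("Write a poem about autumn")
if label is ModerationLabel.NSFW:
    print("Blocked by moderation")  # note: classifier errors also fail safe to NSFW
else:
    print("Safe to forward to the LLM")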
@@ -0,0 +1,195 @@
from enum import Enum, auto

from langchain_core.prompts import ChatPromptTemplate

from django.conf import settings
from chat_backend.services.base_service import BaseService


class PromptType(Enum):
    GENERAL_CHAT = auto()
    RAG = auto()
    IMAGE_GENERATION = auto()
    DATA_ANALYSIS = auto()
    SEARCH = auto()
    UNKNOWN = auto()


class PromptClassifier(BaseService):
    """
    Classifies user prompts to determine which service should handle them.
    """

    def __init__(self):
        super().__init__(temperature=0.1)
        # The LLM configuration now comes from BaseService.

        self.classification_prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    """You are a precision prompt classifier. Strictly categorize prompts into:
1. GENERAL_CHAT - Casual conversation, personal questions, or non-specific inquiries
2. RAG - ONLY when explicitly requesting document/search-based knowledge
3. IMAGE_GENERATION - Specific requests to create/modify images
4. DATA_ANALYSIS - When a user wants to read an uploaded document (PDF, Word, etc.) and generate an index, summary, or extract structured information. Includes prompts like "Please read this document and make me an index for it".
5. SEARCH - When the user is seeking specific, up-to-date information (e.g., current events, celebrity news, sports scores).
6. UNKNOWN - If none of the above fit

Criteria:

1. IMAGE_GENERATION - ONLY if:
- Explicitly contains: "generate/create/draw/make an image/picture/photo/art/illustration"
- Requests visual content creation
- Example: "Make a picture of a castle" → IMAGE_GENERATION

2. RAG - ONLY if:
- Explicitly mentions documents/files/data
- Uses search terms: "find/search/lookup in [source]"
- Example: "What does contracts.pdf say?" → RAG

3. DATA_ANALYSIS - ONLY if:
- The user provides or references an uploaded document (PDF, Word, etc.) and asks for an index, summary, extraction, or analysis of its contents.
- Example: "Please read this document and make me an index for it" → DATA_ANALYSIS
- Example: "Here is the file content. What is the sum of all 'Sales'?" → DATA_ANALYSIS

4. SEARCH - ONLY if:
- User asks for current information (news, weather, sports, stock prices)
- User asks for specific facts that might change or require lookup (e.g. "Who won the 2024 election?")
- Example: "What is the latest news on X?" → SEARCH
- Example: "Who won the Super Bowl this year?" → SEARCH

5. GENERAL_CHAT - DEFAULT category when:
- Doesn't meet above criteria
- Conversational/general knowledge (that doesn't require live search)
- Creative writing (poems, jokes)
- Example: "Tell me a joke" → GENERAL_CHAT
- Example: "Write a poem about cats" → GENERAL_CHAT

Examples:
[Definitely RAG]
- "What does the uploaded PDF say about quarterly results?"
- "Search our documents for the 2023 marketing strategy"

[Definitely DATA_ANALYSIS]
- "Please read this document and make me an index for it"
- "Here is the file content. What is the sum of all 'Sales'?"
- "Based on this CSV data, show me the top 5 customers."

[Definitely SEARCH]
- "Who won the 2024 presidential race?"
- "What is the latest celebrity news?"
- "Current stock price of Apple"

[Definitely GENERAL_CHAT]
- "How does photosynthesis work?" (General knowledge)
- "Tell me a joke"
- "What's your opinion on AI?"

[Borderline -> GENERAL_CHAT]
- "What's our company policy on X?" (No doc reference -> general)

Return ONLY the exact Enum label (e.g. "GENERAL_CHAT"), no explanations.""",
                ),
                ("human", "{prompt}"),
            ]
        )

        self.chain = self.classification_prompt | self.llm

    def _quick_check(self, prompt: str) -> PromptType | None:
        """
        Perform a quick, rule-based classification before involving the LLM.
        Returns a PromptType if a clear match is found, otherwise None.
        """
        lower_prompt = prompt.lower()

        # IMAGE_GENERATION
        if any(
            keyword in lower_prompt
            for keyword in [
                "generate image",
                "create picture",
                "draw an image",
                "make a photo",
                "generate an image",
                "create an illustration",
            ]
        ):
            if getattr(settings, "ALLOW_IMAGE_GENERATION", False):
                return PromptType.IMAGE_GENERATION
            # When image generation is disabled we must never return
            # IMAGE_GENERATION, so treat the request as general chat and let
            # the LLM explain that it can't create images.
            return PromptType.GENERAL_CHAT

        # DATA_ANALYSIS (often involves uploaded documents)
        if any(
            keyword in lower_prompt
            for keyword in [
                "read this document",
                "analyze this file",
                "summarize this pdf",
                "extract data from",
                "index this document",
                "based on this csv",
                "from this spreadsheet",
            ]
        ):
            return PromptType.DATA_ANALYSIS

        # RAG: explicitly asking to search within provided documents. This can
        # overlap with DATA_ANALYSIS, but RAG is about retrieval from a
        # knowledge base, e.g. "What does the uploaded PDF say about quarterly
        # results?" or "Search our documents for the 2023 marketing strategy".
        if (
            "uploaded pdf" in lower_prompt
            or "our documents" in lower_prompt
            or "this document" in lower_prompt
            or "the document" in lower_prompt
        ) and any(
            keyword in lower_prompt
            for keyword in ["say about", "search for", "find in", "lookup in"]
        ):
            return PromptType.RAG

        # SEARCH
        if any(
            keyword in lower_prompt
            for keyword in [
                "latest news",
                "current weather",
                "stock price",
                "who won",
                "what is the current",
                "breaking news",
                "real-time information",
                "up-to-date",
            ]
        ):
            return PromptType.SEARCH

        return None

    async def classify_async(self, prompt: str) -> PromptType:
        """Asynchronously classify the prompt."""
        quick = self._quick_check(prompt)
        if quick:
            return quick
        try:
            response = await self.chain.ainvoke({"prompt": prompt})
            return self._parse_response(response.strip())
        except Exception as e:
            print(f"Classification error: {e}")
            return PromptType.UNKNOWN

    def classify(self, prompt: str) -> PromptType:
        """Synchronously classify the prompt."""
        quick = self._quick_check(prompt)
        if quick:
            return quick
        try:
            response = self.chain.invoke({"prompt": prompt})
            return self._parse_response(response.strip())
        except Exception as e:
            print(f"Classification error: {e}")
            return PromptType.UNKNOWN

    def _parse_response(self, response: str) -> PromptType:
        """Convert a string response to a PromptType enum."""
        response = response.upper().strip()

        # Guard against IMAGE_GENERATION if disabled
        if not getattr(settings, "ALLOW_IMAGE_GENERATION", False):
            if "IMAGE_GENERATION" in response or "IMAGEGENERATION" in response:
                return PromptType.GENERAL_CHAT

        # Direct match
        try:
            return PromptType[response]
        except KeyError:
            pass

        # Handle missing underscores (e.g. GENERALCHAT)
        normalized_response = response.replace("_", "")
        for prompt_type in PromptType:
            if prompt_type.name.replace("_", "") == normalized_response:
                return prompt_type

        # Substring match as a fallback
        for prompt_type in PromptType:
            if prompt_type.name in response:
                return prompt_type

        return PromptType.UNKNOWN


# Singleton instance for easy access
prompt_classifier = PromptClassifier()
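A minimal routing sketch using the singleton; the handler names are illustrative stand-ins for the services in this changeset:

from chat_backend.services.prompt_classifier.prompt_classifier import (
    PromptType,
    prompt_classifier,
)

label = prompt_classifier.classify(
    "Search our documents for the 2023 marketing strategy"
)
if label is PromptType.RAG:
    handler = "rag_service"  # e.g. AsyncRAGService
elif label is PromptType.DATA_ANALYSIS:
    handler = "data_analysis_service"
elif label is PromptType.SEARCH:
    handler = "search_service"
else:
    handler = "chat_service"
print(label, handler)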
31
llm_be/chat_backend/services/prompt_classifier/tests.py
Normal file
@@ -0,0 +1,31 @@
from unittest import TestCase

from chat_backend.services.prompt_classifier.prompt_classifier import (
    PromptClassifier,
    PromptType,
)
from parameterized import parameterized


class PromptClassifierTestCase(TestCase):
    def setUp(self):
        self.service = PromptClassifier()

    @parameterized.expand(
        [
            ["Tell me a joke", PromptType.GENERAL_CHAT],
            ["Create an image of a dog for me", PromptType.IMAGE_GENERATION],
            [
                "highlight the features of the backyard playset if they were to choose us and make the language more long form",
                PromptType.GENERAL_CHAT,
            ],
            ["Great, can you make it about a duck now", PromptType.IMAGE_GENERATION],
        ]
    )
    def test_prompt_classification(self, prompt, expected_output):
        result = self.service.classify(prompt)
        self.assertEqual(result, expected_output)
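These cases call classify(), which can fall through to a live Ollama model when the rule-based quick check does not match, so they behave as integration tests rather than pure unit tests. Assuming the standard Django test runner and the parameterized package are installed, the module can be run on its own with:

python manage.py test chat_backend.services.prompt_classifier.tests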
372
llm_be/chat_backend/services/rag_services.py
Normal file
@@ -0,0 +1,372 @@
import os
import re
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Dict, Any, AsyncGenerator, Generator, Optional

from channels.db import database_sync_to_async
from langchain_community.embeddings import OllamaEmbeddings
from django.conf import settings

# from langchain_community.llms import Ollama
from langchain_ollama import OllamaLLM
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document as LangDocument
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
    PyPDFLoader,
    Docx2txtLoader,
    TextLoader,
    UnstructuredFileLoader,
)
from django.core.files.uploadedfile import UploadedFile
from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
from chat_backend.services.base_service import BaseService


@database_sync_to_async
def get_documents(workspace: DocumentWorkspace | None = None):
    if workspace:
        return [doc for doc in Document.objects.filter(workspace=workspace)]
    else:
        return [doc for doc in Document.objects.all()]


class RAGService(BaseService):
    """Abstract base class for RAG services."""

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        # __init__ runs automatically after __new__ returns the instance
        return cls._instance

    def __init__(self):
        self.embedding_model = OllamaEmbeddings(
            model="llama3.2" if not settings.DEBUG else "gpt-oss:20b"
        )
        super().__init__()
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000, chunk_overlap=200
        )
        self.vector_store = self._initialize_vector_store()

        # Supported file types and their loaders
        self.loader_mapping = {
            ".pdf": PyPDFLoader,
            ".docx": Docx2txtLoader,
            ".txt": TextLoader,
            # Fallback for other file types
            "*": UnstructuredFileLoader,
        }

    def _initialize_vector_store(self) -> Chroma:
        """Initialize and return the Chroma vector store."""
        persist_directory = "./chroma_db/"
        vector_store = Chroma(
            embedding_function=self.embedding_model, persist_directory=persist_directory
        )
        return vector_store

    def clear_vector_store(self):
        """Clear all vectors from the store."""
        self.vector_store.delete_collection()
        self.vector_store = self._initialize_vector_store()

    def _prepare_documents(self, documents: List[Document]) -> None:
        """Load, split, and ingest documents into the vector store."""
        for doc in documents:
            print(f"Processing: {doc.file.name}")
            chunks = self._load_and_split_documents(doc.file.path)
            if chunks:
                self.vector_store.add_documents(chunks)
                self.vector_store.persist()

    def ingest_documents(self, workspace: DocumentWorkspace | None = None) -> None:
        """Ingest documents from a workspace into the vector store."""
        print(f"Getting the documents via the workspace: {workspace}")
        if workspace:
            documents = [doc for doc in Document.objects.filter(workspace=workspace)]
        else:
            documents = [doc for doc in Document.objects.all()]

        print(f"Processing the documents: {documents}")
        self._prepare_documents(documents)

    # @abstractmethod
    # def generate_response(self, conversation: Conversation, query: str, **kwargs):
    #     """Generate a response using RAG."""
    #     pass

    # @abstractmethod
    # def search_documents(
    #     self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4
    # ) -> List[Document]:
    #     """Search relevant documents from the vector store."""
    #     pass

    def _get_file_loader(self, file_path: str):
        """Get the appropriate loader for a file type."""
        ext = Path(file_path).suffix.lower()
        return self.loader_mapping.get(ext, self.loader_mapping["*"])

    def _sanitize_filename(self, filename: str) -> str:
        """Sanitize a filename for safe storage."""
        return re.sub(r"[^\w\-_. ]", "_", filename)

    def _save_uploaded_file(self, uploaded_file: UploadedFile, save_dir: str) -> str:
        """Save an uploaded file to disk."""
        os.makedirs(save_dir, exist_ok=True)
        sanitized_name = self._sanitize_filename(uploaded_file.name)
        file_path = os.path.join(save_dir, sanitized_name)

        with open(file_path, "wb+") as destination:
            for chunk in uploaded_file.chunks():
                destination.write(chunk)

        return file_path

    def _load_and_split_documents(
        self, file_path: str, metadata: dict = None
    ) -> List[Document]:
        """Load and split documents from a file."""
        loader_class = self._get_file_loader(file_path)
        loader = loader_class(file_path)

        docs = loader.load()
        if metadata:
            for doc in docs:
                doc.metadata.update(metadata)

        return self.text_splitter.split_documents(docs)

    def add_files_to_store(
        self,
        file_tuples: List[tuple],  # (file_path, name, workspace_id)
        workspace_id: str,
        source: str = "upload",
        save_dir: str = "data/uploads",
    ) -> Dict[str, Any]:
        """
        Process and add uploaded files to the vector store.

        Args:
            file_tuples: List of (file_path, name, workspace_id) tuples
            workspace_id: ID of the workspace these belong to
            source: Source identifier for documents
            save_dir: Directory the uploaded files were saved to

        Returns:
            Dictionary with processing results
        """
        results = {"total_added": 0, "failed_files": [], "processed_files": []}

        for file_tuple in file_tuples:
            try:
                file_path = file_tuple[0]

                # Prepare metadata
                metadata = {
                    "source": file_tuple[1],
                    "workspace_id": file_tuple[2],
                    "original_filename": file_tuple[1],
                    "file_path": file_path,
                }

                # Load and split documents
                docs = self._load_and_split_documents(file_path, metadata)

                # Add to vector store
                if docs:
                    self.vector_store.add_documents(docs)
                    results["total_added"] += len(docs)
                    results["processed_files"].append(
                        {"filename": file_tuple[1], "document_count": len(docs)}
                    )

            except Exception as e:
                results["failed_files"].append(
                    {"filename": file_tuple[1], "error": str(e)}
                )
                continue

        # Persist changes
        self.vector_store.persist()
        return results


class SyncRAGService(RAGService):
    """Synchronous RAG service implementation."""

    def __init__(self):
        super().__init__()
        self._setup_chain()

    def _setup_chain(self):
        """Set up the RAG chain."""
        template = """Answer the question based only on the following context:
{context}

Conversation history:
{history}

Question: {question}
"""
        self.prompt = ChatPromptTemplate.from_template(template)

        self.rag_chain = (
            {
                "context": self._retriever_with_history,
                "history": lambda x: self._format_history(x["conversation"]),
                "question": lambda x: x["query"],
            }
            | self.prompt
            | self.llm
            | StrOutputParser()
        )

    def _format_history(self, conversation: Conversation) -> str:
        """Format conversation history for the prompt."""
        prompts = Prompt.objects.filter(conversation=conversation).order_by(
            "created_at"
        )
        return "\n".join(
            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
        )

    def _retriever_with_history(self, input_dict: Dict[str, Any]) -> List[Document]:
        """Retrieve documents considering conversation history."""
        query = input_dict["query"]
        conversation = input_dict["conversation"]

        # You could enhance this to consider historical context in retrieval
        relevant_docs = self.search_documents(query, conversation.workspace)
        if not relevant_docs:
            print("Didn't find any relevant docs")
        return relevant_docs

    def search_documents(
        self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4
    ) -> List[Document]:
        """Search relevant documents from the vector store."""
        filter_dict = {}
        if workspace:
            filter_dict["workspace_id"] = workspace.id
        search_kwargs = {"k": k, "filter": filter_dict if filter_dict else None}
        print(f"search_kwargs: {search_kwargs}")
        retriever = self.vector_store.as_retriever(
            search_type="similarity",
            search_kwargs=search_kwargs,
        )
        return retriever.get_relevant_documents(query)

    def generate_response(
        self, conversation: Conversation, query: str, **kwargs
    ) -> Generator[str, None, None]:
        """Generate a response with streaming support."""
        chain_input = {"query": query, "conversation": conversation}

        for chunk in self.rag_chain.stream(chain_input):
            yield chunk


class AsyncRAGService(RAGService):
    """Asynchronous RAG service implementation."""

    def __init__(self):
        super().__init__()
        self._setup_chain()

    def _setup_chain(self):
        """Set up the RAG chain."""
        template = """Answer the question based only on the following context:
{context}

Conversation history:
{history}

Question: {question}
"""
        self.prompt = ChatPromptTemplate.from_template(template)

        self.rag_chain = (
            {
                "context": self._retriever_with_history,
                "history": lambda x: x["recent_conversation"],
                "question": lambda x: x["query"],
            }
            | self.prompt
            | self.llm
            | StrOutputParser()
        )

    async def _format_history(self, conversation: list) -> str:
        """Format a list of chat messages into a history string."""
        # NOTE: conversation here is a list of LangChain messages, not a
        # Prompt queryset like the sync implementation.
        return "\n".join(
            f"{'User' if prompt.type == 'human' else 'AI'}: {prompt.text()}"
            for prompt in conversation
        )

    async def _retriever_with_history(self, input_dict: Dict[str, Any]) -> str:
        """Retrieve documents considering conversation history."""
        print(f"Retrieving history with input: {input_dict}")
        query = input_dict["query"]
        conversation = input_dict["conversation"]
        workspace = input_dict["workspace"]

        # You could enhance this to consider historical context in retrieval
        docs = await self.search_documents(query, workspace)

        if not docs:
            print("Didn't find any relevant docs")

        return "\n\n".join(doc.page_content for doc in docs)

    async def search_documents(
        self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4
    ) -> List[Document]:
        """Search relevant documents from the vector store."""
        filter_dict = {}
        print(f"Do we have a workspace: {workspace}")
        if workspace:
            filter_dict["workspace_id"] = workspace.id
        search_kwargs = {"k": k, "filter": filter_dict if filter_dict else None}
        print(f"search_kwargs: {search_kwargs}")

        retriever = self.vector_store.as_retriever(
            search_type="mmr",
            search_kwargs=search_kwargs,
        )
        return await retriever.aget_relevant_documents(query)

    async def generate_response(
        self,
        conversation: Conversation,
        query: str,
        workspace: DocumentWorkspace,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Generate a response with streaming support."""
        chain_input = {
            "query": query,
            "conversation": conversation,
            "workspace": workspace,
            "recent_conversation": await self._format_history(conversation),
        }

        async for chunk in self.rag_chain.astream(chain_input):
            yield chunk
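A minimal end-to-end sketch of the async service: ingestion happens elsewhere (see signals.py below), so this only retrieves and streams an answer. The workspace lookup assumes Django's async ORM (4.1+), and the question text is illustrative:

import asyncio

from chat_backend.models import DocumentWorkspace
from chat_backend.services.rag_services import AsyncRAGService

async def ask_workspace(question: str) -> str:
    service = AsyncRAGService()
    workspace = await DocumentWorkspace.objects.afirst()  # any workspace, for illustration
    answer = ""
    async for chunk in service.generate_response(
        conversation=[], query=question, workspace=workspace
    ):
        answer += chunk
    return answer

print(asyncio.run(ask_workspace("What does the onboarding guide say about VPN access?")))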
251
llm_be/chat_backend/services/tests.py
Normal file
@@ -0,0 +1,251 @@
import os
from unittest import TestCase, mock
from unittest.mock import MagicMock, patch, AsyncMock
from typing import List, Dict, Any

from django.test import TestCase as DjangoTestCase

from chat_backend.services.rag_services import (
    RAGService,
    SyncRAGService,
    AsyncRAGService,
)
from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
from chat_backend.services.prompt_classifier import PromptClassifier, PromptType
from parameterized import parameterized


# class TestRAGService(TestCase):
#     def setUp(self):
#         self.rag_service = RAGService()
#         self.rag_service.vector_store = MagicMock()
#         self.rag_service.embedding_model = MagicMock()
#         self.rag_service.text_splitter = MagicMock()

#     def test_initialize_vector_store(self):
#         with patch("os.path.exists", return_value=False), patch(
#             "os.makedirs"
#         ) as mock_makedirs, patch(
#             "langchain_community.vectorstores.Chroma"
#         ) as mock_chroma:

#             # Reset the vector store to test initialization
#             self.rag_service.vector_store = None
#             result = self.rag_service._initialize_vector_store()

#             mock_makedirs.assert_called_once_with("chroma_db")
#             mock_chroma.assert_called_once_with(
#                 embedding_function=self.rag_service.embedding_model,
#                 persist_directory="chroma_db",
#             )
#             self.assertIsNotNone(result)

#     def test_prepare_documents(self):
#         mock_doc1 = MagicMock(spec=Document)
#         mock_doc1.content = "Test content"
#         mock_doc1.source = "test_source"
#         mock_doc1.workspace = MagicMock()
#         mock_doc1.workspace.id = 1
#         mock_doc1.id = 1

#         self.rag_service.text_splitter.split_text.return_value = ["chunk1", "chunk2"]

#         result = self.rag_service._prepare_documents([mock_doc1])

#         self.assertEqual(len(result), 2)
#         self.rag_service.text_splitter.split_text.assert_called_once_with(
#             "Test content"
#         )
#         self.assertEqual(result[0].page_content, "chunk1")
#         self.assertEqual(result[0].metadata["source"], "test_source")

#     def test_ingest_documents(self):
#         mock_workspace = MagicMock()
#         mock_document = MagicMock()
#         mock_documents = [mock_document]

#         with patch(
#             "services.rag_services.Document.objects.filter", return_value=mock_documents
#         ):
#             self.rag_service._prepare_documents = MagicMock(
#                 return_value=["processed_doc"]
#             )

#             self.rag_service.ingest_documents(mock_workspace)

#             self.rag_service.vector_store.add_documents.assert_called_once_with(
#                 ["processed_doc"]
#             )
#             self.rag_service.vector_store.persist.assert_called_once()


# class TestSyncRAGService(DjangoTestCase):
#     def setUp(self):
#         self.sync_service = SyncRAGService()
#         self.sync_service.vector_store = MagicMock()
#         self.sync_service.llm = MagicMock()
#         self.sync_service.rag_chain = MagicMock()

#         self.mock_conversation = MagicMock(spec=Conversation)
#         self.mock_conversation.workspace = MagicMock()

#         self.mock_prompt1 = MagicMock(spec=Prompt)
#         self.mock_prompt1.is_user = True
#         self.mock_prompt1.text = "User question"
#         self.mock_prompt1.created_at = "2023-01-01"

#         self.mock_prompt2 = MagicMock(spec=Prompt)
#         self.mock_prompt2.is_user = False
#         self.mock_prompt2.text = "AI response"
#         self.mock_prompt2.created_at = "2023-01-02"

#     def test_format_history(self):
#         with patch("services.rag_services.Prompt.objects.filter") as mock_filter:
#             mock_filter.return_value.order_by.return_value = [
#                 self.mock_prompt1,
#                 self.mock_prompt2,
#             ]

#             result = self.sync_service._format_history(self.mock_conversation)

#             expected = "User: User question\nAI: AI response"
#             self.assertEqual(result, expected)
#             mock_filter.assert_called_once_with(conversation=self.mock_conversation)

#     def test_retriever_with_history(self):
#         input_dict = {"query": "test query", "conversation": self.mock_conversation}

#         self.sync_service.search_documents = MagicMock(return_value=["doc1", "doc2"])

#         result = self.sync_service._retriever_with_history(input_dict)

#         self.sync_service.search_documents.assert_called_once_with(
#             "test query", self.mock_conversation.workspace
#         )
#         self.assertEqual(result, ["doc1", "doc2"])

#     def test_search_documents(self):
#         mock_retriever = MagicMock()
#         mock_retriever.get_relevant_documents.return_value = ["doc1", "doc2"]
#         self.sync_service.vector_store.as_retriever.return_value = mock_retriever

#         result = self.sync_service.search_documents(
#             "test query", self.mock_conversation.workspace
#         )

#         self.sync_service.vector_store.as_retriever.assert_called_once_with(
#             search_type="similarity",
#             search_kwargs={
#                 "k": 4,
#                 "filter": {"workspace_id": self.mock_conversation.workspace.id},
#             },
#         )
#         self.assertEqual(result, ["doc1", "doc2"])

#     def test_generate_response(self):
#         chain_input = {"query": "test query", "conversation": self.mock_conversation}

#         mock_stream = ["chunk1", "chunk2", "chunk3"]
#         self.sync_service.rag_chain.stream.return_value = mock_stream

#         result = list(
#             self.sync_service.generate_response(self.mock_conversation, "test query")
#         )

#         self.sync_service.rag_chain.stream.assert_called_once_with(chain_input)
#         self.assertEqual(result, mock_stream)


# class TestAsyncRAGService(DjangoTestCase):
#     def setUp(self):
#         self.async_service = AsyncRAGService()
#         self.async_service.vector_store = MagicMock()
#         self.async_service.llm = MagicMock()
#         self.async_service.rag_chain = AsyncMock()

#         self.mock_conversation = MagicMock(spec=Conversation)
#         self.mock_conversation.workspace = MagicMock()

#         self.mock_prompt1 = MagicMock(spec=Prompt)
#         self.mock_prompt1.is_user = True
#         self.mock_prompt1.text = "User question"
#         self.mock_prompt1.created_at = "2023-01-01"

#         self.mock_prompt2 = MagicMock(spec=Prompt)
#         self.mock_prompt2.is_user = False
#         self.mock_prompt2.text = "AI response"
#         self.mock_prompt2.created_at = "2023-01-02"

#     async def test_format_history(self):
#         mock_manager = AsyncMock()
#         mock_manager.order_by.return_value.alist.return_value = [
#             self.mock_prompt1,
#             self.mock_prompt2,
#         ]

#         with patch(
#             "services.rag_services.Prompt.objects.filter", return_value=mock_manager
#         ):
#             result = await self.async_service._format_history(self.mock_conversation)

#             expected = "User: User question\nAI: AI response"
#             self.assertEqual(result, expected)
#             mock_manager.order_by.assert_called_once_with("created_at")

#     async def test_retriever_with_history(self):
#         input_dict = {"query": "test query", "conversation": self.mock_conversation}

#         self.async_service.search_documents = AsyncMock(return_value=["doc1", "doc2"])

#         result = await self.async_service._retriever_with_history(input_dict)

#         self.async_service.search_documents.assert_awaited_once_with(
#             "test query", self.mock_conversation.workspace
#         )
#         self.assertEqual(result, ["doc1", "doc2"])

#     async def test_search_documents(self):
#         mock_retriever = AsyncMock()
#         mock_retriever.aget_relevant_documents.return_value = ["doc1", "doc2"]
#         self.async_service.vector_store.as_retriever.return_value = mock_retriever

#         result = await self.async_service.search_documents(
#             "test query", self.mock_conversation.workspace
#         )

#         self.async_service.vector_store.as_retriever.assert_called_once_with(
#             search_type="similarity",
#             search_kwargs={
#                 "k": 4,
#                 "filter": {"workspace_id": self.mock_conversation.workspace.id},
#             },
#         )
#         self.assertEqual(result, ["doc1", "doc2"])

#     async def test_generate_response(self):
#         chain_input = {"query": "test query", "conversation": self.mock_conversation}

#         mock_stream = ["chunk1", "chunk2", "chunk3"]
#         self.async_service.rag_chain.astream.return_value = mock_stream

#         chunks = []
#         async for chunk in self.async_service.generate_response(
#             self.mock_conversation, "test query"
#         ):
#             chunks.append(chunk)

#         self.async_service.rag_chain.astream.assert_awaited_once_with(chain_input)
#         self.assertEqual(chunks, mock_stream)


class PromptClassifierTestCase(TestCase):
    def setUp(self):
        self.service = PromptClassifier()

    @parameterized.expand(
        [
            ["Tell me a joke", PromptType.GENERAL_CHAT],
            ["Create an image of a dog for me", PromptType.IMAGE_GENERATION],
            [
                "highlight the features of the backyard playset if they were to choose us and make the language more long form",
                PromptType.GENERAL_CHAT,
            ],
        ]
    )
    def test_prompt_classification(self, prompt, expected_output):
        result = self.service.classify(prompt)
        self.assertEqual(result, expected_output)
75
llm_be/chat_backend/services/title_generator.py
Normal file
@@ -0,0 +1,75 @@
from langchain_core.prompts import ChatPromptTemplate

# from langchain_community.llms import Ollama
from langchain_ollama import OllamaLLM


class TitleGenerator:
    """
    Generates short, descriptive titles for conversations based on the first prompt.
    """

    def __init__(self):
        self.llm = OllamaLLM(
            model="llama3.2",
            temperature=0.5,  # Slightly creative but not too random
            top_k=20,
            num_ctx=2048,  # A short context is enough for titles
        )

        self.title_prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    """You are a conversation title generator. Create a very short (2-5 word) title based on the user's first message.

Rules:
1. Keep it extremely concise
2. Capture the main topic or intent
3. Use title case
4. No quotes or punctuation
5. Never exceed 5 words

Examples:
- "What's the weather today?" → "Weather Inquiry"
- "Explain quantum computing" → "Quantum Computing Explanation"
- "Generate an image of a dragon" → "Dragon Image Generation"
- "Find our company's privacy policy" → "Privacy Policy Search"

Return ONLY the title, nothing else.""",
                ),
                ("human", "{prompt}"),
            ]
        )

        self.chain = self.title_prompt | self.llm

    async def generate_async(self, prompt: str) -> str:
        """Generate a title asynchronously."""
        try:
            response = await self.chain.ainvoke({"prompt": prompt})
            return self._clean_response(response)
        except Exception as e:
            print(f"Title generation error: {e}")
            return "Conversation"

    def generate(self, prompt: str) -> str:
        """Generate a title synchronously."""
        try:
            response = self.chain.invoke({"prompt": prompt})
            return self._clean_response(response)
        except Exception as e:
            print(f"Title generation error: {e}")
            return "Conversation"

    def _clean_response(self, response: str) -> str:
        """Clean and format the LLM response."""
        # Remove surrounding quotes, punctuation, and whitespace
        response = response.strip("\"'.!? \n\t")
        # Ensure title case and trim to a hard length limit for safety
        return response.title()[:50]


# Singleton instance
title_generator = TitleGenerator()
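A minimal usage sketch; the input text is illustrative and the exact title depends on the model:

from chat_backend.services.title_generator import title_generator

title = title_generator.generate("Explain how Django signals work")
print(title)  # e.g. "Django Signals Explanation"
# From async code: title = await title_generator.generate_async(first_prompt)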
20
llm_be/chat_backend/signals.py
Normal file
@@ -0,0 +1,20 @@
from django.db.models.signals import post_save, post_delete
from django.dispatch import receiver
from chat_backend.models import Document
from .services.rag_services import AsyncRAGService


@receiver(post_save, sender=Document)
def update_vector_on_save(sender, instance, **kwargs):
    """Update the vector store when documents are saved."""
    if kwargs.get("created", False):
        rag_service = AsyncRAGService()
        rag_service.ingest_documents()


@receiver(post_delete, sender=Document)
def delete_vector_on_remove(sender, instance, **kwargs):
    """Handle document deletion by re-indexing the whole workspace."""
    rag_service = AsyncRAGService()
    rag_service.ingest_documents()
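These receivers only fire if this module is imported at startup. A sketch of the usual wiring in the app config; the ChatBackendConfig class name is assumed, following Django convention:

# llm_be/chat_backend/apps.py (sketch)
from django.apps import AppConfig


class ChatBackendConfig(AppConfig):
    default_auto_field = "django.db.models.BigAutoField"
    name = "chat_backend"

    def ready(self):
        # Importing for side effects connects the @receiver handlers above
        from . import signals  # noqa: F401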
97
llm_be/chat_backend/templates/emails/reset_email.html
Normal file
@@ -0,0 +1,97 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Reset Password for Chat by AI ML Operations, LLC</title>
    <style>
        /* Basic reset for email clients */
        body, table, td, a {
            -webkit-text-size-adjust: 100%;
            -ms-text-size-adjust: 100%;
        }
        table, td {
            mso-table-lspace: 0pt;
            mso-table-rspace: 0pt;
        }
        img {
            border: 0;
            height: auto;
            line-height: 100%;
            outline: none;
            text-decoration: none;
            -ms-interpolation-mode: bicubic;
        }
        body {
            margin: 0;
            padding: 0;
            font-family: Arial, sans-serif;
            background-color: #f4f4f4;
        }
        .email-container {
            max-width: 600px;
            margin: 0 auto;
            background-color: #ffffff;
            border: 1px solid #dddddd;
        }
        .header {
            background-color: #007BFF;
            color: #ffffff;
            padding: 20px;
            text-align: center;
        }
        .content {
            padding: 20px;
            color: #333333;
        }
        .footer {
            background-color: #f4f4f4;
            color: #777777;
            text-align: center;
            padding: 10px;
            font-size: 12px;
        }
        .feedback-title {
            font-size: 18px;
            font-weight: bold;
            margin-bottom: 10px;
        }
        .feedback-text {
            font-size: 14px;
            line-height: 1.5;
        }
    </style>
</head>
<body>
    <table role="presentation" width="100%" cellspacing="0" cellpadding="0" border="0" align="center">
        <tr>
            <td>
                <!-- Email Container -->
                <div class="email-container">
                    <!-- Header -->
                    <div class="header">
                        <h1>Password Reset for AI ML Operations, LLC Chat Services</h1>
                    </div>

                    <!-- Content -->
                    <div class="content">
                        <p>Hello,</p>
                        <p>There has been a request for a password reset. If you didn't request this, please email ryan@aimloperations.com</p>

                        <p>Please click this <a href="{{ url }}">link</a> to set your password.</p>
                        <p>Once you have set your password, go <a href="https://chat.aimloperations.com">here</a> to get started.</p>

                        <p>Thank you.</p>
                    </div>

                    <!-- Footer -->
                    <div class="footer">
                        <p>This is an automated message. Please do not reply to this email.</p>
                        <p>© 2023-2025 AI ML Operations, LLC. All rights reserved.</p>
                    </div>
                </div>
            </td>
        </tr>
    </table>
</body>
</html>
3
llm_be/chat_backend/templates/emails/reset_email.txt
Normal file
@@ -0,0 +1,3 @@
Password Reset for AI ML Operations, LLC Chat Services

Password reset for chat.aimloperations.com. Please use {{ url }} to set your password.
@@ -1,3 +1,194 @@
from django.test import TestCase

# Create your tests here.
from django.test import TestCase, Client
from django.urls import reverse
from django.contrib.auth.models import User
from rest_framework.test import APIClient, APITestCase
from rest_framework import status
from .models import DocumentWorkspace, Document, Company
from django.contrib.auth import get_user_model
import tempfile
from django.core.files.uploadedfile import SimpleUploadedFile
from parameterized import parameterized


# Minimal valid PDF bytes
VALID_PDF_BYTES = (
    b"%PDF-1.3\n"
    b"1 0 obj\n"
    b"<< /Type /Catalog /Pages 2 0 R >>\n"
    b"endobj\n"
    b"2 0 obj\n"
    b"<< /Type /Pages /Kids [3 0 R] /Count 1 >>\n"
    b"endobj\n"
    b"3 0 obj\n"
    b"<< /Type /Page /Parent 2 0 R /Resources << >> /MediaBox [0 0 612 792] /Contents 4 0 R >>\n"
    b"endobj\n"
    b"4 0 obj\n"
    b"<< /Length 44 >>\n"
    b"stream\n"
    b"BT /F1 12 Tf 72 720 Td (Test PDF) Tj ET\n"
    b"endstream\n"
    b"endobj\n"
    b"xref\n"
    b"0 5\n"
    b"0000000000 65535 f \n"
    b"0000000009 00000 n \n"
    b"0000000058 00000 n \n"
    b"0000000117 00000 n \n"
    b"0000000223 00000 n \n"
    b"trailer\n"
    b"<< /Size 5 /Root 1 0 R >>\n"
    b"startxref\n"
    b"317\n"
    b"%%EOF"
)


class DocumentWorkspaceViewsTestCase(APITestCase):
    def setUp(self):
        self.company = Company.objects.create(
            name="test", state="IL", zipcode="60189", address="1968 Greensboro Dr"
        )
        self.user = get_user_model().objects.create_user(
            company=self.company,
            username="testuser",
            password="testpass123",
            email="test@test.com",
        )

        self.client = APIClient()
        self.client.force_authenticate(user=self.user)

        self.workspace = DocumentWorkspace.objects.create(
            company=self.user.company, name="Test Workspace"
        )

    def test_list_workspaces(self):
        url = reverse("document_workspaces")
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(len(response.data), 1)
        self.assertEqual(response.data[0]["name"], "Test Workspace")

    def test_create_workspace(self):
        url = reverse("document_workspaces")
        data = {"name": "New Workspace"}
        response = self.client.post(url, data, format="json")
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        self.assertEqual(DocumentWorkspace.objects.count(), 2)

    def test_retrieve_workspace(self):
        url = reverse("document_workspaces")
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data[0]["name"], "Test Workspace")

    # def test_update_workspace(self):
    #     url = reverse('document_workspaces')
    #     data = {
    #         'name': 'Updated Workspace'
    #     }
    #     response = self.client.post(url, data, format='json')
    #     self.assertEqual(response.status_code, status.HTTP_201_CREATED)
    #     self.workspace.refresh_from_db()
    #     self.assertEqual(self.workspace.name, 'Updated Workspace')

    # def test_delete_workspace(self):
    #     url = reverse('document_workspaces', args=[self.workspace.id])
    #     response = self.client.delete(url)
    #     self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
    #     self.assertEqual(DocumentWorkspace.objects.count(), 0)


class DocumentViewsTestCase(APITestCase):
    def setUp(self):
        self.company = Company.objects.create(
            name="test", state="IL", zipcode="60189", address="1968 Greensboro Dr"
        )
        self.user = get_user_model().objects.create_user(
            company=self.company,
            username="testuser",
            password="testpass123",
            email="test@test.com",
        )

        self.client = APIClient()
        self.client.force_authenticate(user=self.user)

        self.workspace = DocumentWorkspace.objects.create(
            company=self.user.company, name="Test Workspace"
        )

        # Create a test file
        self.test_file = SimpleUploadedFile(
            "test.pdf", VALID_PDF_BYTES, content_type="application/pdf"
        )

    def test_upload_document(self):
        url = reverse("documents")
        data = {"file": self.test_file}
        response = self.client.post(url, data, format="multipart")
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        self.assertEqual(Document.objects.count(), 1)

        document = Document.objects.first()
        self.assertEqual(document.workspace.id, self.workspace.id)
        self.assertTrue(document.processed)  # the upload view processes the file before responding

    def test_list_documents(self):
        # First create a document
        Document.objects.create(workspace=self.workspace, file=self.test_file)

        url = reverse("documents")
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(len(response.data), 1)
        self.assertIn("test", response.data[0]["file"])
        self.assertIn("pdf", response.data[0]["file"])

    # def test_delete_document(self):
    #     document = Document.objects.create(
    #         workspace=self.workspace,
    #         file=self.test_file
    #     )

    #     url = reverse('document-detail', args=[document.id])
    #     response = self.client.delete(url)
    #     self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
    #     self.assertEqual(Document.objects.count(), 0)

    def test_upload_invalid_file(self):
        url = reverse("documents")
        data = {"file": "not a file"}
        response = self.client.post(url, data, format="multipart")
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

    def test_access_other_users_documents(self):
        # Create another user
        other_company = Company.objects.create(
            name="test2", state="IL", zipcode="60189", address="1968 Greensboro Dr"
        )
        other_user = get_user_model().objects.create_user(
            company=other_company,
            username="otheruser",
            password="otherpass123",
            email="testing2@test.com",
        )
        other_workspace = DocumentWorkspace.objects.create(
            company=other_user.company, name="Other Workspace"
        )
        other_document = Document.objects.create(
            workspace=other_workspace, file=self.test_file
        )

        # Try to access the other user's document
        url = reverse("documents_details", kwargs={"document_id": other_document.id})
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
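These tests should run under Django's standard test runner, assuming the app is installed as chat_backend and the parameterized package is available:

python manage.py test chat_backend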
@@ -14,27 +14,41 @@ from .views import (
    ConversationDetailView,
    CompanyUsersView,
    SetUserPassword,
    ResetUserPassword,
    ConversationPreferences,
    UserPromptAnalytics,
    UserConversationAnalytics,
    CompanyUsageAnalytics,
    AdminAnalytics
    AdminAnalytics,
    reset_password,
    DocumentWorkspaceView,
    DocumentUploadView,
    DocumentDetailView,
)
from rest_framework.routers import DefaultRouter


urlpatterns = [
    path("token/obtain/", CustomObtainTokenView.as_view(), name="token_create"),
    path("token/refresh/", jwt_views.TokenRefreshView.as_view(), name="token_refresh"),
    path("user/create/", CustomUserCreate.as_view(), name="create_user"),
    path("user/invite/", CustomUserInvite.as_view(), name="invite_user"),
    path("user/set_password/<slug:slug>/", SetUserPassword.as_view(), name="set_password"),
    path("user/reset_password/", reset_password, name="reset_password"),
    path(
        "user/set_password/<slug:slug>/", SetUserPassword.as_view(), name="set_password"
    ),
    path(
        "blacklist/",
        LogoutAndBlacklistRefreshTokenForUserView.as_view(),
        name="blacklist",
    ),
    path("user/get/", CustomUserGet.as_view(), name="get_user"),
    path("user/acknowledge_tos/", AcknowledgeTermsOfService.as_view(), name="acknowledge_tos"),
    path("company_users",CompanyUsersView.as_view(), name="company_users"),
    path(
        "user/acknowledge_tos/",
        AcknowledgeTermsOfService.as_view(),
        name="acknowledge_tos",
    ),
    path("company_users", CompanyUsersView.as_view(), name="company_users"),
    path("user/is_authenticated/", is_authenticated, name="is_authenticated"),
    path("announcment/get/", AnnouncmentView.as_view(), name="get_announcments"),
    path("conversations", ConversationsView.as_view(), name="conversations"),
@@ -44,9 +58,37 @@ urlpatterns = [
        ConversationDetailView.as_view(),
        name="conversation_details",
    ),
    path("conversation_preferences", ConversationPreferences.as_view(), name="conversation_preferences"),
    path("analytics/user_prompts/", UserPromptAnalytics.as_view(), name="analytics_user_prompts"),
    path("analytics/user_conversations/", UserConversationAnalytics.as_view(), name="analytics_user_conversations"),
    path("analytics/company_usage/", CompanyUsageAnalytics.as_view(), name="analytics_company_usage"),
    path(
        "conversation_preferences",
        ConversationPreferences.as_view(),
        name="conversation_preferences",
    ),
    path(
        "analytics/user_prompts/",
        UserPromptAnalytics.as_view(),
        name="analytics_user_prompts",
    ),
    path(
        "analytics/user_conversations/",
        UserConversationAnalytics.as_view(),
        name="analytics_user_conversations",
    ),
    path(
        "analytics/company_usage/",
        CompanyUsageAnalytics.as_view(),
        name="analytics_company_usage",
    ),
    path("analytics/admin/", AdminAnalytics.as_view(), name="analytics_admin"),
    # document urls
    path(
        "document_workspaces/",
        DocumentWorkspaceView.as_view(),
        name="document_workspaces",
    ),
    path("documents/", DocumentUploadView.as_view(), name="documents"),
    path(
        "documents_details/<int:document_id>",
        DocumentDetailView.as_view(),
        name="documents_details",
    ),
]
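Since the project URLconf mounts these routes under api/ (see the project urls.py at the end of this diff), reversing the route names yields the full paths; a small sketch, with the document id 1 purely illustrative:

from django.urls import reverse

reverse("documents")                                     # "/api/documents/"
reverse("documents_details", kwargs={"document_id": 1})  # "/api/documents_details/1"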
@@ -1,7 +1,8 @@
import datetime


def last_day_of_month(any_day):
    # The day 28 exists in every month. 4 days later, it's always next month
    next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
    # subtracting the number of the current day brings us back one month
    return next_month - datetime.timedelta(days=next_month.day)
    # The day 28 exists in every month. 4 days later, it's always next month
    next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
    # subtracting the number of the current day brings us back one month
    return next_month - datetime.timedelta(days=next_month.day)
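The 28-plus-4 trick deserves a worked example: for 2024-02-10, replace(day=28) gives 2024-02-28, adding 4 days lands in March on 2024-03-03, and subtracting that day number (3) steps back to 2024-02-29, the correct leap-year month end.

import datetime

last_day_of_month(datetime.date(2024, 2, 10))  # datetime.date(2024, 2, 29)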
@@ -11,11 +11,23 @@ from .serializers (
    CompanySerializer,
    ConversationSerializer,
    PromptSerializer,
    FeedbackSerializer
    FeedbackSerializer,
    DocumentWorkspaceSerializer,
    DocumentSerializer,
)

from rest_framework.views import APIView
from rest_framework.response import Response
from .models import CustomUser, Announcement, Conversation, Prompt, Feedback,PromptMetric
from .models import (
    CustomUser,
    Announcement,
    Conversation,
    Prompt,
    Feedback,
    PromptMetric,
    DocumentWorkspace,
    Document,
)
from django.views.decorators.cache import never_cache
from django.http import JsonResponse
from datetime import datetime
@@ -24,12 +36,15 @@ from asgiref.sync import sync_to_async, async_to_sync
from channels.generic.websocket import AsyncWebsocketConsumer
from langchain_ollama.llms import OllamaLLM
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.messages import HumanMessage, AIMessage
from langchain_classic.chains import RetrievalQA
import re
import os
from django.conf import settings
import json
import base64
import pandas as pd
import io

# For email support
from django.core.mail import EmailMultiAlternatives
@@ -43,14 +58,32 @@ from django.core.files.base import ContentFile
import math
import datetime
import pytz
from langchain_community.embeddings import OllamaEmbeddings

from dateutil.relativedelta import relativedelta
from django.views.decorators.csrf import csrf_exempt

from .utils import last_day_of_month
from .services.llm_service import AsyncLLMService
from .services.rag_services import AsyncRAGService
from .services.title_generator import title_generator
from .services.moderation_classifier import moderation_classifier, ModerationLabel
from .services.prompt_classifier.prompt_classifier import PromptClassifier, PromptType
from .services.data_analysis_service import AsyncDataAnalysisService


CHANNEL_NAME: str = 'llm_messages'
MODEL_NAME: str = "llama3"

from langchain_classic.chains import create_retrieval_chain
from langchain_classic.chains.combine_documents import create_stuff_documents_chain
from langchain_ollama import ChatOllama

import logging

logger = logging.getLogger(__name__)


CHANNEL_NAME: str = "llm_messages"
MODEL_NAME: str = "llama3.2"

# Create your views here.
class CustomObtainTokenView(TokenObtainPairView):
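Note the model bump from llama3 to llama3.2: the Ollama host needs that model pulled before any of the views below can stream, e.g. with the standard Ollama CLI:

ollama pull llama3.2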
@@ -71,104 +104,220 @@ class CustomUserCreate(APIView):
            return Response(json, status=status.HTTP_201_CREATED)
        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)


def send_invite_email(slug, email_to_invite):
    print(f"url : https://www.chat.aimloperations.com/set_password?slug={slug}")
    url = f"https://www.chat.aimloperations.com/set_password?slug={slug}"
    logger.info("Sending invite email")
    logger.info(f"url : https://chat.aimloperations.com/set_password?slug={slug}")
    url = f"https://chat.aimloperations.com/set_password?slug={slug}"
    subject = "Welcome to AI ML Operations, LLC Chat Services"
    from_email = "ryan@aimloperations.com"
    to=email_to_invite
    to = email_to_invite
    d = {"url": url}
    html_content = get_template(r'emails/invite_email.html').render(d)
    text_content = get_template(r'emails/invite_email.txt').render(d)

    html_content = get_template(r"emails/invite_email.html").render(d)
    text_content = get_template(r"emails/invite_email.txt").render(d)

    msg = EmailMultiAlternatives(subject, text_content, from_email, [to])
    msg.attach_alternative(html_content, "text/html")
    msg.send(fail_silently=True)


def send_password_reset_email(slug, email_to_invite):
    logger.info("Sending reset email")
    logger.info(f"url : https://www.chat.aimloperations.com/set_password?slug={slug}")
    url = f"https://www.chat.aimloperations.com/set_password?slug={slug}"
    subject = "Password reset for AI ML Operations, LLC Chat Services"
    from_email = "ryan@aimloperations.com"
    to = email_to_invite
    d = {"url": url}
    html_content = get_template(r"emails/reset_email.html").render(d)
    text_content = get_template(r"emails/reset_email.txt").render(d)

    msg = EmailMultiAlternatives(subject, text_content, from_email, [to])
    msg.attach_alternative(html_content, "text/html")
    msg.send(fail_silently=True)


def send_feedback_email(feedback_obj):
    logger.info("Sending feedback email")
    subject = "New Feedback for Chat by AI ML Operations, LLC"
    from_email = "ryan@aimloperations.com"
    to="ryan@aimloperations.com"
    to = "ryan@aimloperations.com"
    d = {"title": feedback_obj.title, "feedback_text": feedback_obj.text}
    html_content = get_template(r'emails/feedback_email.html').render(d)
    text_content = get_template(r'emails/feedback_email.txt').render(d)

    html_content = get_template(r"emails/feedback_email.html").render(d)
    text_content = get_template(r"emails/feedback_email.txt").render(d)

    msg = EmailMultiAlternatives(subject, text_content, from_email, [to])
    msg.attach_alternative(html_content, "text/html")
    msg.send(fail_silently=True)


def send_password_reset_email(slug, email_to_invite):
    logger.info("Sending Password reset email")
    url = f"https://www.chat.aimloperations.com/set_password?slug={slug}"
    subject = "Password reset for Chat by AI ML Operations, LLC"
    from_email = "ryan@aimloperations.com"
    to = email_to_invite
    d = {"url": url}
    html_content = get_template(r"emails/reset_email.html").render(d)
    text_content = get_template(r"emails/reset_email.txt").render(d)
    msg = EmailMultiAlternatives(subject, text_content, from_email, [to])
    msg.attach_alternative(html_content, "text/html")
    msg.send(fail_silently=True)


class CustomUserInvite(APIView):
    http_method_names = ['post']
    http_method_names = ["post"]

    def post(self, request, format="json"):
        def valid_email(email_string):
            regex = r'^[a-z0-9]+[\._]?[a-z0-9]+[@]\w+[.]\w+$'
            if re.match(regex,email_string):
            regex = r"^[a-z0-9]+[\._]?[a-z0-9]+[@]\w+[.]\w+$"
            if re.match(regex, email_string):
                return True
            else:
                return False

        email_to_invite = request.data['email']

        if len(email_to_invite) == 0 or not valid_email(email_to_invite) or not request.user.is_company_manager:
        email_to_invite = request.data["email"]

        if (
            len(email_to_invite) == 0
            or not valid_email(email_to_invite)
            or not request.user.is_company_manager
        ):
            return Response(status=status.HTTP_400_BAD_REQUEST)
        # make sure there isn't a user with this email already
        existing_users = CustomUser.objects.filter(email=email_to_invite)
        if len(existing_users) > 0:
            return Response(status=status.HTTP_400_BAD_REQUEST)
        # create the object and send the email
        user = CustomUser.objects.create(email=email_to_invite, username=email_to_invite, company=request.user.company)
        user = CustomUser.objects.create(
            email=email_to_invite,
            username=email_to_invite,
            company=request.user.company,
        )

        # send an email
        send_invite_email(user.slug, email_to_invite)

        return Response(status=status.HTTP_201_CREATED)

class SetUserPassword(APIView):
    http_method_names = ['post','get']

@csrf_exempt
def reset_password(request):
    if request.method == "POST":
        data = json.loads(request.body)
        token = data.get("recaptchaToken")
        payload = {
            "secret": settings.CAPTCHA_SECRET_KEY,
            "response": token,
        }
        response = requests.post(
            "https://www.google.com/recaptcha/api/siteverify", data=payload
        )
        result = response.json()
        if result.get("success") and result.get("score") >= 0.5:
            email = data.get("email")
            user = CustomUser.objects.filter(email=email).first()
            if user:
                user.set_unusable_password()
                user.save()

                # send the email
                send_password_reset_email(user.slug, email)
            return JsonResponse({}, status=200)

    return JsonResponse({}, status=400)


class ResetUserPassword(APIView):
    http_method_names = [
        "post",
    ]
    permission_classes = (permissions.AllowAny,)
    authentication_classes = ()

    def post(self, request, format="json"):
        """
        Send an email with a set password link to the set password page
        Also disable the account
        """
        logger.info(f"Password reset requested. {request.data}")
        token = request.data.get("recaptchaToken")
        email = request.data.get("email")
        payload = {
            "secret": settings.CAPTCHA_SECRET_KEY,
            "response": token,
        }
        response = requests.post(
            "https://www.google.com/recaptcha/api/siteverify", data=payload
        )
        result = response.json()
        if result.get("success") and result.get("score") >= 0.5:
            user = CustomUser.objects.filter(email=email).first()
            if user:
                user.set_unusable_password()
                user.save()

                # send the email
                send_password_reset_email(user.slug, email)
        else:
            logger.error("Captcha secret failed")

        return Response(status=status.HTTP_200_OK)


class SetUserPassword(APIView):
    http_method_names = ["post", "get"]
    permission_classes = (permissions.AllowAny,)
    authentication_classes = ()

    def get(self, request, slug):
        user = CustomUser.objects.get(slug=slug)
        if user.last_login:
            if user.has_usable_password():
                return Response(status=status.HTTP_401_UNAUTHORIZED)
        else:
            return Response(status=status.HTTP_200_OK)

    def post(self, request, slug, format="json"):
        user = CustomUser.objects.get(slug=slug)
        user.set_password(request.data['password'])
        user.set_password(request.data["password"])
        user.save()
        return Response(status=status.HTTP_200_OK)


class CustomUserGet(APIView):
    http_method_names = ['get', 'head', 'post']
    http_method_names = ["get", "head", "post"]

    def get(self, request, format="json"):
        email = request.user.email
        username = request.user.username
        user = CustomUser.objects.get(email=email)
        serializer = CustomUserSerializer(user)
        user = CustomUser.objects.filter(email=email).last()
        logger.info(f"Getting the user: {user}")
        try:
            serializer = CustomUserSerializer(user)
            logger.debug(f"serializer: {serializer}")
            logger.debug(serializer.data)
            return Response(serializer.data, status=status.HTTP_200_OK)
        except Exception as e:
            logger.error(f"Exception: {e}")
            return Response({}, status=status.HTTP_400_BAD_REQUEST)

        return Response(serializer.data, status=status.HTTP_200_OK)

class FeedbackView(APIView):
    http_method_names = ['post','get']
    http_method_names = ["post", "get"]

    def post(self, request, format="json"):
        serializer = FeedbackSerializer(data=request.data)
        print(request.data)
        logger.debug(request.data)
        if serializer.is_valid():
            feedback_obj = serializer.save()
            feedback_obj.user = request.user

            feedback_obj.save()
            send_feedback_email(feedback_obj)
            return Response(serializer.data, status=status.HTTP_201_CREATED)
        else:
            print(serializer.errors)
            logger.error(serializer.errors)
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

    def get(self, request, format="json"):
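The reset endpoint can be exercised from any HTTP client; a minimal sketch, assuming the backend listens on localhost:8000 and a real reCAPTCHA token has been obtained client-side (both are assumptions, not shown in this diff):

import requests

resp = requests.post(
    "http://localhost:8000/api/user/reset_password/",
    json={"email": "user@example.com", "recaptchaToken": "<token-from-widget>"},
)
print(resp.status_code)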
@@ -177,14 +326,15 @@ class FeedbackView(APIView):
        return Response(serializer.data, status=status.HTTP_200_OK)


class AcknowledgeTermsOfService(APIView):
    http_method_names = ['post']
    http_method_names = ["post"]

    def post(self, request, format="json"):
        request.user.has_signed_tos = True
        request.user.save()
        return Response(status=status.HTTP_200_OK)


class CompanyUsersView(APIView):
    def get(self, request, format="json"):
        # TODO: make sure you are a manager of that company
@@ -194,8 +344,7 @@ class CompanyUsersView(APIView):
            return Response(serializer.data, status=status.HTTP_200_OK)
        else:
            return Response(status=status.HTTP_401_UNAUTHORIZED)


    def post(self, request, format="json"):
        if request.user.is_company_manager:
            user = CustomUser.objects.get(email=request.data.get("email"))
@@ -215,7 +364,7 @@ class CompanyUsersView(APIView):
            return Response(status=status.HTTP_200_OK)
        return Response(status=status.HTTP_400_BAD_REQUEST)
        return Response(status=status.HTTP_401_UNAUTHORIZED)

    def delete(self, request, format="json"):
        if request.user.is_company_manager:
            user = CustomUser.objects.get(email=request.data.get("email"))
@@ -224,6 +373,7 @@ class CompanyUsersView(APIView):
            return Response(status=status.HTTP_200_OK)
        return Response(status=status.HTTP_401_UNAUTHORIZED)


class AnnouncmentView(APIView):
    permission_classes = (permissions.AllowAny,)
    serializer_class = AnnouncmentSerializer
@@ -259,7 +409,9 @@ def is_authenticated(request):
class ConversationsView(APIView):
    def get(self, request, format="json"):
        order = "created" if request.user.conversation_order else "-created"
        conversations = Conversation.objects.filter(user=request.user, deleted=False).order_by(order)
        conversations = Conversation.objects.filter(
            user=request.user, deleted=False
        ).order_by(order)
        serializer = ConversationSerializer(conversations, many=True)
        return Response(serializer.data, status=status.HTTP_200_OK)
@@ -283,7 +435,9 @@ class ConversationsView(APIView):
        # conversation.user_id = request.user.id
        # conversation.save()

        return Response({"title": title, "id": conversation.id}, status=status.HTTP_201_CREATED)
        return Response(
            {"title": title, "id": conversation.id}, status=status.HTTP_201_CREATED
        )


class ConversationPreferences(APIView):
@@ -298,7 +452,6 @@ class ConversationPreferences(APIView):
        return Response({"order": user.conversation_order}, status=status.HTTP_200_OK)


class ConversationDetailView(APIView):
    def get(self, request, format="json"):
        conversation_id = request.query_params.get("conversation_id")
@@ -306,9 +459,8 @@ class ConversationDetailView(APIView):
        serializer = PromptSerializer(prompts, many=True)
        return Response(serializer.data, status=status.HTTP_200_OK)

    def post(self, request, format="json"):
        print('In the post')
        logger.info("In the post")
        # Add the prompt to the database
        # make sure there is a conversation for it
        # if there is not a conversation create a title for it
@@ -336,28 +488,30 @@ class ConversationDetailView(APIView):
            prompt_instance = serializer.save()

            # set up the streaming response if it is from the user
            print(f'Do we have a valid user? {is_user}')
            logger.info(f"Do we have a valid user? {is_user}")
            if is_user:
                messages = []
                for prompt_obj in Prompt.objects.filter(conversation__id=conversation_id):
                    messages.append({
                        'content':prompt_obj.message,
                        'role': 'user' if prompt_obj.user_created else 'assistant'
                    })
                for prompt_obj in Prompt.objects.filter(
                    conversation__id=conversation_id
                ):
                    messages.append(
                        {
                            "content": prompt_obj.message,
                            "role": "user" if prompt_obj.user_created else "assistant",
                        }
                    )

                channel_layer = get_channel_layer()
                print(f'Sending to the channel: {CHANNEL_NAME}')
                logger.info(f"Sending to the channel: {CHANNEL_NAME}")
                async_to_sync(channel_layer.group_send)(
                    CHANNEL_NAME, {
                        'type':'receive',
                        'content': messages
                    }
                    CHANNEL_NAME, {"type": "receive", "content": messages}
                )
        except:
            print(f"Error trying to submit to conversation_id: {conversation_id} with request.data: {request.data}")
            logger.error(
                f"Error trying to submit to conversation_id: {conversation_id} with request.data: {request.data}"
            )
            pass

        return Response(status=status.HTTP_200_OK)

    def delete(self, request, format="json"):
@@ -367,13 +521,16 @@ class ConversationDetailView(APIView):
        conversation.save()
        return Response(status=status.HTTP_202_ACCEPTED)


class UserPromptAnalytics(APIView):
    def get(self, request, format="json"):
        now = timezone.now()
        result = []

        number_of_months = 3
        company_user_ids = CustomUser.objects.filter(company=request.user.company).values_list('id', flat=True)
        number_of_months = 3
        company_user_ids = CustomUser.objects.filter(
            company=request.user.company
        ).values_list("id", flat=True)
        for i in range(number_of_months):
            next_year = now.year
            next_month = now.month - i
@@ -383,30 +540,51 @@ class UserPromptAnalytics(APIView):

            start_date = datetime.datetime(next_year, next_month, 1)
            end_date = last_day_of_month(start_date)
            total_conversations = Conversation.objects.filter(created__gte=start_date, created__lte=end_date)
            total_prompts = Prompt.objects.filter(conversation__id__in=total_conversations, created__gte=start_date, created__lte=end_date)
            total_conversations = Conversation.objects.filter(
                created__gte=start_date, created__lte=end_date
            )
            total_prompts = Prompt.objects.filter(
                conversation__id__in=total_conversations,
                created__gte=start_date,
                created__lte=end_date,
            )
            total_users = len(CustomUser.objects.all())
            my_conversations = Conversation.objects.filter(user=request.user)
            my_prompts = Prompt.objects.filter(conversation__in=my_conversations, created__gte=start_date, created__lte=end_date)
            company_conversations = Conversation.objects.filter(user__id__in=company_user_ids)
            company_prompts = Prompt.objects.filter(conversation__in=company_conversations, created__gte=start_date, created__lte=end_date)
            my_prompts = Prompt.objects.filter(
                conversation__in=my_conversations,
                created__gte=start_date,
                created__lte=end_date,
            )
            company_conversations = Conversation.objects.filter(
                user__id__in=company_user_ids
            )
            company_prompts = Prompt.objects.filter(
                conversation__in=company_conversations,
                created__gte=start_date,
                created__lte=end_date,
            )

            result.append(
                {
                    "month": start_date.strftime("%B"),
                    "you": len(my_prompts),
                    "others": len(company_prompts) / len(company_user_ids),
                    "all": len(total_prompts) / total_users,
                }
            )

            result.append({
                "month":start_date.strftime("%B"),
                "you": len(my_prompts),
                "others": len(company_prompts)/len(company_user_ids),
                "all":len(total_prompts)/total_users
            })

        return Response(result[::-1], status=status.HTTP_200_OK)

class UserConversationAnalytics(APIView):
    def get(self, request, format="json"):
        now = timezone.now()
        result = []

        number_of_months = 3
        company_user_ids = CustomUser.objects.filter(company=request.user.company).values_list('id', flat=True)
        number_of_months = 3
        company_user_ids = CustomUser.objects.filter(
            company=request.user.company
        ).values_list("id", flat=True)
        for i in range(number_of_months):
            next_year = now.year
            next_month = now.month - i
@@ -416,28 +594,48 @@ class UserConversationAnalytics(APIView):

            start_date = datetime.datetime(next_year, next_month, 1)
            end_date = last_day_of_month(start_date)
            total_conversations = len(Conversation.objects.filter(created__gte=start_date, created__lte=end_date))
            total_conversations = len(
                Conversation.objects.filter(
                    created__gte=start_date, created__lte=end_date
                )
            )
            total_users = len(CustomUser.objects.all())
            company_conversations = len(Conversation.objects.filter(user__id__in=company_user_ids, created__gte=start_date, created__lte=end_date))
            company_conversations = len(
                Conversation.objects.filter(
                    user__id__in=company_user_ids,
                    created__gte=start_date,
                    created__lte=end_date,
                )
            )

            result.append({
                "month":start_date.strftime("%B"),
                "you": len(Conversation.objects.filter(user=request.user, created__gte=start_date, created__lte=end_date)),
                "others": company_conversations/len(company_user_ids),
                "all":total_conversations/total_users
            })
            result.append(
                {
                    "month": start_date.strftime("%B"),
                    "you": len(
                        Conversation.objects.filter(
                            user=request.user,
                            created__gte=start_date,
                            created__lte=end_date,
                        )
                    ),
                    "others": company_conversations / len(company_user_ids),
                    "all": total_conversations / total_users,
                }
            )

        return Response(result[::-1], status=status.HTTP_200_OK)

class CompanyUsageAnalytics(APIView):
    def get(self, request, format="json"):
        now = timezone.now()
        result = []

        number_of_months = 3
        company_user_ids = CustomUser.objects.filter(company=request.user.company).values_list('id', flat=True)

        number_of_months = 3
        company_user_ids = CustomUser.objects.filter(
            company=request.user.company
        ).values_list("id", flat=True)

        for i in range(number_of_months):
            next_year = now.year
            next_month = now.month - i
@@ -447,19 +645,28 @@ class CompanyUsageAnalytics(APIView):

            start_date = datetime.datetime(next_year, next_month, 1)
            end_date = last_day_of_month(start_date)
            conversations = Conversation.objects.filter(user__id__in=company_user_ids, created__gte=start_date, created__lte=end_date)

            conversation_user_ids = conversations.values_list("user__id", flat=True).distinct()
            result.append({
                "month":start_date.strftime("%B"),
                "used":len(conversation_user_ids),
                "not_used":len(company_user_ids) - len(conversation_user_ids)
            })
            conversations = Conversation.objects.filter(
                user__id__in=company_user_ids,
                created__gte=start_date,
                created__lte=end_date,
            )

            conversation_user_ids = conversations.values_list(
                "user__id", flat=True
            ).distinct()
            result.append(
                {
                    "month": start_date.strftime("%B"),
                    "used": len(conversation_user_ids),
                    "not_used": len(company_user_ids) - len(conversation_user_ids),
                }
            )
        return Response(result[::-1], status=status.HTTP_200_OK)

class AdminAnalytics(APIView):
    def get(self, request, format="json"):
        number_of_months = 3
        number_of_months = 3
        result = []
        now = timezone.now()

@@ -472,37 +679,43 @@ class AdminAnalytics(APIView):

            start_date = datetime.datetime(next_year, next_month, 1)
            end_date = last_day_of_month(start_date)
            durations = [item.get_duration() for item in PromptMetric.objects.filter(created__gte=start_date, created__lte=end_date)]
            durations = [
                item.get_duration()
                for item in PromptMetric.objects.filter(
                    created__gte=start_date, created__lte=end_date
                )
            ]
            if len(durations) == 0:
                result.append({
                    "month":start_date.strftime("%B"),
                    "range":[0,0],
                    "avg": 0,
                    "median":0,
                })
                result.append(
                    {
                        "month": start_date.strftime("%B"),
                        "range": [0, 0],
                        "avg": 0,
                        "median": 0,
                    }
                )
                continue

            average = sum(durations)/len(durations)

            average = sum(durations) / len(durations)
            min_value = min(durations)
            max_value = max(durations)
            durations.sort()
            median = durations[len(durations)//2]
            result.append({
                "month":start_date.strftime("%B"),
                "range":[min_value,max_value],
                "avg": average,
                "median":median,
            })
            median = durations[len(durations) // 2]
            result.append(
                {
                    "month": start_date.strftime("%B"),
                    "range": [min_value, max_value],
                    "avg": average,
                    "median": median,
                }
            )

        return Response(result[::-1], status=status.HTTP_200_OK)
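One detail worth flagging: durations[len(durations) // 2] on a sorted even-length list picks the upper middle element rather than the conventional average of the two middles. Python's statistics.median gives the textbook value; a quick comparison:

import statistics

durations = [1, 2, 3, 4]
durations[len(durations) // 2]  # 3, the upper middle
statistics.median(durations)    # 2.5, the average of the two middles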
prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful assistant."),
    ("user", "{input}")
])

prompt = ChatPromptTemplate.from_messages(
    [("system", "You are a helpful assistant."), ("user", "{input}")]
)

llm = OllamaLLM(model=MODEL_NAME)
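Both prompt and llm are LangChain runnables, so they compose with the pipe operator; a minimal sketch of invoking the pair, assuming an Ollama server is running locally with the configured model available:

chain = prompt | llm
print(chain.invoke({"input": "What is a vector store?"}))  # plain string completion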
@@ -510,205 +723,98 @@ llm = OllamaLLM(model=MODEL_NAME)
# # Chain
# chain = prompt | llm.with_config({"run_name": "model"}) | output_parser.with_config({"run_name": "Assistant"})

@database_sync_to_async
def create_conversation(prompt, email):
    # return the conversation id

    response = llm.invoke("Summarise the phrase in one to four words \"%s\"" % prompt)
    print(f"Response: {response}")
    print(dir(response))
    title = response.replace("\"","")
    title = " ".join(title.split(" ")[:4])

    conversation = Conversation.objects.create(title = title)
    conversation.save()

    user = CustomUser.objects.get(email=email)
    conversation.user_id = user.id
    conversation.save()
    return conversation.id

@database_sync_to_async
def get_messages(conversation_id, prompt, file_string: str = None, file_type: str = ""):
    messages = []

    conversation = Conversation.objects.get(id=conversation_id)
    print(file_string)

    # add the prompt to the conversation
    serializer = PromptSerializer(
        data={
            "message": prompt,
            "user_created": True,
            "created": datetime.now(),
        }
    )
    if serializer.is_valid(raise_exception=True):
        prompt_instance = serializer.save()
        prompt_instance.conversation_id = conversation.id
        prompt_instance.save()
        if file_string:
            file_name = f"prompt_{prompt_instance.id}_data.{file_type}"
            f = ContentFile(file_string, name=file_name)
            prompt_instance.file.save(file_name, f)
            prompt_instance.file_type = file_type
            prompt_instance.save()

    for prompt_obj in Prompt.objects.filter(conversation__id=conversation_id):
        messages.append({
            'content': prompt_obj.message,
            'role': 'user' if prompt_obj.user_created else 'assistant',
            'has_file': prompt_obj.file_exists(),
            'file': prompt_obj.file if prompt_obj.file_exists() else None,
            'file_type': prompt_obj.file_type if prompt_obj.file_exists() else None,
        })

    # now transform the messages
    transformed_messages = []
    for message in messages:

        if message['has_file'] and message['file_type'] != None:
            if 'csv' in message['file_type']:
                file_type = 'csv'
                altered_message = f"{message['content']}\n The file type is csv and the file contents are: {message['file'].read()}"
            elif 'xlsx' in message['file_type']:
                file_type = 'xlsx'
                df = pd.read_excel(message['file'].read())
                altered_message = f"{message['content']}\n The file type is xlsx and the file contents are: {df}"
            elif 'txt' in message['file_type']:
                file_type = 'txt'
                altered_message = f"{message['content']}\n The file type is txt and the file contents are: {message['file'].read()}"
        else:
            altered_message = message['content']

        transformed_message = SystemMessage(content=altered_message) if message['role'] == 'assistant' else HumanMessage(content=altered_message)
        transformed_messages.append(transformed_message)
# Document Views
class DocumentWorkspaceView(APIView):
    # permission_classes = [permissions.IsAuthenticated]

    return transformed_messages, prompt_instance
    def get(self, request):
        workspaces = DocumentWorkspace.objects.filter(company=request.user.company)
        serializer = DocumentWorkspaceSerializer(workspaces, many=True)
        return Response(serializer.data)

    def post(self, request):
        serializer = DocumentWorkspaceSerializer(data=request.data)
        if serializer.is_valid():
            serializer.save(company=request.user.company)
            return Response(serializer.data, status=status.HTTP_201_CREATED)
        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)


@database_sync_to_async
def save_generated_message(conversation_id, message):
    conversation = Conversation.objects.get(id=conversation_id)
class DocumentUploadView(APIView):
    # permission_classes = [permissions.IsAuthenticated]

    # add the prompt to the conversation
    serializer = PromptSerializer(
        data={
            "message": message,
            "user_created": False,
            "created": datetime.now(),
        }
    )
    if serializer.is_valid():
        prompt_instance = serializer.save()
        prompt_instance.conversation_id = conversation.id
        prompt_instance = serializer.save()
    def get(self, request):
        logger.debug(f"request_3: {request}")
        try:
            workspace = DocumentWorkspace.objects.get(company=request.user.company)
            serializer = DocumentSerializer(
                Document.objects.filter(workspace=workspace), many=True
            )
            return Response(serializer.data, status=status.HTTP_200_OK)

@database_sync_to_async
def create_prompt_metric(prompt_id, prompt, has_file, file_type, model_name, conversation_id):
    prompt_metric = PromptMetric.objects.create(
        prompt_id=prompt_id,
        start_time = timezone.now(),
        prompt_length = len(prompt),
        has_file = has_file,
        file_type = file_type,
        model_name=model_name,
        conversation_id=conversation_id,
    )
    prompt_metric.save()
    return prompt_metric
        except:
            return Response(
                {"error": "Workspace not found"}, status=status.HTTP_404_NOT_FOUND
            )

@database_sync_to_async
def update_prompt_metric(prompt_metric, status):
    prompt_metric.event = status
    prompt_metric.save()
    def post(self, request):
        logger.debug(f"request: {request}")

@database_sync_to_async
def finish_prompt_metric(prompt_metric, response_length):
    print(f'finish_prompt_metric: {response_length}')
    prompt_metric.end_time = timezone.now()
    prompt_metric.reponse_length = response_length
    prompt_metric.event = 'FINISHED'
    prompt_metric.save(update_fields=["end_time", "reponse_length","event"])
    print("finish_prompt_metric saved")
        try:
            workspace = DocumentWorkspace.objects.get(company=request.user.company)

class ChatConsumerAgain(AsyncWebsocketConsumer):
    async def connect(self):
        await self.accept()
        except:
            return Response(
                {"error": "Workspace not found"}, status=status.HTTP_404_NOT_FOUND
            )

    async def disconnect(self, close_code):
        await self.close()
        logger.info(request.FILES)
        file = request.FILES.get("file")
        if not file:
            return Response(
                {"error": "No file provided"}, status=status.HTTP_400_BAD_REQUEST
            )

    async def receive(self, text_data=None, bytes_data=None):
        print(f"Text Data: {text_data}")
        print(f"Bytes Data: {bytes_data}")
        if text_data:
            data = json.loads(text_data)
            message = data.get('message',None)
            conversation_id = data.get('conversation_id',None)
            email = data.get("email", None)
            file = data.get("file", None)
            file_type = data.get("fileType", "")
        logger.info("have the workspace and the file")

            if not conversation_id:
                # we need to create a new conversation
                # we will generate a name for it too
                conversation_id = await create_conversation(message, email)

            if conversation_id:
                decoded_file = None
        document = Document.objects.create(workspace=workspace, file=file)

                if file:
                    decoded_file = base64.b64decode(file)
                    print(decoded_file)
                    if 'csv' in file_type:
                        file_type = 'csv'
                        altered_message = f"{message}\n The file type is csv and the file contents are: {decoded_file}"
                    elif 'xmlformats-officedocument' in file_type:
                        file_type = 'xlsx'
                        df = pd.read_excel(decoded_file)
                        altered_message = f"{message}\n The file type is xlsx and the file contents are: {df}"
                    elif 'text' in file_type:
                        file_type = 'txt'
                        altered_message = f"{message}\n The file type is txt and the file contents are: {decoded_file}"
                    else:
                        file_type = 'Not Sure'

                print(f'received: "{message}" for conversation {conversation_id}')
                # TODO: add the message to the database
                # get the new conversation
                # TODO: get the messages here

                messages, prompt = await get_messages(conversation_id, message, decoded_file, file_type)

                prompt_metric = await create_prompt_metric(prompt.id, prompt.message, True if file else False, file_type, MODEL_NAME, conversation_id)
                if file:
                    # update with the altered_message
                    messages = messages[:-1] + [HumanMessage(content=altered_message)]
                print(messages)
                # send it to the LLM
        # process the document in the background (the call below is actually synchronous)
        self.process_document(document)

                # stream the response back
                response = ""
                # start of the message
                await self.send('CONVERSATION_ID')
                await self.send(str(conversation_id))
                await self.send('START_OF_THE_STREAM_ENDER_GAME_42')
                async for chunk in llm.astream(messages):
                    print(f"chunk: {chunk}")
                    response += chunk
                    await self.send(chunk)
                await self.send('END_OF_THE_STREAM_ENDER_GAME_42')
                await save_generated_message(conversation_id, response)
        serializer = DocumentSerializer(document)
        return Response(serializer.data, status=status.HTTP_201_CREATED)

                await finish_prompt_metric(prompt_metric, len(response))

        if bytes_data:
            print("we have byte data")
    def process_document(self, document):
        file_path = os.path.join(settings.MEDIA_ROOT, document.file.name)

        document.processed = True
        document.active = True
        document.save()
        service = AsyncRAGService()
        service.add_files_to_store(
            [(file_path, document.file.name, document.workspace_id)],
            workspace_id=document.workspace_id,
        )


class DocumentDetailView(APIView):
    # permission_classes = [permissions.IsAuthenticated]

    def get(self, request, document_id):
        logger.info(f"request: {request}")
        try:
            workspace = DocumentWorkspace.objects.get(company=request.user.company)

            document = Document.objects.get(workspace=workspace, id=document_id)
        except:
            return Response(
                {"error": "Document not found"}, status=status.HTTP_404_NOT_FOUND
            )

        serializer = DocumentSerializer(document)
        return Response(serializer.data)
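The websocket exchange above is sentinel-framed: the consumer emits CONVERSATION_ID, then the id itself, then a start marker, the raw token chunks, and an end marker. A minimal client sketch, assuming the route from chat_backend.routing (not shown in this diff) is exposed at ws://localhost:8000/ws/chat/ and that the websockets package is installed:

import asyncio
import json
import websockets  # assumed client dependency

async def ask(message):
    async with websockets.connect("ws://localhost:8000/ws/chat/") as ws:
        await ws.send(json.dumps({"message": message, "email": "user@example.com"}))
        reply, expect_id, conversation_id = [], False, None
        while True:
            frame = await ws.recv()
            if frame == "CONVERSATION_ID":
                expect_id = True          # the next frame carries the id
            elif expect_id:
                conversation_id, expect_id = frame, False
            elif frame == "START_OF_THE_STREAM_ENDER_GAME_42":
                continue
            elif frame == "END_OF_THE_STREAM_ENDER_GAME_42":
                break
            else:
                reply.append(frame)       # a streamed token chunk
        return conversation_id, "".join(reply)

print(asyncio.run(ask("Hello")))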
@@ -9,17 +9,23 @@ https://docs.djangoproject.com/en/3.2/howto/deployment/asgi/

import os

import django
from django.core.asgi import get_asgi_application

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "llm_be.settings")

django.setup()


from channels.routing import ProtocolTypeRouter, URLRouter
from channels.auth import AuthMiddlewareStack
import chat_backend.routing
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'llm_be.settings')

application = ProtocolTypeRouter({
    "http": get_asgi_application(),
    "websocket": AuthMiddlewareStack(
        URLRouter(
            chat_backend.routing.websocket_urlpatterns
        )
    ),
})
application = ProtocolTypeRouter(
    {
        "http": get_asgi_application(),
        "websocket": AuthMiddlewareStack(
            URLRouter(chat_backend.routing.websocket_urlpatterns)
        ),
    }
)
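With daphne listed first in INSTALLED_APPS (see the settings below), runserver already speaks ASGI in development; for an explicit production-style process, something like the following should work, assuming daphne is installed:

daphne -b 0.0.0.0 -p 8000 llm_be.asgi:application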
@@ -22,80 +22,86 @@ BASE_DIR = Path(__file__).resolve().parent.parent
# See https://docs.djangoproject.com/en/3.2/howto/deployment/checklist/

# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = 'django-insecure-6suk6fj5q2)1tj%)f(wgw1smnliv5-#&@zvgvj1wp#(#@h#31x'
SECRET_KEY = "django-insecure-6suk6fj5q2)1tj%)f(wgw1smnliv5-#&@zvgvj1wp#(#@h#31x"

# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = True

ALLOWED_HOSTS = ['*.aimloperations.com','localhost','127.0.0.1','chat.aimloperations.com','chatbackend.aimloperations.com']
DEBUG = False
CORS_ALLOW_CREDENTIALS = False
ALLOWED_HOSTS = [
    "*.aimloperations.com",
    "localhost",
    "127.0.0.1",
    "localhost:3000",
    "127.0.0.1:3000",
    "chat.aimloperations.com",
    "chatbackend.aimloperations.com",
]
CORS_ORIGIN_ALLOW_ALL = True
CSRF_TRUSTED_ORIGINS = ["http://localhost", "http://127.0.0.1", "http://localhost:3000"]


# Application definition

INSTALLED_APPS = [
    'daphne',
    'django.contrib.admin',
    'django.contrib.auth',
    'django.contrib.contenttypes',
    'django.contrib.sessions',
    'django.contrib.messages',
    'django.contrib.staticfiles',
    'chat_backend',
    'rest_framework',
    'corsheaders',
    'rest_framework_simplejwt.token_blacklist',

    "daphne",
    "django.contrib.admin",
    "django.contrib.auth",
    "django.contrib.contenttypes",
    "django.contrib.sessions",
    "django.contrib.messages",
    "django.contrib.staticfiles",
    "chat_backend",
    "rest_framework",
    "corsheaders",
    "rest_framework_simplejwt.token_blacklist",
]

MIDDLEWARE = [
    'django.middleware.security.SecurityMiddleware',
    'django.contrib.sessions.middleware.SessionMiddleware',
    'django.middleware.common.CommonMiddleware',
    'django.middleware.csrf.CsrfViewMiddleware',
    'django.contrib.auth.middleware.AuthenticationMiddleware',
    'django.contrib.messages.middleware.MessageMiddleware',
    'django.middleware.clickjacking.XFrameOptionsMiddleware',
    "django.middleware.security.SecurityMiddleware",
    "django.contrib.sessions.middleware.SessionMiddleware",
    "django.middleware.common.CommonMiddleware",
    "django.middleware.csrf.CsrfViewMiddleware",
    "django.contrib.auth.middleware.AuthenticationMiddleware",
    "django.contrib.messages.middleware.MessageMiddleware",
    "django.middleware.clickjacking.XFrameOptionsMiddleware",
    "corsheaders.middleware.CorsMiddleware",
    "django.middleware.common.CommonMiddleware",
]

ROOT_URLCONF = 'llm_be.urls'
ROOT_URLCONF = "llm_be.urls"

# SETTINGS_PATH = os.path.dirname(os.path.dirname(__file__))
# TEMPLATE_DIRS = (
#     os.path.join(SETTINGS_PATH, 'templates'),
# )

print(os.path.join(BASE_DIR, 'templates'))

TEMPLATES = [
    {
        'BACKEND': 'django.template.backends.django.DjangoTemplates',
        'DIRS': [os.path.join(BASE_DIR, 'templates')],
        'APP_DIRS': True,
        'OPTIONS': {
            'context_processors': [
                'django.template.context_processors.debug',
                'django.template.context_processors.request',
                'django.contrib.auth.context_processors.auth',
                'django.contrib.messages.context_processors.messages',
        "BACKEND": "django.template.backends.django.DjangoTemplates",
        "DIRS": [os.path.join(BASE_DIR, "templates")],
        "APP_DIRS": True,
        "OPTIONS": {
            "context_processors": [
                "django.template.context_processors.debug",
                "django.template.context_processors.request",
                "django.contrib.auth.context_processors.auth",
                "django.contrib.messages.context_processors.messages",
            ],
        },
    },
]

WSGI_APPLICATION = 'llm_be.wsgi.application'
ASGI_APPLICATION = 'llm_be.asgi.application'
WSGI_APPLICATION = "llm_be.wsgi.application"
ASGI_APPLICATION = "llm_be.asgi.application"


# Database
# https://docs.djangoproject.com/en/3.2/ref/settings/#databases

DATABASES = {
    'default': {
        'ENGINE': 'django.db.backends.sqlite3',
        'NAME': BASE_DIR / 'db.sqlite3',
    "default": {
        "ENGINE": "django.db.backends.sqlite3",
        "NAME": BASE_DIR / "db.sqlite3",
    }
}

@@ -105,28 +111,26 @@ DATABASES = {

AUTH_PASSWORD_VALIDATORS = [
    {
        'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
        "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator",
    },
    {
        'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
        "NAME": "django.contrib.auth.password_validation.MinimumLengthValidator",
    },
    {
        'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
        "NAME": "django.contrib.auth.password_validation.CommonPasswordValidator",
    },
    {
        'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
        "NAME": "django.contrib.auth.password_validation.NumericPasswordValidator",
    },
]


# Internationalization
# https://docs.djangoproject.com/en/3.2/topics/i18n/

LANGUAGE_CODE = 'en-us'
LANGUAGE_CODE = "en-us"

TIME_ZONE = 'UTC'
TIME_ZONE = "UTC"

USE_I18N = True

@@ -138,39 +142,37 @@ USE_TZ = True
# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/3.2/howto/static-files/

STATIC_URL = '/static/'
STATIC_URL = "/static/"

# Default primary key field type
# https://docs.djangoproject.com/en/3.2/ref/settings/#default-auto-field

DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"

# custom user model
AUTH_USER_MODEL = 'chat_backend.CustomUser'
AUTH_USER_MODEL = "chat_backend.CustomUser"

# rest framework jwt stuff
REST_FRAMEWORK = {
    'DEFAULT_PERMISSION_CLASSES': (
        'rest_framework.permissions.IsAuthenticated',
    ),
    'DEFAULT_AUTHENTICATION_CLASSES': (
        'rest_framework_simplejwt.authentication.JWTAuthentication',
    ), #
    "DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.IsAuthenticated",),
    "DEFAULT_AUTHENTICATION_CLASSES": (
        "rest_framework_simplejwt.authentication.JWTAuthentication",
    ), #
}

SIMPLE_JWT = {
    'ACCESS_TOKEN_LIFETIME':timedelta(hours=5),
    'REFRESH_TOKEN_LIFETIME':timedelta(days=14),
    'ROTATE_REFRESH_TOKENS':True,
    'BLACKLIST_AFTER_ROTATION':True,
    'ALGORITHM':"HS256",
    "SIGNING_KEY":SECRET_KEY,
    'VERIFYING_KEY':None,
    "AUTH_HEADER_TYPES":('JWT',),
    'USER_ID_FIELD':'id',
    'USER_ID_CLAIM':'user_id',
    'AUTH_TOKEN_CLASSES':('rest_framework_simplejwt.tokens.AccessToken',),
    'TOKEN_TYPE_CLAIM':'token_type',
    "ACCESS_TOKEN_LIFETIME": timedelta(hours=5),
    "REFRESH_TOKEN_LIFETIME": timedelta(days=14),
    "ROTATE_REFRESH_TOKENS": True,
    "BLACKLIST_AFTER_ROTATION": True,
    "ALGORITHM": "HS256",
    "SIGNING_KEY": SECRET_KEY,
    "VERIFYING_KEY": None,
    "AUTH_HEADER_TYPES": ("JWT",),
    "USER_ID_FIELD": "id",
    "USER_ID_CLAIM": "user_id",
    "AUTH_TOKEN_CLASSES": ("rest_framework_simplejwt.tokens.AccessToken",),
    "TOKEN_TYPE_CLAIM": "token_type",
}
|
||||
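Since JWTAuthentication is the project-wide default, every API request must carry a token; a minimal sketch of how clients would obtain one, assuming SimpleJWT's stock views (the URL paths below are illustrative, not taken from chat_backend.urls):

# Hedged sketch: standard SimpleJWT token endpoints. Paths are assumptions.
from django.urls import path
from rest_framework_simplejwt.views import TokenObtainPairView, TokenRefreshView

urlpatterns = [
    path("api/token/", TokenObtainPairView.as_view(), name="token_obtain_pair"),
    path("api/token/refresh/", TokenRefreshView.as_view(), name="token_refresh"),
]

# Because AUTH_HEADER_TYPES is ("JWT",) rather than the default ("Bearer",),
# clients authenticate with:  Authorization: JWT <access_token>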
# CORS settings
@@ -181,8 +183,8 @@ CORS_ALLOWED_ORIGINS = [

# channel settings
CHANNEL_LAYERS = {
    'default': {
        'BACKEND': 'channels.layers.InMemoryChannelLayer',
    "default": {
        "BACKEND": "channels.layers.InMemoryChannelLayer",
    },
}
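The in-memory channel layer only spans a single process, so WebSocket groups will not be shared across multiple Daphne workers; a hedged sketch of the usual production swap, assuming the channels_redis package (which this requirements list does not currently include):

# Hedged sketch: Redis-backed channel layer for multi-process deployments.
# Assumes channels_redis is installed; it is not pinned in this repo.
CHANNEL_LAYERS = {
    "default": {
        "BACKEND": "channels_redis.core.RedisChannelLayer",
        "CONFIG": {"hosts": [("127.0.0.1", 6379)]},
    },
}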
@@ -198,8 +200,60 @@ CHANNEL_LAYERS = {
# EMAIL_TIMEOUT = os.getenv("APP_EMAIL_TIMEOUT", 60)

# SMTP2GO
EMAIL_HOST = 'mail.smtp2go.com'
EMAIL_HOST_USER = 'info.aimloperations.com'
EMAIL_HOST_PASSWORD = 'ZDErIII2sipNNVMz'
EMAIL_HOST = "mail.smtp2go.com"
EMAIL_HOST_USER = "info.aimloperations.com"
EMAIL_HOST_PASSWORD = "ZDErIII2sipNNVMz"
EMAIL_PORT = 2525
EMAIL_USE_TLS = True
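Note that EMAIL_HOST_PASSWORD is committed in plaintext, so it lives in git history even if removed later; a hedged sketch of reading the credentials from the environment instead, following the APP_EMAIL_* naming the commented EMAIL_TIMEOUT line already hints at (the exact variable names are assumptions):

# Hedged sketch: SMTP credentials from the environment. The APP_EMAIL_*
# variable names are illustrative assumptions.
import os

EMAIL_HOST = os.getenv("APP_EMAIL_HOST", "mail.smtp2go.com")
EMAIL_HOST_USER = os.getenv("APP_EMAIL_USER", "")
EMAIL_HOST_PASSWORD = os.getenv("APP_EMAIL_PASSWORD", "")
EMAIL_PORT = int(os.getenv("APP_EMAIL_PORT", "2525"))
EMAIL_USE_TLS = True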
# Captcha
CAPTCHA_SECRET_KEY = "6LfENu4qAAAAABdrj6JTviq-LfdPP5imhE-Os7h9"

directory_path = 'logs'
# LOGGING = {
#     'version': 1,
#     'disable_existing_loggers': False,
#     'formatters': {
#         'verbose': {
#             'format': '{levelname} {asctime} {module} {process:d} {thread:d} {message}',
#             'style': '{',
#         },
#         'simple': {
#             'format': '{levelname} {message}',
#             'style': '{',
#         },
#     },
#     'handlers': {
#         'console': {
#             'level': 'INFO',
#             'class': 'logging.StreamHandler',
#             'formatter': 'simple',
#         },
#         'file': {
#             'level': 'DEBUG',
#             'class': 'logging.handlers.RotatingFileHandler',
#             'filename': f'{directory_path}/django.log',
#             'maxBytes': 1024 * 1024 * 5,  # 5 MB
#             'backupCount': 5,
#             'formatter': 'verbose',
#         },
#     },
#     'loggers': {
#         'django': {
#             'handlers': ['console', 'file'],
#             'level': 'INFO',
#             'propagate': True,
#         },
#         'my_app': {
#             'handlers': ['console', 'file'],
#             'level': 'DEBUG',
#             'propagate': False,
#         },
#     },
# }

os.makedirs(directory_path, exist_ok=True)
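The logs directory is still created even though LOGGING is commented out; if the rotating-file setup is ever re-enabled, a minimal sketch consistent with the commented block above (console at INFO, rotating file at DEBUG):

# Hedged sketch: the smallest active LOGGING dict matching the commented block.
LOGGING = {
    "version": 1,
    "disable_existing_loggers": False,
    "handlers": {
        "console": {"level": "INFO", "class": "logging.StreamHandler"},
        "file": {
            "level": "DEBUG",
            "class": "logging.handlers.RotatingFileHandler",
            "filename": f"{directory_path}/django.log",
            "maxBytes": 1024 * 1024 * 5,  # 5 MB
            "backupCount": 5,
        },
    },
    "loggers": {
        "django": {"handlers": ["console", "file"], "level": "INFO"},
    },
}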
# Feature Flags
ALLOW_IMAGE_GENERATION = False
ALLOW_INTERNET_ACCESS = True
@@ -13,12 +13,17 @@ Including another URLconf
    1. Import the include() function: from django.urls import include, path
    2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""

from django.contrib import admin
from django.urls import path, include
from django.conf import settings
from django.conf.urls.static import static

urlpatterns = [
    path('admin/', admin.site.urls),
    path('api/', include('chat_backend.urls')),
] + static(settings.STATIC_URL, document_root=settings.STATIC_ROOT) + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)
urlpatterns = (
    [
        path("admin/", admin.site.urls),
        path("api/", include("chat_backend.urls")),
    ]
    + static(settings.STATIC_URL, document_root=settings.STATIC_ROOT)
    + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)
)
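Worth noting: django.conf.urls.static.static() returns an empty pattern list whenever DEBUG is False, so these extra routes only serve files during development and a web server (or WhiteNoise) must front STATIC_ROOT and MEDIA_ROOT in production. An equivalent, more explicit form of the same wiring, as a sketch:

# Sketch: the static() helper already no-ops outside DEBUG; this spells
# out the same behavior explicitly.
from django.conf import settings
from django.conf.urls.static import static
from django.contrib import admin
from django.urls import include, path

urlpatterns = [
    path("admin/", admin.site.urls),
    path("api/", include("chat_backend.urls")),
]
if settings.DEBUG:
    urlpatterns += static(settings.STATIC_URL, document_root=settings.STATIC_ROOT)
    urlpatterns += static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)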
@@ -11,6 +11,6 @@ import os

from django.core.wsgi import get_wsgi_application

os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'llm_be.settings')
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "llm_be.settings")

application = get_wsgi_application()
@@ -6,7 +6,7 @@ import sys

def main():
    """Run administrative tasks."""
    os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'llm_be.settings')
    os.environ.setdefault("DJANGO_SETTINGS_MODULE", "llm_be.settings")
    try:
        from django.core.management import execute_from_command_line
    except ImportError as exc:
@@ -18,5 +18,5 @@ def main():
    execute_from_command_line(sys.argv)


if __name__ == '__main__':
if __name__ == "__main__":
    main()
8
llm_be/templates/admin/base_site.html
Normal file
8
llm_be/templates/admin/base_site.html
Normal file
@@ -0,0 +1,8 @@
{% extends "admin/base_site.html" %}

{% block extrahead %}
{{ block.super }}
{% if not debug %}
<script async defer src="https://tianji.aimloperations.com/tracker.js" data-website-id="cm7x7mrcy03kfddsw2jyejzub"></script>
{% endif %}
{% endblock %}
@@ -30,16 +30,16 @@ djangorestframework-simplejwt==5.3.1
duckdb==1.1.3
et_xmlfile==2.0.0
exceptiongroup==1.2.2
Faker==33.1.0
Faker
filelock==3.16.1
fonttools==4.55.3
frozenlist==1.5.0
fsspec==2024.12.0
greenlet==3.1.1
h11==0.14.0
httpcore==1.0.7
httpx==0.27.2
httpx-sse==0.4.0
httpcore
httpx
httpx-sse
hyperlink==21.0.0
idna==3.10
importlib_resources==6.4.5
@@ -48,14 +48,14 @@ Jinja2==3.1.5
jiter==0.8.2
jsonpatch==1.33
jsonpointer==3.0.0
kiwisolver==1.4.7
langchain==0.3.13
langchain-community==0.3.13
langchain-core==0.3.28
langchain-ollama==0.2.2
langchain-openai==0.2.14
langchain-text-splitters==0.3.4
langsmith==0.2.7
kiwisolver
langchain
langchain-community
langchain-core
langchain-ollama
langchain-openai
langchain-text-splitters
langsmith
lxml==5.3.0
MarkupSafe==3.0.2
marshmallow==3.23.2
@@ -77,14 +77,14 @@ nvidia-cusparse-cu12==12.3.1.170
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
ollama==0.4.5
ollama-python==0.1.2
openai==1.58.1
ollama
ollama-python
openai
openpyxl==3.1.5
orjson==3.10.13
packaging==24.2
pandas==2.2.3
pandasai==2.4.1
pandasai
pathspec==0.12.1
pillow==11.0.0
platformdirs==4.3.6
265
requirements.txt
Normal file
265
requirements.txt
Normal file
@@ -0,0 +1,265 @@
accelerate==1.12.0
aiofiles==25.1.0
aiohappyeyeballs==2.6.1
aiohttp==3.13.2
aiosignal==1.4.0
annotated-doc==0.0.4
annotated-types==0.7.0
antlr4-python3-runtime==4.13.2
anyio==4.12.0
asgiref==3.11.0
astor==0.8.1
attrs==25.4.0
autobahn==25.11.1
Automat==25.4.16
backoff==2.2.1
bcrypt==5.0.0
beautifulsoup4==4.14.3
black==25.11.0
brotli==1.2.0
build==1.3.0
cachetools==5.5.2
cbor2==5.7.1
certifi==2025.11.12
cffi==2.0.0
channels==4.3.2
chardet==5.2.0
charset-normalizer==3.4.4
chroma-hnswlib==0.7.6
chromadb==1.3.5
click==8.3.1
coloredlogs==15.0.1
constantly==23.10.4
contourpy==1.3.3
cryptography==46.0.3
cycler==0.12.1
daphne==4.2.1
dataclasses-json==0.6.7
ddgs==9.9.3
Deprecated==1.3.1
distro==1.9.0
Django==6.0
django-autoslug==1.9.9
django-cors-headers==4.9.0
django-filter==25.2
djangorestframework==3.16.1
djangorestframework_simplejwt==5.5.1
duckdb==1.4.2
durationpy==0.10
effdet==0.4.1
emoji==2.15.0
et_xmlfile==2.0.0
eval_type_backport==0.3.1
fake-useragent==2.2.0
Faker==38.2.0
fastapi==0.124.0
filelock==3.20.0
filetype==1.2.0
flatbuffers==25.9.23
fonttools==4.61.0
frozenlist==1.8.0
fsspec==2025.12.0
google-api-core==2.28.1
google-auth==2.43.0
google-cloud-vision==3.11.0
googleapis-common-protos==1.72.0
greenlet==3.3.0
grpcio==1.76.0
grpcio-status==1.76.0
h11==0.16.0
h2==4.3.0
hf-xet==1.2.0
hpack==4.1.0
html5lib==1.1
httpcore==1.0.9
httptools==0.7.1
httpx==0.28.1
httpx-sse==0.4.3
huggingface-hub==0.36.0
humanfriendly==10.0
hyperframe==6.1.0
hyperlink==21.0.0
idna==3.11
importlib_metadata==8.7.0
importlib_resources==6.5.2
Incremental==24.11.0
Jinja2==3.1.6
jiter==0.12.0
joblib==1.5.2
jsonpatch==1.33
jsonpointer==3.0.0
jsonschema==4.25.1
jsonschema-specifications==2025.9.1
kiwisolver==1.4.9
kubernetes==34.1.0
langchain==1.1.2
langchain-chroma==1.0.0
langchain-classic==1.0.0
langchain-community==0.4.1
langchain-core==1.1.1
langchain-ollama==1.0.0
langchain-text-splitters==1.0.0
langdetect==1.0.9
langgraph==1.0.4
langgraph-checkpoint==3.0.1
langgraph-prebuilt==1.0.5
langgraph-sdk==0.2.14
langsmith==0.4.56
lxml==6.0.2
Markdown==3.10
markdown-it-py==4.0.0
MarkupSafe==3.0.3
marshmallow==3.26.1
matplotlib==3.10.7
mdurl==0.1.2
ml_dtypes==0.5.4
mmh3==5.2.0
mpmath==1.3.0
msgpack==1.1.2
multidict==6.7.0
mypy_extensions==1.1.0
nest-asyncio==1.6.0
networkx==3.6
nltk==3.9.2
numpy==2.2.6
nvidia-cublas-cu12==12.8.4.1
nvidia-cuda-cupti-cu12==12.8.90
nvidia-cuda-nvrtc-cu12==12.8.93
nvidia-cuda-runtime-cu12==12.8.90
nvidia-cudnn-cu12==9.10.2.21
nvidia-cufft-cu12==11.3.3.83
nvidia-cufile-cu12==1.13.1.3
nvidia-curand-cu12==10.3.9.90
nvidia-cusolver-cu12==11.7.3.90
nvidia-cusparse-cu12==12.5.8.93
nvidia-cusparselt-cu12==0.7.1
nvidia-nccl-cu12==2.27.5
nvidia-nvjitlink-cu12==12.8.93
nvidia-nvshmem-cu12==3.3.20
nvidia-nvtx-cu12==12.8.90
oauthlib==3.3.1
olefile==0.47
ollama==0.6.1
omegaconf==2.3.0
onnx==1.20.0
onnxruntime==1.23.2
openai==2.9.0
opencv-python==4.12.0.88
openpyxl==3.1.5
opentelemetry-api==1.39.0
opentelemetry-exporter-otlp-proto-common==1.39.0
opentelemetry-exporter-otlp-proto-grpc==1.39.0
opentelemetry-instrumentation==0.53b1
opentelemetry-instrumentation-asgi==0.53b1
opentelemetry-instrumentation-fastapi==0.53b1
opentelemetry-proto==1.39.0
opentelemetry-sdk==1.39.0
opentelemetry-semantic-conventions==0.60b0
opentelemetry-util-http==0.53b1
orjson==3.11.5
ormsgpack==1.12.0
overrides==7.7.0
packaging==25.0
pandas==2.3.3
pandasai==2.4.2
parameterized==0.9.0
pathspec==0.12.1
pdf2image==1.17.0
pdfminer.six==20251107
pi==0.1.2
pi_heif==1.1.1
pikepdf==10.0.2
pillow==12.0.0
platformdirs==4.5.1
posthog==5.4.0
primp==0.15.0
propcache==0.4.1
proto-plus==1.26.1
protobuf==6.33.2
psutil==7.1.3
py-ubjson==0.16.1
pyasn1==0.6.1
pyasn1_modules==0.4.2
pybase64==1.4.3
pycocotools==2.0.10
pycparser==2.23
pydantic==2.12.5
pydantic-settings==2.12.0
pydantic_core==2.41.5
Pygments==2.19.2
PyJWT==2.10.1
pyOpenSSL==25.3.0
pyparsing==3.2.5
pypdf==6.4.0
pypdfium2==5.1.0
PyPika==0.48.9
pyproject_hooks==1.2.0
python-dateutil==2.9.0.post0
python-docx==1.2.0
python-dotenv==1.2.1
python-iso639==2025.11.16
python-magic==0.4.27
python-multipart==0.0.20
python-oxmsg==0.0.2
pytokens==0.3.0
pytz==2025.2
PyYAML==6.0.3
RapidFuzz==3.14.3
referencing==0.37.0
regex==2025.11.3
requests==2.32.5
requests-oauthlib==2.0.0
requests-toolbelt==1.0.0
rich==14.2.0
rpds-py==0.30.0
rsa==4.9.1
safetensors==0.7.0
scipy==1.16.3
service-identity==24.2.0
setuptools==80.9.0
shellingham==1.5.4
six==1.17.0
sniffio==1.3.1
socksio==1.0.0
soupsieve==2.8
SQLAlchemy==2.0.44
sqlglot==28.1.0
sqlglotrs==0.8.0
sqlparse==0.5.4
starlette==0.50.0
sympy==1.14.0
tenacity==9.1.2
timm==1.0.22
tokenizers==0.22.1
torch==2.9.1
torchvision==0.24.1
tqdm==4.67.1
transformers==4.57.3
triton==3.5.1
Twisted==25.5.0
txaio==25.12.1
typer==0.20.0
typer-slim==0.20.0
typing-inspect==0.9.0
typing-inspection==0.4.2
typing_extensions==4.15.0
tzdata==2025.2
ujson==5.11.0
unstructured==0.18.21
unstructured-client==0.42.4
unstructured.pytesseract==0.3.15
unstructured_inference==1.1.2
urllib3==2.3.0
uuid_utils==0.12.0
uvicorn==0.38.0
uvloop==0.22.1
watchfiles==1.1.1
webencodings==0.5.1
websocket-client==1.9.0
websockets==15.0.1
wrapt==2.0.1
xxhash==3.6.0
yarl==1.22.0
zipp==3.23.0
zope.interface==8.1.1
zstandard==0.25.0
9
strip_and_upgrade.py
Normal file
9
strip_and_upgrade.py
Normal file
@@ -0,0 +1,9 @@
# Strip the version pins from requirements.dev and write the bare package
# names to requirements.txt, so the next install resolves latest versions.
with open("requirements.dev") as infile, open("requirements.txt", "w") as outfile:
    for line in infile:
        line = line.strip()
        if line:
            name = line.split("==")[0]
            print(name)
            outfile.write(name + "\n")
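A hedged usage sketch of the workflow the script's name implies: freeze the current environment with `pip freeze > requirements.dev`, run `python strip_and_upgrade.py` to drop the pins, upgrade with `pip install -U -r requirements.txt`, then re-freeze to produce the pinned requirements.txt recorded above.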