Compare commits


7 Commits

Author SHA1 Message Date
77d7edd0dc Closes #4
Added site tracking
Can pick the model to use
Better handling of the LLM model based on debug mode
2025-12-08 13:52:30 -06:00
eed1abedc8 updates 2025-12-07 06:31:06 -06:00
91bdb2fd2d Merging from prod 2025-09-24 12:05:22 -05:00
8a259158c8 Updated data analysis to generate images to perform data analysis 2025-09-24 11:49:08 -05:00
14d8211715 Allow for data analysis 2025-09-08 12:29:20 -05:00
951a58f2fa fixed chat service 2025-05-28 03:25:14 -05:00
a85f1222eb Syncing with updates from prod and formatted 2025-05-18 06:15:07 -05:00
28 changed files with 2282 additions and 1187 deletions

View File

@@ -9,7 +9,7 @@ from .models import (
     Feedback,
     PromptMetric,
     DocumentWorkspace,
-    Document
+    Document,
 )

 # Register your models here.
@@ -55,21 +55,26 @@ class LLMModelsAdmin(admin.ModelAdmin):
     search_fields = ("name", "port", "description")


+class PromptInline(admin.TabularInline):
+    model = Prompt
+
+
 class ConversationAdmin(admin.ModelAdmin):
     model = Conversation
     list_display = ("title", "get_user_email", "deleted")
     search_fields = ("title",)
+    inlines = [PromptInline,]


 class PromptAdmin(admin.ModelAdmin):
     model = Prompt
-    list_display = ("message", "user_created", "get_conversation_title")
+    list_display = ("id","message", "user_created", "get_conversation_title","created")
     search_fields = ("message",)


 class PromptMetricAdmin(admin.ModelAdmin):
     model = PromptMetric
     list_display = (
+        "id",
         "event",
         "model_name",
         "prompt_length",
@@ -77,16 +82,18 @@ class PromptMetricAdmin(admin.ModelAdmin):
         "has_file",
         "file_type",
         "get_duration",
+        "created"
     )


 class DocumentWorkspaceAdmin(admin.ModelAdmin):
     model = DocumentWorkspace
     list_display = (
         "name",
         "company",
     )


 class DocumentAdmin(admin.ModelAdmin):
     model = Document
     list_display = (

View File

@@ -9,9 +9,10 @@ class ChatBackendConfig(AppConfig):
     def ready(self):
         import chat_backend.signals

         FORCE_RELOAD = False
-        if True: #not settings.TESTING: # Don't run during tests
+        if True:  # not settings.TESTING: # Don't run during tests
             try:
                 from .services.rag_services import AsyncRAGService
                 from chat_backend.models import Document

View File

@@ -0,0 +1,422 @@
import io
import json
import base64
import logging
import pandas as pd
from datetime import datetime
from django.utils import timezone
from django.conf import settings
from django.core.files.base import ContentFile
from channels.generic.websocket import AsyncWebsocketConsumer
from channels.db import database_sync_to_async
from channels.layers import get_channel_layer
from asgiref.sync import sync_to_async, async_to_sync
from langchain_core.messages import HumanMessage, AIMessage
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.runnables import RunnableLambda, RunnableBranch, RunnablePassthrough
from langchain_core.tracers.context import collect_runs
from .models import Conversation, Prompt, PromptMetric, DocumentWorkspace, Document, CustomUser
from .serializers import PromptSerializer
from .services.llm_service import AsyncLLMService
from .services.rag_services import AsyncRAGService
from .services.title_generator import title_generator
from .services.moderation_classifier import moderation_classifier, ModerationLabel
from .services.prompt_classifier.prompt_classifier import PromptClassifier, PromptType
from .services.data_analysis_service import AsyncDataAnalysisService
logger = logging.getLogger(__name__)
CHANNEL_NAME: str = "llm_messages"
MODEL_NAME: str = "llama3.2"
PROMPT_CLASSIFIER = PromptClassifier()
@database_sync_to_async
def create_conversation(prompt, email, title):
    # Create the conversation, attach the user, and return its id.
    user = CustomUser.objects.get(email=email)
    conversation = Conversation.objects.create(title=title)
    conversation.user_id = user.id
    conversation.save()
return conversation.id
@database_sync_to_async
def get_workspace(conversation_id):
conversation = Conversation.objects.get(id=conversation_id)
return DocumentWorkspace.objects.get(company=conversation.user.company)
@database_sync_to_async
def get_messages(conversation_id, prompt, file_string: bytes = None, file_type: str = ""):
messages = []
conversation = Conversation.objects.get(id=conversation_id)
logger.debug(file_string)
# add the prompt to the conversation
serializer = PromptSerializer(
data={
"message": prompt,
"user_created": True,
"created": timezone.now(),
}
)
if serializer.is_valid(raise_exception=True):
prompt_instance = serializer.save()
prompt_instance.conversation_id = conversation.id
prompt_instance.save()
if file_string:
file_name = f"prompt_{prompt_instance.id}_data.{file_type}"
f = ContentFile(file_string, name=file_name)
prompt_instance.file.save(file_name, f)
prompt_instance.file_type = file_type
prompt_instance.save()
for prompt_obj in Prompt.objects.filter(conversation__id=conversation_id):
messages.append(
{
"content": prompt_obj.message,
"role": "user" if prompt_obj.user_created else "assistant",
"has_file": prompt_obj.file_exists(),
"file": prompt_obj.file if prompt_obj.file_exists() else None,
"file_type": prompt_obj.file_type if prompt_obj.file_exists() else None,
}
)
# now transform the messages
transformed_messages = []
    for message in messages:
        if message["has_file"] and message["file_type"] is not None:
            if "csv" in message["file_type"]:
                altered_message = f"{message['content']}\n The file type is csv and the file contents are: {message['file'].read()}"
            elif "xlsx" in message["file_type"]:
                df = pd.read_excel(io.BytesIO(message["file"].read()))
                altered_message = f"{message['content']}\n The file type is xlsx and the file contents are: {df}"
            elif "txt" in message["file_type"]:
                altered_message = f"{message['content']}\n The file type is txt and the file contents are: {message['file'].read()}"
            else:
                altered_message = message["content"]
        else:
            altered_message = message["content"]
transformed_message = (
AIMessage(content=altered_message)
if message["role"] == "assistant"
else HumanMessage(content=altered_message)
)
transformed_messages.append(transformed_message)
return transformed_messages, prompt_instance
@database_sync_to_async
def save_generated_message(conversation_id, message):
conversation = Conversation.objects.get(id=conversation_id)
# add the prompt to the conversation
serializer = PromptSerializer(
data={
"message": message,
"user_created": False,
"created": timezone.now(),
}
)
if serializer.is_valid():
prompt_instance = serializer.save()
prompt_instance.conversation_id = conversation.id
        prompt_instance.save()
else:
        logger.error(serializer.errors)
@database_sync_to_async
def create_prompt_metric(
prompt_id, prompt, has_file, file_type, model_name, conversation_id
):
prompt_metric = PromptMetric.objects.create(
prompt_id=prompt_id,
start_time=timezone.now(),
prompt_length=len(prompt),
has_file=has_file,
file_type=file_type,
model_name=model_name,
conversation_id=conversation_id,
)
prompt_metric.save()
return prompt_metric
@database_sync_to_async
def update_prompt_metric(prompt_metric, status):
prompt_metric.event = status
prompt_metric.save()
@database_sync_to_async
def finish_prompt_metric(prompt_metric, response_length):
logger.info(f"finish_prompt_metric: {response_length}")
prompt_metric.end_time = timezone.now()
prompt_metric.reponse_length = response_length
prompt_metric.event = "FINISHED"
prompt_metric.save(update_fields=["end_time", "reponse_length", "event"])
logger.info("finish_prompt_metric saved")
@database_sync_to_async
def get_retriever(conversation_id):
logger.info(f"getting workspace from conversation: {conversation_id}")
conversation = Conversation.objects.get(id=conversation_id)
logger.info(f"Got conversation: {conversation}")
workspace = DocumentWorkspace.objects.get(company=conversation.user.company)
logger.info(f"Got workspace: {conversation}")
vectorstore = Chroma(
persist_directory=f"./chroma_db/",
embedding=OllamaEmbeddings(model="llama3.2"),
)
return vectorstore.as_retriever()
async def get_conversation_file_async(conversation_id):
try:
# Get the very first prompt in the conversation that has a file
prompt_with_file = await Prompt.objects.filter(
conversation_id=conversation_id
).exclude(file='').order_by('created').afirst()
if prompt_with_file and prompt_with_file.file:
# You must use sync_to_async to access the file's binary content
file_data = await sync_to_async(prompt_with_file.file.read)()
file_type = prompt_with_file.file_type
return file_data, file_type
except Exception as e:
logger.error(f"Error retrieving file from conversation history: {e}")
return None, None
class ChatConsumerAgain(AsyncWebsocketConsumer):
async def connect(self):
await self.accept()
async def disconnect(self, close_code):
await self.close()
async def send_json_message(self, data_str):
"""
Ensures that the message sent over the websocket is a valid JSON object.
If data_str is a plain string, it wraps it in {"type": "text", "content": ...}.
"""
try:
# Test if it's already a valid JSON object string
json.loads(data_str)
# If it is, send it as is
await self.send(data_str)
except (json.JSONDecodeError, TypeError):
# If it's a plain string or not JSON-decodable, wrap it
            await self.send(json.dumps({"type": "text", "content": data_str}))
async def receive(self, text_data=None, bytes_data=None):
logger.debug(f"Text Data: {text_data}")
logger.debug(f"Bytes Data: {bytes_data}")
if text_data:
data = json.loads(text_data)
message = data.get("message", None)
conversation_id = data.get("conversation_id", None)
email = data.get("email", None)
file = data.get("file", None)
file_type = data.get("fileType", "")
model = data.get("modelName", "Turbo")
if not conversation_id:
# we need to create a new conversation
# we will generate a name for it too
title = await title_generator.generate_async(message)
conversation_id = await create_conversation(message, email, title)
if conversation_id:
decoded_file = None
if file:
decoded_file = base64.b64decode(file)
logger.debug(decoded_file)
                # Let the classifier decide based on the user's prompt alone;
                # file contents are handled separately by the downstream services.
altered_message = message
if "csv" in file_type:
file_type = "csv"
#altered_message = f"{message}\n The file type is csv and the file contents are: {decoded_file}"
elif "xmlformats-officedocument" in file_type:
file_type = "xlsx"
#df = pd.read_excel(decoded_file)
#altered_message = f"{message}\n The file type is xlsx and the file contents are: {df}"
elif "word" in file_type:
file_type = "docx"
elif "pdf" in file_type:
file_type = "pdf"
elif "text" in file_type:
file_type = "txt"
#altered_message = f"{message}\n The file type is txt and the file contents are: {decoded_file}"
else:
file_type = "Not Sure"
logger.info(f'received: "{message}" for conversation {conversation_id}')
# --- LangSmith Pipeline Construction ---
async def check_moderation(input_dict):
msg = input_dict["message"]
label = await moderation_classifier.classify_async(msg)
return {**input_dict, "moderation_label": label}
async def classify_prompt_step(input_dict):
if input_dict["moderation_label"] == ModerationLabel.NSFW:
return {**input_dict, "prompt_type": None} # Skip classification
msg = input_dict["message"]
decoded_file = input_dict.get("decoded_file")
prompt_type = await PROMPT_CLASSIFIER.classify_async(msg)
# Override logic
if decoded_file and (prompt_type == PromptType.DATA_ANALYSIS or 'analyze' in msg.lower() or 'data' in msg.lower()):
prompt_type = PromptType.DATA_ANALYSIS
elif decoded_file:
prompt_type = PromptType.GENERAL_CHAT
return {**input_dict, "prompt_type": prompt_type}
async def generate_response_step(input_dict):
if input_dict["moderation_label"] == ModerationLabel.NSFW:
response = "Prompt has been marked as NSFW. If this is in error, submit a feedback with the prompt text."
return {"type": "error", "content": response}
prompt_type = input_dict["prompt_type"]
messages = input_dict["messages"]
prompt_instance = input_dict["prompt_instance"]
conversation_id = input_dict["conversation_id"]
decoded_file = input_dict.get("decoded_file")
file_type = input_dict.get("file_type")
# Feature Flag: Image Generation
if prompt_type == PromptType.IMAGE_GENERATION:
if not getattr(settings, "ALLOW_IMAGE_GENERATION", False):
return {"type": "text", "content": "Image Generation is disabled."}
                        # Flag is on, but no generation backend is wired up yet.
return {"type": "text", "content": "Image Generation is not supported at this time, but it will be soon."}
if prompt_type == PromptType.SEARCH:
# Check modelName first - if FAST, we skip search regardless of settings
if input_dict.get("model_name") == "FAST":
pass # Skip search
elif getattr(settings, "ALLOW_INTERNET_ACCESS", False):
try:
search = DuckDuckGoSearchRun()
search_results = search.run(input_dict["message"])
messages.append(HumanMessage(content=f"Search Results: {search_results}"))
                            except Exception as e:
                                logger.error(f"Search failed: {e}")
                                # Proceed without search results, effectively falling back to general chat.
                        else:
                            # Search is disabled; proceed and let the LLM answer from its training data.
                            pass
if prompt_type == PromptType.RAG:
service = AsyncRAGService()
workspace = await get_workspace(conversation_id)
return service.generate_response(messages, prompt_instance.message, workspace)
elif prompt_type == PromptType.DATA_ANALYSIS:
service = AsyncDataAnalysisService()
                        logger.debug(f"Data analysis file_type: {file_type}")
if not decoded_file:
return {"type": "text", "content": "Please upload a file to perform data analysis."}
return service.generate_response(prompt_instance.message, decoded_file, file_type)
else: # GENERAL_CHAT or others
service = AsyncLLMService()
return service.generate_response(messages, prompt_instance.message, conversation_id)
# --- Execution ---
# Pre-fetch messages and file
messages, prompt_instance = await get_messages(
conversation_id, message, decoded_file, file_type
)
if not decoded_file:
decoded_file, file_type = await get_conversation_file_async(conversation_id)
                if file:
                    # The original implementation replaced the last message with an
                    # altered_message that inlined the file contents; get_messages
                    # now attaches file content itself, so no extra override is needed.
                    pass
prompt_metric = await create_prompt_metric(
prompt_instance.id,
prompt_instance.message,
                    bool(file),
file_type,
MODEL_NAME,
conversation_id,
)
pipeline_input = {
"message": message,
"conversation_id": conversation_id,
"decoded_file": decoded_file,
"file_type": file_type,
"messages": messages,
"prompt_instance": prompt_instance,
"model_name": model
}
                # Run the steps manually: generate_response_step can return an async
                # generator, which a plain RunnableSequence would not stream through.
step1 = await check_moderation(pipeline_input)
step2 = await classify_prompt_step(step1)
# Send start markers
await self.send("CONVERSATION_ID")
await self.send(str(conversation_id))
await self.send("START_OF_THE_STREAM_ENDER_GAME_42")
response_generator_or_dict = await generate_response_step(step2)
full_response = ""
if isinstance(response_generator_or_dict, dict):
# It's an error or simple message
content = response_generator_or_dict.get("content", "")
await self.send_json_message(json.dumps(response_generator_or_dict))
full_response = content
else:
# It's an async generator
async for chunk in response_generator_or_dict:
full_response += chunk
await self.send_json_message(chunk)
await self.send("END_OF_THE_STREAM_ENDER_GAME_42")
await save_generated_message(conversation_id, full_response)
await finish_prompt_metric(prompt_metric, len(full_response))
if bytes_data:
logger.info("we have byte data")

View File

@@ -0,0 +1,363 @@
import json
import base64
import logging
import pandas as pd
from datetime import datetime
from typing import TypedDict, Annotated, List, Union, Dict, Any
from django.utils import timezone
from django.conf import settings
from django.core.files.base import ContentFile
from channels.generic.websocket import AsyncWebsocketConsumer
from channels.db import database_sync_to_async
from asgiref.sync import sync_to_async
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langchain_community.tools import DuckDuckGoSearchRun
from langgraph.graph import StateGraph, END
from .models import Conversation, Prompt, PromptMetric, DocumentWorkspace, CustomUser
from .serializers import PromptSerializer
from .services.llm_service import AsyncLLMService
from .services.rag_services import AsyncRAGService
from .services.title_generator import title_generator
from .services.moderation_classifier import moderation_classifier, ModerationLabel
from .services.prompt_classifier.prompt_classifier import PromptClassifier, PromptType
from .services.data_analysis_service import AsyncDataAnalysisService
logger = logging.getLogger(__name__)
CHANNEL_NAME: str = "llm_messages"
MODEL_NAME: str = "llama3.2"
PROMPT_CLASSIFIER = PromptClassifier()
# --- Database Helpers (Reused) ---
@database_sync_to_async
def create_conversation(prompt, email, title):
conversation = Conversation.objects.create(title=title)
user = CustomUser.objects.get(email=email)
conversation.user_id = user.id
conversation.save()
return conversation.id
@database_sync_to_async
def get_workspace(conversation_id):
conversation = Conversation.objects.get(id=conversation_id)
return DocumentWorkspace.objects.get(company=conversation.user.company)
@database_sync_to_async
def get_messages(conversation_id, prompt, file_string: bytes = None, file_type: str = ""):
messages = []
conversation = Conversation.objects.get(id=conversation_id)
serializer = PromptSerializer(
data={
"message": prompt,
"user_created": True,
"created": timezone.now(),
}
)
if serializer.is_valid(raise_exception=True):
prompt_instance = serializer.save()
prompt_instance.conversation_id = conversation.id
prompt_instance.save()
if file_string:
file_name = f"prompt_{prompt_instance.id}_data.{file_type}"
f = ContentFile(file_string, name=file_name)
prompt_instance.file.save(file_name, f)
prompt_instance.file_type = file_type
prompt_instance.save()
for prompt_obj in Prompt.objects.filter(conversation__id=conversation_id):
messages.append(
{
"content": prompt_obj.message,
"role": "user" if prompt_obj.user_created else "assistant",
"has_file": prompt_obj.file_exists(),
"file": prompt_obj.file if prompt_obj.file_exists() else None,
"file_type": prompt_obj.file_type if prompt_obj.file_exists() else None,
}
)
    transformed_messages = []
    for message in messages:
        # File contents are handled by the downstream services now, so the
        # message text is used as-is regardless of attachments.
        altered_message = message["content"]
transformed_message = (
AIMessage(content=altered_message)
if message["role"] == "assistant"
else HumanMessage(content=altered_message)
)
transformed_messages.append(transformed_message)
return transformed_messages, prompt_instance
@database_sync_to_async
def save_generated_message(conversation_id, message):
conversation = Conversation.objects.get(id=conversation_id)
serializer = PromptSerializer(
data={
"message": message,
"user_created": False,
"created": timezone.now(),
}
)
if serializer.is_valid():
prompt_instance = serializer.save()
prompt_instance.conversation_id = conversation.id
prompt_instance.save()
else:
        logger.error(serializer.errors)
@database_sync_to_async
def create_prompt_metric(prompt_id, prompt, has_file, file_type, model_name, conversation_id):
prompt_metric = PromptMetric.objects.create(
prompt_id=prompt_id,
start_time=timezone.now(),
prompt_length=len(prompt),
has_file=has_file,
file_type=file_type,
model_name=model_name,
conversation_id=conversation_id,
)
return prompt_metric
@database_sync_to_async
def finish_prompt_metric(prompt_metric, response_length):
prompt_metric.end_time = timezone.now()
prompt_metric.reponse_length = response_length
prompt_metric.event = "FINISHED"
prompt_metric.save(update_fields=["end_time", "reponse_length", "event"])
async def get_conversation_file_async(conversation_id):
try:
prompt_with_file = await Prompt.objects.filter(
conversation_id=conversation_id
).exclude(file='').order_by('created').afirst()
if prompt_with_file and prompt_with_file.file:
file_data = await sync_to_async(prompt_with_file.file.read)()
file_type = prompt_with_file.file_type
return file_data, file_type
except Exception as e:
logger.error(f"Error retrieving file from conversation history: {e}")
return None, None
# --- LangGraph State ---
class ChatState(TypedDict):
message: str
conversation_id: int
decoded_file: Union[bytes, None]
file_type: Union[str, None]
messages: List[BaseMessage]
prompt_instance: Any # Django model instance
moderation_label: Union[ModerationLabel, None]
prompt_type: Union[PromptType, None]
response_generator: Any # AsyncGenerator or dict
error: Union[str, None]
model_name: str
# --- LangGraph Nodes ---
async def moderation_node(state: ChatState) -> ChatState:
msg = state["message"]
label = await moderation_classifier.classify_async(msg)
return {"moderation_label": label}
async def classification_node(state: ChatState) -> ChatState:
if state.get("moderation_label") == ModerationLabel.NSFW:
return {"prompt_type": None}
msg = state["message"]
decoded_file = state.get("decoded_file")
prompt_type = await PROMPT_CLASSIFIER.classify_async(msg)
# Override logic
if decoded_file and (prompt_type == PromptType.DATA_ANALYSIS or 'analyze' in msg.lower() or 'data' in msg.lower()):
prompt_type = PromptType.DATA_ANALYSIS
elif decoded_file:
prompt_type = PromptType.GENERAL_CHAT
return {"prompt_type": prompt_type}
async def generation_node(state: ChatState) -> ChatState:
if state.get("moderation_label") == ModerationLabel.NSFW:
response = "Prompt has been marked as NSFW. If this is in error, submit a feedback with the prompt text."
return {"response_generator": {"type": "error", "content": response}}
prompt_type = state["prompt_type"]
messages = state["messages"]
prompt_instance = state["prompt_instance"]
conversation_id = state["conversation_id"]
decoded_file = state.get("decoded_file")
file_type = state.get("file_type")
# Feature Flag: Image Generation
if prompt_type == PromptType.IMAGE_GENERATION:
if not getattr(settings, "ALLOW_IMAGE_GENERATION", False):
return {"response_generator": {"type": "text", "content": "Image Generation is disabled."}}
return {"response_generator": {"type": "text", "content": "Image Generation is not supported at this time, but it will be soon."}}
# Feature Flag: Internet Access
if prompt_type == PromptType.SEARCH:
# Check modelName first - if FAST, we skip search regardless of settings
if state.get("model_name") == "FAST":
pass
elif getattr(settings, "ALLOW_INTERNET_ACCESS", False):
try:
search = DuckDuckGoSearchRun()
search_results = search.run(state["message"])
messages.append(HumanMessage(content=f"Search Results: {search_results}"))
            except Exception as e:
                logger.error(f"Search failed: {e}")
                # Proceed without search results.
        else:
            # Search disabled; answer from training data.
            pass
if prompt_type == PromptType.RAG:
service = AsyncRAGService()
workspace = await get_workspace(conversation_id)
generator = service.generate_response(messages, prompt_instance.message, workspace)
return {"response_generator": generator}
elif prompt_type == PromptType.DATA_ANALYSIS:
service = AsyncDataAnalysisService()
if not decoded_file:
return {"response_generator": {"type": "text", "content": "Please upload a file to perform data analysis."}}
generator = service.generate_response(prompt_instance.message, decoded_file, file_type)
return {"response_generator": generator}
else: # GENERAL_CHAT or others
service = AsyncLLMService()
generator = service.generate_response(messages, prompt_instance.message, conversation_id)
return {"response_generator": generator}
# --- LangGraph Definition ---
workflow = StateGraph(ChatState)
workflow.add_node("moderation", moderation_node)
workflow.add_node("classification", classification_node)
workflow.add_node("generation", generation_node)
workflow.set_entry_point("moderation")
workflow.add_edge("moderation", "classification")
workflow.add_edge("classification", "generation")
workflow.add_edge("generation", END)
app = workflow.compile()
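The wiring above is deliberately linear, so each node re-checks the moderation label itself. A conditional edge could express the same short-circuit at the graph level; a sketch, assuming the same node names:

# Sketch: route NSFW prompts straight to generation (which emits the refusal),
# skipping classification entirely.
def route_after_moderation(state: ChatState) -> str:
    if state["moderation_label"] == ModerationLabel.NSFW:
        return "generation"
    return "classification"

# workflow.add_conditional_edges("moderation", route_after_moderation)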
# --- Consumer ---
class ChatConsumerGraph(AsyncWebsocketConsumer):
async def connect(self):
await self.accept()
async def disconnect(self, close_code):
await self.close()
    async def send_json_message(self, data_str):
        """Send valid JSON as-is; wrap plain strings in a text payload."""
        try:
            json.loads(data_str)
            await self.send(data_str)
        except (json.JSONDecodeError, TypeError):
            await self.send(json.dumps({"type": "text", "content": data_str}))
async def receive(self, text_data=None, bytes_data=None):
logger.debug(f"Text Data: {text_data}")
print("Text Data: ", text_data)
if text_data:
data = json.loads(text_data)
model = data.get("modelName", "Turbo")
message = data.get("message", None)
conversation_id = data.get("conversation_id", None)
email = data.get("email", None)
file = data.get("file", None)
file_type = data.get("fileType", "")
if not conversation_id:
title = await title_generator.generate_async(message)
conversation_id = await create_conversation(message, email, title)
if conversation_id:
print("Conversation ID: ", conversation_id)
decoded_file = None
if file:
decoded_file = base64.b64decode(file)
if "csv" in file_type: file_type = "csv"
elif "xmlformats-officedocument" in file_type: file_type = "xlsx"
elif "word" in file_type: file_type = "docx"
elif "pdf" in file_type: file_type = "pdf"
elif "text" in file_type: file_type = "txt"
else: file_type = "Not Sure"
# Pre-fetch messages and file
messages, prompt_instance = await get_messages(
conversation_id, message, decoded_file, file_type
)
print("Messages: ", messages)
if not decoded_file:
decoded_file, file_type = await get_conversation_file_async(conversation_id)
prompt_metric = await create_prompt_metric(
prompt_instance.id,
prompt_instance.message,
                    bool(file),
file_type,
MODEL_NAME,
conversation_id,
)
# Initialize State
initial_state = {
"message": message,
"conversation_id": conversation_id,
"decoded_file": decoded_file,
"file_type": file_type,
"messages": messages,
"prompt_instance": prompt_instance,
"moderation_label": None,
"prompt_type": None,
"response_generator": None,
"error": None,
"model_name": model
}
print("Initial State: ", initial_state)
# Run Graph
final_state = await app.ainvoke(initial_state)
print("Final State: ", final_state)
response_generator_or_dict = final_state["response_generator"]
print("Response Generator: ", response_generator_or_dict)
# Send start markers
await self.send("CONVERSATION_ID")
await self.send(str(conversation_id))
await self.send("START_OF_THE_STREAM_ENDER_GAME_42")
full_response = ""
if isinstance(response_generator_or_dict, dict):
content = response_generator_or_dict.get("content", "")
await self.send_json_message(json.dumps(response_generator_or_dict))
full_response = content
else:
async for chunk in response_generator_or_dict:
full_response += chunk
await self.send_json_message(chunk)
await self.send("END_OF_THE_STREAM_ENDER_GAME_42")
await save_generated_message(conversation_id, full_response)
await finish_prompt_metric(prompt_metric, len(full_response))

View File

@@ -0,0 +1,20 @@
# Generated by Django 5.1.7 on 2025-09-24 16:44
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("chat_backend", "0020_documentworkspace_document"),
]
operations = [
migrations.AlterField(
model_name="prompt",
name="message",
field=models.CharField(
help_text="The text for a prompt", max_length=102400
),
),
]

View File

@@ -53,6 +53,9 @@ class Company(TimeInfoBase):
         help_text="A list of LLMs that company can use",
     )

+    def __str__(self):
+        return self.name
+
 class CustomUser(AbstractUser):
     company = models.ForeignKey(
@@ -71,7 +74,7 @@ class CustomUser(AbstractUser):
     )

     def get_set_password_url(self):
-        return f"https://www.chat.aimloperations.com/set_password?slug={self.slug}"
+        return f"https://chat.aimloperations.com/set_password?slug={self.slug}"

 FEEDBACK_CHOICE = (
@@ -158,7 +161,7 @@ class Conversation(TimeInfoBase):
 class Prompt(TimeInfoBase):
-    message = models.CharField(max_length=10 * 1024, help_text="The text for a prompt")
+    message = models.CharField(max_length=100 * 1024, help_text="The text for a prompt")
     user_created = models.BooleanField(
         help_text="True if was created by the user. False if it was generate by the LLM"
     )
@@ -220,14 +223,16 @@ class PromptMetric(TimeInfoBase):
             return difference.seconds
         return 0

 # Document Models
 class DocumentWorkspace(TimeInfoBase):
     name = models.CharField(max_length=255)
     company = models.ForeignKey(Company, on_delete=models.CASCADE)

 class Document(TimeInfoBase):
     workspace = models.ForeignKey(DocumentWorkspace, on_delete=models.CASCADE)
-    file = models.FileField(upload_to='documents/')
+    file = models.FileField(upload_to="documents/")
     uploaded_at = models.DateTimeField(auto_now_add=True)
     processed = models.BooleanField(default=False)
     active = models.BooleanField(default=False)

View File

@@ -1,6 +1,8 @@
 from django.urls import re_path
-from .views import ChatConsumerAgain
+from .consumers import ChatConsumerAgain
+from .consumers_graph import ChatConsumerGraph

 websocket_urlpatterns = [
     re_path(r"ws/chat_again/$", ChatConsumerAgain.as_asgi()),
+    re_path(r"ws/conditional_chat/$", ChatConsumerGraph.as_asgi()),
 ]
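For context, these patterns are typically mounted in a standard Channels asgi.py (not part of this diff; the settings module name here is hypothetical):

import os

from channels.auth import AuthMiddlewareStack
from channels.routing import ProtocolTypeRouter, URLRouter
from django.core.asgi import get_asgi_application

from chat_backend.routing import websocket_urlpatterns

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "project.settings")  # hypothetical

application = ProtocolTypeRouter({
    "http": get_asgi_application(),
    "websocket": AuthMiddlewareStack(URLRouter(websocket_urlpatterns)),
})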

View File

@@ -9,7 +9,7 @@ from .models import (
     Feedback,
     FEEDBACK_CATEGORIES,
     DocumentWorkspace,
-    Document
+    Document,
 )
@@ -99,11 +99,20 @@ class BasicUserSerializer(serializers.ModelSerializer):
 class DocumentWorkspaceSerializer(serializers.ModelSerializer):
     class Meta:
         model = DocumentWorkspace
-        fields = ['id', 'name', 'created']
-        read_only_fields = ['id', 'created']
+        fields = ["id", "name", "created"]
+        read_only_fields = ["id", "created"]

 class DocumentSerializer(serializers.ModelSerializer):
     class Meta:
         model = Document
-        fields = ['id', 'workspace', 'file', 'uploaded_at', 'processed', 'created', 'active']
-        read_only_fields = ['id', 'uploaded_at', 'processed', 'created']
+        fields = [
+            "id",
+            "workspace",
+            "file",
+            "uploaded_at",
+            "processed",
+            "created",
+            "active",
+        ]
+        read_only_fields = ["id", "uploaded_at", "processed", "created"]

View File

@@ -0,0 +1,18 @@
from abc import ABC, abstractmethod
from langchain_ollama import OllamaLLM
from langchain_core.output_parsers import StrOutputParser
from django.conf import settings
class BaseService(ABC):
"""Abstract base class for LLM conversation services."""
def __init__(self, temperature=0.7):
self.llm = OllamaLLM(
model="llama3.2" if not settings.DEBUG else "gpt-oss:20b",
            temperature=temperature,
top_k=50,
top_p=0.9,
repeat_penalty=1.1,
num_ctx=4096,
)
self.output_parser = StrOutputParser()
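ModerationClassifier and PromptClassifier later in this diff subclass this base; a minimal sketch of the pattern (EchoService is hypothetical):

# Hypothetical subclass: reuses the shared Ollama configuration and parser.
class EchoService(BaseService):
    def run(self, text: str) -> str:
        chain = self.llm | self.output_parser
        return chain.invoke(text)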

View File

@@ -0,0 +1,194 @@
import pandas as pd
import io
import re
import json
import base64
import matplotlib.pyplot as plt
from typing import AsyncGenerator
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import OllamaLLM
from langchain_core.output_parsers import StrOutputParser
import docx
import pypdf
from django.conf import settings
class AsyncDataAnalysisService:
"""Asynchronous service for performing data analysis with an LLM."""
def __init__(self):
# A model with a large context window and strong analytical skills is best
self.llm = OllamaLLM(
model="llama3.2" if not settings.DEBUG else "gpt-oss:20b",
temperature=0.3,
num_ctx=8192,
)
self.output_parser = StrOutputParser()
self._setup_chain()
def _setup_chain(self):
"""Set up the LLM chain with a prompt tailored for data analysis."""
template = """You are an expert data analyst. Your role is to directly answer a user's question about a dataset or document they have provided.
You will be given a summary and a sample of the dataset, or the content of the document.
Based on this information, provide a clear and concise answer to the user's question.
Do not provide Python code or any other code. The user is not a developer and wants a direct answer.
Even if you don't think the data provides enough evidence for the query, still provide a response.
---
Data/Document Content:
{data_summary}
---
User's Question: {query}
Answer:"""
self.prompt = ChatPromptTemplate.from_template(template)
self.analysis_chain = (
{
"data_summary": lambda x: x["data_summary"],
"query": lambda x: x["query"],
}
| self.prompt
| self.llm
| self.output_parser
)
def _get_dataframe_summary(self, df: pd.DataFrame) -> str:
"""Generates a structured summary of the DataFrame for the LLM."""
num_rows, num_cols = df.shape
summary_lines = [
f"DataFrame has {num_rows} rows and {num_cols} columns.",
"Column Information (Name, Dtype, Non-Null Count):",
"--------------------------------------------------",
]
# Add a concise summary using df.info()
info_buffer = io.StringIO()
df.info(buf=info_buffer, verbose=True, show_counts=True)
summary_lines.append(info_buffer.getvalue())
summary_lines.append("\nDescriptive Statistics (for numerical columns):")
summary_lines.append("--------------------------------------------")
summary_lines.append(df.describe().to_string())
summary_lines.append("\nSample of Data:")
summary_lines.append("-----------------")
        # Show the first 5 rows to give a feel for the data
summary_lines.append(df.head(5).to_string())
return "\n".join(summary_lines)
def _read_docx(self, file_bytes: bytes) -> str:
"""Reads text from a DOCX file."""
doc = docx.Document(io.BytesIO(file_bytes))
full_text = []
for para in doc.paragraphs:
full_text.append(para.text)
return "\n".join(full_text)
def _read_pdf(self, file_bytes: bytes) -> str:
"""Reads text from a PDF file."""
pdf_reader = pypdf.PdfReader(io.BytesIO(file_bytes))
full_text = []
for page in pdf_reader.pages:
full_text.append(page.extract_text())
return "\n".join(full_text)
def _generate_plot(self, query: str, df: pd.DataFrame) -> str:
"""
Generates a plot from a DataFrame based on a natural language query,
encodes it in Base64, and returns it.
If columns are specified (e.g., "plot X vs Y"), it uses them.
If not, it automatically picks the first two numerical columns.
"""
col1, col2 = None, None
title = "Scatter Plot"
# Attempt to find explicitly mentioned columns, e.g., "plot Column1 vs Column2"
match = re.search(r"(?:plot|scatter|visualize)\s+(.*?)\s+(?:vs|versus|and)\s+(.*)", query, re.IGNORECASE)
if match:
potential_col1 = match.group(1).strip()
potential_col2 = match.group(2).strip()
if potential_col1 in df.columns and potential_col2 in df.columns:
col1, col2 = potential_col1, potential_col2
title = f"Scatterplot of {col1} vs {col2}"
# If no valid columns were explicitly found, auto-detect
if not col1 or not col2:
numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
if len(numeric_cols) >= 2:
col1, col2 = numeric_cols[0], numeric_cols[1]
title = f"Scatterplot of {col1} vs {col2} (Auto-selected)"
else:
raise ValueError("I couldn't find two numerical columns to plot automatically. Please specify columns, like 'plot column_A vs column_B'.")
fig, ax = plt.subplots()
ax.scatter(df[col1], df[col2])
ax.set_xlabel(col1)
ax.set_ylabel(col2)
ax.set_title(title)
ax.grid(True)
buf = io.BytesIO()
plt.savefig(buf, format='png', bbox_inches='tight')
plt.close(fig)
buf.seek(0)
image_base64 = base64.b64encode(buf.read()).decode('utf-8')
return image_base64
async def generate_response(
self,
query: str,
decoded_file: bytes,
file_type: str,
) -> AsyncGenerator[str, None]:
"""
Generate a response based on the uploaded data and user query.
This can be a text analysis or a plot visualization.
"""
try:
df = None
data_summary = ""
file_type = file_type.lower()
if "csv" in file_type:
df = pd.read_csv(io.BytesIO(decoded_file))
data_summary = self._get_dataframe_summary(df)
elif "xlsx" in file_type or "spreadsheet" in file_type:
df = pd.read_excel(io.BytesIO(decoded_file))
data_summary = self._get_dataframe_summary(df)
elif "word" in file_type or "docx" in file_type:
data_summary = self._read_docx(decoded_file)
elif "pdf" in file_type:
data_summary = self._read_pdf(decoded_file)
else:
yield json.dumps({"type": "error", "content": f"Unsupported file type: {file_type}. I can analyze CSV, XLSX, DOCX, and PDF files."})
return
# Only attempt plotting if we have a DataFrame
if df is not None:
plot_keywords = ["plot", "graph", "scatter", "visualize"]
if any(keyword in query.lower() for keyword in plot_keywords):
try:
image_base64 = self._generate_plot(query, df)
yield json.dumps({
"type": "plot",
"format": "png",
"image": image_base64
})
except ValueError as e:
yield json.dumps({"type": "error", "content": str(e)})
return
chain_input = {"data_summary": data_summary, "query": query}
async for chunk in self.analysis_chain.astream(chain_input):
                yield chunk
except Exception as e:
yield json.dumps({"type": "error", "content": f"An error occurred: {e}"})

View File

@@ -7,6 +7,7 @@ from diffusers import StableDiffusionPipeline, DPMSolverSinglestepScheduler

 logger = logging.getLogger(__name__)

+
 class ImageGenerationService:
     """
     Service for text-to-image generation using Stable Diffusion.
@@ -69,7 +70,7 @@ class ImageGenerationService:
         prompt: str,
         negative_prompt: Optional[str] = None,
         output_path: Optional[str] = None,
-        **kwargs
+        **kwargs,
     ) -> Tuple[Image.Image, dict]:
         """
         Generate image from text prompt.
@@ -94,9 +95,7 @@ class ImageGenerationService:
         with torch.inference_mode():
-            result = self.pipeline(
-                prompt=prompt,
-                negative_prompt=negative_prompt,
-                **params
-            )
+            result = self.pipeline(
+                prompt=prompt, negative_prompt=negative_prompt, **params
+            )

         image = result.images[0]
@@ -127,7 +126,7 @@ class AsyncImageGenerationService:
         prompt: str,
         negative_prompt: Optional[str] = None,
         output_path: Optional[str] = None,
-        **kwargs
+        **kwargs,
     ) -> Tuple[Image.Image, dict]:
         """Async version of generate_image"""
         import asyncio
@@ -139,7 +138,7 @@ class AsyncImageGenerationService:
             prompt=prompt,
             negative_prompt=negative_prompt,
             output_path=output_path,
-            **kwargs
+            **kwargs,
         )

         return await loop.run_in_executor(None, func)

View File

@@ -1,23 +1,26 @@
 from abc import ABC, abstractmethod
 from typing import AsyncGenerator, Generator, Optional
-from langchain_community.llms import Ollama
+# from langchain_community.llms import Ollama
+from langchain_ollama import OllamaLLM
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
+from django.conf import settings

 from chat_backend.models import Conversation, Prompt

 class LLMService(ABC):
     """Abstract base class for LLM conversation services."""

     def __init__(self):
-        self.llm = Ollama(
-            model="llama3.2",
+        self.llm = OllamaLLM(
+            model="llama3.2" if not settings.DEBUG else "gpt-oss:20b",
             temperature=0.7,
             top_k=50,
             top_p=0.9,
             repeat_penalty=1.1,
-            num_ctx=4096
+            num_ctx=4096,
         )
         self.output_parser = StrOutputParser()
@@ -28,10 +31,11 @@ class LLMService(ABC):
     def _format_history(self, conversation: Conversation) -> str:
         """Format conversation history for the prompt."""
-        prompts = Prompt.objects.filter(conversation=conversation).order_by('created_at')
+        prompts = Prompt.objects.filter(conversation=conversation).order_by(
+            "created_at"
+        )
         return "\n".join(
-            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}"
-            for prompt in prompts
+            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
         )
@@ -56,19 +60,18 @@ class SyncLLMService(LLMService):
         self.conversation_chain = (
             {
                 "history": lambda x: self._format_history(x["conversation"]),
-                "query": lambda x: x["query"]
+                "query": lambda x: x["query"],
             }
             | self.prompt
             | self.llm
             | self.output_parser
         )

-    def generate_response(self, conversation: Conversation, query: str, **kwargs) -> Generator[str, None, None]:
+    def generate_response(
+        self, conversation: Conversation, query: str, **kwargs
+    ) -> Generator[str, None, None]:
         """Generate response with streaming support."""
-        chain_input = {
-            "query": query,
-            "conversation": conversation
-        }
+        chain_input = {"query": query, "conversation": conversation}

         for chunk in self.conversation_chain.stream(chain_input):
             yield chunk
@@ -102,37 +105,48 @@ class AsyncLLMService(LLMService):
         self.conversation_chain = (
             {
-                "context": lambda x: self._format_history(x["conversation"]),
-                "recent_history": lambda x: self._get_recent_messages(x["conversation"]),
-                "query": lambda x: x["query"]
+                "context": lambda x: x["conversation"],
+                "recent_history": lambda x: x["recent_conversation"],
+                "query": lambda x: x["query"],
             }
             | self.prompt
             | self.llm
             | self.output_parser
         )

-    async def _format_history(self, conversation: Conversation) -> str:
+    async def _format_history(self, conversation: list) -> str:
         """Async version of format conversation history."""
-        prompts = await Prompt.objects.filter(conversation=conversation).order_by('created_at').alist()
-        return "\n".join(
-            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}"
-            for prompt in prompts
-        )
+        return "\n".join(
+            f"{'User' if prompt.type == 'human' else 'AI'}: {prompt.text()}"
+            for prompt in conversation
+        )

-    async def _get_recent_messages(self, conversation: Conversation) -> str:
+    async def _get_recent_messages(self, conversation: list) -> str:
         """Async version of format conversation history."""
-        prompts = await Prompt.objects.filter(conversation=conversation).order_by('created_at').alist()[-3:]
-        return "\n".join(
-            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}"
-            for prompt in prompts
-        )
+        return "\n".join(
+            f"{'User' if prompt.type == 'human' else 'AI'}: {prompt.text()}"
+            for prompt in conversation
+        )

-    async def generate_response(self, conversation: Conversation, query: str, **kwargs) -> AsyncGenerator[str, None]:
+    async def generate_response(
+        self, conversation: Conversation, query: str, conversation_id: int, **kwargs
+    ) -> AsyncGenerator[str, None]:
         """Generate response with async streaming support."""
         chain_input = {
             "query": query,
-            "conversation": conversation
-        }
+            "conversation": await self._format_history(conversation),
+            "recent_conversation": await self._get_recent_messages(conversation[-6:]),
+        }
         async for chunk in self.conversation_chain.astream(chain_input):
             yield chunk
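After this change, AsyncLLMService takes the LangChain message list rather than a Conversation object; the consumers call it as generate_response(messages, prompt_instance.message, conversation_id). A usage sketch, assuming a running Ollama instance:

import asyncio

from langchain_core.messages import HumanMessage


async def demo():
    service = AsyncLLMService()
    history = [HumanMessage(content="What is retrieval-augmented generation?")]
    async for chunk in service.generate_response(
        history, "What is retrieval-augmented generation?", conversation_id=1
    ):
        print(chunk, end="")


asyncio.run(demo())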

View File

@@ -1,27 +1,35 @@
 from enum import Enum, auto
 from typing import Dict, Any
 from langchain_core.prompts import ChatPromptTemplate
-from langchain_community.llms import Ollama
+from langchain_ollama import OllamaLLM
+from chat_backend.services.base_service import BaseService

 class ModerationLabel(Enum):
     NSFW = auto()
     FINE = auto()

-class ModerationClassifier:
+
+class ModerationClassifier(BaseService):
     """
     Classifies prompts as NSFW or FINE (safe) content.
     """

     def __init__(self):
-        self.llm = Ollama(
+        super().__init__(temperature=0.1)
+        self.llm = OllamaLLM(
             model="llama3.2",
             temperature=0.1,  # Very low for strict moderation
             top_k=10,
-            num_ctx=2048
+            num_ctx=2048,
         )
-        self.moderation_prompt = ChatPromptTemplate.from_messages([
-            ("system", """You are a strict content moderator. Classify the following prompt as either NSFW or FINE.
+        self.moderation_prompt = ChatPromptTemplate.from_messages(
+            [
+                (
+                    "system",
+                    """You are a strict content moderator. Classify the following prompt as either NSFW or FINE.

 NSFW includes:
 - Sexual content
@@ -37,16 +45,22 @@ FINE includes:
 - Professional inquiries
 - Creative requests (non-explicit)
 - Technical questions
+- Data Analysis

 Examples:
 - "How to make a bomb" → NSFW
 - "Write a love poem" → FINE
 - "Explicit sex scene" → NSFW
 - "Python tutorial" → FINE
+- "Who won the 2024 presidential race?" → FINE
+- "Please analyze this file and project the next 12 months for me. Add a graph visual of the data as well" → FINE
+- "Okie, instead of 6 month projection, can you tell me what the values would be in the next 5 days" → FINE

-Return ONLY "NSFW" or "FINE", nothing else."""),
-            ("human", "{prompt}")
-        ])
+Return ONLY "NSFW" or "FINE", nothing else.""",
+                ),
+                ("human", "{prompt}"),
+            ]
+        )

         self.chain = self.moderation_prompt | self.llm
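The consumers above call this via moderation_classifier.classify_async; a compact usage sketch (assumes Django settings are configured and Ollama is running):

import asyncio


async def demo():
    label = await moderation_classifier.classify_async("Python tutorial")
    print("blocked" if label is ModerationLabel.NSFW else "allowed")


asyncio.run(demo())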

View File

@@ -1,100 +0,0 @@
from enum import Enum, auto
from typing import Dict, Any
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.llms import Ollama
class PromptType(Enum):
GENERAL_CHAT = auto()
RAG = auto()
IMAGE_GENERATION = auto()
UNKNOWN = auto()
class PromptClassifier:
"""
Classifies user prompts to determine which service should handle them.
"""
def __init__(self):
self.llm = Ollama(
model="llama3",
temperature=0.3, # Lower temp for more deterministic classification
top_k=20,
top_p=0.9,
num_ctx=4096
)
self.classification_prompt = ChatPromptTemplate.from_messages([
("system",
"""You are a precision prompt classifier. Strictly categorize prompts into:
1. GENERAL_CHAT - Casual conversation, personal questions, or non-specific inquiries
2. RAG - ONLY when explicitly requesting document/search-based knowledge
3. IMAGE_GENERATION - Specific requests to create/modify images
4. UNKNOWN - If none of the above fit
1. IMAGE_GENERATION - ONLY if:
- Explicitly contains: "generate/create/draw/make an image/picture/photo/art/illustration"
- Requests visual content creation
- Example: "Make a picture of a castle" → IMAGE_GENERATION
2. RAG - ONLY if:
- Explicitly mentions documents/files/data
- Uses search terms: "find/search/lookup in [source]"
- Example: "What does contracts.pdf say?" → RAG
3. GENERAL_CHAT - DEFAULT category when:
- Doesn't meet above criteria
- Conversational/general knowledge
- Uncertain cases
- Example: "Tell me a joke" → GENERAL_CHAT
Examples:
[Definitely RAG]
- "What does the uploaded PDF say about quarterly results?"
- "Search our documents for the 2023 marketing strategy"
- "Find the contract clause about termination"
[Definitely GENERAL_CHAT]
- "How does photosynthesis work?" (General knowledge)
- "Tell me a joke"
- "What's your opinion on AI?"
[Borderline → GENERAL_CHAT]
- "What's our company policy on X?" (No doc reference → general)
- "Explain quantum computing" (General knowledge)
- "Summarize the meeting" (No doc reference)
Return ONLY the label, no explanations."""),
("human", "{prompt}")
])
self.chain = self.classification_prompt | self.llm
async def classify_async(self, prompt: str) -> PromptType:
"""Asynchronously classify the prompt"""
try:
response = await self.chain.ainvoke({"prompt": prompt})
return self._parse_response(response.strip())
except Exception as e:
print(f"Classification error: {e}")
return PromptType.UNKNOWN
def classify(self, prompt: str) -> PromptType:
"""Synchronously classify the prompt"""
try:
response = self.chain.invoke({"prompt": prompt})
return self._parse_response(response.strip())
except Exception as e:
print(f"Classification error: {e}")
return PromptType.UNKNOWN
def _parse_response(self, response: str) -> PromptType:
"""Convert string response to PromptType enum"""
response = response.upper()
for prompt_type in PromptType:
if prompt_type.name in response:
return prompt_type
return PromptType.UNKNOWN
# Singleton instance for easy access
prompt_classifier = PromptClassifier()

View File

@@ -0,0 +1,195 @@
from enum import Enum, auto
from typing import Dict, Any
from langchain_core.prompts import ChatPromptTemplate
from django.conf import settings
from chat_backend.services.base_service import BaseService
class PromptType(Enum):
GENERAL_CHAT = auto()
RAG = auto()
IMAGE_GENERATION = auto()
DATA_ANALYSIS = auto()
SEARCH = auto()
UNKNOWN = auto()
class PromptClassifier(BaseService):
"""
Classifies user prompts to determine which service should handle them.
"""
def __init__(self):
super().__init__(temperature=0.1)
# self.llm = OllamaLLM(
# model="llama3.2",
# temperature=0.3, # Lower temp for more deterministic classification
# top_k=20,
# top_p=0.9,
# num_ctx=4096,
# )
self.classification_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""You are a precision prompt classifier. Strictly categorize prompts into:
1. GENERAL_CHAT - Casual conversation, personal questions, or non-specific inquiries
2. RAG - ONLY when explicitly requesting document/search-based knowledge
3. IMAGE_GENERATION - Specific requests to create/modify images
4. DATA_ANALYSIS - When a user wants to read an uploaded document (PDF, Word, etc.) and generate an index, summary, or extract structured information. Includes prompts like "Please read this document and make me an index for it".
5. SEARCH - When the user is seeking specific, up-to-date information (e.g., current events, celebrity news, sports scores).
6. UNKNOWN - If none of the above fit
1. IMAGE_GENERATION - ONLY if:
- Explicitly contains: "generate/create/draw/make an image/picture/photo/art/illustration"
- Requests visual content creation
- Example: "Make a picture of a castle" → IMAGE_GENERATION
2. RAG - ONLY if:
- Explicitly mentions documents/files/data
- Uses search terms: "find/search/lookup in [source]"
- Example: "What does contracts.pdf say?" → RAG
3. DATA_ANALYSIS - ONLY if:
- The user provides or references an uploaded document (PDF, Word, etc.) and asks for an index, summary, extraction, or analysis of its contents.
- Example: "Please read this document and make me an index for it" → DATA_ANALYSIS
- Example: "Here is the file content. What is the sum of all 'Sales'?" → DATA_ANALYSIS
4. SEARCH - ONLY if:
- User asks for current information (news, weather, sports, stock prices)
- User asks for specific facts that might change or require lookup (e.g. "Who won the 2024 election?")
- Example: "What is the latest news on X?" → SEARCH
- Example: "Who won the Super Bowl this year?" → SEARCH
5. GENERAL_CHAT - DEFAULT category when:
- Doesn't meet above criteria
- Conversational/general knowledge (that doesn't require live search)
- Creative writing (poems, jokes)
- Example: "Tell me a joke" → GENERAL_CHAT
- Example: "Write a poem about cats" → GENERAL_CHAT
Examples:
[Definitely RAG]
- "What does the uploaded PDF say about quarterly results?"
- "Search our documents for the 2023 marketing strategy"
[Definitely DATA_ANALYSIS]
- "Please read this document and make me an index for it"
- "Here is the file content. What is the sum of all 'Sales'?"
- "Based on this CSV data, show me the top 5 customers."
[Definitely SEARCH]
- "Who won the 2024 presidential race?"
- "What is the latest celebrity news?"
- "Current stock price of Apple"
[Definitely GENERAL_CHAT]
- "How does photosynthesis work?" (General knowledge)
- "Tell me a joke"
- "What's your opinion on AI?"
[Borderline -> GENERAL_CHAT]
- "What's our company policy on X?" (No doc reference -> general)
Return ONLY the exact Enum label (e.g. "GENERAL_CHAT"), no explanations."""
),
("human", "{prompt}"),
]
)
self.chain = self.classification_prompt | self.llm
def _quick_check(self, prompt: str) -> PromptType | None:
"""
Performs a quick, rule-based classification before involving the LLM.
Returns a PromptType if a clear match is found, otherwise None.
"""
lower_prompt = prompt.lower()
# IMAGE_GENERATION
if any(keyword in lower_prompt for keyword in ["generate image", "create picture", "draw an image", "make a photo", "generate an image", "create an illustration"]):
if getattr(settings, "ALLOW_IMAGE_GENERATION", False):
return PromptType.IMAGE_GENERATION
            # Image generation is disabled: return GENERAL_CHAT so the LLM can
            # explain that it can't generate images instead of routing to a disabled service.
return PromptType.GENERAL_CHAT
# DATA_ANALYSIS (often involves uploaded documents)
if any(keyword in lower_prompt for keyword in ["read this document", "analyze this file", "summarize this pdf", "extract data from", "index this document", "based on this csv", "from this spreadsheet"]):
return PromptType.DATA_ANALYSIS
# RAG (explicitly asking to search within provided context/documents)
# This might overlap with DATA_ANALYSIS, but RAG is more about retrieval from a knowledge base.
# The prompt examples for RAG are "What does the uploaded PDF say about quarterly results?"
# "Search our documents for the 2023 marketing strategy"
if ("uploaded pdf" in lower_prompt or "our documents" in lower_prompt or "this document" in lower_prompt or "the document" in lower_prompt) and \
any(keyword in lower_prompt for keyword in ["say about", "search for", "find in", "lookup in"]):
return PromptType.RAG
# SEARCH
if any(keyword in lower_prompt for keyword in ["latest news", "current weather", "stock price", "who won", "what is the current", "breaking news", "real-time information", "up-to-date"]):
return PromptType.SEARCH
return None
    async def classify_async(self, prompt: str) -> PromptType:
        """Asynchronously classify the prompt"""
        quick = self._quick_check(prompt)
        if quick:
            return quick
        try:
            response = await self.chain.ainvoke({"prompt": prompt})
            return self._parse_response(response.strip())
        except Exception as e:
            print(f"Classification error: {e}")
            return PromptType.UNKNOWN

    def classify(self, prompt: str) -> PromptType:
        """Synchronously classify the prompt"""
        quick = self._quick_check(prompt)
        if quick:
            return quick
        try:
            response = self.chain.invoke({"prompt": prompt})
            return self._parse_response(response.strip())
        except Exception as e:
            print(f"Classification error: {e}")
            return PromptType.UNKNOWN
    def _parse_response(self, response: str) -> PromptType:
        """Convert string response to PromptType enum"""
        response = response.upper().strip()
        print(response)
        # Guard against IMAGE_GENERATION if disabled
        if not getattr(settings, "ALLOW_IMAGE_GENERATION", False):
            if "IMAGE_GENERATION" in response or "IMAGEGENERATION" in response:
                return PromptType.GENERAL_CHAT
        # Direct match
        try:
            return PromptType[response]
        except KeyError:
            pass
        # Handle missing underscores (e.g. GENERALCHAT)
        normalized_response = response.replace("_", "")
        for prompt_type in PromptType:
            if prompt_type.name.replace("_", "") == normalized_response:
                return prompt_type
        # Substring match as fallback
        for prompt_type in PromptType:
            if prompt_type.name in response:
                return prompt_type
        return PromptType.UNKNOWN


# Singleton instance for easy access
prompt_classifier = PromptClassifier()
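For reference, a minimal usage sketch of the singleton above (the sketch itself is assumed, not part of the commit; the import path matches the one used by the views change later in this diff):

    from chat_backend.services.prompt_classifier.prompt_classifier import (
        prompt_classifier,
        PromptType,
    )

    # Rule-based _quick_check runs first; the LLM is only invoked for ambiguous prompts.
    label = prompt_classifier.classify("Draw an image of a castle")
    if label == PromptType.IMAGE_GENERATION:
        ...  # route to the image pipeline
    elif label == PromptType.UNKNOWN:
        ...  # classification failed; fall back to general chat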

View File

@@ -0,0 +1,31 @@
import os
from unittest import TestCase, mock
from unittest.mock import MagicMock, patch, AsyncMock
from typing import List, Dict, Any
from django.test import TestCase as DjangoTestCase
from chat_backend.services.rag_services import (
    RAGService,
    SyncRAGService,
    AsyncRAGService,
)
from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
from chat_backend.services.prompt_classifier.prompt_classifier import PromptClassifier, PromptType
from parameterized import parameterized


class PromptClassifierTestCase(TestCase):
    def setUp(self):
        self.service = PromptClassifier()

    @parameterized.expand([
        ["Tell me a joke", PromptType.GENERAL_CHAT],
        ["Create an image of a dog for me", PromptType.IMAGE_GENERATION],
        ["highlight the features of the backyard playset if they were to choose us and make the language more long form", PromptType.GENERAL_CHAT],
        ["Great, can you make it about a duck now", PromptType.IMAGE_GENERATION],
    ])
    def test_prompt_classification(self, prompt, expected_output):
        result = self.service.classify(prompt)
        self.assertEqual(result, expected_output)
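Note that these cases exercise the real classifier, so any prompt that misses the rule-based `_quick_check` hits a live Ollama model; running them assumes an Ollama server is available. The standard Django invocation (assumed, not shown in the diff) would be `python manage.py test chat_backend`.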

View File

@@ -3,7 +3,10 @@ from abc import ABC, abstractmethod
from typing import List, Dict, Any, AsyncGenerator, Generator, Optional
from channels.db import database_sync_to_async
from langchain_community.embeddings import OllamaEmbeddings
from django.conf import settings

# from langchain_community.llms import Ollama
from langchain_ollama import OllamaLLM
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document as LangDocument
from langchain_core.output_parsers import StrOutputParser
@@ -14,11 +17,13 @@ from langchain_community.document_loaders import (
    PyPDFLoader,
    Docx2txtLoader,
    TextLoader,
    UnstructuredFileLoader,
)
from django.core.files.uploadedfile import UploadedFile
from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
from pathlib import Path
from chat_backend.services.base_service import BaseService


@database_sync_to_async
def get_documents(workspace: DocumentWorkspace | None = None):
@@ -28,9 +33,9 @@ def get_documents(workspace: DocumentWorkspace | None = None):
    return [doc for doc in Document.objects.all()]


class RAGService(BaseService):
    """Abstract base class for RAG services."""

    _instance = None

    def __new__(cls):
@@ -40,36 +45,27 @@ class RAGService(ABC):
        return cls._instance

    def __init__(self):
        self.embedding_model = OllamaEmbeddings(
            model="llama3.2" if not settings.DEBUG else "gpt-oss:20b"
        )
        super().__init__()
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000, chunk_overlap=200
        )
        self.vector_store = self._initialize_vector_store()

        # Supported file types and their loaders
        self.loader_mapping = {
            ".pdf": PyPDFLoader,
            ".docx": Docx2txtLoader,
            ".txt": TextLoader,
            # Fallback for other file types
            "*": UnstructuredFileLoader,
        }

    def _initialize_vector_store(self) -> Chroma:
        """Initialize and return the Chroma vector store."""
        persist_directory = "./chroma_db/"
        vector_store = Chroma(
            embedding_function=self.embedding_model, persist_directory=persist_directory
        )
        return vector_store
@@ -84,16 +80,14 @@ class RAGService(ABC):
        for doc in documents:
            print(f"Processing: {doc.file.name}")
            loader_class = self._get_file_loader(doc.file.name)
            loader = loader_class(doc.file)
            chunks = self._load_and_split_documents(doc.file.path)
            if chunks:
                self.vector_store.add_documents(chunks)
                self.vector_store.persist()

    def ingest_documents(self, workspace: DocumentWorkspace | None = None) -> None:
        """Ingest documents from a workspace into the vector store."""
        print(f"Getting the Document via the workspace: {workspace}")
@@ -105,25 +99,26 @@ class RAGService(ABC):
        print(f"Processing the documents : {documents}")
        self._prepare_documents(documents)

    # @abstractmethod
    # def generate_response(self, conversation: Conversation, query: str, **kwargs):
    #     """Generate a response using RAG."""
    #     pass

    # @abstractmethod
    # def search_documents(
    #     self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4
    # ) -> List[Document]:
    #     """Search relevant documents from the vector store."""
    #     pass

    def _get_file_loader(self, file_path: str):
        """Get appropriate loader for file type"""
        ext = Path(file_path).suffix.lower()
        return self.loader_mapping.get(ext, self.loader_mapping["*"])

    def _sanitize_filename(self, filename: str) -> str:
        """Sanitize filename for safe storage"""
        return re.sub(r"[^\w\-_. ]", "_", filename)

    def _save_uploaded_file(self, uploaded_file: UploadedFile, save_dir: str) -> str:
        """Save uploaded file to disk"""
@@ -131,13 +126,15 @@ class RAGService(ABC):
        sanitized_name = self._sanitize_filename(uploaded_file.name)
        file_path = os.path.join(save_dir, sanitized_name)
        with open(file_path, "wb+") as destination:
            for chunk in uploaded_file.chunks():
                destination.write(chunk)
        return file_path

    def _load_and_split_documents(
        self, file_path: str, metadata: dict = None
    ) -> List[Document]:
        """Load and split documents from file"""
        loader_class = self._get_file_loader(file_path)
        loader = loader_class(file_path)
@@ -154,7 +151,7 @@ class RAGService(ABC):
        file_tupls: List[UploadedFile],  # (file_path, name, workspace_id)
        workspace_id: str,
        source: str = "upload",
        save_dir: str = "data/uploads",
    ) -> Dict[str, Any]:
        """
        Process and add uploaded files to vector store
@@ -168,23 +165,18 @@ class RAGService(ABC):
        Returns:
            Dictionary with processing results
        """
        results = {"total_added": 0, "failed_files": [], "processed_files": []}

        for file_tuple in file_tupls:
            try:
                # Save file to disk
                # Prepare metadata
                metadata = {
                    "source": file_tuple[1],
                    "workspace_id": file_tuple[2],
                    "original_filename": file_tuple[1],
                    "file_path": file_tuple[0],
                }

                # Load and split documents
@@ -193,17 +185,15 @@ class RAGService(ABC):
                # Add to vector store
                if docs:
                    self.vector_store.add_documents(docs)
                    results["total_added"] += len(docs)
                    results["processed_files"].append(
                        {"filename": file_tuple[1], "document_count": len(docs)}
                    )
            except Exception as e:
                results["failed_files"].append(
                    {"filename": file_tuple[1], "error": str(e)}
                )
                continue

        # Persist changes
@@ -234,7 +224,7 @@ class SyncRAGService(RAGService):
            {
                "context": self._retriever_with_history,
                "history": lambda x: self._format_history(x["conversation"]),
                "question": lambda x: x["query"],
            }
            | self.prompt
            | self.llm
@@ -243,10 +233,11 @@ class SyncRAGService(RAGService):
    def _format_history(self, conversation: Conversation) -> str:
        """Format conversation history for the prompt."""
        prompts = Prompt.objects.filter(conversation=conversation).order_by(
            "created_at"
        )
        return "\n".join(
            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
        )

    def _retriever_with_history(self, input_dict: Dict[str, Any]) -> str:
@@ -262,8 +253,9 @@ class SyncRAGService(RAGService):
        else:
            return relevant_docs

    def search_documents(
        self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4
    ) -> List[Document]:
        """Search relevant documents from the vector store."""
        filter_dict = {}
        if workspace:
@@ -271,19 +263,15 @@ class SyncRAGService(RAGService):
        print(f"search_kwargs: {search_kwargs}")
        retriever = self.vector_store.as_retriever(
            search_type="similarity",
            search_kwargs={"k": k, "filter": filter_dict if filter_dict else None},
        )
        return retriever.get_relevant_documents(query)

    def generate_response(
        self, conversation: Conversation, query: str, **kwargs
    ) -> Generator[str, None, None]:
        """Generate response with streaming support."""
        chain_input = {"query": query, "conversation": conversation}
        for chunk in self.rag_chain.stream(chain_input):
            yield chunk
@@ -311,8 +299,8 @@ class AsyncRAGService(RAGService):
        self.rag_chain = (
            {
                "context": self._retriever_with_history,
                "history": lambda x: x["recent_conversation"],  # self._format_history(x["conversation"]),
                "question": lambda x: x["query"],
            }
            | self.prompt
            | self.llm
@@ -321,12 +309,16 @@ class AsyncRAGService(RAGService):
    async def _format_history(self, conversation: Conversation) -> str:
        """Format conversation history for the prompt."""
        # prompts = (
        #     await Prompt.objects.filter(conversation=conversation)
        #     .order_by("created_at")
        #     .alist()
        # )
        # print(f"prompts that we are seeding with are: {prompts}")
        # return "\n".join(
        #     f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
        # )
        return "\n".join(
            [
                f"{'User' if prompt.type == 'human' else 'AI'}: {prompt.text()}"
                for prompt in conversation
            ]
        )

    async def _retriever_with_history(self, input_dict: Dict[str, Any]) -> str:
        """Retrieve documents considering conversation history."""
@@ -336,7 +328,7 @@ class AsyncRAGService(RAGService):
        workspace = input_dict["workspace"]
        # You could enhance this to consider historical context in retrieval
        docs = await self.search_documents(query, workspace)
        if not docs:
            print("Didn't find any relevant docs")
@@ -344,34 +336,36 @@ class AsyncRAGService(RAGService):
        print("\n\n".join(doc.page_content for doc in docs))
        return "\n\n".join(doc.page_content for doc in docs)

    async def search_documents(
        self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4
    ) -> List[Document]:
        """Search relevant documents from the vector store."""
        filter_dict = {}
        print(f"Do we have a workspace: {workspace}")
        if workspace:
            filter_dict["workspace_id"] = workspace.id
        search_kwargs = {"k": k, "filter": filter_dict if filter_dict else None}
        print(f"search_kwargs: {search_kwargs}")
        retriever = self.vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={"k": k, "filter": filter_dict if filter_dict else None},
        )
        return await retriever.aget_relevant_documents(query)

    async def generate_response(
        self,
        conversation: Conversation,
        query: str,
        workspace: DocumentWorkspace,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Generate response with streaming support."""
        chain_input = {
            "query": query,
            "conversation": conversation,
            "workspace": workspace,
            "recent_conversation": await self._format_history(conversation),
        }
        async for chunk in self.rag_chain.astream(chain_input):
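A hedged sketch of consuming this streaming generator from a Channels consumer; only `AsyncRAGService.generate_response` comes from the diff, and the surrounding handler code is an assumption:

    # assumed: inside an AsyncWebsocketConsumer handler, with conversation,
    # workspace, and query already resolved
    rag_service = AsyncRAGService()
    async for chunk in rag_service.generate_response(
        conversation=conversation, query=query, workspace=workspace
    ):
        await self.send(text_data=json.dumps({"message": chunk}))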

View File

@@ -5,215 +5,247 @@ from typing import List, Dict, Any
from django.test import TestCase as DjangoTestCase
from chat_backend.services.rag_services import (
    RAGService,
    SyncRAGService,
    AsyncRAGService,
)
from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
from chat_backend.services.prompt_classifier import PromptClassifier, PromptType
from parameterized import parameterized
# class TestRAGService(TestCase):
# def setUp(self):
# self.rag_service = RAGService()
# self.rag_service.vector_store = MagicMock()
# self.rag_service.embedding_model = MagicMock()
# self.rag_service.text_splitter = MagicMock()
# def test_initialize_vector_store(self):
# with patch("os.path.exists", return_value=False), patch(
# "os.makedirs"
# ) as mock_makedirs, patch(
# "langchain_community.vectorstores.Chroma"
# ) as mock_chroma:
# # Reset the vector store to test initialization
# self.rag_service.vector_store = None
# result = self.rag_service._initialize_vector_store()
# mock_makedirs.assert_called_once_with("chroma_db")
# mock_chroma.assert_called_once_with(
# embedding_function=self.rag_service.embedding_model,
# persist_directory="chroma_db",
# )
# self.assertIsNotNone(result)
# def test_prepare_documents(self):
# mock_doc1 = MagicMock(spec=Document)
# mock_doc1.content = "Test content"
# mock_doc1.source = "test_source"
# mock_doc1.workspace = MagicMock()
# mock_doc1.workspace.id = 1
# mock_doc1.id = 1
# self.rag_service.text_splitter.split_text.return_value = ["chunk1", "chunk2"]
# result = self.rag_service._prepare_documents([mock_doc1])
# self.assertEqual(len(result), 2)
# self.rag_service.text_splitter.split_text.assert_called_once_with(
# "Test content"
# )
# self.assertEqual(result[0].page_content, "chunk1")
# self.assertEqual(result[0].metadata["source"], "test_source")
# def test_ingest_documents(self):
# mock_workspace = MagicMock()
# mock_document = MagicMock()
# mock_documents = [mock_document]
# with patch(
# "services.rag_services.Document.objects.filter", return_value=mock_documents
# ):
# self.rag_service._prepare_documents = MagicMock(
# return_value=["processed_doc"]
# )
# self.rag_service.ingest_documents(mock_workspace)
# self.rag_service.vector_store.add_documents.assert_called_once_with(
# ["processed_doc"]
# )
# self.rag_service.vector_store.persist.assert_called_once()
# class TestSyncRAGService(DjangoTestCase):
# def setUp(self):
# self.sync_service = SyncRAGService()
# self.sync_service.vector_store = MagicMock()
# self.sync_service.llm = MagicMock()
# self.sync_service.rag_chain = MagicMock()
# self.mock_conversation = MagicMock(spec=Conversation)
# self.mock_conversation.workspace = MagicMock()
# self.mock_prompt1 = MagicMock(spec=Prompt)
# self.mock_prompt1.is_user = True
# self.mock_prompt1.text = "User question"
# self.mock_prompt1.created_at = "2023-01-01"
# self.mock_prompt2 = MagicMock(spec=Prompt)
# self.mock_prompt2.is_user = False
# self.mock_prompt2.text = "AI response"
# self.mock_prompt2.created_at = "2023-01-02"
# def test_format_history(self):
# with patch("services.rag_services.Prompt.objects.filter") as mock_filter:
# mock_filter.return_value.order_by.return_value = [
# self.mock_prompt1,
# self.mock_prompt2,
# ]
# result = self.sync_service._format_history(self.mock_conversation)
# expected = "User: User question\nAI: AI response"
# self.assertEqual(result, expected)
# mock_filter.assert_called_once_with(conversation=self.mock_conversation)
# def test_retriever_with_history(self):
# input_dict = {"query": "test query", "conversation": self.mock_conversation}
# self.sync_service.search_documents = MagicMock(return_value=["doc1", "doc2"])
# result = self.sync_service._retriever_with_history(input_dict)
# self.sync_service.search_documents.assert_called_once_with(
# "test query", self.mock_conversation.workspace
# )
# self.assertEqual(result, ["doc1", "doc2"])
# def test_search_documents(self):
# mock_retriever = MagicMock()
# mock_retriever.get_relevant_documents.return_value = ["doc1", "doc2"]
# self.sync_service.vector_store.as_retriever.return_value = mock_retriever
# result = self.sync_service.search_documents(
# "test query", self.mock_conversation.workspace
# )
# self.sync_service.vector_store.as_retriever.assert_called_once_with(
# search_type="similarity",
# search_kwargs={
# "k": 4,
# "filter": {"workspace_id": self.mock_conversation.workspace.id},
# },
# )
# self.assertEqual(result, ["doc1", "doc2"])
# def test_generate_response(self):
# chain_input = {"query": "test query", "conversation": self.mock_conversation}
# mock_stream = ["chunk1", "chunk2", "chunk3"]
# self.sync_service.rag_chain.stream.return_value = mock_stream
# result = list(
# self.sync_service.generate_response(self.mock_conversation, "test query")
# )
# self.sync_service.rag_chain.stream.assert_called_once_with(chain_input)
# self.assertEqual(result, mock_stream)
# class TestAsyncRAGService(DjangoTestCase):
# def setUp(self):
# self.async_service = AsyncRAGService()
# self.async_service.vector_store = MagicMock()
# self.async_service.llm = MagicMock()
# self.async_service.rag_chain = AsyncMock()
# self.mock_conversation = MagicMock(spec=Conversation)
# self.mock_conversation.workspace = MagicMock()
# self.mock_prompt1 = MagicMock(spec=Prompt)
# self.mock_prompt1.is_user = True
# self.mock_prompt1.text = "User question"
# self.mock_prompt1.created_at = "2023-01-01"
# self.mock_prompt2 = MagicMock(spec=Prompt)
# self.mock_prompt2.is_user = False
# self.mock_prompt2.text = "AI response"
# self.mock_prompt2.created_at = "2023-01-02"
# async def test_format_history(self):
# mock_manager = AsyncMock()
# mock_manager.order_by.return_value.alist.return_value = [
# self.mock_prompt1,
# self.mock_prompt2,
# ]
# with patch(
# "services.rag_services.Prompt.objects.filter", return_value=mock_manager
# ):
# result = await self.async_service._format_history(self.mock_conversation)
# expected = "User: User question\nAI: AI response"
# self.assertEqual(result, expected)
# mock_manager.order_by.assert_called_once_with("created_at")
# async def test_retriever_with_history(self):
# input_dict = {"query": "test query", "conversation": self.mock_conversation}
# self.async_service.search_documents = AsyncMock(return_value=["doc1", "doc2"])
# result = await self.async_service._retriever_with_history(input_dict)
# self.async_service.search_documents.assert_awaited_once_with(
# "test query", self.mock_conversation.workspace
# )
# self.assertEqual(result, ["doc1", "doc2"])
# async def test_search_documents(self):
# mock_retriever = AsyncMock()
# mock_retriever.aget_relevant_documents.return_value = ["doc1", "doc2"]
# self.async_service.vector_store.as_retriever.return_value = mock_retriever
# result = await self.async_service.search_documents(
# "test query", self.mock_conversation.workspace
# )
# self.async_service.vector_store.as_retriever.assert_called_once_with(
# search_type="similarity",
# search_kwargs={
# "k": 4,
# "filter": {"workspace_id": self.mock_conversation.workspace.id},
# },
# )
# self.assertEqual(result, ["doc1", "doc2"])
# async def test_generate_response(self):
# chain_input = {"query": "test query", "conversation": self.mock_conversation}
# mock_stream = ["chunk1", "chunk2", "chunk3"]
# self.async_service.rag_chain.astream.return_value = mock_stream
# chunks = []
# async for chunk in self.async_service.generate_response(
# self.mock_conversation, "test query"
# ):
# chunks.append(chunk)
# self.async_service.rag_chain.astream.assert_awaited_once_with(chain_input)
# self.assertEqual(chunks, mock_stream)
class PromptClassifierTestCase(TestCase):
    def setUp(self):
        self.service = PromptClassifier()

    @parameterized.expand([
        ["Tell me a joke", PromptType.GENERAL_CHAT],
        ["Create an image of a dog for me", PromptType.IMAGE_GENERATION],
        ["highlight the features of the backyard playset if they were to choose us and make the language more long form", PromptType.GENERAL_CHAT],
    ])
    def test_prompt_classification(self, prompt, expected_output):
        result = self.service.classify(prompt)
        self.assertEqual(result, expected_output)

View File

@@ -1,22 +1,28 @@
from langchain_core.prompts import ChatPromptTemplate

# from langchain_community.llms import Ollama
from langchain_ollama import OllamaLLM
from typing import Optional


class TitleGenerator:
    """
    Generates short, descriptive titles for conversations based on the first prompt.
    """

    def __init__(self):
        self.llm = OllamaLLM(
            model="llama3.2",
            temperature=0.5,  # Slightly creative but not too random
            top_k=20,
            num_ctx=2048,  # Shorter context needed for titles
        )
        self.title_prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    """You are a conversation title generator. Create a very short (2-5 word) title based on the user's first message.

Rules:
1. Keep it extremely concise
@@ -31,9 +37,11 @@ Examples:
- "Generate an image of a dragon" → "Dragon Image Generation"
- "Find our company's privacy policy" → "Privacy Policy Search"

Return ONLY the title, nothing else.""",
                ),
                ("human", "{prompt}"),
            ]
        )

        self.chain = self.title_prompt | self.llm
@@ -58,7 +66,7 @@ Return ONLY the title, nothing else."""),
    def _clean_response(self, response: str) -> str:
        """Clean and format the LLM response"""
        # Remove any quotes or punctuation
        response = response.strip("\"'.!? \n\t")
        # Ensure title case and trim
        return response.title()[:50]  # Hard limit for safety
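A hedged usage sketch; `title_generator` is the module-level instance imported elsewhere in this diff, but the public generate method is not shown in these hunks, so the sketch drives the chain and cleanup helper directly (an assumption, not the documented API):

    from chat_backend.services.title_generator import title_generator

    raw = title_generator.chain.invoke({"prompt": "How do I cook pasta?"})
    title = title_generator._clean_response(raw)  # short, title-cased, at most 50 chars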

View File

@@ -3,14 +3,16 @@ from django.dispatch import receiver
from chat_backend.models import Document
from .services.rag_services import AsyncRAGService


@receiver(post_save, sender=Document)
def update_vector_on_save(sender, instance, **kwargs):
    """Update vector store when documents are saved"""
    if kwargs.get("created", False):
        rag_service = AsyncRAGService()
        rag_service.ingest_documents()


@receiver(post_delete, sender=Document)
def delete_vector_on_remove(sender, instance, **kwargs):
    """Handle document deletion by re-indexing the whole workspace"""
View File

@@ -10,51 +10,51 @@ from .models import DocumentWorkspace, Document, Company
from django.contrib.auth import get_user_model
import tempfile
from django.core.files.uploadedfile import SimpleUploadedFile
from parameterized import parameterized

# Minimal valid PDF bytes
VALID_PDF_BYTES = (
    b"%PDF-1.3\n"
    b"1 0 obj\n"
    b"<< /Type /Catalog /Pages 2 0 R >>\n"
    b"endobj\n"
    b"2 0 obj\n"
    b"<< /Type /Pages /Kids [3 0 R] /Count 1 >>\n"
    b"endobj\n"
    b"3 0 obj\n"
    b"<< /Type /Page /Parent 2 0 R /Resources << >> /MediaBox [0 0 612 792] /Contents 4 0 R >>\n"
    b"endobj\n"
    b"4 0 obj\n"
    b"<< /Length 44 >>\n"
    b"stream\n"
    b"BT /F1 12 Tf 72 720 Td (Test PDF) Tj ET\n"
    b"endstream\n"
    b"endobj\n"
    b"xref\n"
    b"0 5\n"
    b"0000000000 65535 f \n"
    b"0000000009 00000 n \n"
    b"0000000058 00000 n \n"
    b"0000000117 00000 n \n"
    b"0000000223 00000 n \n"
    b"trailer\n"
    b"<< /Size 5 /Root 1 0 R >>\n"
    b"startxref\n"
    b"317\n"
    b"%%EOF"
)


class DocumentWorkspaceViewsTestCase(APITestCase):
    def setUp(self):
        self.company = Company.objects.create(
            name="test", state="IL", zipcode="60189", address="1968 Greensboro Dr"
        )
        self.user = get_user_model().objects.create_user(
            company=self.company,
            username="testuser",
            password="testpass123",
            email="test@test.com",
        )
@@ -62,31 +62,28 @@ class DocumentWorkspaceViewsTestCase(APITestCase):
        self.client.force_authenticate(user=self.user)
        self.workspace = DocumentWorkspace.objects.create(
            company=self.user.company, name="Test Workspace"
        )

    def test_list_workspaces(self):
        url = reverse("document_workspaces")
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(len(response.data), 1)
        self.assertEqual(response.data[0]["name"], "Test Workspace")

    def test_create_workspace(self):
        url = reverse("document_workspaces")
        data = {"name": "New Workspace"}
        response = self.client.post(url, data, format="json")
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        self.assertEqual(DocumentWorkspace.objects.count(), 2)

    def test_retrieve_workspace(self):
        url = reverse("document_workspaces")
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data[0]["name"], "Test Workspace")

    # def test_update_workspace(self):
    #     url = reverse('document_workspaces')
@@ -104,18 +101,16 @@ class DocumentWorkspaceViewsTestCase(APITestCase):
    #     self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
    #     self.assertEqual(DocumentWorkspace.objects.count(), 0)


class DocumentViewsTestCase(APITestCase):
    def setUp(self):
        self.company = Company.objects.create(
            name="test", state="IL", zipcode="60189", address="1968 Greensboro Dr"
        )
        self.user = get_user_model().objects.create_user(
            company=self.company,
            username="testuser",
            password="testpass123",
            email="test@test.com",
        )
@@ -123,23 +118,18 @@ class DocumentViewsTestCase(APITestCase):
        self.client.force_authenticate(user=self.user)
        self.workspace = DocumentWorkspace.objects.create(
            company=self.user.company, name="Test Workspace"
        )
        # Create a test file
        self.test_file = SimpleUploadedFile(
            "test.pdf", VALID_PDF_BYTES, content_type="application/pdf"
        )

    def test_upload_document(self):
        url = reverse("documents")
        data = {"file": self.test_file}
        response = self.client.post(url, data, format="multipart")
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        self.assertEqual(Document.objects.count(), 1)
@@ -149,17 +139,14 @@ class DocumentViewsTestCase(APITestCase):
    def test_list_documents(self):
        # First create a document
        Document.objects.create(workspace=self.workspace, file=self.test_file)
        url = reverse("documents")
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(len(response.data), 1)
        self.assertIn("test", response.data[0]["file"])
        self.assertIn("pdf", response.data[0]["file"])

    # def test_delete_document(self):
    #     document = Document.objects.create(
@@ -173,38 +160,35 @@ class DocumentViewsTestCase(APITestCase):
    #     self.assertEqual(Document.objects.count(), 0)

    def test_upload_invalid_file(self):
        url = reverse("documents")
        data = {"file": "not a file"}
        response = self.client.post(url, data, format="multipart")
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

    def test_access_other_users_documents(self):
        # Create another user
        other_company = Company.objects.create(
            name="test2", state="IL", zipcode="60189", address="1968 Greensboro Dr"
        )
        other_user = get_user_model().objects.create_user(
            company=other_company,
            username="otheruser",
            password="otherpass123",
            email="testing2@test.com",
        )
        other_workspace = DocumentWorkspace.objects.create(
            company=other_user.company, name="Other Workspace"
        )
        other_document = Document.objects.create(
            workspace=other_workspace, file=self.test_file
        )
        # Try to access the other user's document
        url = reverse("documents_details", kwargs={"document_id": other_document.id})
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

View File

@@ -23,8 +23,7 @@ from .views import (
    reset_password,
    DocumentWorkspaceView,
    DocumentUploadView,
    DocumentDetailView,
)
from rest_framework.routers import DefaultRouter
@@ -80,11 +79,16 @@ urlpatterns = [
        name="analytics_company_usage",
    ),
    path("analytics/admin/", AdminAnalytics.as_view(), name="analytics_admin"),
    # document urls
    path(
        "document_workspaces/",
        DocumentWorkspaceView.as_view(),
        name="document_workspaces",
    ),
    path("documents/", DocumentUploadView.as_view(), name="documents"),
    path(
        "documents_details/<int:document_id>",
        DocumentDetailView.as_view(),
        name="documents_details",
    ),
]
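A short sketch of resolving the reworked routes (assumed; standard Django `reverse`, and the resulting paths assume this URLconf is mounted at the project root):

    from django.urls import reverse

    reverse("document_workspaces")                            # "/document_workspaces/"
    reverse("documents")                                      # "/documents/"
    reverse("documents_details", kwargs={"document_id": 42})  # "/documents_details/42"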

View File

@@ -13,8 +13,9 @@ from .serializers import (
PromptSerializer, PromptSerializer,
FeedbackSerializer, FeedbackSerializer,
DocumentWorkspaceSerializer, DocumentWorkspaceSerializer,
DocumentSerializer DocumentSerializer,
) )
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.response import Response from rest_framework.response import Response
from .models import ( from .models import (
@@ -25,7 +26,7 @@ from .models import (
Feedback, Feedback,
PromptMetric, PromptMetric,
DocumentWorkspace, DocumentWorkspace,
Document Document,
) )
from django.views.decorators.cache import never_cache from django.views.decorators.cache import never_cache
from django.http import JsonResponse from django.http import JsonResponse
@@ -35,14 +36,15 @@ from asgiref.sync import sync_to_async, async_to_sync
from channels.generic.websocket import AsyncWebsocketConsumer from channels.generic.websocket import AsyncWebsocketConsumer
from langchain_ollama.llms import OllamaLLM from langchain_ollama.llms import OllamaLLM
from langchain_core.prompts import ChatPromptTemplate from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.messages import HumanMessage, AIMessage
from langchain.chains import RetrievalQA from langchain_classic.chains import RetrievalQA
import re import re
import os import os
from django.conf import settings from django.conf import settings
import json import json
import base64 import base64
import pandas as pd import pandas as pd
import io
# For email support # For email support
from django.core.mail import EmailMultiAlternatives from django.core.mail import EmailMultiAlternatives
@@ -66,17 +68,23 @@ from .services.llm_service import AsyncLLMService
from .services.rag_services import AsyncRAGService from .services.rag_services import AsyncRAGService
from .services.title_generator import title_generator from .services.title_generator import title_generator
from .services.moderation_classifier import moderation_classifier, ModerationLabel from .services.moderation_classifier import moderation_classifier, ModerationLabel
from .services.prompt_classifier import prompt_classifier, PromptType from .services.prompt_classifier.prompt_classifier import PromptClassifier, PromptType
from .services.data_analysis_service import AsyncDataAnalysisService
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain from langchain_classic.chains import create_retrieval_chain
from langchain_classic.chains.combine_documents import create_stuff_documents_chain
from langchain_ollama import ChatOllama from langchain_ollama import ChatOllama
CHANNEL_NAME: str = "llm_messages" import logging
MODEL_NAME: str = "llama3" logger = logging.getLogger(__name__)
CHANNEL_NAME: str = "llm_messages"
MODEL_NAME: str = "llama3.2"
# Create your views here. # Create your views here.
class CustomObtainTokenView(TokenObtainPairView): class CustomObtainTokenView(TokenObtainPairView):
permission_classes = (permissions.AllowAny,) permission_classes = (permissions.AllowAny,)
@@ -98,9 +106,9 @@ class CustomUserCreate(APIView):
def send_invite_email(slug, email_to_invite): def send_invite_email(slug, email_to_invite):
print("Sending invite email") logger.info("Sending invite email")
print(f"url : https://www.chat.aimloperations.com/set_password?slug={slug}") logger.info(f"url : https://chat.aimloperations.com/set_password?slug={slug}")
url = f"https://www.chat.aimloperations.com/set_password?slug={slug}" url = f"https://chat.aimloperations.com/set_password?slug={slug}"
subject = "Welcome to AI ML Operations, LLC Chat Services" subject = "Welcome to AI ML Operations, LLC Chat Services"
from_email = "ryan@aimloperations.com" from_email = "ryan@aimloperations.com"
to = email_to_invite to = email_to_invite
@@ -113,8 +121,24 @@ def send_invite_email(slug, email_to_invite):
msg.send(fail_silently=True) msg.send(fail_silently=True)
def send_password_reset_email(slug, email_to_invite):
logger.info("Sending reset email")
logger.info(f"url : https://www.chat.aimloperations.com/set_password?slug={slug}")
url = f"https://www.chat.aimloperations.com/set_password?slug={slug}"
subject = "Password reset for AI ML Operations, LLC Chat Services"
from_email = "ryan@aimloperations.com"
to = email_to_invite
d = {"url": url}
html_content = get_template(r"emails/reset_email.html").render(d)
text_content = get_template(r"emails/reset_email.txt").render(d)
msg = EmailMultiAlternatives(subject, text_content, from_email, [to])
msg.attach_alternative(html_content, "text/html")
msg.send(fail_silently=True)
def send_feedback_email(feedback_obj): def send_feedback_email(feedback_obj):
print("Sending feedback email") logger.info("Sending feedback email")
subject = "New Feedback for Chat by AI ML Operations, LLC" subject = "New Feedback for Chat by AI ML Operations, LLC"
from_email = "ryan@aimloperations.com" from_email = "ryan@aimloperations.com"
to = "ryan@aimloperations.com" to = "ryan@aimloperations.com"
@@ -128,7 +152,7 @@ def send_feedback_email(feedback_obj):
def send_password_reset_email(slug, email_to_invite): def send_password_reset_email(slug, email_to_invite):
print("Sending Password reset email") logger.info("Sending Password reset email")
url = f"https://www.chat.aimloperations.com/set_password?slug={slug}" url = f"https://www.chat.aimloperations.com/set_password?slug={slug}"
subject = "Password reset for Chat by AI ML Operations, LLC" subject = "Password reset for Chat by AI ML Operations, LLC"
from_email = "ryan@aimloperations.com" from_email = "ryan@aimloperations.com"
@@ -176,19 +200,22 @@ class CustomUserInvite(APIView):
return Response(status=status.HTTP_201_CREATED) return Response(status=status.HTTP_201_CREATED)
@csrf_exempt @csrf_exempt
def reset_password(request): def reset_password(request):
if request.method == "POST": if request.method == "POST":
data = json.loads(request.body) data = json.loads(request.body)
token = data.get('recaptchaToken') token = data.get("recaptchaToken")
payload = { payload = {
'secret': settings.CAPTCHA_SECRET_KEY, "secret": settings.CAPTCHA_SECRET_KEY,
'response': token, "response": token,
} }
response = requests.post('https://www.google.com/recaptcha/api/siteverify', data=payload) response = requests.post(
"https://www.google.com/recaptcha/api/siteverify", data=payload
)
result = response.json() result = response.json()
if result.get('success') and result.get('score') >= 0.5: if result.get("success") and result.get("score") >= 0.5:
email = data.get('email') email = data.get("email")
user = CustomUser.objects.filter(email=email).first() user = CustomUser.objects.filter(email=email).first()
if user: if user:
user.set_unusable_password() user.set_unusable_password()
@@ -198,9 +225,9 @@ def reset_password(request):
send_password_reset_email(user.slug, email) send_password_reset_email(user.slug, email)
JsonResponse(status=200) JsonResponse(status=200)
JsonResponse(status=400) JsonResponse(status=400)
class ResetUserPassword(APIView): class ResetUserPassword(APIView):
http_method_names = [ http_method_names = [
"post", "post",
@@ -213,15 +240,17 @@ class ResetUserPassword(APIView):
Send an email with a set password link to the set password page
Also disable the account
"""
- print(f"Password reset for requests. {request.data}")
+ logger.info(f"Password reset for requests. {request.data}")
- token = request.data.get('recaptchaToken')
+ token = request.data.get("recaptchaToken")
payload = {
- 'secret': settings.CAPTCHA_SECRET_KEY,
+ "secret": settings.CAPTCHA_SECRET_KEY,
- 'response': recaptchaToken,
+ "response": recaptchaToken,
}
- response = requests.post('https://www.google.com/recaptcha/api/siteverify', data=payload)
+ response = requests.post(
+ "https://www.google.com/recaptcha/api/siteverify", data=payload
+ )
result = response.json()
- if result.get('success') and result.get('score') >= 0.5:
+ if result.get("success") and result.get("score") >= 0.5:
user = CustomUser.objects.filter(email=email).first()
if user:
user.set_unusable_password()
@@ -230,7 +259,7 @@ class ResetUserPassword(APIView):
# send the email
send_password_reset_email(user.slug, email)
else:
- print('Captcha secret failed')
+ logger.error("Captcha secret failed")
return Response(status=status.HTTP_200_OK)
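Both password-reset paths above repeat the same siteverify round-trip and the same 0.5 score cutoff. A minimal sketch of how that check could be factored into a helper, assuming the same endpoint and threshold (the helper name is hypothetical, not part of this diff):

import requests
from django.conf import settings

def verify_recaptcha(token: str, threshold: float = 0.5) -> bool:  # hypothetical helper
    # POST the client-side token plus the server secret to Google's verifier.
    response = requests.post(
        "https://www.google.com/recaptcha/api/siteverify",
        data={"secret": settings.CAPTCHA_SECRET_KEY, "response": token},
        timeout=10,
    )
    result = response.json()
    # reCAPTCHA v3 responds with a success flag and a 0.0-1.0 score.
    return bool(result.get("success")) and result.get("score", 0.0) >= threshold

Note that ResetUserPassword builds its payload from recaptchaToken, a name that is never bound (the value was read into token), so that branch would raise a NameError as written.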
@@ -261,10 +290,16 @@ class CustomUserGet(APIView):
email = request.user.email
username = request.user.username
- user = CustomUser.objects.get(email=email)
+ user = CustomUser.objects.filter(email=email).last()
logger.info(f"Getting the user: {user}")
try:
serializer = CustomUserSerializer(user)
logger.debug(f"serializer: {serializer}")
logger.debug(serializer.data)
return Response(serializer.data, status=status.HTTP_200_OK)
except Exception as e:
logger.error(f"Exception: {e}")
return Response({}, status=status.HTTP_400_BAD_REQUEST)
class FeedbackView(APIView):
@@ -272,7 +307,7 @@ class FeedbackView(APIView):
def post(self, request, format="json"): def post(self, request, format="json"):
serializer = FeedbackSerializer(data=request.data) serializer = FeedbackSerializer(data=request.data)
print(request.data) logger.debug(request.data)
if serializer.is_valid(): if serializer.is_valid():
feedback_obj = serializer.save() feedback_obj = serializer.save()
@@ -282,7 +317,7 @@ class FeedbackView(APIView):
send_feedback_email(feedback_obj)
return Response(serializer.data, status=status.HTTP_201_CREATED)
else:
- print(serializer.errors)
+ logger.error(serializer.errors)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
def get(self, request, format="json"):
@@ -425,7 +460,7 @@ class ConversationDetailView(APIView):
return Response(serailzer.data, status=status.HTTP_200_OK)
def post(self, request, format="json"):
- print("In the post")
+ logger.info("In the post")
# Add the prompt to the database
# make sure there is a conversation for it
# if there is not a conversation create a title for it
@@ -453,7 +488,7 @@ class ConversationDetailView(APIView):
prompt_instance = serializer.save()
# set up the streaming response if it is from the user
- print(f"Do we have a valid user? {is_user}")
+ logger.info(f"Do we have a valid user? {is_user}")
if is_user:
messages = []
for prompt_obj in Prompt.objects.filter(
@@ -467,12 +502,12 @@ class ConversationDetailView(APIView):
)
channel_layer = get_channel_layer()
- print(f"Sending to the channel: {CHANNEL_NAME}")
+ logger.info(f"Sending to the channel: {CHANNEL_NAME}")
async_to_sync(channel_layer.group_send)(
CHANNEL_NAME, {"type": "receive", "content": messages}
)
except:
- print(
+ logger.error(
f"Error trying to submit to conversation_id: {conversation_id} with request.data: {request.data}"
)
pass
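The group_send above only reaches a browser if some consumer has joined CHANNEL_NAME and exposes a handler for the "receive" event type, since Channels dispatches a group event to the method named after its "type" key. A minimal sketch under those assumptions; the consumer class is hypothetical, and CHANNEL_NAME stands in for the module constant used by the view:

import json
from channels.generic.websocket import AsyncWebsocketConsumer

CHANNEL_NAME = "chat"  # placeholder; the real constant lives in the views module

class PromptStreamConsumer(AsyncWebsocketConsumer):  # hypothetical
    async def connect(self):
        # Join the group that ConversationDetailView publishes to.
        await self.channel_layer.group_add(CHANNEL_NAME, self.channel_name)
        await self.accept()

    async def disconnect(self, close_code):
        await self.channel_layer.group_discard(CHANNEL_NAME, self.channel_name)

    async def receive(self, event):
        # {"type": "receive"} dispatches here; note this shadows the base
        # class's websocket receive(text_data, bytes_data) handler.
        await self.send(text_data=json.dumps(event["content"]))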
@@ -689,274 +724,12 @@ llm = OllamaLLM(model=MODEL_NAME)
# chain = prompt | llm.with_config({"run_name": "model"}) | output_parser.with_config({"run_name": "Assistant"})
@database_sync_to_async
def create_conversation(prompt, email, title):
# return the conversation id
conversation = Conversation.objects.create(title=title)
conversation.save()
user = CustomUser.objects.get(email=email)
conversation.user_id = user.id
conversation.save()
return conversation.id
@database_sync_to_async
def get_workspace(conversation_id):
conversation = Conversation.objects.get(id=conversation_id)
return DocumentWorkspace.objects.get(company=conversation.user.company)
@database_sync_to_async
def get_messages(conversation_id, prompt, file_string: str = None, file_type: str = ""):
messages = []
conversation = Conversation.objects.get(id=conversation_id)
print(file_string)
# add the prompt to the conversation
serializer = PromptSerializer(
data={
"message": prompt,
"user_created": True,
"created": timezone.now(),
}
)
if serializer.is_valid(raise_exception=True):
prompt_instance = serializer.save()
prompt_instance.conversation_id = conversation.id
prompt_instance.save()
if file_string:
file_name = f"prompt_{prompt_instance.id}_data.{file_type}"
f = ContentFile(file_string, name=file_name)
prompt_instance.file.save(file_name, f)
prompt_instance.file_type = file_type
prompt_instance.save()
for prompt_obj in Prompt.objects.filter(conversation__id=conversation_id):
messages.append(
{
"content": prompt_obj.message,
"role": "user" if prompt_obj.user_created else "assistant",
"has_file": prompt_obj.file_exists(),
"file": prompt_obj.file if prompt_obj.file_exists() else None,
"file_type": prompt_obj.file_type if prompt_obj.file_exists() else None,
}
)
# now transform the messages
transformed_messages = []
for message in messages:
if message["has_file"] and message["file_type"] != None:
if "csv" in message["file_type"]:
file_type = "csv"
altered_message = f"{message['content']}\n The file type is csv and the file contents are: {message['file'].read()}"
elif "xlsx" in message["file_type"]:
file_type = "xlsx"
df = pd.read_excel(message["file"].read())
altered_message = f"{message['content']}\n The file type is xlsx and the file contents are: {df}"
elif "txt" in message["file_type"]:
file_type = "txt"
altered_message = f"{message['content']}\n The file type is csv and the file contents are: {message['file'].read()}"
else:
altered_message = message["content"]
transformed_message = (
SystemMessage(content=altered_message)
if message["role"] == "assistant"
else HumanMessage(content=altered_message)
)
transformed_messages.append(transformed_message)
return transformed_messages, prompt_instance
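The branches above inline raw file bytes into the prompt text, and the txt branch labels its contents "csv", apparently a copy-paste slip. A sketch of the same flattening as a standalone helper, assuming csv/txt bytes are embedded directly and xlsx is parsed first (pandas.read_excel wants a file-like object, hence BytesIO):

import io
import pandas as pd

def flatten_file_into_prompt(content: str, data: bytes, file_type: str) -> str:
    # Mirror the branches above: csv and txt are embedded verbatim,
    # xlsx is parsed so the model sees a tabular repr instead of raw bytes.
    if "csv" in file_type:
        return f"{content}\n The file type is csv and the file contents are: {data}"
    if "xlsx" in file_type:
        df = pd.read_excel(io.BytesIO(data))
        return f"{content}\n The file type is xlsx and the file contents are: {df}"
    if "txt" in file_type:
        return f"{content}\n The file type is txt and the file contents are: {data.decode(errors='replace')}"
    return content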
@database_sync_to_async
def save_generated_message(conversation_id, message):
conversation = Conversation.objects.get(id=conversation_id)
# add the prompt to the conversation
serializer = PromptSerializer(
data={
"message": message,
"user_created": False,
"created": timezone.now(),
}
)
if serializer.is_valid():
prompt_instance = serializer.save()
prompt_instance.conversation_id = conversation.id
prompt_instance = serializer.save()
@database_sync_to_async
def create_prompt_metric(
prompt_id, prompt, has_file, file_type, model_name, conversation_id
):
prompt_metric = PromptMetric.objects.create(
prompt_id=prompt_id,
start_time=timezone.now(),
prompt_length=len(prompt),
has_file=has_file,
file_type=file_type,
model_name=model_name,
conversation_id=conversation_id,
)
prompt_metric.save()
return prompt_metric
@database_sync_to_async
def update_prompt_metric(prompt_metric, status):
prompt_metric.event = status
prompt_metric.save()
@database_sync_to_async
def finish_prompt_metric(prompt_metric, response_length):
print(f"finish_prompt_metric: {response_length}")
prompt_metric.end_time = timezone.now()
prompt_metric.reponse_length = response_length
prompt_metric.event = "FINISHED"
prompt_metric.save(update_fields=["end_time", "reponse_length", "event"])
print("finish_prompt_metric saved")
@database_sync_to_async
def get_retriever(conversation_id):
print(f'getting workspace from conversation: {conversation_id}')
conversation = Conversation.objects.get(id=conversation_id)
print(f'Got conversation: {conversation}')
workspace = DocumentWorkspace.objects.get(company=conversation.user.company)
print(f'Got workspace: {conversation}')
vectorstore = Chroma(
persist_directory=f"./chroma_db/",
embedding=OllamaEmbeddings(model="llama3.2"),
)
return vectorstore.as_retriever()
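As written, this retriever searches the whole ./chroma_db collection even though the workspace was just looked up. If chunks are stored with a workspace_id in their metadata, as the add_files_to_store arguments further down suggest, the lookup can be scoped with a metadata filter; a sketch assuming that metadata key:

from langchain_community.vectorstores import Chroma
from langchain_ollama import OllamaEmbeddings

def get_workspace_retriever(workspace_id: int):
    vectorstore = Chroma(
        persist_directory="./chroma_db/",
        embedding_function=OllamaEmbeddings(model="llama3.2"),
    )
    # Restrict similarity search to this workspace's chunks; assumes each
    # chunk was ingested with {"workspace_id": ...} metadata.
    return vectorstore.as_retriever(
        search_kwargs={"k": 4, "filter": {"workspace_id": workspace_id}}
    )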
class ChatConsumerAgain(AsyncWebsocketConsumer):
async def connect(self):
await self.accept()
async def disconnect(self, close_code):
await self.close()
async def receive(self, text_data=None, bytes_data=None):
print(f"Text Data: {text_data}")
print(f"Bytes Data: {bytes_data}")
if text_data:
data = json.loads(text_data)
message = data.get("message", None)
conversation_id = data.get("conversation_id", None)
email = data.get("email", None)
file = data.get("file", None)
file_type = data.get("fileType", "")
model = data.get("modelName", "Turbo")
if not conversation_id:
# we need to create a new conversation
# we will generate a name for it too
title = await title_generator.generate_async(message)
conversation_id = await create_conversation(message, email, title)
if conversation_id:
decoded_file = None
if file:
decoded_file = base64.b64decode(file)
print(decoded_file)
if "csv" in file_type:
file_type = "csv"
altered_message = f"{message}\n The file type is csv and the file contents are: {decoded_file}"
elif "xmlformats-officedocument" in file_type:
file_type = "xlsx"
df = pd.read_excel(decoded_file)
altered_message = f"{message}\n The file type is xlsx and the file contents are: {df}"
elif "text" in file_type:
file_type = "txt"
altered_message = f"{message}\n The file type is txt and the file contents are: {decoded_file}"
else:
file_type = "Not Sure"
print(f'received: "{message}" for conversation {conversation_id}')
# check the moderation here
if await moderation_classifier.classify_async(message) == ModerationLabel.NSFW:
response = "Prompt has been marked as NSFW. If this is in error, submit a feedback with the prompt text."
print("this prompt has been marked as NSFW")
await self.send("CONVERSATION_ID")
await self.send(str(conversation_id))
await self.send("START_OF_THE_STREAM_ENDER_GAME_42")
await self.send(response)
await self.send("END_OF_THE_STREAM_ENDER_GAME_42")
await save_generated_message(conversation_id, response)
return
# TODO: add the message to the database
# get the new conversation
# TODO: get the messages here
messages, prompt = await get_messages(
conversation_id, message, decoded_file, file_type
)
prompt_type = await prompt_classifier.classify_async(message)
print(f"prompt_type: {prompt_type} for {message}")
prompt_metric = await create_prompt_metric(
prompt.id,
prompt.message,
True if file else False,
file_type,
MODEL_NAME,
conversation_id,
)
if file:
# udpate with the altered_message
messages = messages[:-1] + [HumanMessage(content=altered_message)]
print(messages)
# send it to the LLM
# stream the response back
response = ""
# start of the message
await self.send("CONVERSATION_ID")
await self.send(str(conversation_id))
await self.send("START_OF_THE_STREAM_ENDER_GAME_42")
if prompt_type == PromptType.RAG:
service = AsyncRAGService()
#await service.ingest_documents()
workspace = await get_workspace(conversation_id)
print('Time to get the rag response')
async for chunk in service.generate_response(messages, prompt.message, workspace):
print(f"chunk: {chunk}")
response += chunk
await self.send(chunk)
elif prompt_type == PromptType.IMAGE_GENERATION:
response = "Image Generation is not supported at this time, but it will be soon."
await self.send(response)
else:
service = AsyncLLMService()
async for chunk in service.generate_response(messages, prompt.message):
print(f"chunk: {chunk}")
response += chunk
await self.send(chunk)
await self.send("END_OF_THE_STREAM_ENDER_GAME_42")
await save_generated_message(conversation_id, response)
await finish_prompt_metric(prompt_metric, len(response))
if bytes_data:
print("we have byte data")
# Document Views
class DocumentWorkspaceView(APIView):
- #permission_classes = [permissions.IsAuthenticated]
+ # permission_classes = [permissions.IsAuthenticated]
def get(self, request):
workspaces = DocumentWorkspace.objects.filter(company=request.user.company)
@@ -970,47 +743,52 @@ class DocumentWorkspaceView(APIView):
return Response(serializer.data, status=status.HTTP_201_CREATED)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
class DocumentUploadView(APIView):
- #permission_classes = [permissions.IsAuthenticated]Z
+ # permission_classes = [permissions.IsAuthenticated]Z
def get(self, request):
- print(f'request_3: {request}')
+ logger.debug(f"request_3: {request}")
try:
workspace = DocumentWorkspace.objects.get(company=request.user.company)
- serializer = DocumentSerializer(Document.objects.filter(workspace=workspace), many=True)
+ serializer = DocumentSerializer(
+ Document.objects.filter(workspace=workspace), many=True
+ )
return Response(serializer.data, status=status.HTTP_200_OK)
except:
- return Response({'error': "Workspace not found"}, status=status.HTTP_404_NOT_FOUND)
+ return Response(
+ {"error": "Workspace not found"}, status=status.HTTP_404_NOT_FOUND
+ )
def post(self, request):
- print(f'request: {request}')
+ logger.debug(f"request: {request}")
try:
workspace = DocumentWorkspace.objects.get(company=request.user.company)
except:
- return Response({'error': "Workspace not found"}, status=status.HTTP_404_NOT_FOUND)
- print(request.FILES)
- file = request.FILES.get('file')
- if not file:
- return Response({"error":"No file provided"}, status=status.HTTP_400_BAD_REQUEST)
- print("have the workspace and the file")
- document = Document.objects.create(
- workspace=workspace,
- file=file
- )
+ return Response(
+ {"error": "Workspace not found"}, status=status.HTTP_404_NOT_FOUND
+ )
+ logger.info(request.FILES)
+ file = request.FILES.get("file")
+ if not file:
+ return Response(
+ {"error": "No file provided"}, status=status.HTTP_400_BAD_REQUEST
+ )
+ logger.info("have the workspace and the file")
+ document = Document.objects.create(workspace=workspace, file=file)
# process the document inthe background
self.process_document(document)
serializer = DocumentSerializer(document)
return Response(serializer.data, status=status.HTTP_201_CREATED)
def process_document(self, document):
file_path = os.path.join(settings.MEDIA_ROOT, document.file.name)
@@ -1018,22 +796,25 @@ class DocumentUploadView(APIView):
document.active = True
document.save()
service = AsyncRAGService()
- service.add_files_to_store([(file_path, document.file.name, document.workspace_id)], workspace_id=document.workspace_id)
+ service.add_files_to_store(
+ [(file_path, document.file.name, document.workspace_id)],
+ workspace_id=document.workspace_id,
+ )
class DocumentDetailView(APIView):
- #permission_classes = [permissions.IsAuthenticated]
+ # permission_classes = [permissions.IsAuthenticated]
def get(self, request, document_id):
- print(f'request: {request}')
+ logger.info(f"request: {request}")
try:
workspace = DocumentWorkspace.objects.get(company=request.user.company)
- document = Document.objects.get(
- workspace=workspace,
- id=document_id
- )
+ document = Document.objects.get(workspace=workspace, id=document_id)
except:
- return Response({'error': "Document not found"}, status=status.HTTP_404_NOT_FOUND)
+ return Response(
+ {"error": "Document not found"}, status=status.HTTP_404_NOT_FOUND
+ )
serializer = DocumentWorkspaceSerializer(workspaces, many=True)
return Response(serializer.data)
View File
@@ -9,13 +9,18 @@ https://docs.djangoproject.com/en/3.2/howto/deployment/asgi/
import os
import django
from django.core.asgi import get_asgi_application
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "llm_be.settings")
django.setup()
from channels.routing import ProtocolTypeRouter, URLRouter
from channels.auth import AuthMiddlewareStack
import chat_backend.routing
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "llm_be.settings")
application = ProtocolTypeRouter(
{
"http": get_asgi_application(),
View File
@@ -25,7 +25,7 @@ BASE_DIR = Path(__file__).resolve().parent.parent
SECRET_KEY = "django-insecure-6suk6fj5q2)1tj%)f(wgw1smnliv5-#&@zvgvj1wp#(#@h#31x" SECRET_KEY = "django-insecure-6suk6fj5q2)1tj%)f(wgw1smnliv5-#&@zvgvj1wp#(#@h#31x"
# SECURITY WARNING: don't run with debug turned on in production! # SECURITY WARNING: don't run with debug turned on in production!
DEBUG = True DEBUG = False
CORS_ALLOW_CREDENTIALS = False CORS_ALLOW_CREDENTIALS = False
ALLOWED_HOSTS = [ ALLOWED_HOSTS = [
"*.aimloperations.com", "*.aimloperations.com",
@@ -161,7 +161,7 @@ REST_FRAMEWORK = {
}
SIMPLE_JWT = {
- "ACCESS_TOKEN_LIFETIME": timedelta(hours=24),
+ "ACCESS_TOKEN_LIFETIME": timedelta(hours=5),
"REFRESH_TOKEN_LIFETIME": timedelta(days=14),
"ROTATE_REFRESH_TOKENS": True,
"BLACKLIST_AFTER_ROTATION": True,
@@ -208,3 +208,52 @@ EMAIL_USE_TLS = True
# Captcha
CAPTCHA_SECRET_KEY = "6LfENu4qAAAAABdrj6JTviq-LfdPP5imhE-Os7h9"
directory_path = 'logs'
# LOGGING = {
# 'version': 1,
# 'disable_existing_loggers': False,
# 'formatters': {
# 'verbose': {
# 'format': '{levelname} {asctime} {module} {process:d} {thread:d} {message}',
# 'style': '{',
# },
# 'simple': {
# 'format': '{levelname} {message}',
# 'style': '{',
# },
# },
# 'handlers': {
# 'console': {
# 'level': 'INFO',
# 'class': 'logging.StreamHandler',
# 'formatter': 'simple',
# },
# 'file': {
# 'level': 'DEBUG',
# 'class': 'logging.handlers.RotatingFileHandler',
# 'filename': f'{directory_path}/django.log',
# 'maxBytes': 1024 * 1024 * 5, # 5 MB
# 'backupCount': 5,
# 'formatter': 'verbose',
# },
# },
# 'loggers': {
# 'django': {
# 'handlers': ['console', 'file'],
# 'level': 'INFO',
# 'propagate': True,
# },
# 'my_app': {
# 'handlers': ['console', 'file'],
# 'level': 'DEBUG',
# 'propagate': False,
# },
# },
# }
os.makedirs(directory_path, exist_ok=True)
# Feature Flags
ALLOW_IMAGE_GENERATION = False
ALLOW_INTERNET_ACCESS = True
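The flags are plain settings constants, so gating a code path is a one-line check; a sketch of how a view or consumer might consult them (illustrative only, not part of this diff):

from django.conf import settings

def image_generation_enabled() -> bool:
    # Falls back to False if the flag is ever removed from settings.
    return getattr(settings, "ALLOW_IMAGE_GENERATION", False)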
View File
@@ -0,0 +1,8 @@
{% extends "admin/base_site.html" %}
{% block extrahead %}
{{ block.super }}
{% if not debug %}
<script async defer src="https://tianji.aimloperations.com/tracker.js" data-website-id="cm7x7mrcy03kfddsw2jyejzub"></script>
{% endif %}
{% endblock %}
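The {% if not debug %} guard depends on Django's debug template context processor, which injects the variable only when DEBUG is True and the visiting IP appears in INTERNAL_IPS; in production (DEBUG = False above) debug is simply absent, so the tracker renders. To keep the script out of local development as well, the usual setting is:

INTERNAL_IPS = ["127.0.0.1"]  # lets context_processors.debug provide the debug flag locally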
View File
@@ -1,231 +1,265 @@
- aiofiles==24.1.0
+ accelerate==1.12.0
+ aiofiles==25.1.0
aiohappyeyeballs==2.6.1
- aiohttp==3.11.18
+ aiohttp==3.13.2
- aiosignal==1.3.2
+ aiosignal==1.4.0
+ annotated-doc==0.0.4
annotated-types==0.7.0
- antlr4-python3-runtime==4.9.3
+ antlr4-python3-runtime==4.13.2
- anyio==4.8.0
+ anyio==4.12.0
- asgiref==3.8.1
+ asgiref==3.11.0
astor==0.8.1
- attrs==25.1.0
+ attrs==25.4.0
- autobahn==24.4.2
+ autobahn==25.11.1
- Automat==24.8.1
+ Automat==25.4.16
backoff==2.2.1
- bcrypt==4.3.0
+ bcrypt==5.0.0
- beautifulsoup4==4.13.4
+ beautifulsoup4==4.14.3
- black==25.1.0
+ black==25.11.0
- build==1.2.2.post1
+ brotli==1.2.0
+ build==1.3.0
cachetools==5.5.2
- certifi==2025.1.31
+ cbor2==5.7.1
- cffi==1.17.1
+ certifi==2025.11.12
- channels==4.2.0
+ cffi==2.0.0
+ channels==4.3.2
chardet==5.2.0
- charset-normalizer==3.4.1
+ charset-normalizer==3.4.4
chroma-hnswlib==0.7.6
- chromadb==1.0.7
+ chromadb==1.3.5
- click==8.1.8
+ click==8.3.1
coloredlogs==15.0.1
constantly==23.10.4
- contourpy==1.3.1
+ contourpy==1.3.3
- cryptography==44.0.2
+ cryptography==46.0.3
cycler==0.12.1
- daphne==4.1.2
+ daphne==4.2.1
dataclasses-json==0.6.7
- Deprecated==1.2.18
+ ddgs==9.9.3
+ Deprecated==1.3.1
distro==1.9.0
- Django==5.1.7
+ Django==6.0
django-autoslug==1.9.9
- django-cors-headers==4.7.0
+ django-cors-headers==4.9.0
- django-filter==25.1
+ django-filter==25.2
- djangorestframework==3.15.2
+ djangorestframework==3.16.1
- djangorestframework_simplejwt==5.5.0
+ djangorestframework_simplejwt==5.5.1
- duckdb==1.2.1
+ duckdb==1.4.2
- durationpy==0.9
+ durationpy==0.10
effdet==0.4.1
- emoji==2.14.1
+ emoji==2.15.0
- eval_type_backport==0.2.2
+ et_xmlfile==2.0.0
- Faker==37.0.0
+ eval_type_backport==0.3.1
- fastapi==0.115.9
+ fake-useragent==2.2.0
- filelock==3.17.0
+ Faker==38.2.0
+ fastapi==0.124.0
+ filelock==3.20.0
filetype==1.2.0
- flatbuffers==25.2.10
+ flatbuffers==25.9.23
- fonttools==4.56.0
+ fonttools==4.61.0
- frozenlist==1.6.0
+ frozenlist==1.8.0
- fsspec==2025.2.0
+ fsspec==2025.12.0
- google-api-core==2.24.2
+ google-api-core==2.28.1
- google-auth==2.39.0
+ google-auth==2.43.0
- google-cloud-vision==3.10.1
+ google-cloud-vision==3.11.0
- googleapis-common-protos==1.70.0
+ googleapis-common-protos==1.72.0
- greenlet==3.1.1
+ greenlet==3.3.0
- grpcio==1.72.0rc1
+ grpcio==1.76.0
- grpcio-status==1.72.0rc1
+ grpcio-status==1.76.0
- h11==0.14.0
+ h11==0.16.0
+ h2==4.3.0
+ hf-xet==1.2.0
+ hpack==4.1.0
html5lib==1.1
- httpcore==1.0.7
+ httpcore==1.0.9
- httptools==0.6.4
+ httptools==0.7.1
httpx==0.28.1
- httpx-sse==0.4.0
+ httpx-sse==0.4.3
- huggingface-hub==0.30.2
+ huggingface-hub==0.36.0
humanfriendly==10.0
+ hyperframe==6.1.0
hyperlink==21.0.0
- idna==3.10
+ idna==3.11
- importlib_metadata==8.6.1
+ importlib_metadata==8.7.0
importlib_resources==6.5.2
- incremental==24.7.2
+ Incremental==24.11.0
Jinja2==3.1.6
- jiter==0.8.2
+ jiter==0.12.0
- joblib==1.4.2
+ joblib==1.5.2
jsonpatch==1.33
jsonpointer==3.0.0
- jsonschema==4.23.0
+ jsonschema==4.25.1
- jsonschema-specifications==2025.4.1
+ jsonschema-specifications==2025.9.1
- kiwisolver==1.4.8
+ kiwisolver==1.4.9
- kubernetes==32.0.1
+ kubernetes==34.1.0
- langchain==0.3.24
+ langchain==1.1.2
- langchain-community==0.3.23
+ langchain-chroma==1.0.0
- langchain-core==0.3.56
+ langchain-classic==1.0.0
- langchain-ollama==0.2.3
+ langchain-community==0.4.1
- langchain-text-splitters==0.3.8
+ langchain-core==1.1.1
+ langchain-ollama==1.0.0
+ langchain-text-splitters==1.0.0
langdetect==1.0.9
- langsmith==0.3.13
+ langgraph==1.0.4
- lxml==5.4.0
+ langgraph-checkpoint==3.0.1
- Markdown==3.7
+ langgraph-prebuilt==1.0.5
- markdown-it-py==3.0.0
+ langgraph-sdk==0.2.14
- MarkupSafe==3.0.2
+ langsmith==0.4.56
+ lxml==6.0.2
+ Markdown==3.10
+ markdown-it-py==4.0.0
+ MarkupSafe==3.0.3
marshmallow==3.26.1
- matplotlib==3.10.1
+ matplotlib==3.10.7
mdurl==0.1.2
- mmh3==5.1.0
+ ml_dtypes==0.5.4
+ mmh3==5.2.0
mpmath==1.3.0
- multidict==6.4.3
+ msgpack==1.1.2
- mypy-extensions==1.0.0
+ multidict==6.7.0
+ mypy_extensions==1.1.0
nest-asyncio==1.6.0
- networkx==3.4.2
+ networkx==3.6
- nltk==3.9.1
+ nltk==3.9.2
- numpy==2.2.3
+ numpy==2.2.6
- nvidia-cublas-cu12==12.6.4.1
+ nvidia-cublas-cu12==12.8.4.1
- nvidia-cuda-cupti-cu12==12.6.80
+ nvidia-cuda-cupti-cu12==12.8.90
- nvidia-cuda-nvrtc-cu12==12.6.77
+ nvidia-cuda-nvrtc-cu12==12.8.93
- nvidia-cuda-runtime-cu12==12.6.77
+ nvidia-cuda-runtime-cu12==12.8.90
- nvidia-cudnn-cu12==9.5.1.17
+ nvidia-cudnn-cu12==9.10.2.21
- nvidia-cufft-cu12==11.3.0.4
+ nvidia-cufft-cu12==11.3.3.83
- nvidia-cufile-cu12==1.11.1.6
+ nvidia-cufile-cu12==1.13.1.3
- nvidia-curand-cu12==10.3.7.77
+ nvidia-curand-cu12==10.3.9.90
- nvidia-cusolver-cu12==11.7.1.2
+ nvidia-cusolver-cu12==11.7.3.90
- nvidia-cusparse-cu12==12.5.4.2
+ nvidia-cusparse-cu12==12.5.8.93
- nvidia-cusparselt-cu12==0.6.3
+ nvidia-cusparselt-cu12==0.7.1
- nvidia-nccl-cu12==2.26.2
+ nvidia-nccl-cu12==2.27.5
- nvidia-nvjitlink-cu12==12.6.85
+ nvidia-nvjitlink-cu12==12.8.93
- nvidia-nvtx-cu12==12.6.77
+ nvidia-nvshmem-cu12==3.3.20
- oauthlib==3.2.2
+ nvidia-nvtx-cu12==12.8.90
+ oauthlib==3.3.1
olefile==0.47
- ollama==0.4.7
+ ollama==0.6.1
omegaconf==2.3.0
- onnx==1.18.0
+ onnx==1.20.0
- onnxruntime==1.21.1
+ onnxruntime==1.23.2
- openai==1.65.4
+ openai==2.9.0
- opencv-python==4.11.0.86
+ opencv-python==4.12.0.88
- opentelemetry-api==1.32.1
+ openpyxl==3.1.5
- opentelemetry-exporter-otlp-proto-common==1.32.1
+ opentelemetry-api==1.39.0
- opentelemetry-exporter-otlp-proto-grpc==1.32.1
+ opentelemetry-exporter-otlp-proto-common==1.39.0
+ opentelemetry-exporter-otlp-proto-grpc==1.39.0
opentelemetry-instrumentation==0.53b1
opentelemetry-instrumentation-asgi==0.53b1
opentelemetry-instrumentation-fastapi==0.53b1
- opentelemetry-proto==1.32.1
+ opentelemetry-proto==1.39.0
- opentelemetry-sdk==1.32.1
+ opentelemetry-sdk==1.39.0
- opentelemetry-semantic-conventions==0.53b1
+ opentelemetry-semantic-conventions==0.60b0
opentelemetry-util-http==0.53b1
- orjson==3.10.15
+ orjson==3.11.5
+ ormsgpack==1.12.0
overrides==7.7.0
- packaging==24.2
+ packaging==25.0
- pandas==2.2.3
+ pandas==2.3.3
pandasai==2.4.2
+ parameterized==0.9.0
pathspec==0.12.1
pdf2image==1.17.0
- pdfminer.six==20250506
+ pdfminer.six==20251107
- pi_heif==0.22.0
+ pi==0.1.2
- pikepdf==9.7.0
+ pi_heif==1.1.1
- pillow==11.1.0
+ pikepdf==10.0.2
- platformdirs==4.3.6
+ pillow==12.0.0
- posthog==4.0.1
+ platformdirs==4.5.1
- propcache==0.3.1
+ posthog==5.4.0
+ primp==0.15.0
+ propcache==0.4.1
proto-plus==1.26.1
- protobuf==6.31.0rc2
+ protobuf==6.33.2
- psutil==7.0.0
+ psutil==7.1.3
+ py-ubjson==0.16.1
pyasn1==0.6.1
- pyasn1_modules==0.4.1
+ pyasn1_modules==0.4.2
- pycocotools==2.0.8
+ pybase64==1.4.3
- pycparser==2.22
+ pycocotools==2.0.10
- pydantic==2.11.4
+ pycparser==2.23
- pydantic-settings==2.9.1
+ pydantic==2.12.5
- pydantic_core==2.33.2
+ pydantic-settings==2.12.0
- Pygments==2.19.1
+ pydantic_core==2.41.5
+ Pygments==2.19.2
PyJWT==2.10.1
- pyOpenSSL==25.0.0
+ pyOpenSSL==25.3.0
- pyparsing==3.2.1
+ pyparsing==3.2.5
- pypdf==5.4.0
+ pypdf==6.4.0
- pypdfium2==4.30.1
+ pypdfium2==5.1.0
PyPika==0.48.9
pyproject_hooks==1.2.0
python-dateutil==2.9.0.post0
- python-dotenv==1.0.1
+ python-docx==1.2.0
- python-iso639==2025.2.18
+ python-dotenv==1.2.1
+ python-iso639==2025.11.16
python-magic==0.4.27
python-multipart==0.0.20
python-oxmsg==0.0.2
- pytz==2025.1
+ pytokens==0.3.0
- PyYAML==6.0.2
+ pytz==2025.2
- RapidFuzz==3.13.0
+ PyYAML==6.0.3
- referencing==0.36.2
+ RapidFuzz==3.14.3
- regex==2024.11.6
+ referencing==0.37.0
- requests==2.32.3
+ regex==2025.11.3
+ requests==2.32.5
requests-oauthlib==2.0.0
requests-toolbelt==1.0.0
- rich==14.0.0
+ rich==14.2.0
- rpds-py==0.24.0
+ rpds-py==0.30.0
rsa==4.9.1
- safetensors==0.5.3
+ safetensors==0.7.0
- scipy==1.15.2
+ scipy==1.16.3
service-identity==24.2.0
- setuptools==75.8.2
+ setuptools==80.9.0
shellingham==1.5.4
six==1.17.0
sniffio==1.3.1
- soupsieve==2.7
+ socksio==1.0.0
- SQLAlchemy==2.0.38
+ soupsieve==2.8
- sqlglot==26.9.0
+ SQLAlchemy==2.0.44
- sqlglotrs==0.4.0
+ sqlglot==28.1.0
- sqlparse==0.5.3
+ sqlglotrs==0.8.0
- starlette==0.45.3
+ sqlparse==0.5.4
+ starlette==0.50.0
sympy==1.14.0
- tenacity==9.0.0
+ tenacity==9.1.2
- timm==1.0.15
+ timm==1.0.22
- tokenizers==0.21.1
+ tokenizers==0.22.1
- torch==2.7.0
+ torch==2.9.1
- torchvision==0.22.0
+ torchvision==0.24.1
tqdm==4.67.1
- transformers==4.51.3
+ transformers==4.57.3
- triton==3.3.0
+ triton==3.5.1
- Twisted==24.11.0
+ Twisted==25.5.0
- txaio==23.1.1
+ txaio==25.12.1
- typer==0.15.3
+ typer==0.20.0
+ typer-slim==0.20.0
typing-inspect==0.9.0
- typing-inspection==0.4.0
+ typing-inspection==0.4.2
- typing_extensions==4.12.2
+ typing_extensions==4.15.0
- tzdata==2025.1
+ tzdata==2025.2
- unstructured==0.17.2
+ ujson==5.11.0
- unstructured-client==0.34.0
+ unstructured==0.18.21
- unstructured-inference==0.8.10
+ unstructured-client==0.42.4
unstructured.pytesseract==0.3.15
+ unstructured_inference==1.1.2
urllib3==2.3.0
- uvicorn==0.34.2
+ uuid_utils==0.12.0
- uvloop==0.21.0
+ uvicorn==0.38.0
- watchfiles==1.0.5
+ uvloop==0.22.1
+ watchfiles==1.1.1
webencodings==0.5.1
- websocket-client==1.8.0
+ websocket-client==1.9.0
websockets==15.0.1
- wrapt==1.17.2
+ wrapt==2.0.1
- yarl==1.20.0
+ xxhash==3.6.0
- zipp==3.21.0
+ yarl==1.22.0
- zope.interface==7.2
+ zipp==3.23.0
- zstandard==0.23.0
+ zope.interface==8.1.1
+ zstandard==0.25.0