Allow for data analysis

This commit is contained in:
2025-09-08 12:29:20 -05:00
parent 951a58f2fa
commit 14d8211715
4 changed files with 236 additions and 83 deletions

View File

@@ -0,0 +1,103 @@
import pandas as pd
import io
from typing import AsyncGenerator
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import OllamaLLM
from langchain_core.output_parsers import StrOutputParser
class AsyncDataAnalysisService:
"""Asynchronous service for performing data analysis with an LLM."""
def __init__(self):
# A model with a large context window and strong analytical skills is best
self.llm = OllamaLLM(
model="llama3.2",
temperature=0.3,
num_ctx=8192,
)
self.output_parser = StrOutputParser()
self._setup_chain()
def _setup_chain(self):
"""Set up the LLM chain with a prompt tailored for data analysis."""
template = """You are an expert data analyst. A user has provided a summary and sample of a dataset and is asking a question about it.
Analyze the provided information and answer the user's question. If a calculation is requested, perform it based on the summary statistics provided. If the data is not suitable for the request, explain why.
---
Data Summary:
{data_summary}
---
User's Question: {query}
Answer:"""
self.prompt = ChatPromptTemplate.from_template(template)
self.analysis_chain = (
{
"data_summary": lambda x: x["data_summary"],
"query": lambda x: x["query"],
}
| self.prompt
| self.llm
| self.output_parser
)
def _get_dataframe_summary(self, df: pd.DataFrame) -> str:
"""Generates a structured summary of the DataFrame for the LLM."""
num_rows, num_cols = df.shape
summary_lines = [
f"DataFrame has {num_rows} rows and {num_cols} columns.",
"Column Information (Name, Dtype, Non-Null Count):",
"--------------------------------------------------",
]
# Add a concise summary using df.info()
info_buffer = io.StringIO()
df.info(buf=info_buffer, verbose=True, show_counts=True)
summary_lines.append(info_buffer.getvalue())
summary_lines.append("\nDescriptive Statistics (for numerical columns):")
summary_lines.append("--------------------------------------------")
summary_lines.append(df.describe().to_string())
summary_lines.append("\nSample of Data:")
summary_lines.append("-----------------")
# Show the first 5 rows and a few random rows to give a feel for the data
summary_lines.append(df.head(5).to_string())
return "\n".join(summary_lines)
async def generate_response(
self,
query: str,
decoded_file: bytes,
file_type: str,
) -> AsyncGenerator[str, None]:
"""Generate a response based on the uploaded data and user query."""
try:
# Read the file content into a DataFrame
if file_type == "csv":
df = pd.read_csv(io.BytesIO(decoded_file))
elif file_type == "xlsx":
df = pd.read_excel(io.BytesIO(decoded_file))
else:
yield "I can only analyze CSV and XLSX files at this time."
return
# Get the structured summary instead of the full data
data_summary = self._get_dataframe_summary(df)
# Prepare the input for the LLM chain
chain_input = {
"data_summary": data_summary,
"query": query,
}
async for chunk in self.analysis_chain.astream(chain_input):
yield chunk
except Exception as e:
yield f"An error occurred while processing the file: {e}"

View File

@@ -9,6 +9,7 @@ class PromptType(Enum):
GENERAL_CHAT = auto()
RAG = auto()
IMAGE_GENERATION = auto()
DATA_ANALYSIS = auto()
UNKNOWN = auto()
@@ -35,43 +36,45 @@ class PromptClassifier(BaseService):
1. GENERAL_CHAT - Casual conversation, personal questions, or non-specific inquiries
2. RAG - ONLY when explicitly requesting document/search-based knowledge
3. IMAGE_GENERATION - Specific requests to create/modify images
4. UNKNOWN - If none of the above fit
4. DATA_ANALYSIS - When a user is asking questions about an uploaded spreadsheet or CSV file. The user's message contains the data from the file.
5. UNKNOWN - If none of the above fit
1. IMAGE_GENERATION - ONLY if:
- Explicitly contains: "generate/create/draw/make an image/picture/photo/art/illustration"
- Requests visual content creation
- Example: "Make a picture of a castle" → IMAGE_GENERATION
- Explicitly contains: "generate/create/draw/make an image/picture/photo/art/illustration"
- Requests visual content creation
- Example: "Make a picture of a castle" → IMAGE_GENERATION
2. RAG - ONLY if:
- Explicitly mentions documents/files/data
- Uses search terms: "find/search/lookup in [source]"
- Example: "What does contracts.pdf say?" → RAG
- Explicitly mentions documents/files/data
- Uses search terms: "find/search/lookup in [source]"
- Example: "What does contracts.pdf say?" → RAG
3. GENERAL_CHAT - DEFAULT category when:
- Doesn't meet above criteria
- Conversational/general knowledge
- Uncertain cases
- Example: "Tell me a joke" → GENERAL_CHAT
3. DATA_ANALYSIS - ONLY if:
- The message explicitly contains structured data from a file (e.g., a DataFrame string)
- The user is asking to analyze, summarize, or plot the data
- Example: "Here is the sales data. What is the average revenue per product?" -> DATA_ANALYSIS
4. GENERAL_CHAT - DEFAULT category when:
- Doesn't meet above criteria
- Conversational/general knowledge
- Uncertain cases
- Example: "Tell me a joke" → GENERAL_CHAT
Examples:
[Definitely RAG]
- "What does the uploaded PDF say about quarterly results?"
- "Search our documents for the 2023 marketing strategy"
- "Find the contract clause about termination"
[Definitely DATA_ANALYSIS]
- "Here is the file content. What is the sum of all 'Sales'?"
- "Based on this CSV data, show me the top 5 customers."
[Definitely GENERAL_CHAT]
- "How does photosynthesis work?" (General knowledge)
- "Tell me a joke"
- "What's your opinion on AI?"
[Borderline GENERAL_CHAT]
- "What's our company policy on X?" (No doc reference general)
- "Explain quantum computing" (General knowledge)
- "Summarize the meeting" (No doc reference)
[Definitely NOT IMAGE_GENERATION]
- "Great, can you make it about a duck now"
- "highlight the features of the backyard playset if they were to choose us and make the language more long form"
[Borderline -> GENERAL_CHAT]
- "What's our company policy on X?" (No doc reference -> general)
Return ONLY the label, no explanations.""",
),

View File

@@ -44,6 +44,7 @@ from django.conf import settings
import json
import base64
import pandas as pd
import io
# For email support
from django.core.mail import EmailMultiAlternatives
@@ -68,16 +69,22 @@ from .services.rag_services import AsyncRAGService
from .services.title_generator import title_generator
from .services.moderation_classifier import moderation_classifier, ModerationLabel
from .services.prompt_classifier.prompt_classifier import PromptClassifier, PromptType
from .services.data_analysis_service import AsyncDataAnalysisService
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_ollama import ChatOllama
import logging
logger = logging.getLogger(__name__)
CHANNEL_NAME: str = "llm_messages"
MODEL_NAME: str = "llama3.2"
# Create your views here.
class CustomObtainTokenView(TokenObtainPairView):
permission_classes = (permissions.AllowAny,)
@@ -99,8 +106,8 @@ class CustomUserCreate(APIView):
def send_invite_email(slug, email_to_invite):
print("Sending invite email")
print(f"url : https://chat.aimloperations.com/set_password?slug={slug}")
logger.info("Sending invite email")
logger.info(f"url : https://chat.aimloperations.com/set_password?slug={slug}")
url = f"https://chat.aimloperations.com/set_password?slug={slug}"
subject = "Welcome to AI ML Operations, LLC Chat Services"
from_email = "ryan@aimloperations.com"
@@ -115,8 +122,8 @@ def send_invite_email(slug, email_to_invite):
def send_password_reset_email(slug, email_to_invite):
print("Sending reset email")
print(f"url : https://www.chat.aimloperations.com/set_password?slug={slug}")
logger.info("Sending reset email")
logger.info(f"url : https://www.chat.aimloperations.com/set_password?slug={slug}")
url = f"https://www.chat.aimloperations.com/set_password?slug={slug}"
subject = "Password reset for AI ML Operations, LLC Chat Services"
from_email = "ryan@aimloperations.com"
@@ -131,7 +138,7 @@ def send_password_reset_email(slug, email_to_invite):
def send_feedback_email(feedback_obj):
print("Sending feedback email")
logger.info("Sending feedback email")
subject = "New Feedback for Chat by AI ML Operations, LLC"
from_email = "ryan@aimloperations.com"
to = "ryan@aimloperations.com"
@@ -145,7 +152,7 @@ def send_feedback_email(feedback_obj):
def send_password_reset_email(slug, email_to_invite):
print("Sending Password reset email")
logger.info("Sending Password reset email")
url = f"https://www.chat.aimloperations.com/set_password?slug={slug}"
subject = "Password reset for Chat by AI ML Operations, LLC"
from_email = "ryan@aimloperations.com"
@@ -233,7 +240,7 @@ class ResetUserPassword(APIView):
Send an email with a set password link to the set password page
Also disable the account
"""
print(f"Password reset for requests. {request.data}")
logger.info(f"Password reset for requests. {request.data}")
token = request.data.get("recaptchaToken")
payload = {
"secret": settings.CAPTCHA_SECRET_KEY,
@@ -252,7 +259,7 @@ class ResetUserPassword(APIView):
# send the email
send_password_reset_email(user.slug, email)
else:
print("Captcha secret failed")
logger.error("Captcha secret failed")
return Response(status=status.HTTP_200_OK)
@@ -284,14 +291,14 @@ class CustomUserGet(APIView):
email = request.user.email
username = request.user.username
user = CustomUser.objects.filter(email=email).last()
print(f"Getting the user: {user}")
logger.info(f"Getting the user: {user}")
try:
serializer = CustomUserSerializer(user)
print(f"serializer: {serializer}")
print(serializer.data)
logger.debug(f"serializer: {serializer}")
logger.debug(serializer.data)
return Response(serializer.data, status=status.HTTP_200_OK)
except Exception as e:
print(f"Exception: {e}")
logger.error(f"Exception: {e}")
return Response({}, status=status.HTTP_400_BAD_REQUEST)
@@ -300,7 +307,7 @@ class FeedbackView(APIView):
def post(self, request, format="json"):
serializer = FeedbackSerializer(data=request.data)
print(request.data)
logger.debug(request.data)
if serializer.is_valid():
feedback_obj = serializer.save()
@@ -310,7 +317,7 @@ class FeedbackView(APIView):
send_feedback_email(feedback_obj)
return Response(serializer.data, status=status.HTTP_201_CREATED)
else:
print(serializer.errors)
logger.error(serializer.errors)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
def get(self, request, format="json"):
@@ -453,7 +460,7 @@ class ConversationDetailView(APIView):
return Response(serailzer.data, status=status.HTTP_200_OK)
def post(self, request, format="json"):
print("In the post")
logger.info("In the post")
# Add the prompt to the database
# make sure there is a conversation for it
# if there is not a conversation create a title for it
@@ -481,7 +488,7 @@ class ConversationDetailView(APIView):
prompt_instance = serializer.save()
# set up the streaming response if it is from the user
print(f"Do we have a valid user? {is_user}")
logger.info(f"Do we have a valid user? {is_user}")
if is_user:
messages = []
for prompt_obj in Prompt.objects.filter(
@@ -495,12 +502,12 @@ class ConversationDetailView(APIView):
)
channel_layer = get_channel_layer()
print(f"Sending to the channel: {CHANNEL_NAME}")
logger.info(f"Sending to the channel: {CHANNEL_NAME}")
async_to_sync(channel_layer.group_send)(
CHANNEL_NAME, {"type": "receive", "content": messages}
)
except:
print(
logger.error(
f"Error trying to submit to conversation_id: {conversation_id} with request.data: {request.data}"
)
pass
@@ -740,7 +747,7 @@ def get_messages(conversation_id, prompt, file_string: str = None, file_type: st
messages = []
conversation = Conversation.objects.get(id=conversation_id)
print(file_string)
logger.debug(file_string)
# add the prompt to the conversation
serializer = PromptSerializer(
@@ -843,27 +850,43 @@ def update_prompt_metric(prompt_metric, status):
@database_sync_to_async
def finish_prompt_metric(prompt_metric, response_length):
print(f"finish_prompt_metric: {response_length}")
logger.info(f"finish_prompt_metric: {response_length}")
prompt_metric.end_time = timezone.now()
prompt_metric.reponse_length = response_length
prompt_metric.event = "FINISHED"
prompt_metric.save(update_fields=["end_time", "reponse_length", "event"])
print("finish_prompt_metric saved")
logger.info("finish_prompt_metric saved")
@database_sync_to_async
def get_retriever(conversation_id):
print(f"getting workspace from conversation: {conversation_id}")
logger.info(f"getting workspace from conversation: {conversation_id}")
conversation = Conversation.objects.get(id=conversation_id)
print(f"Got conversation: {conversation}")
logger.info(f"Got conversation: {conversation}")
workspace = DocumentWorkspace.objects.get(company=conversation.user.company)
print(f"Got workspace: {conversation}")
logger.info(f"Got workspace: {conversation}")
vectorstore = Chroma(
persist_directory=f"./chroma_db/",
embedding=OllamaEmbeddings(model="llama3.2"),
)
return vectorstore.as_retriever()
async def get_conversation_file_async(conversation_id):
try:
# Get the very first prompt in the conversation that has a file
prompt_with_file = await Prompt.objects.filter(
conversation_id=conversation_id
).exclude(file='').order_by('created').afirst()
if prompt_with_file and prompt_with_file.file:
# You must use sync_to_async to access the file's binary content
file_data = await sync_to_async(prompt_with_file.file.read)()
file_type = prompt_with_file.file_type
return file_data, file_type
except Exception as e:
logger.error(f"Error retrieving file from conversation history: {e}")
return None, None
PROMPT_CLASSIFIER = PromptClassifier()
class ChatConsumerAgain(AsyncWebsocketConsumer):
@@ -874,8 +897,8 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
await self.close()
async def receive(self, text_data=None, bytes_data=None):
print(f"Text Data: {text_data}")
print(f"Bytes Data: {bytes_data}")
logger.debug(f"Text Data: {text_data}")
logger.debug(f"Bytes Data: {bytes_data}")
if text_data:
data = json.loads(text_data)
message = data.get("message", None)
@@ -896,21 +919,26 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
if file:
decoded_file = base64.b64decode(file)
print(decoded_file)
logger.debug(decoded_file)
# The `altered_message` should only be created if a file exists
# and you want to pass its content directly to the classifier.
# Here, we'll let the classifier decide based on the user's prompt
# and then handle the file content separately.
altered_message = message
if "csv" in file_type:
file_type = "csv"
altered_message = f"{message}\n The file type is csv and the file contents are: {decoded_file}"
#altered_message = f"{message}\n The file type is csv and the file contents are: {decoded_file}"
elif "xmlformats-officedocument" in file_type:
file_type = "xlsx"
df = pd.read_excel(decoded_file)
altered_message = f"{message}\n The file type is xlsx and the file contents are: {df}"
#df = pd.read_excel(decoded_file)
#altered_message = f"{message}\n The file type is xlsx and the file contents are: {df}"
elif "text" in file_type:
file_type = "txt"
altered_message = f"{message}\n The file type is txt and the file contents are: {decoded_file}"
#altered_message = f"{message}\n The file type is txt and the file contents are: {decoded_file}"
else:
file_type = "Not Sure"
print(f'received: "{message}" for conversation {conversation_id}')
logger.info(f'received: "{message}" for conversation {conversation_id}')
# check the moderation here
if (
@@ -918,7 +946,7 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
== ModerationLabel.NSFW
):
response = "Prompt has been marked as NSFW. If this is in error, submit a feedback with the prompt text."
print("this prompt has been marked as NSFW")
logger.warning("this prompt has been marked as NSFW")
await self.send("CONVERSATION_ID")
await self.send(str(conversation_id))
await self.send("START_OF_THE_STREAM_ENDER_GAME_42")
@@ -934,11 +962,19 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
messages, prompt = await get_messages(
conversation_id, message, decoded_file, file_type
)
if not decoded_file:
decoded_file, file_type = await get_conversation_file_async(conversation_id)
prompt_type = await PROMPT_CLASSIFIER.classify_async(message)
logger.info(f"prompt_type: {prompt_type} for {message}")
print(f"prompt_type: {prompt_type} for {message}")
if file:
# Check for a file AND the new DATA_ANALYSIS type
# The classifier might not correctly identify a data analysis prompt
# without the file contents. So, we'll add a check to override.
if decoded_file and (prompt_type == PromptType.DATA_ANALYSIS or 'analyze' in message.lower() or 'data' in message.lower()):
prompt_type = PromptType.DATA_ANALYSIS
elif decoded_file:
# If a decoded_file is uploaded but the query is general, default to GENERAL_CHAT
prompt_type = PromptType.GENERAL_CHAT
prompt_metric = await create_prompt_metric(
@@ -952,7 +988,7 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
if file:
# udpate with the altered_message
messages = messages[:-1] + [HumanMessage(content=altered_message)]
print(messages)
logger.info(messages)
# send it to the LLM
# stream the response back
@@ -965,19 +1001,27 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
service = AsyncRAGService()
# await service.ingest_documents()
workspace = await get_workspace(conversation_id)
print("Time to get the rag response")
logger.info("Time to get the rag response")
async for chunk in service.generate_response(
messages, prompt.message, workspace
):
response += chunk
await self.send(chunk)
elif prompt_type == PromptType.DATA_ANALYSIS:
service = AsyncDataAnalysisService()
if not decoded_file:
await self.send("Please upload a file to perform data analysis.")
else:
async for chunk in service.generate_response(prompt.message, decoded_file, file_type):
response += chunk
await self.send(chunk)
elif prompt_type == PromptType.IMAGE_GENERATION:
response = "Image Generation is not supported at this time, but it will be soon."
await self.send(response)
else:
print(f"using the AsyncLLMService\n\n{messages}\n{prompt.message}")
logger.info(f"using the AsyncLLMService\n\n{messages}\n{prompt.message}")
service = AsyncLLMService()
async for chunk in service.generate_response(
messages, prompt.message, conversation_id
@@ -990,7 +1034,7 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
await finish_prompt_metric(prompt_metric, len(response))
if bytes_data:
print("we have byte data")
logger.info("we have byte data")
# Document Views
@@ -1014,7 +1058,7 @@ class DocumentUploadView(APIView):
# permission_classes = [permissions.IsAuthenticated]Z
def get(self, request):
print(f"request_3: {request}")
logger.debug(f"request_3: {request}")
try:
workspace = DocumentWorkspace.objects.get(company=request.user.company)
serializer = DocumentSerializer(
@@ -1028,7 +1072,7 @@ class DocumentUploadView(APIView):
)
def post(self, request):
print(f"request: {request}")
logger.debug(f"request: {request}")
try:
workspace = DocumentWorkspace.objects.get(company=request.user.company)
@@ -1038,14 +1082,14 @@ class DocumentUploadView(APIView):
{"error": "Workspace not found"}, status=status.HTTP_404_NOT_FOUND
)
print(request.FILES)
logger.info(request.FILES)
file = request.FILES.get("file")
if not file:
return Response(
{"error": "No file provided"}, status=status.HTTP_400_BAD_REQUEST
)
print("have the workspace and the file")
logger.info("have the workspace and the file")
document = Document.objects.create(workspace=workspace, file=file)
@@ -1072,7 +1116,7 @@ class DocumentDetailView(APIView):
# permission_classes = [permissions.IsAuthenticated]
def get(self, request, document_id):
print(f"request: {request}")
logger.info(f"request: {request}")
try:
workspace = DocumentWorkspace.objects.get(company=request.user.company)