Updated data analysis to generate plot images

2025-09-24 11:49:08 -05:00
parent 14d8211715
commit 8a259158c8
5 changed files with 124 additions and 24 deletions

View File

@@ -0,0 +1,20 @@
# Generated by Django 5.1.7 on 2025-09-24 16:44

from django.db import migrations, models


class Migration(migrations.Migration):

    dependencies = [
        ("chat_backend", "0020_documentworkspace_document"),
    ]

    operations = [
        migrations.AlterField(
            model_name="prompt",
            name="message",
            field=models.CharField(
                help_text="The text for a prompt", max_length=102400
            ),
        ),
    ]

View File

@@ -161,7 +161,7 @@ class Conversation(TimeInfoBase):
class Prompt(TimeInfoBase):
    message = models.CharField(max_length=10 * 1024, help_text="The text for a prompt")
    message = models.CharField(max_length=100 * 1024, help_text="The text for a prompt")
    user_created = models.BooleanField(
        help_text="True if it was created by the user. False if it was generated by the LLM"
    )
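
Note: the model declares the new limit as 100 * 1024 while the autogenerated migration above materializes it as the literal 102400. These are the same value, so the model and migration agree:

# Sanity check of the limit used in both files
assert 100 * 1024 == 102400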

View File

@@ -1,10 +1,15 @@
import pandas as pd
import io
import re
import json
import base64
import matplotlib.pyplot as plt

from typing import AsyncGenerator

from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import OllamaLLM
from langchain_core.output_parsers import StrOutputParser

class AsyncDataAnalysisService:
    """Asynchronous service for performing data analysis with an LLM."""
@@ -20,8 +25,11 @@ class AsyncDataAnalysisService:
    def _setup_chain(self):
        """Set up the LLM chain with a prompt tailored for data analysis."""
        template = """You are an expert data analyst. A user has provided a summary and sample of a dataset and is asking a question about it.
Analyze the provided information and answer the user's question. If a calculation is requested, perform it based on the summary statistics provided. If the data is not suitable for the request, explain why.
        template = """You are an expert data analyst. Your role is to directly answer a user's question about a dataset they have provided.
You will be given a summary and a sample of the dataset.
Based on this information, provide a clear and concise answer to the user's question.
Do not provide Python code or any other code. The user is not a developer and wants a direct answer.
Even if you don't think the data provides enough evidence for the query, still provide a response.
---
Data Summary:
@@ -69,35 +77,87 @@ Answer:"""
return "\n".join(summary_lines)
def _generate_plot(self, query: str, df: pd.DataFrame) -> str:
"""
Generates a plot from a DataFrame based on a natural language query,
encodes it in Base64, and returns it.
If columns are specified (e.g., "plot X vs Y"), it uses them.
If not, it automatically picks the first two numerical columns.
"""
col1, col2 = None, None
title = "Scatter Plot"
# Attempt to find explicitly mentioned columns, e.g., "plot Column1 vs Column2"
match = re.search(r"(?:plot|scatter|visualize)\s+(.*?)\s+(?:vs|versus|and)\s+(.*)", query, re.IGNORECASE)
if match:
potential_col1 = match.group(1).strip()
potential_col2 = match.group(2).strip()
if potential_col1 in df.columns and potential_col2 in df.columns:
col1, col2 = potential_col1, potential_col2
title = f"Scatterplot of {col1} vs {col2}"
# If no valid columns were explicitly found, auto-detect
if not col1 or not col2:
numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
if len(numeric_cols) >= 2:
col1, col2 = numeric_cols[0], numeric_cols[1]
title = f"Scatterplot of {col1} vs {col2} (Auto-selected)"
else:
raise ValueError("I couldn't find two numerical columns to plot automatically. Please specify columns, like 'plot column_A vs column_B'.")
fig, ax = plt.subplots()
ax.scatter(df[col1], df[col2])
ax.set_xlabel(col1)
ax.set_ylabel(col2)
ax.set_title(title)
ax.grid(True)
buf = io.BytesIO()
plt.savefig(buf, format='png', bbox_inches='tight')
plt.close(fig)
buf.seek(0)
image_base64 = base64.b64encode(buf.read()).decode('utf-8')
return image_base64
    async def generate_response(
        self,
        query: str,
        decoded_file: bytes,
        file_type: str,
    ) -> AsyncGenerator[str, None]:
        """Generate a response based on the uploaded data and user query."""
        """
        Generate a response based on the uploaded data and user query.

        This can be a text analysis or a plot visualization.
        """
        try:
            # Read the file content into a DataFrame
            if file_type == "csv":
                df = pd.read_csv(io.BytesIO(decoded_file))
            elif file_type == "xlsx":
                df = pd.read_excel(io.BytesIO(decoded_file))
            else:
                yield "I can only analyze CSV and XLSX files at this time."
                yield json.dumps({"type": "error", "content": "I can only analyze CSV and XLSX files."})
                return

            # Get the structured summary instead of the full data
            data_summary = self._get_dataframe_summary(df)

            plot_keywords = ["plot", "graph", "scatter", "visualize"]
            if any(keyword in query.lower() for keyword in plot_keywords):
                try:
                    image_base64 = self._generate_plot(query, df)
                    yield json.dumps({
                        "type": "plot",
                        "format": "png",
                        "image": image_base64
                    })
                except ValueError as e:
                    yield json.dumps({"type": "error", "content": str(e)})
                return

            # Prepare the input for the LLM chain
            chain_input = {
                "data_summary": data_summary,
                "query": query,
            }
            data_summary = self._get_dataframe_summary(df)
            chain_input = {"data_summary": data_summary, "query": query}

            async for chunk in self.analysis_chain.astream(chain_input):
                yield chunk
                yield chunk  # json.dumps({"type": "text", "content": chunk})
        except Exception as e:
            yield f"An error occurred while processing the file: {e}"
            yield json.dumps({"type": "error", "content": f"An error occurred: {e}"})

View File

@@ -29,7 +29,7 @@ class ModerationClassifier(BaseService):
        (
            "system",
            """You are a strict content moderator. Classify the following prompt as either NSFW or FINE.
NSFW includes:
- Sexual content
- Violence/gore
@@ -50,6 +50,7 @@ Examples:
- "Write a love poem" → FINE
- "Explicit sex scene" → NSFW
- "Python tutorial" → FINE
- "Please analyze this file and project the next 12 months for me. Add a graph visual of the data as well" → FINE
Return ONLY "NSFW" or "FINE", nothing else.""",
        ),

View File

@@ -823,6 +823,8 @@ def save_generated_message(conversation_id, message):
        prompt_instance = serializer.save()
        prompt_instance.conversation_id = conversation.id
        prompt_instance.save()
    else:
        print(serializer.errors)
@database_sync_to_async
@@ -896,6 +898,20 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
    async def disconnect(self, close_code):
        await self.close()

    async def send_json_message(self, data_str):
        """
        Ensures that the message sent over the websocket is a valid JSON object.

        If data_str is a plain string, it wraps it in {"type": "text", "content": ...}.
        """
        try:
            # Test if it's already a valid JSON object string
            json.loads(data_str)
            # If it is, send it as is
            await self.send(data_str)
        except (json.JSONDecodeError, TypeError):
            # If it's a plain string or not JSON-decodable, wrap it
            await self.send(json.dumps({"type": "text", "content": data_str}))

    async def receive(self, text_data=None, bytes_data=None):
        logger.debug(f"Text Data: {text_data}")
        logger.debug(f"Bytes Data: {bytes_data}")
@@ -946,13 +962,15 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
            == ModerationLabel.NSFW
        ):
            response = "Prompt has been marked as NSFW. If this is in error, submit a feedback with the prompt text."
            response_to_send = json.dumps({"type": "error", "content": response})
            logger.warning("this prompt has been marked as NSFW")
            await self.send("CONVERSATION_ID")
            await self.send(str(conversation_id))
            await self.send("START_OF_THE_STREAM_ENDER_GAME_42")
            await self.send(response)
            await self.send_json_message(response_to_send)
            await self.send("END_OF_THE_STREAM_ENDER_GAME_42")
            await save_generated_message(conversation_id, response)
            await save_generated_message(conversation_id, response_to_send)
            return

        # TODO: add the message to the database
@@ -1007,18 +1025,18 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
                    messages, prompt.message, workspace
                ):
                    response += chunk
                    await self.send(chunk)
                    await self.send_json_message(chunk)
            elif prompt_type == PromptType.DATA_ANALYSIS:
                service = AsyncDataAnalysisService()
                if not decoded_file:
                    await self.send("Please upload a file to perform data analysis.")
                    await self.send_json_message("Please upload a file to perform data analysis.")
                else:
                    async for chunk in service.generate_response(prompt.message, decoded_file, file_type):
                        response += chunk
                        await self.send(chunk)
                        await self.send_json_message(chunk)
            elif prompt_type == PromptType.IMAGE_GENERATION:
                response = "Image Generation is not supported at this time, but it will be soon."
                await self.send(response)
                await self.send_json_message(response)
            else:
                logger.info(f"using the AsyncLLMService\n\n{messages}\n{prompt.message}")
@@ -1027,8 +1045,9 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
                    messages, prompt.message, conversation_id
                ):
                    response += chunk
                    await self.send(chunk)
                    await self.send_json_message(chunk)

        await self.send("END_OF_THE_STREAM_ENDER_GAME_42")
        await save_generated_message(conversation_id, response)
        await finish_prompt_metric(prompt_metric, len(response))
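
Taken together, the consumer now frames every reply the same way: the literal marker "CONVERSATION_ID", the conversation id, a start marker, a stream of chunks routed through send_json_message, and an end marker. A minimal client-side sketch of reading that framing, assuming each send() arrives as one text frame (read_stream is illustrative, not part of this commit):

def read_stream(frames):
    """Illustrative parser for the stream framing used by ChatConsumerAgain."""
    it = iter(frames)
    assert next(it) == "CONVERSATION_ID"
    conversation_id = next(it)
    assert next(it) == "START_OF_THE_STREAM_ENDER_GAME_42"
    chunks = []
    for frame in it:
        if frame == "END_OF_THE_STREAM_ENDER_GAME_42":
            break
        # Frames between the markers may be JSON envelopes or plain text;
        # send_json_message aims to make them all JSON-parseable
        chunks.append(frame)
    return conversation_id, chunks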