Updated the data analysis service to generate plot images alongside text analysis

This commit is contained in:
2025-09-24 11:49:08 -05:00
parent 14d8211715
commit 8a259158c8
5 changed files with 124 additions and 24 deletions

View File

@@ -0,0 +1,20 @@
+# Generated by Django 5.1.7 on 2025-09-24 16:44
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ("chat_backend", "0020_documentworkspace_document"),
+    ]
+
+    operations = [
+        migrations.AlterField(
+            model_name="prompt",
+            name="message",
+            field=models.CharField(
+                help_text="The text for a prompt", max_length=102400
+            ),
+        ),
+    ]
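For cross-reference: the migration stores the evaluated limit (102400), while the model change below keeps the expression form. A trivial check, assuming nothing beyond the two files in this commit:

    # The migration stores the evaluated value; models.py keeps the expression.
    assert 100 * 1024 == 102400   # new limit: 100 KiB of prompt text
    assert 10 * 1024 == 10240     # old limit, for comparison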

View File

@@ -161,7 +161,7 @@ class Conversation(TimeInfoBase):
 class Prompt(TimeInfoBase):
-    message = models.CharField(max_length=10 * 1024, help_text="The text for a prompt")
+    message = models.CharField(max_length=100 * 1024, help_text="The text for a prompt")
     user_created = models.BooleanField(
         help_text="True if it was created by the user. False if it was generated by the LLM"
     )

View File

@@ -1,10 +1,15 @@
 import pandas as pd
 import io
+import re
+import json
+import base64
+import matplotlib.pyplot as plt
 from typing import AsyncGenerator
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_ollama import OllamaLLM
 from langchain_core.output_parsers import StrOutputParser


 class AsyncDataAnalysisService:
     """Asynchronous service for performing data analysis with an LLM."""
@@ -20,8 +25,11 @@ class AsyncDataAnalysisService:
     def _setup_chain(self):
         """Set up the LLM chain with a prompt tailored for data analysis."""
-        template = """You are an expert data analyst. A user has provided a summary and sample of a dataset and is asking a question about it.
-Analyze the provided information and answer the user's question. If a calculation is requested, perform it based on the summary statistics provided. If the data is not suitable for the request, explain why.
+        template = """You are an expert data analyst. Your role is to directly answer a user's question about a dataset they have provided.
+You will be given a summary and a sample of the dataset.
+Based on this information, provide a clear and concise answer to the user's question.
+Do not provide Python code or any other code. The user is not a developer and wants a direct answer.
+Even if you don't think the data provides enough evidence for the query, still provide a response.

 ---
 Data Summary:
@@ -69,35 +77,87 @@ Answer:"""
return "\n".join(summary_lines) return "\n".join(summary_lines)
def _generate_plot(self, query: str, df: pd.DataFrame) -> str:
"""
Generates a plot from a DataFrame based on a natural language query,
encodes it in Base64, and returns it.
If columns are specified (e.g., "plot X vs Y"), it uses them.
If not, it automatically picks the first two numerical columns.
"""
col1, col2 = None, None
title = "Scatter Plot"
# Attempt to find explicitly mentioned columns, e.g., "plot Column1 vs Column2"
match = re.search(r"(?:plot|scatter|visualize)\s+(.*?)\s+(?:vs|versus|and)\s+(.*)", query, re.IGNORECASE)
if match:
potential_col1 = match.group(1).strip()
potential_col2 = match.group(2).strip()
if potential_col1 in df.columns and potential_col2 in df.columns:
col1, col2 = potential_col1, potential_col2
title = f"Scatterplot of {col1} vs {col2}"
# If no valid columns were explicitly found, auto-detect
if not col1 or not col2:
numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
if len(numeric_cols) >= 2:
col1, col2 = numeric_cols[0], numeric_cols[1]
title = f"Scatterplot of {col1} vs {col2} (Auto-selected)"
else:
raise ValueError("I couldn't find two numerical columns to plot automatically. Please specify columns, like 'plot column_A vs column_B'.")
fig, ax = plt.subplots()
ax.scatter(df[col1], df[col2])
ax.set_xlabel(col1)
ax.set_ylabel(col2)
ax.set_title(title)
ax.grid(True)
buf = io.BytesIO()
plt.savefig(buf, format='png', bbox_inches='tight')
plt.close(fig)
buf.seek(0)
image_base64 = base64.b64encode(buf.read()).decode('utf-8')
return image_base64
     async def generate_response(
         self,
         query: str,
         decoded_file: bytes,
         file_type: str,
     ) -> AsyncGenerator[str, None]:
-        """Generate a response based on the uploaded data and user query."""
+        """
+        Generate a response based on the uploaded data and user query.
+        This can be a text analysis or a plot visualization.
+        """
         try:
-            # Read the file content into a DataFrame
             if file_type == "csv":
                 df = pd.read_csv(io.BytesIO(decoded_file))
             elif file_type == "xlsx":
                 df = pd.read_excel(io.BytesIO(decoded_file))
             else:
-                yield "I can only analyze CSV and XLSX files at this time."
+                yield json.dumps({"type": "error", "content": "I can only analyze CSV and XLSX files."})
                 return

-            # Get the structured summary instead of the full data
-            data_summary = self._get_dataframe_summary(df)
+            plot_keywords = ["plot", "graph", "scatter", "visualize"]
+            if any(keyword in query.lower() for keyword in plot_keywords):
+                try:
+                    image_base64 = self._generate_plot(query, df)
+                    yield json.dumps({
+                        "type": "plot",
+                        "format": "png",
+                        "image": image_base64
+                    })
+                except ValueError as e:
+                    yield json.dumps({"type": "error", "content": str(e)})
+                return

-            # Prepare the input for the LLM chain
-            chain_input = {
-                "data_summary": data_summary,
-                "query": query,
-            }
+            data_summary = self._get_dataframe_summary(df)
+            chain_input = {"data_summary": data_summary, "query": query}
             async for chunk in self.analysis_chain.astream(chain_input):
-                yield chunk
+                yield chunk  # json.dumps({"type": "text", "content": chunk})
         except Exception as e:
-            yield f"An error occurred while processing the file: {e}"
+            yield json.dumps({"type": "error", "content": f"An error occurred: {e}"})

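Likewise, a hedged sketch of consuming generate_response, assuming plot and error frames arrive as JSON object strings while LLM answer text streams as raw chunks; the helper name is hypothetical:

    import json

    async def consume_analysis(service, query: str, decoded_file: bytes, file_type: str):
        # Hypothetical consumer: dispatch on frame type, fall back to raw text.
        async for chunk in service.generate_response(query, decoded_file, file_type):
            try:
                frame = json.loads(chunk)
            except (json.JSONDecodeError, TypeError):
                frame = None
            if isinstance(frame, dict) and frame.get("type") == "plot":
                print(f"[plot frame: {frame['format']}, {len(frame['image'])} base64 chars]")
            elif isinstance(frame, dict) and frame.get("type") == "error":
                print(f"[error] {frame['content']}")
            else:
                print(chunk, end="")  # raw text token from the LLM stream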
View File

@@ -29,7 +29,7 @@ class ModerationClassifier(BaseService):
             (
                 "system",
                 """You are a strict content moderator. Classify the following prompt as either NSFW or FINE.

 NSFW includes:
 - Sexual content
 - Violence/gore
@@ -50,6 +50,7 @@ Examples:
- "Write a love poem" → FINE - "Write a love poem" → FINE
- "Explicit sex scene" → NSFW - "Explicit sex scene" → NSFW
- "Python tutorial" → FINE - "Python tutorial" → FINE
- "Please analyze this file and project the next 12 months for me. Add a graph visual of the data as well" → FINE
Return ONLY "NSFW" or "FINE", nothing else.""", Return ONLY "NSFW" or "FINE", nothing else.""",
), ),

View File

@@ -823,6 +823,8 @@ def save_generated_message(conversation_id, message):
         prompt_instance = serializer.save()
         prompt_instance.conversation_id = conversation.id
         prompt_instance = serializer.save()
+    else:
+        print(serializer.errors)


 @database_sync_to_async
@@ -896,6 +898,20 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
     async def disconnect(self, close_code):
         await self.close()

+    async def send_json_message(self, data_str):
+        """
+        Ensures that the message sent over the websocket is a valid JSON object.
+        If data_str is a plain string, it wraps it in {"type": "text", "content": ...}.
+        """
+        try:
+            # Test if it's already a valid JSON object string
+            json.loads(data_str)
+            # If it is, send it as is
+            await self.send(data_str)
+        except (json.JSONDecodeError, TypeError):
+            # If it's a plain string or not JSON-decodable, wrap it so the
+            # client always receives a JSON frame, as the docstring promises
+            await self.send(json.dumps({"type": "text", "content": data_str}))
     async def receive(self, text_data=None, bytes_data=None):
         logger.debug(f"Text Data: {text_data}")
         logger.debug(f"Bytes Data: {bytes_data}")
@@ -946,13 +962,15 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
                 == ModerationLabel.NSFW
             ):
                 response = "Prompt has been marked as NSFW. If this is in error, submit feedback with the prompt text."
+                response_to_send = json.dumps({"type": "error", "content": response})
                 logger.warning("this prompt has been marked as NSFW")
                 await self.send("CONVERSATION_ID")
                 await self.send(str(conversation_id))
                 await self.send("START_OF_THE_STREAM_ENDER_GAME_42")
-                await self.send(response)
+                await self.send_json_message(response_to_send)
                 await self.send("END_OF_THE_STREAM_ENDER_GAME_42")
-                await save_generated_message(conversation_id, response)
+                await save_generated_message(conversation_id, response_to_send)
                 return

         # TODO: add the message to the database
@@ -1007,18 +1025,18 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
                         messages, prompt.message, workspace
                     ):
                         response += chunk
-                        await self.send(chunk)
+                        await self.send_json_message(chunk)
                 elif prompt_type == PromptType.DATA_ANALYSIS:
                     service = AsyncDataAnalysisService()
                     if not decoded_file:
-                        await self.send("Please upload a file to perform data analysis.")
+                        await self.send_json_message("Please upload a file to perform data analysis.")
                     else:
                         async for chunk in service.generate_response(prompt.message, decoded_file, file_type):
                             response += chunk
-                            await self.send(chunk)
+                            await self.send_json_message(chunk)
                 elif prompt_type == PromptType.IMAGE_GENERATION:
                     response = "Image Generation is not supported at this time, but it will be soon."
-                    await self.send(response)
+                    await self.send_json_message(response)
                 else:
                     logger.info(f"using the AsyncLLMService\n\n{messages}\n{prompt.message}")
@@ -1027,8 +1045,9 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
                         messages, prompt.message, conversation_id
                     ):
                         response += chunk
-                        await self.send(chunk)
+                        await self.send_json_message(chunk)
             await self.send("END_OF_THE_STREAM_ENDER_GAME_42")
             await save_generated_message(conversation_id, response)
             await finish_prompt_metric(prompt_metric, len(response))
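Finally, a heavily hedged sketch of a client for the framing this consumer emits (CONVERSATION_ID, the START/END stream markers, then JSON or plain-text frames); the websockets library, URI, and payload shape are assumptions, not part of this repo:

    import asyncio
    import json
    import websockets  # assumption: a plain websocket client library

    async def read_stream(uri: str, payload: dict):
        async with websockets.connect(uri) as ws:
            await ws.send(json.dumps(payload))
            while True:
                msg = await ws.recv()
                if msg == "CONVERSATION_ID":
                    conversation_id = await ws.recv()  # next frame carries the id
                    continue
                if msg == "START_OF_THE_STREAM_ENDER_GAME_42":
                    continue
                if msg == "END_OF_THE_STREAM_ENDER_GAME_42":
                    break
                try:
                    frame = json.loads(msg)
                except json.JSONDecodeError:
                    frame = {"type": "text", "content": msg}
                print(frame)

    # Hypothetical invocation, endpoint and payload invented for illustration:
    # asyncio.run(read_stream("ws://localhost:8000/ws/chat/", {"message": "plot height vs weight"}))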