Updated data analysis to generate plot images from the uploaded data
llm_be/chat_backend/migrations/0021_alter_prompt_message.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+# Generated by Django 5.1.7 on 2025-09-24 16:44
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ("chat_backend", "0020_documentworkspace_document"),
+    ]
+
+    operations = [
+        migrations.AlterField(
+            model_name="prompt",
+            name="message",
+            field=models.CharField(
+                help_text="The text for a prompt", max_length=102400
+            ),
+        ),
+    ]
@@ -161,7 +161,7 @@ class Conversation(TimeInfoBase):


 class Prompt(TimeInfoBase):
-    message = models.CharField(max_length=10 * 1024, help_text="The text for a prompt")
+    message = models.CharField(max_length=100 * 1024, help_text="The text for a prompt")
     user_created = models.BooleanField(
         help_text="True if it was created by the user. False if it was generated by the LLM"
     )

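Note: 100 * 1024 evaluates to 102400, so the model change above is exactly what migration 0021 encodes. A minimal sketch of the new limit in action (the `exclude` list and constructor arguments are assumptions, since Prompt's other fields aren't shown in this diff):

    from django.core.exceptions import ValidationError
    from chat_backend.models import Prompt  # app label taken from the migration

    prompt = Prompt(message="x" * (100 * 1024), user_created=False)  # exactly 102400 chars
    try:
        prompt.full_clean(exclude=["conversation"])  # assumed FK name, skipped for brevity
        print("message fits the new 100 KiB limit")
    except ValidationError as exc:
        print(exc)  # one character more would trip max_length validation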
@@ -1,10 +1,15 @@
 import pandas as pd
+import io
+import re
+import json
+import base64
+import matplotlib.pyplot as plt
 from typing import AsyncGenerator
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_ollama import OllamaLLM
 from langchain_core.output_parsers import StrOutputParser


 class AsyncDataAnalysisService:
     """Asynchronous service for performing data analysis with an LLM."""

@@ -20,8 +25,11 @@ class AsyncDataAnalysisService:

     def _setup_chain(self):
         """Set up the LLM chain with a prompt tailored for data analysis."""
-        template = """You are an expert data analyst. A user has provided a summary and sample of a dataset and is asking a question about it.
-Analyze the provided information and answer the user's question. If a calculation is requested, perform it based on the summary statistics provided. If the data is not suitable for the request, explain why.
+        template = """You are an expert data analyst. Your role is to directly answer a user's question about a dataset they have provided.
+You will be given a summary and a sample of the dataset.
+Based on this information, provide a clear and concise answer to the user's question.
+Do not provide Python code or any other code. The user is not a developer and wants a direct answer.
+Even if you don't think the data provides enough evidence for the query, still provide a response.

 ---
 Data Summary:
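For context, a minimal sketch of how `_setup_chain` presumably wires this template into `self.analysis_chain` using the imports above; the Ollama model name is an assumption, not taken from this diff:

    from langchain_core.prompts import ChatPromptTemplate
    from langchain_core.output_parsers import StrOutputParser
    from langchain_ollama import OllamaLLM

    prompt = ChatPromptTemplate.from_template(template)  # `template` as defined above
    llm = OllamaLLM(model="llama3")                      # model id is an assumption
    analysis_chain = prompt | llm | StrOutputParser()    # exposes the .astream() used below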
@@ -69,35 +77,87 @@ Answer:"""

         return "\n".join(summary_lines)

+    def _generate_plot(self, query: str, df: pd.DataFrame) -> str:
+        """
+        Generates a plot from a DataFrame based on a natural language query,
+        encodes it in Base64, and returns it.
+
+        If columns are specified (e.g., "plot X vs Y"), it uses them.
+        If not, it automatically picks the first two numerical columns.
+        """
+        col1, col2 = None, None
+        title = "Scatter Plot"
+
+        # Attempt to find explicitly mentioned columns, e.g., "plot Column1 vs Column2"
+        match = re.search(r"(?:plot|scatter|visualize)\s+(.*?)\s+(?:vs|versus|and)\s+(.*)", query, re.IGNORECASE)
+        if match:
+            potential_col1 = match.group(1).strip()
+            potential_col2 = match.group(2).strip()
+            if potential_col1 in df.columns and potential_col2 in df.columns:
+                col1, col2 = potential_col1, potential_col2
+                title = f"Scatterplot of {col1} vs {col2}"
+
+        # If no valid columns were explicitly found, auto-detect
+        if not col1 or not col2:
+            numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
+            if len(numeric_cols) >= 2:
+                col1, col2 = numeric_cols[0], numeric_cols[1]
+                title = f"Scatterplot of {col1} vs {col2} (Auto-selected)"
+            else:
+                raise ValueError("I couldn't find two numerical columns to plot automatically. Please specify columns, like 'plot column_A vs column_B'.")
+
+        fig, ax = plt.subplots()
+        ax.scatter(df[col1], df[col2])
+        ax.set_xlabel(col1)
+        ax.set_ylabel(col2)
+        ax.set_title(title)
+        ax.grid(True)
+
+        buf = io.BytesIO()
+        plt.savefig(buf, format='png', bbox_inches='tight')
+        plt.close(fig)
+        buf.seek(0)
+
+        image_base64 = base64.b64encode(buf.read()).decode('utf-8')
+        return image_base64

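Usage sketch for the new `_generate_plot` (assumes the service constructor takes no required arguments, which this diff doesn't show):

    import base64
    import pandas as pd

    df = pd.DataFrame({"height": [150, 160, 170], "weight": [50, 62, 75]})
    service = AsyncDataAnalysisService()
    b64 = service._generate_plot("plot height vs weight", df)

    with open("scatter.png", "wb") as f:
        f.write(base64.b64decode(b64))  # round-trips the Base64 string back to PNG bytes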
     async def generate_response(
         self,
         query: str,
         decoded_file: bytes,
         file_type: str,
     ) -> AsyncGenerator[str, None]:
-        """Generate a response based on the uploaded data and user query."""
+        """
+        Generate a response based on the uploaded data and user query.
+        This can be a text analysis or a plot visualization.
+        """
         try:
             # Read the file content into a DataFrame
             if file_type == "csv":
                 df = pd.read_csv(io.BytesIO(decoded_file))
             elif file_type == "xlsx":
                 df = pd.read_excel(io.BytesIO(decoded_file))
             else:
-                yield "I can only analyze CSV and XLSX files at this time."
+                yield json.dumps({"type": "error", "content": "I can only analyze CSV and XLSX files."})
                 return

-            # Get the structured summary instead of the full data
-            data_summary = self._get_dataframe_summary(df)
+            plot_keywords = ["plot", "graph", "scatter", "visualize"]
+            if any(keyword in query.lower() for keyword in plot_keywords):
+                try:
+                    image_base64 = self._generate_plot(query, df)
+                    yield json.dumps({
+                        "type": "plot",
+                        "format": "png",
+                        "image": image_base64
+                    })
+                except ValueError as e:
+                    yield json.dumps({"type": "error", "content": str(e)})
+                return

-            # Prepare the input for the LLM chain
-            chain_input = {
-                "data_summary": data_summary,
-                "query": query,
-            }
+            data_summary = self._get_dataframe_summary(df)
+            chain_input = {"data_summary": data_summary, "query": query}

             async for chunk in self.analysis_chain.astream(chain_input):
-                yield chunk
+                yield chunk  # json.dumps({"type": "text", "content": chunk})

         except Exception as e:
-            yield f"An error occurred while processing the file: {e}"
+            yield json.dumps({"type": "error", "content": f"An error occurred: {e}"})

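With this change, plot and error chunks from `generate_response` are JSON strings, while plain LLM text still streams through raw (the text wrapper is commented out above). A sketch of a consumer that handles both (names assumed):

    import base64
    import json

    async def consume(service, query, decoded_file, file_type):
        async for chunk in service.generate_response(query, decoded_file, file_type):
            try:
                msg = json.loads(chunk)
            except json.JSONDecodeError:
                print(chunk, end="")  # raw text chunk from the LLM stream
                continue
            if msg["type"] == "plot":
                with open("analysis.png", "wb") as f:
                    f.write(base64.b64decode(msg["image"]))
            elif msg["type"] == "error":
                print("error:", msg["content"])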
@@ -29,7 +29,7 @@ class ModerationClassifier(BaseService):
         (
             "system",
             """You are a strict content moderator. Classify the following prompt as either NSFW or FINE.

 NSFW includes:
 - Sexual content
 - Violence/gore
@@ -50,6 +50,7 @@ Examples:
 - "Write a love poem" → FINE
 - "Explicit sex scene" → NSFW
 - "Python tutorial" → FINE
+- "Please analyze this file and project the next 12 months for me. Add a graph visual of the data as well" → FINE

 Return ONLY "NSFW" or "FINE", nothing else.""",
         ),

@@ -823,6 +823,8 @@ def save_generated_message(conversation_id, message):
         prompt_instance = serializer.save()
+        prompt_instance.conversation_id = conversation.id
+        prompt_instance = serializer.save()
     else:
         print(serializer.errors)


 @database_sync_to_async
@@ -896,6 +898,20 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
     async def disconnect(self, close_code):
         await self.close()

+    async def send_json_message(self, data_str):
+        """
+        Ensures that the message sent over the websocket is a valid JSON object.
+        If data_str is a plain string, it wraps it in {"type": "text", "content": ...}.
+        """
+        try:
+            # Test if it's already a valid JSON object string
+            json.loads(data_str)
+            # If it is, send it as is
+            await self.send(data_str)
+        except (json.JSONDecodeError, TypeError):
+            # If it's a plain string or not JSON-decodable, wrap it
+            await self.send(json.dumps({"type": "text", "content": data_str}))

     async def receive(self, text_data=None, bytes_data=None):
         logger.debug(f"Text Data: {text_data}")
         logger.debug(f"Bytes Data: {bytes_data}")
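The wrapping behavior of `send_json_message`, mirrored as a standalone sketch for illustration:

    import json

    def normalize(data_str: str) -> str:
        """Mirror of send_json_message's logic: pass JSON through, wrap plain text."""
        try:
            json.loads(data_str)
            return data_str  # already valid JSON
        except (json.JSONDecodeError, TypeError):
            return json.dumps({"type": "text", "content": data_str})

    assert normalize('{"type": "error", "content": "x"}') == '{"type": "error", "content": "x"}'
    assert normalize("hello") == '{"type": "text", "content": "hello"}'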
@@ -946,13 +962,15 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
                 == ModerationLabel.NSFW
             ):
                 response = "Prompt has been marked as NSFW. If this is in error, submit feedback with the prompt text."
+                response_to_send = json.dumps({"type": "error", "content": response})

                 logger.warning("this prompt has been marked as NSFW")
                 await self.send("CONVERSATION_ID")
                 await self.send(str(conversation_id))
                 await self.send("START_OF_THE_STREAM_ENDER_GAME_42")
-                await self.send(response)
+                await self.send_json_message(response_to_send)
                 await self.send("END_OF_THE_STREAM_ENDER_GAME_42")
-                await save_generated_message(conversation_id, response)
+                await save_generated_message(conversation_id, response_to_send)
                 return

             # TODO: add the message to the database
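For reference, the frame sequence a client sees for a moderated prompt, per the markers above; a toy parser (marker strings are from the diff, the frame list is illustrative):

    START = "START_OF_THE_STREAM_ENDER_GAME_42"
    END = "END_OF_THE_STREAM_ENDER_GAME_42"

    def parse_frames(frames: list[str]) -> tuple[str, list[str]]:
        """Returns (conversation_id, payload frames between the stream markers)."""
        conv_id = frames[frames.index("CONVERSATION_ID") + 1]
        return conv_id, frames[frames.index(START) + 1 : frames.index(END)]

    frames = [
        "CONVERSATION_ID", "42", START,
        '{"type": "error", "content": "Prompt has been marked as NSFW."}',
        END,
    ]
    print(parse_frames(frames))  # ('42', ['{"type": "error", ...}'])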
@@ -1007,18 +1025,18 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
                     messages, prompt.message, workspace
                 ):
                     response += chunk
-                    await self.send(chunk)
+                    await self.send_json_message(chunk)
             elif prompt_type == PromptType.DATA_ANALYSIS:
                 service = AsyncDataAnalysisService()
                 if not decoded_file:
-                    await self.send("Please upload a file to perform data analysis.")
+                    await self.send_json_message("Please upload a file to perform data analysis.")
                 else:
                     async for chunk in service.generate_response(prompt.message, decoded_file, file_type):
                         response += chunk
-                        await self.send(chunk)
+                        await self.send_json_message(chunk)
             elif prompt_type == PromptType.IMAGE_GENERATION:
                 response = "Image Generation is not supported at this time, but it will be soon."
-                await self.send(response)
+                await self.send_json_message(response)

             else:
                 logger.info(f"using the AsyncLLMService\n\n{messages}\n{prompt.message}")
@@ -1027,8 +1045,9 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
                     messages, prompt.message, conversation_id
                 ):
                     response += chunk
-                    await self.send(chunk)
+                    await self.send_json_message(chunk)
             await self.send("END_OF_THE_STREAM_ENDER_GAME_42")

             await save_generated_message(conversation_id, response)

             await finish_prompt_metric(prompt_metric, len(response))