From 8a259158c8fa9a804444dc49af455069d78e8f09 Mon Sep 17 00:00:00 2001
From: Ryan Westfall
Date: Wed, 24 Sep 2025 11:49:08 -0500
Subject: [PATCH] Updated data analysis to generate plot images

---
 .../migrations/0021_alter_prompt_message.py   | 20 +++++
 llm_be/chat_backend/models.py                 |  2 +-
 .../services/data_analysis_service.py         | 90 +++++++++++++++----
 .../services/moderation_classifier.py         |  3 +-
 llm_be/chat_backend/views.py                  | 33 +++++--
 5 files changed, 124 insertions(+), 24 deletions(-)
 create mode 100644 llm_be/chat_backend/migrations/0021_alter_prompt_message.py

diff --git a/llm_be/chat_backend/migrations/0021_alter_prompt_message.py b/llm_be/chat_backend/migrations/0021_alter_prompt_message.py
new file mode 100644
index 0000000..9485ea0
--- /dev/null
+++ b/llm_be/chat_backend/migrations/0021_alter_prompt_message.py
@@ -0,0 +1,20 @@
+# Generated by Django 5.1.7 on 2025-09-24 16:44
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ("chat_backend", "0020_documentworkspace_document"),
+    ]
+
+    operations = [
+        migrations.AlterField(
+            model_name="prompt",
+            name="message",
+            field=models.CharField(
+                help_text="The text for a prompt", max_length=102400
+            ),
+        ),
+    ]
diff --git a/llm_be/chat_backend/models.py b/llm_be/chat_backend/models.py
index cdb93a9..de870be 100644
--- a/llm_be/chat_backend/models.py
+++ b/llm_be/chat_backend/models.py
@@ -161,7 +161,7 @@ class Conversation(TimeInfoBase):
 
 
 class Prompt(TimeInfoBase):
-    message = models.CharField(max_length=10 * 1024, help_text="The text for a prompt")
+    message = models.CharField(max_length=100 * 1024, help_text="The text for a prompt")
     user_created = models.BooleanField(
         help_text="True if was created by the user. False if it was generate by the LLM"
     )
diff --git a/llm_be/chat_backend/services/data_analysis_service.py b/llm_be/chat_backend/services/data_analysis_service.py
index 4abb568..4a62f5a 100644
--- a/llm_be/chat_backend/services/data_analysis_service.py
+++ b/llm_be/chat_backend/services/data_analysis_service.py
@@ -1,10 +1,15 @@
 import pandas as pd
 import io
+import re
+import json
+import base64
+import matplotlib.pyplot as plt
 from typing import AsyncGenerator
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_ollama import OllamaLLM
 from langchain_core.output_parsers import StrOutputParser
+
 
 class AsyncDataAnalysisService:
     """Asynchronous service for performing data analysis with an LLM."""
 
@@ -20,8 +25,11 @@ class AsyncDataAnalysisService:
 
     def _setup_chain(self):
         """Set up the LLM chain with a prompt tailored for data analysis."""
-        template = """You are an expert data analyst. A user has provided a summary and sample of a dataset and is asking a question about it.
-Analyze the provided information and answer the user's question. If a calculation is requested, perform it based on the summary statistics provided. If the data is not suitable for the request, explain why.
+        template = """You are an expert data analyst. Your role is to directly answer a user's question about a dataset they have provided.
+You will be given a summary and a sample of the dataset.
+Based on this information, provide a clear and concise answer to the user's question.
+Do not provide Python code or any other code. The user is not a developer and wants a direct answer.
+Even if you don't think the data provides enough evidence for the query, still provide a response --- Data Summary: @@ -69,35 +77,87 @@ Answer:""" return "\n".join(summary_lines) + def _generate_plot(self, query: str, df: pd.DataFrame) -> str: + """ + Generates a plot from a DataFrame based on a natural language query, + encodes it in Base64, and returns it. + + If columns are specified (e.g., "plot X vs Y"), it uses them. + If not, it automatically picks the first two numerical columns. + """ + col1, col2 = None, None + title = "Scatter Plot" + + # Attempt to find explicitly mentioned columns, e.g., "plot Column1 vs Column2" + match = re.search(r"(?:plot|scatter|visualize)\s+(.*?)\s+(?:vs|versus|and)\s+(.*)", query, re.IGNORECASE) + if match: + potential_col1 = match.group(1).strip() + potential_col2 = match.group(2).strip() + if potential_col1 in df.columns and potential_col2 in df.columns: + col1, col2 = potential_col1, potential_col2 + title = f"Scatterplot of {col1} vs {col2}" + + # If no valid columns were explicitly found, auto-detect + if not col1 or not col2: + numeric_cols = df.select_dtypes(include=['number']).columns.tolist() + if len(numeric_cols) >= 2: + col1, col2 = numeric_cols[0], numeric_cols[1] + title = f"Scatterplot of {col1} vs {col2} (Auto-selected)" + else: + raise ValueError("I couldn't find two numerical columns to plot automatically. Please specify columns, like 'plot column_A vs column_B'.") + + fig, ax = plt.subplots() + ax.scatter(df[col1], df[col2]) + ax.set_xlabel(col1) + ax.set_ylabel(col2) + ax.set_title(title) + ax.grid(True) + + buf = io.BytesIO() + plt.savefig(buf, format='png', bbox_inches='tight') + plt.close(fig) + buf.seek(0) + + image_base64 = base64.b64encode(buf.read()).decode('utf-8') + return image_base64 + async def generate_response( self, query: str, decoded_file: bytes, file_type: str, ) -> AsyncGenerator[str, None]: - """Generate a response based on the uploaded data and user query.""" - + """ + Generate a response based on the uploaded data and user query. + This can be a text analysis or a plot visualization. + """ try: - # Read the file content into a DataFrame if file_type == "csv": df = pd.read_csv(io.BytesIO(decoded_file)) elif file_type == "xlsx": df = pd.read_excel(io.BytesIO(decoded_file)) else: - yield "I can only analyze CSV and XLSX files at this time." 
+            yield json.dumps({"type": "error", "content": "I can only analyze CSV and XLSX files."})
                 return
 
-            # Get the structured summary instead of the full data
-            data_summary = self._get_dataframe_summary(df)
+            plot_keywords = ["plot", "graph", "scatter", "visualize"]
+            if any(keyword in query.lower() for keyword in plot_keywords):
+                try:
+                    image_base64 = self._generate_plot(query, df)
+                    yield json.dumps({
+                        "type": "plot",
+                        "format": "png",
+                        "image": image_base64
+                    })
+                except ValueError as e:
+                    yield json.dumps({"type": "error", "content": str(e)})
+                return
 
-            # Prepare the input for the LLM chain
-            chain_input = {
-                "data_summary": data_summary,
-                "query": query,
-            }
+            data_summary = self._get_dataframe_summary(df)
+            chain_input = {"data_summary": data_summary, "query": query}
 
             async for chunk in self.analysis_chain.astream(chain_input):
-                yield chunk
+                yield chunk
 
         except Exception as e:
-            yield f"An error occurred while processing the file: {e}"
+            yield json.dumps({"type": "error", "content": f"An error occurred: {e}"})
diff --git a/llm_be/chat_backend/services/moderation_classifier.py b/llm_be/chat_backend/services/moderation_classifier.py
index a8b77e7..cb3baa1 100644
--- a/llm_be/chat_backend/services/moderation_classifier.py
+++ b/llm_be/chat_backend/services/moderation_classifier.py
@@ -29,7 +29,7 @@ class ModerationClassifier(BaseService):
         (
             "system",
             """You are a strict content moderator. Classify the following prompt as either NSFW or FINE.
-            
+
 NSFW includes:
 - Sexual content
 - Violence/gore
@@ -50,6 +50,7 @@ Examples:
 - "Write a love poem" → FINE
 - "Explicit sex scene" → NSFW
 - "Python tutorial" → FINE
+- "Please analyze this file and project the next 12 months for me. Add a graph visual of the data as well" → FINE
 
 Return ONLY "NSFW" or "FINE", nothing else.""",
         ),
diff --git a/llm_be/chat_backend/views.py b/llm_be/chat_backend/views.py
index a2c6a5c..583630c 100644
--- a/llm_be/chat_backend/views.py
+++ b/llm_be/chat_backend/views.py
@@ -823,6 +823,8 @@ def save_generated_message(conversation_id, message):
         prompt_instance = serializer.save()
         prompt_instance.conversation_id = conversation.id
         prompt_instance = serializer.save()
+    else:
+        logger.error(serializer.errors)
 
 
 @database_sync_to_async
@@ -896,6 +898,20 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
     async def disconnect(self, close_code):
         await self.close()
 
+    async def send_json_message(self, data_str):
+        """
+        Ensures that the message sent over the websocket is a valid JSON object.
+        If data_str is a plain string, it wraps it in {"type": "text", "content": ...}.
+        """
+        try:
+            # Test if it's already a valid JSON object string
+            json.loads(data_str)
+            # If it is, send it as is
+            await self.send(data_str)
+        except (json.JSONDecodeError, TypeError):
+            # If it's a plain string or not JSON-decodable, wrap it before sending
+            await self.send(json.dumps({"type": "text", "content": data_str}))
+
     async def receive(self, text_data=None, bytes_data=None):
         logger.debug(f"Text Data: {text_data}")
         logger.debug(f"Bytes Data: {bytes_data}")
@@ -946,13 +962,15 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
                 == ModerationLabel.NSFW
             ):
                 response = "Prompt has been marked as NSFW. If this is in error, submit a feedback with the prompt text."
+ response_to_send = json.dumps({"type": "error", "content": response}) + logger.warning("this prompt has been marked as NSFW") await self.send("CONVERSATION_ID") await self.send(str(conversation_id)) await self.send("START_OF_THE_STREAM_ENDER_GAME_42") - await self.send(response) + await self.send_json_message(response_to_send) await self.send("END_OF_THE_STREAM_ENDER_GAME_42") - await save_generated_message(conversation_id, response) + await save_generated_message(conversation_id, response_to_send) return # TODO: add the message to the database @@ -1007,18 +1025,18 @@ class ChatConsumerAgain(AsyncWebsocketConsumer): messages, prompt.message, workspace ): response += chunk - await self.send(chunk) + await self.send_json_message(chunk) elif prompt_type == PromptType.DATA_ANALYSIS: service = AsyncDataAnalysisService() if not decoded_file: - await self.send("Please upload a file to perform data analysis.") + await self.send_json_message("Please upload a file to perform data analysis.") else: async for chunk in service.generate_response(prompt.message, decoded_file, file_type): response += chunk - await self.send(chunk) + await self.send_json_message(chunk) elif prompt_type == PromptType.IMAGE_GENERATION: response = "Image Generation is not supported at this time, but it will be soon." - await self.send(response) + await self.send_json_message(response) else: logger.info(f"using the AsyncLLMService\n\n{messages}\n{prompt.message}") @@ -1027,8 +1045,9 @@ class ChatConsumerAgain(AsyncWebsocketConsumer): messages, prompt.message, conversation_id ): response += chunk - await self.send(chunk) + await self.send_json_message(chunk) await self.send("END_OF_THE_STREAM_ENDER_GAME_42") + await save_generated_message(conversation_id, response) await finish_prompt_metric(prompt_metric, len(response))
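
How a client consumes the new message envelopes is outside this patch, so the following is only a minimal sketch, not the frontend implementation. It assumes the framing used by ChatConsumerAgain above (the CONVERSATION_ID preamble plus the START_OF_THE_STREAM_ENDER_GAME_42 / END_OF_THE_STREAM_ENDER_GAME_42 markers) and the JSON envelopes ({"type": "text" | "plot" | "error", ...}) yielded by AsyncDataAnalysisService; handle_stream and analysis_plot.png are illustrative names, not part of the codebase.

import base64
import json


def handle_stream(frames):
    """Collect text from one streamed response and write any plot to disk."""
    text_parts = []
    streaming = False
    for frame in frames:
        if frame == "START_OF_THE_STREAM_ENDER_GAME_42":
            streaming = True
            continue
        if frame == "END_OF_THE_STREAM_ENDER_GAME_42":
            break
        if not streaming:
            # "CONVERSATION_ID" and the id itself arrive before the stream starts.
            continue
        try:
            message = json.loads(frame)
        except json.JSONDecodeError:
            message = None
        if not isinstance(message, dict):
            # Plain frames: treat the raw string as a text chunk.
            message = {"type": "text", "content": frame}
        if message.get("type") == "plot":
            # The service Base64-encodes a PNG, so decode it back to bytes.
            with open("analysis_plot.png", "wb") as fh:
                fh.write(base64.b64decode(message["image"]))
        else:
            # "text" and "error" envelopes both carry a human-readable string.
            text_parts.append(str(message.get("content", "")))
    return "".join(text_parts)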
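
The plot path in AsyncDataAnalysisService can also be exercised without the websocket layer or a running model, since a plot-keyword query short-circuits before the LLM chain is invoked. The snippet below is a rough local check under those assumptions: the import path mirrors the repository layout in this patch, the sample DataFrame and column names are invented, and pandas, matplotlib, and langchain-ollama must be importable in the environment.

import asyncio
import json

import pandas as pd

from llm_be.chat_backend.services.data_analysis_service import AsyncDataAnalysisService


async def main():
    # A tiny invented dataset with two numeric columns.
    df = pd.DataFrame({"month": [1, 2, 3, 4], "sales": [10, 14, 9, 20]})
    csv_bytes = df.to_csv(index=False).encode("utf-8")

    service = AsyncDataAnalysisService()
    # "plot month vs sales" matches both the plot keywords and the column regex,
    # so the service should yield a single JSON envelope with a Base64 PNG.
    async for chunk in service.generate_response("plot month vs sales", csv_bytes, "csv"):
        message = json.loads(chunk)
        print(message["type"], message.get("format"))


asyncio.run(main())

On the happy path this prints "plot png"; a dataset without two numeric columns would instead produce an "error" envelope from the ValueError branch.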