class AsyncDataAnalysisService:
    """Asynchronous service for performing data analysis with an LLM.

    An uploaded CSV/XLSX file is parsed with pandas and condensed into a
    text summary (column info, descriptive statistics and the first rows).
    That summary, not the full data, is sent to the model together with the
    user's question, and the answer is streamed back chunk by chunk.
    """

    # Rough character budget for the summary placed in the prompt. The model
    # below is configured with num_ctx=8192, and that window also has to hold
    # the instructions, the question and the generated answer.
    # NOTE(review): assumes a few characters per token -- tune against the
    # model's tokenizer if summaries get cut too early or too late.
    MAX_SUMMARY_CHARS = 16_000

    # Number of leading rows included as a data sample in the summary.
    SAMPLE_ROWS = 5

    def __init__(self):
        # A model with a large context window and strong analytical skills is best
        self.llm = OllamaLLM(
            model="llama3.2",
            temperature=0.3,
            num_ctx=8192,
        )
        self.output_parser = StrOutputParser()
        self._setup_chain()

    def _setup_chain(self):
        """Set up the LLM chain with a prompt tailored for data analysis."""
        template = """You are an expert data analyst. A user has provided a summary and sample of a dataset and is asking a question about it.
Analyze the provided information and answer the user's question. If a calculation is requested, perform it based on the summary statistics provided. If the data is not suitable for the request, explain why.

---
Data Summary:
{data_summary}
---

User's Question: {query}
Answer:"""

        self.prompt = ChatPromptTemplate.from_template(template)

        self.analysis_chain = (
            {
                "data_summary": lambda x: x["data_summary"],
                "query": lambda x: x["query"],
            }
            | self.prompt
            | self.llm
            | self.output_parser
        )

    @staticmethod
    def _load_dataframe(decoded_file: bytes, file_type: str) -> "pd.DataFrame | None":
        """Parse raw file bytes into a DataFrame.

        Args:
            decoded_file: Raw bytes of the uploaded file.
            file_type: "csv" or "xlsx"; matching ignores case and
                surrounding whitespace. ``None`` is treated as unsupported.

        Returns:
            The parsed DataFrame, or ``None`` when the file type is not
            supported.

        Raises:
            Exception: Whatever pandas raises for malformed or empty content.
        """
        readers = {"csv": pd.read_csv, "xlsx": pd.read_excel}
        reader = readers.get((file_type or "").strip().lower())
        if reader is None:
            return None
        return reader(io.BytesIO(decoded_file))

    def _get_dataframe_summary(self, df: pd.DataFrame) -> str:
        """Generate a structured text summary of the DataFrame for the LLM.

        The summary contains the shape, the output of ``df.info()``, the
        output of ``df.describe()`` and the first ``SAMPLE_ROWS`` rows. It is
        cut to ``MAX_SUMMARY_CHARS`` characters when it grows beyond that.

        Args:
            df: The DataFrame to summarize; may be empty.

        Returns:
            The summary as a single newline-joined string.
        """
        num_rows, num_cols = df.shape

        # Column names, dtypes and non-null counts via df.info()
        info_buffer = io.StringIO()
        df.info(buf=info_buffer, verbose=True, show_counts=True)

        # describe() raises ValueError on a frame without columns, so only
        # call it when there is something to describe.
        if num_cols:
            stats_text = df.describe().to_string()
        else:
            stats_text = "(no columns to describe)"

        summary_lines = [
            f"DataFrame has {num_rows} rows and {num_cols} columns.",
            "Column Information (Name, Dtype, Non-Null Count):",
            "--------------------------------------------------",
            info_buffer.getvalue(),
            "\nDescriptive Statistics (for numerical columns):",
            "--------------------------------------------",
            stats_text,
            "\nSample of Data:",
            "-----------------",
            # Only the leading rows are shown to give a feel for the data
            df.head(self.SAMPLE_ROWS).to_string(),
        ]
        summary = "\n".join(summary_lines)

        if len(summary) > self.MAX_SUMMARY_CHARS:
            summary = summary[: self.MAX_SUMMARY_CHARS] + "\n[summary truncated]"
        return summary

    async def generate_response(
        self,
        query: str,
        decoded_file: bytes,
        file_type: str,
    ) -> AsyncGenerator[str, None]:
        """Stream an answer to ``query`` based on the uploaded data.

        Args:
            query: The user's question about the data.
            decoded_file: Raw bytes of the uploaded file.
            file_type: "csv" or "xlsx" (case-insensitive).

        Yields:
            Chunks of the model's answer. On an unsupported file type, a
            parsing failure or a model failure, a single human-readable
            message is yielded instead of raising, so the caller can forward
            it to the client like any other chunk.
        """
        # Local import keeps this block self-contained.
        import asyncio

        try:
            # pandas parsing is blocking work; run it in a worker thread so
            # the event loop stays responsive for other connections.
            df = await asyncio.to_thread(
                self._load_dataframe, decoded_file, file_type
            )
            if df is None:
                yield "I can only analyze CSV and XLSX files at this time."
                return

            # Get the structured summary instead of the full data
            data_summary = self._get_dataframe_summary(df)
        except Exception as e:
            yield f"An error occurred while processing the file: {e}"
            return

        chain_input = {
            "data_summary": data_summary,
            "query": query,
        }

        # Kept separate from the parsing step so a model/streaming failure is
        # not reported to the user as a problem with their file.
        try:
            async for chunk in self.analysis_chain.astream(chain_input):
                yield chunk
        except Exception as e:
            yield f"An error occurred while analyzing the data: {e}"
→ RAG + - Explicitly mentions documents/files/data + - Uses search terms: "find/search/lookup in [source]" + - Example: "What does contracts.pdf say?" → RAG -3. GENERAL_CHAT - DEFAULT category when: - - Doesn't meet above criteria - - Conversational/general knowledge - - Uncertain cases - - Example: "Tell me a joke" → GENERAL_CHAT +3. DATA_ANALYSIS - ONLY if: + - The message explicitly contains structured data from a file (e.g., a DataFrame string) + - The user is asking to analyze, summarize, or plot the data + - Example: "Here is the sales data. What is the average revenue per product?" -> DATA_ANALYSIS + +4. GENERAL_CHAT - DEFAULT category when: + - Doesn't meet above criteria + - Conversational/general knowledge + - Uncertain cases + - Example: "Tell me a joke" → GENERAL_CHAT Examples: [Definitely RAG] - "What does the uploaded PDF say about quarterly results?" - "Search our documents for the 2023 marketing strategy" -- "Find the contract clause about termination" + +[Definitely DATA_ANALYSIS] +- "Here is the file content. What is the sum of all 'Sales'?" +- "Based on this CSV data, show me the top 5 customers." [Definitely GENERAL_CHAT] - "How does photosynthesis work?" (General knowledge) - "Tell me a joke" -- "What's your opinion on AI?" -[Borderline → GENERAL_CHAT] -- "What's our company policy on X?" (No doc reference → general) -- "Explain quantum computing" (General knowledge) -- "Summarize the meeting" (No doc reference) - -[Definitely NOT IMAGE_GENERATION] -- "Great, can you make it about a duck now" -- "highlight the features of the backyard playset if they were to choose us and make the language more long form" +[Borderline -> GENERAL_CHAT] +- "What's our company policy on X?" 
(No doc reference -> general) Return ONLY the label, no explanations.""", ), diff --git a/llm_be/chat_backend/views.py b/llm_be/chat_backend/views.py index 1937826..a2c6a5c 100644 --- a/llm_be/chat_backend/views.py +++ b/llm_be/chat_backend/views.py @@ -44,6 +44,7 @@ from django.conf import settings import json import base64 import pandas as pd +import io # For email support from django.core.mail import EmailMultiAlternatives @@ -68,16 +69,22 @@ from .services.rag_services import AsyncRAGService from .services.title_generator import title_generator from .services.moderation_classifier import moderation_classifier, ModerationLabel from .services.prompt_classifier.prompt_classifier import PromptClassifier, PromptType +from .services.data_analysis_service import AsyncDataAnalysisService + from langchain.chains import create_retrieval_chain from langchain.chains.combine_documents import create_stuff_documents_chain from langchain_ollama import ChatOllama +import logging +logger = logging.getLogger(__name__) + + + CHANNEL_NAME: str = "llm_messages" MODEL_NAME: str = "llama3.2" - # Create your views here. 
class CustomObtainTokenView(TokenObtainPairView): permission_classes = (permissions.AllowAny,) @@ -99,8 +106,8 @@ class CustomUserCreate(APIView): def send_invite_email(slug, email_to_invite): - print("Sending invite email") - print(f"url : https://chat.aimloperations.com/set_password?slug={slug}") + logger.info("Sending invite email") + logger.info(f"url : https://chat.aimloperations.com/set_password?slug={slug}") url = f"https://chat.aimloperations.com/set_password?slug={slug}" subject = "Welcome to AI ML Operations, LLC Chat Services" from_email = "ryan@aimloperations.com" @@ -115,8 +122,8 @@ def send_invite_email(slug, email_to_invite): def send_password_reset_email(slug, email_to_invite): - print("Sending reset email") - print(f"url : https://www.chat.aimloperations.com/set_password?slug={slug}") + logger.info("Sending reset email") + logger.info(f"url : https://www.chat.aimloperations.com/set_password?slug={slug}") url = f"https://www.chat.aimloperations.com/set_password?slug={slug}" subject = "Password reset for AI ML Operations, LLC Chat Services" from_email = "ryan@aimloperations.com" @@ -131,7 +138,7 @@ def send_password_reset_email(slug, email_to_invite): def send_feedback_email(feedback_obj): - print("Sending feedback email") + logger.info("Sending feedback email") subject = "New Feedback for Chat by AI ML Operations, LLC" from_email = "ryan@aimloperations.com" to = "ryan@aimloperations.com" @@ -145,7 +152,7 @@ def send_feedback_email(feedback_obj): def send_password_reset_email(slug, email_to_invite): - print("Sending Password reset email") + logger.info("Sending Password reset email") url = f"https://www.chat.aimloperations.com/set_password?slug={slug}" subject = "Password reset for Chat by AI ML Operations, LLC" from_email = "ryan@aimloperations.com" @@ -233,7 +240,7 @@ class ResetUserPassword(APIView): Send an email with a set password link to the set password page Also disable the account """ - print(f"Password reset for requests. 
{request.data}") + logger.info(f"Password reset for requests. {request.data}") token = request.data.get("recaptchaToken") payload = { "secret": settings.CAPTCHA_SECRET_KEY, @@ -252,7 +259,7 @@ class ResetUserPassword(APIView): # send the email send_password_reset_email(user.slug, email) else: - print("Captcha secret failed") + logger.error("Captcha secret failed") return Response(status=status.HTTP_200_OK) @@ -284,14 +291,14 @@ class CustomUserGet(APIView): email = request.user.email username = request.user.username user = CustomUser.objects.filter(email=email).last() - print(f"Getting the user: {user}") + logger.info(f"Getting the user: {user}") try: serializer = CustomUserSerializer(user) - print(f"serializer: {serializer}") - print(serializer.data) + logger.debug(f"serializer: {serializer}") + logger.debug(serializer.data) return Response(serializer.data, status=status.HTTP_200_OK) except Exception as e: - print(f"Exception: {e}") + logger.error(f"Exception: {e}") return Response({}, status=status.HTTP_400_BAD_REQUEST) @@ -300,7 +307,7 @@ class FeedbackView(APIView): def post(self, request, format="json"): serializer = FeedbackSerializer(data=request.data) - print(request.data) + logger.debug(request.data) if serializer.is_valid(): feedback_obj = serializer.save() @@ -310,7 +317,7 @@ class FeedbackView(APIView): send_feedback_email(feedback_obj) return Response(serializer.data, status=status.HTTP_201_CREATED) else: - print(serializer.errors) + logger.error(serializer.errors) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) def get(self, request, format="json"): @@ -453,7 +460,7 @@ class ConversationDetailView(APIView): return Response(serailzer.data, status=status.HTTP_200_OK) def post(self, request, format="json"): - print("In the post") + logger.info("In the post") # Add the prompt to the database # make sure there is a conversation for it # if there is not a conversation create a title for it @@ -481,7 +488,7 @@ class 
ConversationDetailView(APIView): prompt_instance = serializer.save() # set up the streaming response if it is from the user - print(f"Do we have a valid user? {is_user}") + logger.info(f"Do we have a valid user? {is_user}") if is_user: messages = [] for prompt_obj in Prompt.objects.filter( @@ -495,12 +502,12 @@ class ConversationDetailView(APIView): ) channel_layer = get_channel_layer() - print(f"Sending to the channel: {CHANNEL_NAME}") + logger.info(f"Sending to the channel: {CHANNEL_NAME}") async_to_sync(channel_layer.group_send)( CHANNEL_NAME, {"type": "receive", "content": messages} ) except: - print( + logger.error( f"Error trying to submit to conversation_id: {conversation_id} with request.data: {request.data}" ) pass @@ -740,7 +747,7 @@ def get_messages(conversation_id, prompt, file_string: str = None, file_type: st messages = [] conversation = Conversation.objects.get(id=conversation_id) - print(file_string) + logger.debug(file_string) # add the prompt to the conversation serializer = PromptSerializer( @@ -843,27 +850,43 @@ def update_prompt_metric(prompt_metric, status): @database_sync_to_async def finish_prompt_metric(prompt_metric, response_length): - print(f"finish_prompt_metric: {response_length}") + logger.info(f"finish_prompt_metric: {response_length}") prompt_metric.end_time = timezone.now() prompt_metric.reponse_length = response_length prompt_metric.event = "FINISHED" prompt_metric.save(update_fields=["end_time", "reponse_length", "event"]) - print("finish_prompt_metric saved") + logger.info("finish_prompt_metric saved") @database_sync_to_async def get_retriever(conversation_id): - print(f"getting workspace from conversation: {conversation_id}") + logger.info(f"getting workspace from conversation: {conversation_id}") conversation = Conversation.objects.get(id=conversation_id) - print(f"Got conversation: {conversation}") + logger.info(f"Got conversation: {conversation}") workspace = DocumentWorkspace.objects.get(company=conversation.user.company) - 
print(f"Got workspace: {conversation}") + logger.info(f"Got workspace: {conversation}") vectorstore = Chroma( persist_directory=f"./chroma_db/", embedding=OllamaEmbeddings(model="llama3.2"), ) return vectorstore.as_retriever() +async def get_conversation_file_async(conversation_id): + try: + # Get the very first prompt in the conversation that has a file + prompt_with_file = await Prompt.objects.filter( + conversation_id=conversation_id + ).exclude(file='').order_by('created').afirst() + + if prompt_with_file and prompt_with_file.file: + # You must use sync_to_async to access the file's binary content + file_data = await sync_to_async(prompt_with_file.file.read)() + file_type = prompt_with_file.file_type + return file_data, file_type + except Exception as e: + logger.error(f"Error retrieving file from conversation history: {e}") + return None, None + PROMPT_CLASSIFIER = PromptClassifier() class ChatConsumerAgain(AsyncWebsocketConsumer): @@ -874,8 +897,8 @@ class ChatConsumerAgain(AsyncWebsocketConsumer): await self.close() async def receive(self, text_data=None, bytes_data=None): - print(f"Text Data: {text_data}") - print(f"Bytes Data: {bytes_data}") + logger.debug(f"Text Data: {text_data}") + logger.debug(f"Bytes Data: {bytes_data}") if text_data: data = json.loads(text_data) message = data.get("message", None) @@ -896,21 +919,26 @@ class ChatConsumerAgain(AsyncWebsocketConsumer): if file: decoded_file = base64.b64decode(file) - print(decoded_file) + logger.debug(decoded_file) + # The `altered_message` should only be created if a file exists + # and you want to pass its content directly to the classifier. + # Here, we'll let the classifier decide based on the user's prompt + # and then handle the file content separately. 
+ altered_message = message if "csv" in file_type: file_type = "csv" - altered_message = f"{message}\n The file type is csv and the file contents are: {decoded_file}" + #altered_message = f"{message}\n The file type is csv and the file contents are: {decoded_file}" elif "xmlformats-officedocument" in file_type: file_type = "xlsx" - df = pd.read_excel(decoded_file) - altered_message = f"{message}\n The file type is xlsx and the file contents are: {df}" + #df = pd.read_excel(decoded_file) + #altered_message = f"{message}\n The file type is xlsx and the file contents are: {df}" elif "text" in file_type: file_type = "txt" - altered_message = f"{message}\n The file type is txt and the file contents are: {decoded_file}" + #altered_message = f"{message}\n The file type is txt and the file contents are: {decoded_file}" else: file_type = "Not Sure" - print(f'received: "{message}" for conversation {conversation_id}') + logger.info(f'received: "{message}" for conversation {conversation_id}') # check the moderation here if ( @@ -918,7 +946,7 @@ class ChatConsumerAgain(AsyncWebsocketConsumer): == ModerationLabel.NSFW ): response = "Prompt has been marked as NSFW. If this is in error, submit a feedback with the prompt text." 
- print("this prompt has been marked as NSFW") + logger.warning("this prompt has been marked as NSFW") await self.send("CONVERSATION_ID") await self.send(str(conversation_id)) await self.send("START_OF_THE_STREAM_ENDER_GAME_42") @@ -934,11 +962,19 @@ class ChatConsumerAgain(AsyncWebsocketConsumer): messages, prompt = await get_messages( conversation_id, message, decoded_file, file_type ) - + if not decoded_file: + decoded_file, file_type = await get_conversation_file_async(conversation_id) prompt_type = await PROMPT_CLASSIFIER.classify_async(message) + logger.info(f"prompt_type: {prompt_type} for {message}") print(f"prompt_type: {prompt_type} for {message}") - if file: + # Check for a file AND the new DATA_ANALYSIS type + # The classifier might not correctly identify a data analysis prompt + # without the file contents. So, we'll add a check to override. + if decoded_file and (prompt_type == PromptType.DATA_ANALYSIS or 'analyze' in message.lower() or 'data' in message.lower()): + prompt_type = PromptType.DATA_ANALYSIS + elif decoded_file: + # If a decoded_file is uploaded but the query is general, default to GENERAL_CHAT prompt_type = PromptType.GENERAL_CHAT prompt_metric = await create_prompt_metric( @@ -952,7 +988,7 @@ class ChatConsumerAgain(AsyncWebsocketConsumer): if file: # udpate with the altered_message messages = messages[:-1] + [HumanMessage(content=altered_message)] - print(messages) + logger.info(messages) # send it to the LLM # stream the response back @@ -965,19 +1001,27 @@ class ChatConsumerAgain(AsyncWebsocketConsumer): service = AsyncRAGService() # await service.ingest_documents() workspace = await get_workspace(conversation_id) - print("Time to get the rag response") + logger.info("Time to get the rag response") async for chunk in service.generate_response( messages, prompt.message, workspace ): response += chunk await self.send(chunk) + elif prompt_type == PromptType.DATA_ANALYSIS: + service = AsyncDataAnalysisService() + if not decoded_file: + 
await self.send("Please upload a file to perform data analysis.") + else: + async for chunk in service.generate_response(prompt.message, decoded_file, file_type): + response += chunk + await self.send(chunk) elif prompt_type == PromptType.IMAGE_GENERATION: response = "Image Generation is not supported at this time, but it will be soon." await self.send(response) else: - print(f"using the AsyncLLMService\n\n{messages}\n{prompt.message}") + logger.info(f"using the AsyncLLMService\n\n{messages}\n{prompt.message}") service = AsyncLLMService() async for chunk in service.generate_response( messages, prompt.message, conversation_id @@ -990,7 +1034,7 @@ class ChatConsumerAgain(AsyncWebsocketConsumer): await finish_prompt_metric(prompt_metric, len(response)) if bytes_data: - print("we have byte data") + logger.info("we have byte data") # Document Views @@ -1014,7 +1058,7 @@ class DocumentUploadView(APIView): # permission_classes = [permissions.IsAuthenticated]Z def get(self, request): - print(f"request_3: {request}") + logger.debug(f"request_3: {request}") try: workspace = DocumentWorkspace.objects.get(company=request.user.company) serializer = DocumentSerializer( @@ -1028,7 +1072,7 @@ class DocumentUploadView(APIView): ) def post(self, request): - print(f"request: {request}") + logger.debug(f"request: {request}") try: workspace = DocumentWorkspace.objects.get(company=request.user.company) @@ -1038,14 +1082,14 @@ class DocumentUploadView(APIView): {"error": "Workspace not found"}, status=status.HTTP_404_NOT_FOUND ) - print(request.FILES) + logger.info(request.FILES) file = request.FILES.get("file") if not file: return Response( {"error": "No file provided"}, status=status.HTTP_400_BAD_REQUEST ) - print("have the workspace and the file") + logger.info("have the workspace and the file") document = Document.objects.create(workspace=workspace, file=file) @@ -1072,7 +1116,7 @@ class DocumentDetailView(APIView): # permission_classes = [permissions.IsAuthenticated] def get(self, 
request, document_id): - print(f"request: {request}") + logger.info(f"request: {request}") try: workspace = DocumentWorkspace.objects.get(company=request.user.company) diff --git a/requirements.txt b/requirements.txt index c4eff53..1fcc228 100644 --- a/requirements.txt +++ b/requirements.txt @@ -43,8 +43,9 @@ duckdb==1.2.1 durationpy==0.9 effdet==0.4.1 emoji==2.14.1 +et_xmlfile==2.0.0 eval_type_backport==0.2.2 -Faker +Faker==37.0.0 fastapi==0.115.9 filelock==3.17.0 filetype==1.2.0 @@ -52,13 +53,13 @@ flatbuffers==25.2.10 fonttools==4.56.0 frozenlist==1.6.0 fsspec==2025.2.0 -google-api-core -google-auth -google-cloud-vision -googleapis-common-protos +google-api-core==2.24.2 +google-auth==2.39.0 +google-cloud-vision==3.10.1 +googleapis-common-protos==1.70.0 greenlet==3.1.1 -grpcio -grpcio-status +grpcio==1.72.0rc1 +grpcio-status==1.72.0rc1 h11==0.14.0 html5lib==1.1 httpcore==1.0.7 @@ -121,13 +122,14 @@ oauthlib==3.2.2 olefile==0.47 ollama==0.4.7 omegaconf==2.3.0 -onnx -onnxruntime +onnx==1.18.0 +onnxruntime==1.21.1 openai==1.65.4 opencv-python==4.11.0.86 -opentelemetry-api -opentelemetry-exporter-otlp-proto-common -opentelemetry-exporter-otlp-proto-grpc +openpyxl==3.1.5 +opentelemetry-api==1.32.1 +opentelemetry-exporter-otlp-proto-common==1.32.1 +opentelemetry-exporter-otlp-proto-grpc==1.32.1 opentelemetry-instrumentation==0.53b1 opentelemetry-instrumentation-asgi==0.53b1 opentelemetry-instrumentation-fastapi==0.53b1 @@ -138,8 +140,9 @@ opentelemetry-util-http==0.53b1 orjson==3.10.15 overrides==7.7.0 packaging==24.2 -pandas -pandasai +pandas==2.2.3 +pandasai==2.4.2 +parameterized==0.9.0 pathspec==0.12.1 pdf2image==1.17.0 pdfminer.six==20250506 @@ -150,7 +153,7 @@ platformdirs==4.3.6 posthog==4.0.1 propcache==0.3.1 proto-plus==1.26.1 -protobuf +protobuf==6.31.0rc2 psutil==7.0.0 pyasn1==0.6.1 pyasn1_modules==0.4.1 @@ -160,7 +163,7 @@ pydantic==2.11.4 pydantic-settings==2.9.1 pydantic_core==2.33.2 Pygments==2.19.1 -PyJWT +PyJWT==2.10.1 pyOpenSSL==25.0.0 pyparsing==3.2.1 
pypdf==5.4.0