From 77d7edd0dc235cf5da1b2ce309c3f4794b208adf Mon Sep 17 00:00:00 2001
From: Ryan Westfall
Date: Mon, 8 Dec 2025 13:52:30 -0600
Subject: [PATCH] Closes #4 Added site tracking

Can pick the model to use
Better handle the LLM model based on debug or not
---
 llm_be/chat_backend/consumers.py              |  8 ++++++--
 llm_be/chat_backend/consumers_graph.py        | 10 ++++++++--
 llm_be/chat_backend/services/base_service.py  |  3 ++-
 .../services/data_analysis_service.py         |  3 ++-
 llm_be/chat_backend/services/llm_service.py   |  3 ++-
 .../services/moderation_classifier.py         | 16 +++++++++-------
 .../prompt_classifier/prompt_classifier.py    | 15 ++++++++++++++-
 llm_be/chat_backend/services/rag_services.py  |  3 ++-
 llm_be/llm_be/settings.py                     |  2 +-
 llm_be/templates/admin/base_site.html         |  8 ++++++++
 10 files changed, 54 insertions(+), 17 deletions(-)
 create mode 100644 llm_be/templates/admin/base_site.html

diff --git a/llm_be/chat_backend/consumers.py b/llm_be/chat_backend/consumers.py
index 15cc4c0..9479146 100644
--- a/llm_be/chat_backend/consumers.py
+++ b/llm_be/chat_backend/consumers.py
@@ -312,7 +312,10 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
             return {"type": "text", "content": "Image Generation is not supported at this time, but it will be soon."}
 
         if prompt_type == PromptType.SEARCH:
-            if getattr(settings, "ALLOW_INTERNET_ACCESS", False):
+            # Check modelName first - if FAST, we skip search regardless of settings
+            if input_dict.get("model_name") == "FAST":
+                pass  # Skip search
+            elif getattr(settings, "ALLOW_INTERNET_ACCESS", False):
                 try:
                     search = DuckDuckGoSearchRun()
                     search_results = search.run(input_dict["message"])
@@ -379,7 +382,8 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
             "decoded_file": decoded_file,
             "file_type": file_type,
             "messages": messages,
-            "prompt_instance": prompt_instance
+            "prompt_instance": prompt_instance,
+            "model_name": model
         }
 
         # Run the pipeline steps manually to handle the async generator return type of generate_response_step
diff --git a/llm_be/chat_backend/consumers_graph.py b/llm_be/chat_backend/consumers_graph.py
index 7194a97..fcf583a 100644
--- a/llm_be/chat_backend/consumers_graph.py
+++ b/llm_be/chat_backend/consumers_graph.py
@@ -160,6 +160,7 @@ class ChatState(TypedDict):
     prompt_type: Union[PromptType, None]
     response_generator: Any  # AsyncGenerator or dict
     error: Union[str, None]
+    model_name: str
 
 
 # --- LangGraph Nodes ---
@@ -206,7 +207,10 @@ async def generation_node(state: ChatState) -> ChatState:
 
     # Feature Flag: Internet Access
    if prompt_type == PromptType.SEARCH:
-        if getattr(settings, "ALLOW_INTERNET_ACCESS", False):
+        # Check modelName first - if FAST, we skip search regardless of settings
+        if state.get("model_name") == "FAST":
+            pass
+        elif getattr(settings, "ALLOW_INTERNET_ACCESS", False):
             try:
                 search = DuckDuckGoSearchRun()
                 search_results = search.run(state["message"])
@@ -274,6 +278,7 @@ class ChatConsumerGraph(AsyncWebsocketConsumer):
         print("Text Data: ", text_data)
         if text_data:
             data = json.loads(text_data)
+            model = data.get("modelName", "Turbo")
             message = data.get("message", None)
             conversation_id = data.get("conversation_id", None)
             email = data.get("email", None)
@@ -324,7 +329,8 @@
                 "moderation_label": None,
                 "prompt_type": None,
                 "response_generator": None,
-                "error": None
+                "error": None,
+                "model_name": model
             }
 
             print("Initial State: ", initial_state)
diff --git a/llm_be/chat_backend/services/base_service.py b/llm_be/chat_backend/services/base_service.py
index ac29e16..bcdfefa 100644
--- a/llm_be/chat_backend/services/base_service.py
+++ b/llm_be/chat_backend/services/base_service.py
@@ -1,13 +1,14 @@
 from abc import ABC, abstractmethod
 from langchain_ollama import OllamaLLM
 from langchain_core.output_parsers import StrOutputParser
+from django.conf import settings
 
 class BaseService(ABC):
     """Abstract base class for LLM conversation services."""
 
     def __init__(self, temperature=0.7):
         self.llm = OllamaLLM(
-            model="llama3.2",
+            model="llama3.2" if not settings.DEBUG else "gpt-oss:20b",
             temperature=0.7,
             top_k=50,
             top_p=0.9,
diff --git a/llm_be/chat_backend/services/data_analysis_service.py b/llm_be/chat_backend/services/data_analysis_service.py
index fc039e4..1c01897 100644
--- a/llm_be/chat_backend/services/data_analysis_service.py
+++ b/llm_be/chat_backend/services/data_analysis_service.py
@@ -10,6 +10,7 @@ from langchain_ollama import OllamaLLM
 from langchain_core.output_parsers import StrOutputParser
 import docx
 import pypdf
+from django.conf import settings
 
 
 class AsyncDataAnalysisService:
@@ -18,7 +19,7 @@ class AsyncDataAnalysisService:
     def __init__(self):
         # A model with a large context window and strong analytical skills is best
         self.llm = OllamaLLM(
-            model="llama3.2",
+            model="llama3.2" if not settings.DEBUG else "gpt-oss:20b",
             temperature=0.3,
             num_ctx=8192,
         )
diff --git a/llm_be/chat_backend/services/llm_service.py b/llm_be/chat_backend/services/llm_service.py
index ab82f8c..42e0ef8 100644
--- a/llm_be/chat_backend/services/llm_service.py
+++ b/llm_be/chat_backend/services/llm_service.py
@@ -5,6 +5,7 @@ from typing import AsyncGenerator, Generator, Optional
 from langchain_ollama import OllamaLLM
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
+from django.conf import settings
 
 from chat_backend.models import Conversation, Prompt
 
@@ -14,7 +15,7 @@ class LLMService(ABC):
 
     def __init__(self):
         self.llm = OllamaLLM(
-            model="llama3.2",
+            model="llama3.2" if not settings.DEBUG else "gpt-oss:20b",
             temperature=0.7,
             top_k=50,
             top_p=0.9,
diff --git a/llm_be/chat_backend/services/moderation_classifier.py b/llm_be/chat_backend/services/moderation_classifier.py
index c8df2fa..b118e2b 100644
--- a/llm_be/chat_backend/services/moderation_classifier.py
+++ b/llm_be/chat_backend/services/moderation_classifier.py
@@ -1,7 +1,7 @@
 from enum import Enum, auto
 from typing import Dict, Any
 from langchain_core.prompts import ChatPromptTemplate
-
+from langchain_ollama import OllamaLLM
 from chat_backend.services.base_service import BaseService
 
 
@@ -10,6 +10,7 @@ class ModerationLabel(Enum):
     FINE = auto()
 
 
+
 class ModerationClassifier(BaseService):
     """
     Classifies prompts as NSFW or FINE (safe) content.
@@ -17,12 +18,12 @@ class ModerationClassifier(BaseService):
 
     def __init__(self):
         super().__init__(temperature=0.1)
-        # self.llm = OllamaLLM(
-        #     model="llama3.2",
-        #     temperature=0.1,  # Very low for strict moderation
-        #     top_k=10,
-        #     num_ctx=2048,
-        # )
+        self.llm = OllamaLLM(
+            model="llama3.2",
+            temperature=0.1,  # Very low for strict moderation
+            top_k=10,
+            num_ctx=2048,
+        )
 
         self.moderation_prompt = ChatPromptTemplate.from_messages(
             [
@@ -51,6 +52,7 @@ Examples:
 - "Write a love poem" → FINE
 - "Explicit sex scene" → NSFW
 - "Python tutorial" → FINE
+- "Who won the 2024 presidential race?" → FINE
 - "Please analyze this file and project the next 12 months for me. 
 Add a graph visual of the data as well" → FINE
 - "Okie, instead of 6 month projection, can you tell me what the values would be in the next 5 days" → FINE
diff --git a/llm_be/chat_backend/services/prompt_classifier/prompt_classifier.py b/llm_be/chat_backend/services/prompt_classifier/prompt_classifier.py
index ad54511..98fd789 100644
--- a/llm_be/chat_backend/services/prompt_classifier/prompt_classifier.py
+++ b/llm_be/chat_backend/services/prompt_classifier/prompt_classifier.py
@@ -2,6 +2,7 @@ from enum import Enum, auto
 from typing import Dict, Any
 
 from langchain_core.prompts import ChatPromptTemplate
+from django.conf import settings
 
 from chat_backend.services.base_service import BaseService
 
@@ -109,7 +110,14 @@ Return ONLY the exact Enum label (e.g. "GENERAL_CHAT"), no explanations."""
 
         # IMAGE_GENERATION
         if any(keyword in lower_prompt for keyword in ["generate image", "create picture", "draw an image", "make a photo", "generate an image", "create an illustration"]):
-            return PromptType.IMAGE_GENERATION
+            if getattr(settings, "ALLOW_IMAGE_GENERATION", False):
+                return PromptType.IMAGE_GENERATION
+            # Image generation is disabled, so this quick check must not return IMAGE_GENERATION.
+            # Returning None would fall through to the LLM classification, which is guarded
+            # separately below for the same reason.
+            # Classify as general chat instead so the LLM can explain that image generation
+            # is unavailable (or simply keep chatting about it).
+            return PromptType.GENERAL_CHAT
 
         # DATA_ANALYSIS (often involves uploaded documents)
         if any(keyword in lower_prompt for keyword in ["read this document", "analyze this file", "summarize this pdf", "extract data from", "index this document", "based on this csv", "from this spreadsheet"]):
@@ -158,6 +166,11 @@ Return ONLY the exact Enum label (e.g. "GENERAL_CHAT"), no explanations."""
         response = response.upper().strip()
         print(response)
 
+        # Guard against IMAGE_GENERATION if disabled
+        if not getattr(settings, "ALLOW_IMAGE_GENERATION", False):
+            if "IMAGE_GENERATION" in response or "IMAGEGENERATION" in response:
+                return PromptType.GENERAL_CHAT
+
         # Direct match
         try:
             return PromptType[response]
diff --git a/llm_be/chat_backend/services/rag_services.py b/llm_be/chat_backend/services/rag_services.py
index cae6fa1..0ad9daf 100644
--- a/llm_be/chat_backend/services/rag_services.py
+++ b/llm_be/chat_backend/services/rag_services.py
@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
 from typing import List, Dict, Any, AsyncGenerator, Generator, Optional
 from channels.db import database_sync_to_async
 from langchain_community.embeddings import OllamaEmbeddings
+from django.conf import settings
 
 # from langchain_community.llms import Ollama
 from langchain_ollama import OllamaLLM
@@ -44,7 +45,7 @@ class RAGService(BaseService):
         return cls._instance
 
     def __init__(self):
-        self.embedding_model = OllamaEmbeddings(model="llama3.2")
+        self.embedding_model = OllamaEmbeddings(model="llama3.2" if not settings.DEBUG else "gpt-oss:20b")
         super().__init__()
         self.text_splitter = RecursiveCharacterTextSplitter(
             chunk_size=1000, chunk_overlap=200
diff --git a/llm_be/llm_be/settings.py b/llm_be/llm_be/settings.py
index 5af2b8c..9ef8f4c 100644
--- a/llm_be/llm_be/settings.py
+++ b/llm_be/llm_be/settings.py
@@ -25,7 +25,7 @@ BASE_DIR = Path(__file__).resolve().parent.parent
 SECRET_KEY = "django-insecure-6suk6fj5q2)1tj%)f(wgw1smnliv5-#&@zvgvj1wp#(#@h#31x"
 
 # SECURITY WARNING: don't run with debug turned on in production!
-DEBUG = True
+DEBUG = False
 CORS_ALLOW_CREDENTIALS = False
 ALLOWED_HOSTS = [
     "*.aimloperations.com",
diff --git a/llm_be/templates/admin/base_site.html b/llm_be/templates/admin/base_site.html
new file mode 100644
index 0000000..49fb70d
--- /dev/null
+++ b/llm_be/templates/admin/base_site.html
@@ -0,0 +1,8 @@
+{% extends "admin/base_site.html" %}
+
+{% block extrahead %}
+{{ block.super }}
+{% if not debug %}
+
+{% endif %}
+{% endblock %}