Added site tracking
Can pick the model to use
Better handling of the LLM model depending on whether DEBUG is enabled
2025-12-08 13:52:30 -06:00
parent eed1abedc8
commit 77d7edd0dc
10 changed files with 54 additions and 17 deletions
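
The model choice rides on a modelName field in the websocket payload (defaulting to "Turbo", with "FAST" skipping the web search step). A minimal sketch of such a client message, using only field names that appear in the diffs below; the values are placeholders, not taken from this commit:

    # Illustrative client payload (values are placeholders)
    payload = {
        "modelName": "FAST",        # or "Turbo", the default used by the consumer
        "message": "What is the capital of France?",
        "conversation_id": None,
        "email": "user@example.com",
    }
    # The consumer reads it with: model = data.get("modelName", "Turbo")
    # and, when modelName == "FAST", skips the DuckDuckGo search.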

View File

@@ -312,7 +312,10 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
             return {"type": "text", "content": "Image Generation is not supported at this time, but it will be soon."}
         if prompt_type == PromptType.SEARCH:
-            if getattr(settings, "ALLOW_INTERNET_ACCESS", False):
+            # Check modelName first - if FAST, we skip search regardless of settings
+            if input_dict.get("model_name") == "FAST":
+                pass  # Skip search
+            elif getattr(settings, "ALLOW_INTERNET_ACCESS", False):
                 try:
                     search = DuckDuckGoSearchRun()
                     search_results = search.run(input_dict["message"])
@@ -379,7 +382,8 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
             "decoded_file": decoded_file,
             "file_type": file_type,
             "messages": messages,
-            "prompt_instance": prompt_instance
+            "prompt_instance": prompt_instance,
+            "model_name": model
         }
         # Run the pipeline steps manually to handle the async generator return type of generate_response_step
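
The trailing comment refers to driving generate_response_step by hand because it returns an async generator. A minimal, self-contained sketch of that pattern; the step body here is a stand-in, not code from this repo:

    import asyncio

    async def generate_response_step(input_dict):
        # Stand-in for the real pipeline step: it yields response chunks.
        for chunk in ("Hello", ", ", "world"):
            yield chunk

    async def run_pipeline(input_dict):
        # An async generator cannot simply be awaited like the other steps,
        # so it is iterated manually and each chunk is forwarded as it arrives.
        async for chunk in generate_response_step(input_dict):
            print(chunk, end="", flush=True)

    asyncio.run(run_pipeline({"message": "hi", "model_name": "Turbo"}))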

View File

@@ -160,6 +160,7 @@ class ChatState(TypedDict):
     prompt_type: Union[PromptType, None]
     response_generator: Any  # AsyncGenerator or dict
     error: Union[str, None]
+    model_name: str
 # --- LangGraph Nodes ---
@@ -206,7 +207,10 @@ async def generation_node(state: ChatState) -> ChatState:
     # Feature Flag: Internet Access
     if prompt_type == PromptType.SEARCH:
-        if getattr(settings, "ALLOW_INTERNET_ACCESS", False):
+        # Check modelName first - if FAST, we skip search regardless of settings
+        if state.get("model_name") == "FAST":
+            pass
+        elif getattr(settings, "ALLOW_INTERNET_ACCESS", False):
             try:
                 search = DuckDuckGoSearchRun()
                 search_results = search.run(state["message"])
@@ -274,6 +278,7 @@ class ChatConsumerGraph(AsyncWebsocketConsumer):
         print("Text Data: ", text_data)
         if text_data:
             data = json.loads(text_data)
+            model = data.get("modelName", "Turbo")
             message = data.get("message", None)
             conversation_id = data.get("conversation_id", None)
             email = data.get("email", None)
@@ -324,7 +329,8 @@ class ChatConsumerGraph(AsyncWebsocketConsumer):
                 "moderation_label": None,
                 "prompt_type": None,
                 "response_generator": None,
-                "error": None
+                "error": None,
+                "model_name": model
             }
             print("Initial State: ", initial_state)

View File

@@ -1,13 +1,14 @@
 from abc import ABC, abstractmethod
 from langchain_ollama import OllamaLLM
 from langchain_core.output_parsers import StrOutputParser
+from django.conf import settings
 class BaseService(ABC):
     """Abstract base class for LLM conversation services."""
     def __init__(self, temperature=0.7):
         self.llm = OllamaLLM(
-            model="llama3.2",
+            model="llama3.2" if not settings.DEBUG else "gpt-oss:20b",
             temperature=0.7,
             top_k=50,
             top_p=0.9,
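
The same DEBUG-based model switch recurs in several of the services below. Purely as an illustration, a sketch of a shared helper that would centralize it; the helper name get_ollama_model_name is hypothetical and not part of this commit:

    # Hypothetical helper, not part of this commit: keeps the DEBUG-based
    # model mapping in one place instead of repeating the conditional per service.
    from django.conf import settings

    def get_ollama_model_name(prod_model: str = "llama3.2",
                              debug_model: str = "gpt-oss:20b") -> str:
        """Return the Ollama model tag to use for the current DEBUG setting."""
        return debug_model if settings.DEBUG else prod_model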

View File

@@ -10,6 +10,7 @@ from langchain_ollama import OllamaLLM
 from langchain_core.output_parsers import StrOutputParser
 import docx
 import pypdf
+from django.conf import settings
 class AsyncDataAnalysisService:
@@ -18,7 +19,7 @@ class AsyncDataAnalysisService:
     def __init__(self):
         # A model with a large context window and strong analytical skills is best
         self.llm = OllamaLLM(
-            model="llama3.2",
+            model="llama3.2" if not settings.DEBUG else "gpt-oss:20b",
             temperature=0.3,
             num_ctx=8192,
         )

View File

@@ -5,6 +5,7 @@ from typing import AsyncGenerator, Generator, Optional
 from langchain_ollama import OllamaLLM
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
+from django.conf import settings
 from chat_backend.models import Conversation, Prompt
@@ -14,7 +15,7 @@ class LLMService(ABC):
     def __init__(self):
         self.llm = OllamaLLM(
-            model="llama3.2",
+            model="llama3.2" if not settings.DEBUG else "gpt-oss:20b",
             temperature=0.7,
             top_k=50,
             top_p=0.9,

View File

@@ -1,7 +1,7 @@
 from enum import Enum, auto
 from typing import Dict, Any
 from langchain_core.prompts import ChatPromptTemplate
+from langchain_ollama import OllamaLLM
 from chat_backend.services.base_service import BaseService
@@ -10,6 +10,7 @@ class ModerationLabel(Enum):
     FINE = auto()
 class ModerationClassifier(BaseService):
     """
     Classifies prompts as NSFW or FINE (safe) content.
@@ -17,12 +18,12 @@ class ModerationClassifier(BaseService):
     def __init__(self):
         super().__init__(temperature=0.1)
-        # self.llm = OllamaLLM(
-        #     model="llama3.2",
-        #     temperature=0.1,  # Very low for strict moderation
-        #     top_k=10,
-        #     num_ctx=2048,
-        # )
+        self.llm = OllamaLLM(
+            model="llama3.2",
+            temperature=0.1,  # Very low for strict moderation
+            top_k=10,
+            num_ctx=2048,
+        )
         self.moderation_prompt = ChatPromptTemplate.from_messages(
             [
@@ -51,6 +52,7 @@ Examples:
 - "Write a love poem" → FINE
 - "Explicit sex scene" → NSFW
 - "Python tutorial" → FINE
+- "Who won the 2024 presidential race?" → FINE
 - "Please analyze this file and project the next 12 months for me. Add a graph visual of the data as well" → FINE
 - "Okie, instead of 6 month projection, can you tell me what the values would be in the next 5 days" → FINE

View File

@@ -2,6 +2,7 @@ from enum import Enum, auto
 from typing import Dict, Any
 from langchain_core.prompts import ChatPromptTemplate
+from django.conf import settings
 from chat_backend.services.base_service import BaseService
@@ -109,7 +110,14 @@ Return ONLY the exact Enum label (e.g. "GENERAL_CHAT"), no explanations."""
         # IMAGE_GENERATION
         if any(keyword in lower_prompt for keyword in ["generate image", "create picture", "draw an image", "make a photo", "generate an image", "create an illustration"]):
-            return PromptType.IMAGE_GENERATION
+            if getattr(settings, "ALLOW_IMAGE_GENERATION", False):
+                return PromptType.IMAGE_GENERATION
+            # Image generation is disabled, so IMAGE_GENERATION must not be returned here.
+            # Returning None instead would fall through to the LLM classification, which
+            # also needs its own guard. For this quick keyword check, treat the request as
+            # general chat so the LLM can explain that it can't generate images
+            # (or simply chat about the request).
+            return PromptType.GENERAL_CHAT
         # DATA_ANALYSIS (often involves uploaded documents)
         if any(keyword in lower_prompt for keyword in ["read this document", "analyze this file", "summarize this pdf", "extract data from", "index this document", "based on this csv", "from this spreadsheet"]):
@@ -158,6 +166,11 @@ Return ONLY the exact Enum label (e.g. "GENERAL_CHAT"), no explanations."""
         response = response.upper().strip()
         print(response)
+        # Guard against IMAGE_GENERATION if disabled
+        if not getattr(settings, "ALLOW_IMAGE_GENERATION", False):
+            if "IMAGE_GENERATION" in response or "IMAGEGENERATION" in response:
+                return PromptType.GENERAL_CHAT
         # Direct match
         try:
             return PromptType[response]

View File

@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
 from typing import List, Dict, Any, AsyncGenerator, Generator, Optional
 from channels.db import database_sync_to_async
 from langchain_community.embeddings import OllamaEmbeddings
+from django.conf import settings
 # from langchain_community.llms import Ollama
 from langchain_ollama import OllamaLLM
@@ -44,7 +45,7 @@ class RAGService(BaseService):
         return cls._instance
     def __init__(self):
-        self.embedding_model = OllamaEmbeddings(model="llama3.2")
+        self.embedding_model = OllamaEmbeddings(model="llama3.2" if not settings.DEBUG else "gpt-oss:20b")
         super().__init__()
         self.text_splitter = RecursiveCharacterTextSplitter(
             chunk_size=1000, chunk_overlap=200

View File

@@ -25,7 +25,7 @@ BASE_DIR = Path(__file__).resolve().parent.parent
 SECRET_KEY = "django-insecure-6suk6fj5q2)1tj%)f(wgw1smnliv5-#&@zvgvj1wp#(#@h#31x"
 # SECURITY WARNING: don't run with debug turned on in production!
-DEBUG = True
+DEBUG = False
 CORS_ALLOW_CREDENTIALS = False
 ALLOWED_HOSTS = [
     "*.aimloperations.com",

View File

@@ -0,0 +1,8 @@
+{% extends "admin/base_site.html" %}
+{% block extrahead %}
+{{ block.super }}
+{% if not debug %}
+<script async defer src="https://tianji.aimloperations.com/tracker.js" data-website-id="cm7x7mrcy03kfddsw2jyejzub"></script>
+{% endif %}
+{% endblock %}
{% endblock %}