Closes #4
Added site tracking Can pick the model the use Better handle llm model based on debug or not
This commit is contained in:
@@ -312,7 +312,10 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
|
||||
return {"type": "text", "content": "Image Generation is not supported at this time, but it will be soon."}
|
||||
|
||||
if prompt_type == PromptType.SEARCH:
|
||||
if getattr(settings, "ALLOW_INTERNET_ACCESS", False):
|
||||
# Check modelName first - if FAST, we skip search regardless of settings
|
||||
if input_dict.get("model_name") == "FAST":
|
||||
pass # Skip search
|
||||
elif getattr(settings, "ALLOW_INTERNET_ACCESS", False):
|
||||
try:
|
||||
search = DuckDuckGoSearchRun()
|
||||
search_results = search.run(input_dict["message"])
|
||||
@@ -379,7 +382,8 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
|
||||
"decoded_file": decoded_file,
|
||||
"file_type": file_type,
|
||||
"messages": messages,
|
||||
"prompt_instance": prompt_instance
|
||||
"prompt_instance": prompt_instance,
|
||||
"model_name": model
|
||||
}
|
||||
|
||||
# Run the pipeline steps manually to handle the async generator return type of generate_response_step
|
||||
|
||||
@@ -160,6 +160,7 @@ class ChatState(TypedDict):
|
||||
prompt_type: Union[PromptType, None]
|
||||
response_generator: Any # AsyncGenerator or dict
|
||||
error: Union[str, None]
|
||||
model_name: str
|
||||
|
||||
|
||||
# --- LangGraph Nodes ---
|
||||
@@ -206,7 +207,10 @@ async def generation_node(state: ChatState) -> ChatState:
|
||||
|
||||
# Feature Flag: Internet Access
|
||||
if prompt_type == PromptType.SEARCH:
|
||||
if getattr(settings, "ALLOW_INTERNET_ACCESS", False):
|
||||
# Check modelName first - if FAST, we skip search regardless of settings
|
||||
if state.get("model_name") == "FAST":
|
||||
pass
|
||||
elif getattr(settings, "ALLOW_INTERNET_ACCESS", False):
|
||||
try:
|
||||
search = DuckDuckGoSearchRun()
|
||||
search_results = search.run(state["message"])
|
||||
@@ -274,6 +278,7 @@ class ChatConsumerGraph(AsyncWebsocketConsumer):
|
||||
print("Text Data: ", text_data)
|
||||
if text_data:
|
||||
data = json.loads(text_data)
|
||||
model = data.get("modelName", "Turbo")
|
||||
message = data.get("message", None)
|
||||
conversation_id = data.get("conversation_id", None)
|
||||
email = data.get("email", None)
|
||||
@@ -324,7 +329,8 @@ class ChatConsumerGraph(AsyncWebsocketConsumer):
|
||||
"moderation_label": None,
|
||||
"prompt_type": None,
|
||||
"response_generator": None,
|
||||
"error": None
|
||||
"error": None,
|
||||
"model_name": model
|
||||
}
|
||||
print("Initial State: ", initial_state)
|
||||
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from langchain_ollama import OllamaLLM
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from django.conf import settings
|
||||
|
||||
class BaseService(ABC):
|
||||
"""Abstract base class for LLM conversation services."""
|
||||
|
||||
def __init__(self, temperature=0.7):
|
||||
self.llm = OllamaLLM(
|
||||
model="llama3.2",
|
||||
model="llama3.2" if not settings.DEBUG else "gpt-oss:20b",
|
||||
temperature=0.7,
|
||||
top_k=50,
|
||||
top_p=0.9,
|
||||
|
||||
@@ -10,6 +10,7 @@ from langchain_ollama import OllamaLLM
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
import docx
|
||||
import pypdf
|
||||
from django.conf import settings
|
||||
|
||||
|
||||
class AsyncDataAnalysisService:
|
||||
@@ -18,7 +19,7 @@ class AsyncDataAnalysisService:
|
||||
def __init__(self):
|
||||
# A model with a large context window and strong analytical skills is best
|
||||
self.llm = OllamaLLM(
|
||||
model="llama3.2",
|
||||
model="llama3.2" if not settings.DEBUG else "gpt-oss:20b",
|
||||
temperature=0.3,
|
||||
num_ctx=8192,
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import AsyncGenerator, Generator, Optional
|
||||
from langchain_ollama import OllamaLLM
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from django.conf import settings
|
||||
|
||||
from chat_backend.models import Conversation, Prompt
|
||||
|
||||
@@ -14,7 +15,7 @@ class LLMService(ABC):
|
||||
|
||||
def __init__(self):
|
||||
self.llm = OllamaLLM(
|
||||
model="llama3.2",
|
||||
model="llama3.2" if not settings.DEBUG else "gpt-oss:20b",
|
||||
temperature=0.7,
|
||||
top_k=50,
|
||||
top_p=0.9,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from enum import Enum, auto
|
||||
from typing import Dict, Any
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
|
||||
from langchain_ollama import OllamaLLM
|
||||
from chat_backend.services.base_service import BaseService
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ class ModerationLabel(Enum):
|
||||
FINE = auto()
|
||||
|
||||
|
||||
|
||||
class ModerationClassifier(BaseService):
|
||||
"""
|
||||
Classifies prompts as NSFW or FINE (safe) content.
|
||||
@@ -17,12 +18,12 @@ class ModerationClassifier(BaseService):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(temperature=0.1)
|
||||
# self.llm = OllamaLLM(
|
||||
# model="llama3.2",
|
||||
# temperature=0.1, # Very low for strict moderation
|
||||
# top_k=10,
|
||||
# num_ctx=2048,
|
||||
# )
|
||||
self.llm = OllamaLLM(
|
||||
model="llama3.2",
|
||||
temperature=0.1, # Very low for strict moderation
|
||||
top_k=10,
|
||||
num_ctx=2048,
|
||||
)
|
||||
|
||||
self.moderation_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
@@ -51,6 +52,7 @@ Examples:
|
||||
- "Write a love poem" → FINE
|
||||
- "Explicit sex scene" → NSFW
|
||||
- "Python tutorial" → FINE
|
||||
- "Who won the 2024 presidental race?" → FINE
|
||||
- "Please analyze this file and project the next 12 months for me. Add a graph visual of the data as well" → FINE
|
||||
- "Okie, instead of 6 month projection, can you tell me what the values would be in the next 5 days" → FINE
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ from enum import Enum, auto
|
||||
from typing import Dict, Any
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
|
||||
from django.conf import settings
|
||||
from chat_backend.services.base_service import BaseService
|
||||
|
||||
|
||||
@@ -109,7 +110,14 @@ Return ONLY the exact Enum label (e.g. "GENERAL_CHAT"), no explanations."""
|
||||
|
||||
# IMAGE_GENERATION
|
||||
if any(keyword in lower_prompt for keyword in ["generate image", "create picture", "draw an image", "make a photo", "generate an image", "create an illustration"]):
|
||||
return PromptType.IMAGE_GENERATION
|
||||
if getattr(settings, "ALLOW_IMAGE_GENERATION", False):
|
||||
return PromptType.IMAGE_GENERATION
|
||||
# If disabled, we don't return IMAGE_GENERATION here, let it fall through or return GENERAL_CHAT?
|
||||
# The requirement says "NOT allowed to return IMAGE_GENERATION".
|
||||
# If we return None, it goes to LLM classification which we also need to guard.
|
||||
# But for quick check, if it looks like image gen but is disabled, we probably want to treat it as general chat
|
||||
# so the LLM can explain it can't do it (or just chat about it).
|
||||
return PromptType.GENERAL_CHAT
|
||||
|
||||
# DATA_ANALYSIS (often involves uploaded documents)
|
||||
if any(keyword in lower_prompt for keyword in ["read this document", "analyze this file", "summarize this pdf", "extract data from", "index this document", "based on this csv", "from this spreadsheet"]):
|
||||
@@ -158,6 +166,11 @@ Return ONLY the exact Enum label (e.g. "GENERAL_CHAT"), no explanations."""
|
||||
response = response.upper().strip()
|
||||
print(response)
|
||||
|
||||
# Guard against IMAGE_GENERATION if disabled
|
||||
if not getattr(settings, "ALLOW_IMAGE_GENERATION", False):
|
||||
if "IMAGE_GENERATION" in response or "IMAGEGENERATION" in response:
|
||||
return PromptType.GENERAL_CHAT
|
||||
|
||||
# Direct match
|
||||
try:
|
||||
return PromptType[response]
|
||||
|
||||
@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, AsyncGenerator, Generator, Optional
|
||||
from channels.db import database_sync_to_async
|
||||
from langchain_community.embeddings import OllamaEmbeddings
|
||||
from django.conf import settings
|
||||
|
||||
# from langchain_community.llms import Ollama
|
||||
from langchain_ollama import OllamaLLM
|
||||
@@ -44,7 +45,7 @@ class RAGService(BaseService):
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
self.embedding_model = OllamaEmbeddings(model="llama3.2")
|
||||
self.embedding_model = OllamaEmbeddings(model="llama3.2" if not settings.DEBUG else "gpt-oss:20b")
|
||||
super().__init__()
|
||||
self.text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=1000, chunk_overlap=200
|
||||
|
||||
@@ -25,7 +25,7 @@ BASE_DIR = Path(__file__).resolve().parent.parent
|
||||
SECRET_KEY = "django-insecure-6suk6fj5q2)1tj%)f(wgw1smnliv5-#&@zvgvj1wp#(#@h#31x"
|
||||
|
||||
# SECURITY WARNING: don't run with debug turned on in production!
|
||||
DEBUG = True
|
||||
DEBUG = False
|
||||
CORS_ALLOW_CREDENTIALS = False
|
||||
ALLOWED_HOSTS = [
|
||||
"*.aimloperations.com",
|
||||
|
||||
8
llm_be/templates/admin/base_site.html
Normal file
8
llm_be/templates/admin/base_site.html
Normal file
@@ -0,0 +1,8 @@
|
||||
{% extends "admin/base_site.html" %}
|
||||
|
||||
{% block extrahead %}
|
||||
{{ block.super }}
|
||||
{% if not debug %}
|
||||
<script async defer src="https://tianji.aimloperations.com/tracker.js" data-website-id="cm7x7mrcy03kfddsw2jyejzub"></script>
|
||||
{% endif %}
|
||||
{% endblock %}
|
||||
Reference in New Issue
Block a user