Closes #4
Added site tracking. Can pick the model to use. Better handling of the LLM model based on DEBUG or not.
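The recurring change across the service classes is the same DEBUG-driven model pick. A minimal sketch of that pattern, assuming Django settings and langchain_ollama's OllamaLLM (the build_llm helper is illustrative only; the commit inlines the conditional at each call site):

    # Sketch only: the commit repeats this conditional at every OllamaLLM construction;
    # build_llm is an illustrative helper, not part of the committed code.
    from django.conf import settings
    from langchain_ollama import OllamaLLM

    def build_llm(temperature: float = 0.7) -> OllamaLLM:
        # DEBUG machines get the larger local gpt-oss:20b model; production stays on llama3.2.
        model = "llama3.2" if not settings.DEBUG else "gpt-oss:20b"
        return OllamaLLM(model=model, temperature=temperature)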
@@ -312,7 +312,10 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
             return {"type": "text", "content": "Image Generation is not supported at this time, but it will be soon."}
 
         if prompt_type == PromptType.SEARCH:
-            if getattr(settings, "ALLOW_INTERNET_ACCESS", False):
+            # Check modelName first - if FAST, we skip search regardless of settings
+            if input_dict.get("model_name") == "FAST":
+                pass # Skip search
+            elif getattr(settings, "ALLOW_INTERNET_ACCESS", False):
                 try:
                     search = DuckDuckGoSearchRun()
                     search_results = search.run(input_dict["message"])
@@ -379,7 +382,8 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
             "decoded_file": decoded_file,
             "file_type": file_type,
             "messages": messages,
-            "prompt_instance": prompt_instance
+            "prompt_instance": prompt_instance,
+            "model_name": model
         }
 
         # Run the pipeline steps manually to handle the async generator return type of generate_response_step
@@ -160,6 +160,7 @@ class ChatState(TypedDict):
     prompt_type: Union[PromptType, None]
     response_generator: Any # AsyncGenerator or dict
     error: Union[str, None]
+    model_name: str
 
 
 # --- LangGraph Nodes ---
@@ -206,7 +207,10 @@ async def generation_node(state: ChatState) -> ChatState:
 
     # Feature Flag: Internet Access
     if prompt_type == PromptType.SEARCH:
-        if getattr(settings, "ALLOW_INTERNET_ACCESS", False):
+        # Check modelName first - if FAST, we skip search regardless of settings
+        if state.get("model_name") == "FAST":
+            pass
+        elif getattr(settings, "ALLOW_INTERNET_ACCESS", False):
             try:
                 search = DuckDuckGoSearchRun()
                 search_results = search.run(state["message"])
@@ -274,6 +278,7 @@ class ChatConsumerGraph(AsyncWebsocketConsumer):
         print("Text Data: ", text_data)
         if text_data:
             data = json.loads(text_data)
+            model = data.get("modelName", "Turbo")
             message = data.get("message", None)
             conversation_id = data.get("conversation_id", None)
             email = data.get("email", None)
@@ -324,7 +329,8 @@ class ChatConsumerGraph(AsyncWebsocketConsumer):
             "moderation_label": None,
             "prompt_type": None,
             "response_generator": None,
-            "error": None
+            "error": None,
+            "model_name": model
         }
         print("Initial State: ", initial_state)
 
@@ -1,13 +1,14 @@
 from abc import ABC, abstractmethod
 from langchain_ollama import OllamaLLM
 from langchain_core.output_parsers import StrOutputParser
+from django.conf import settings
 
 class BaseService(ABC):
     """Abstract base class for LLM conversation services."""
 
     def __init__(self, temperature=0.7):
         self.llm = OllamaLLM(
-            model="llama3.2",
+            model="llama3.2" if not settings.DEBUG else "gpt-oss:20b",
             temperature=0.7,
             top_k=50,
             top_p=0.9,
@@ -10,6 +10,7 @@ from langchain_ollama import OllamaLLM
 from langchain_core.output_parsers import StrOutputParser
 import docx
 import pypdf
+from django.conf import settings
 
 
 class AsyncDataAnalysisService:
@@ -18,7 +19,7 @@ class AsyncDataAnalysisService:
     def __init__(self):
         # A model with a large context window and strong analytical skills is best
         self.llm = OllamaLLM(
-            model="llama3.2",
+            model="llama3.2" if not settings.DEBUG else "gpt-oss:20b",
             temperature=0.3,
             num_ctx=8192,
         )
@@ -5,6 +5,7 @@ from typing import AsyncGenerator, Generator, Optional
 from langchain_ollama import OllamaLLM
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
+from django.conf import settings
 
 from chat_backend.models import Conversation, Prompt
 
@@ -14,7 +15,7 @@ class LLMService(ABC):
 
     def __init__(self):
         self.llm = OllamaLLM(
-            model="llama3.2",
+            model="llama3.2" if not settings.DEBUG else "gpt-oss:20b",
             temperature=0.7,
             top_k=50,
             top_p=0.9,
@@ -1,7 +1,7 @@
 from enum import Enum, auto
 from typing import Dict, Any
 from langchain_core.prompts import ChatPromptTemplate
+from langchain_ollama import OllamaLLM
 from chat_backend.services.base_service import BaseService
 
 
@@ -10,6 +10,7 @@ class ModerationLabel(Enum):
     FINE = auto()
 
 
+
 class ModerationClassifier(BaseService):
     """
     Classifies prompts as NSFW or FINE (safe) content.
@@ -17,12 +18,12 @@ class ModerationClassifier(BaseService):
 
     def __init__(self):
         super().__init__(temperature=0.1)
-        # self.llm = OllamaLLM(
-        #     model="llama3.2",
-        #     temperature=0.1, # Very low for strict moderation
-        #     top_k=10,
-        #     num_ctx=2048,
-        # )
+        self.llm = OllamaLLM(
+            model="llama3.2",
+            temperature=0.1, # Very low for strict moderation
+            top_k=10,
+            num_ctx=2048,
+        )
 
         self.moderation_prompt = ChatPromptTemplate.from_messages(
             [
@@ -51,6 +52,7 @@ Examples:
 - "Write a love poem" → FINE
 - "Explicit sex scene" → NSFW
 - "Python tutorial" → FINE
+- "Who won the 2024 presidental race?" → FINE
 - "Please analyze this file and project the next 12 months for me. Add a graph visual of the data as well" → FINE
 - "Okie, instead of 6 month projection, can you tell me what the values would be in the next 5 days" → FINE
 
@@ -2,6 +2,7 @@ from enum import Enum, auto
 from typing import Dict, Any
 from langchain_core.prompts import ChatPromptTemplate
 
+from django.conf import settings
 from chat_backend.services.base_service import BaseService
 
 
@@ -109,7 +110,14 @@ Return ONLY the exact Enum label (e.g. "GENERAL_CHAT"), no explanations."""
 
         # IMAGE_GENERATION
         if any(keyword in lower_prompt for keyword in ["generate image", "create picture", "draw an image", "make a photo", "generate an image", "create an illustration"]):
-            return PromptType.IMAGE_GENERATION
+            if getattr(settings, "ALLOW_IMAGE_GENERATION", False):
+                return PromptType.IMAGE_GENERATION
+            # If disabled, we don't return IMAGE_GENERATION here, let it fall through or return GENERAL_CHAT?
+            # The requirement says "NOT allowed to return IMAGE_GENERATION".
+            # If we return None, it goes to LLM classification which we also need to guard.
+            # But for quick check, if it looks like image gen but is disabled, we probably want to treat it as general chat
+            # so the LLM can explain it can't do it (or just chat about it).
+            return PromptType.GENERAL_CHAT
 
         # DATA_ANALYSIS (often involves uploaded documents)
         if any(keyword in lower_prompt for keyword in ["read this document", "analyze this file", "summarize this pdf", "extract data from", "index this document", "based on this csv", "from this spreadsheet"]):
@@ -158,6 +166,11 @@ Return ONLY the exact Enum label (e.g. "GENERAL_CHAT"), no explanations."""
         response = response.upper().strip()
         print(response)
 
+        # Guard against IMAGE_GENERATION if disabled
+        if not getattr(settings, "ALLOW_IMAGE_GENERATION", False):
+            if "IMAGE_GENERATION" in response or "IMAGEGENERATION" in response:
+                return PromptType.GENERAL_CHAT
+
         # Direct match
         try:
             return PromptType[response]
@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
 from typing import List, Dict, Any, AsyncGenerator, Generator, Optional
 from channels.db import database_sync_to_async
 from langchain_community.embeddings import OllamaEmbeddings
+from django.conf import settings
 
 # from langchain_community.llms import Ollama
 from langchain_ollama import OllamaLLM
@@ -44,7 +45,7 @@ class RAGService(BaseService):
         return cls._instance
 
     def __init__(self):
-        self.embedding_model = OllamaEmbeddings(model="llama3.2")
+        self.embedding_model = OllamaEmbeddings(model="llama3.2" if not settings.DEBUG else "gpt-oss:20b")
         super().__init__()
         self.text_splitter = RecursiveCharacterTextSplitter(
             chunk_size=1000, chunk_overlap=200
@@ -25,7 +25,7 @@ BASE_DIR = Path(__file__).resolve().parent.parent
 SECRET_KEY = "django-insecure-6suk6fj5q2)1tj%)f(wgw1smnliv5-#&@zvgvj1wp#(#@h#31x"
 
 # SECURITY WARNING: don't run with debug turned on in production!
-DEBUG = True
+DEBUG = False
 CORS_ALLOW_CREDENTIALS = False
 ALLOWED_HOSTS = [
     "*.aimloperations.com",
llm_be/templates/admin/base_site.html (new file, 8 lines)
@@ -0,0 +1,8 @@
+{% extends "admin/base_site.html" %}
+
+{% block extrahead %}
+{{ block.super }}
+{% if not debug %}
+<script async defer src="https://tianji.aimloperations.com/tracker.js" data-website-id="cm7x7mrcy03kfddsw2jyejzub"></script>
+{% endif %}
+{% endblock %}