Syncing with updates from prod and formatted

2025-05-18 06:15:07 -05:00
parent d8a912e2c3
commit a85f1222eb
18 changed files with 616 additions and 518 deletions

View File

@@ -9,7 +9,7 @@ from .models import (
     Feedback,
     PromptMetric,
     DocumentWorkspace,
-    Document
+    Document,
 )
 
 # Register your models here.
@@ -79,14 +79,15 @@ class PromptMetricAdmin(admin.ModelAdmin):
         "get_duration",
     )
 
 
 class DocumentWorkspaceAdmin(admin.ModelAdmin):
     model = DocumentWorkspace
     list_display = (
         "name",
         "company",
     )
 
 
 class DocumentAdmin(admin.ModelAdmin):
     model = Document
     list_display = (

View File

@@ -8,18 +8,19 @@ class ChatBackendConfig(AppConfig):
     name = "chat_backend"
 
     def ready(self):
         import chat_backend.signals
 
         FORCE_RELOAD = False
-        if True:  #not settings.TESTING: # Don't run during tests
+        if True:  # not settings.TESTING: # Don't run during tests
             try:
                 from .services.rag_services import AsyncRAGService
                 from chat_backend.models import Document
 
                 # Check if Chroma needs initialization
                 if Document.objects.exists():
                     rag_service = AsyncRAGService()
                     if rag_service.vector_store._collection.count() == 0:
                         print("Initializing ChromaDB with existing documents...")
                         rag_service.ingest_documents()
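
Note that the hunk above keeps the test guard disabled (`if True:` with the `settings.TESTING` check commented out), so the Chroma warm-up runs even under test. A minimal sketch of what the re-enabled guard could look like, assuming a `TESTING` flag is defined in the project settings (not confirmed by this diff):

# Hypothetical sketch only; assumes settings.py defines something like
# TESTING = "test" in sys.argv
from django.apps import AppConfig
from django.conf import settings


class ChatBackendConfig(AppConfig):
    name = "chat_backend"

    def ready(self):
        import chat_backend.signals  # register signal handlers

        if not getattr(settings, "TESTING", False):  # skip warm-up during tests
            from .services.rag_services import AsyncRAGService
            from chat_backend.models import Document

            if Document.objects.exists():
                rag_service = AsyncRAGService()
                # Only ingest when the Chroma collection is empty
                if rag_service.vector_store._collection.count() == 0:
                    rag_service.ingest_documents()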

View File

@@ -53,6 +53,9 @@ class Company(TimeInfoBase):
         help_text="A list of LLMs that company can use",
     )
 
+    def __str__(self):
+        return self.name
+
 
 class CustomUser(AbstractUser):
     company = models.ForeignKey(
@@ -71,7 +74,7 @@ class CustomUser(AbstractUser):
     )
 
     def get_set_password_url(self):
-        return f"https://www.chat.aimloperations.com/set_password?slug={self.slug}"
+        return f"https://chat.aimloperations.com/set_password?slug={self.slug}"
 
 
 FEEDBACK_CHOICE = (
@@ -220,14 +223,16 @@ class PromptMetric(TimeInfoBase):
             return difference.seconds
         return 0
 
 
 # Document Models
 class DocumentWorkspace(TimeInfoBase):
     name = models.CharField(max_length=255)
     company = models.ForeignKey(Company, on_delete=models.CASCADE)
 
 
 class Document(TimeInfoBase):
     workspace = models.ForeignKey(DocumentWorkspace, on_delete=models.CASCADE)
-    file = models.FileField(upload_to='documents/')
+    file = models.FileField(upload_to="documents/")
     uploaded_at = models.DateTimeField(auto_now_add=True)
     processed = models.BooleanField(default=False)
     active = models.BooleanField(default=False)
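
For orientation, a hedged Django-shell sketch of the models this hunk touches; the workspace name and file path are illustrative, not taken from the diff:

from chat_backend.models import Company, Document, DocumentWorkspace

company = Company.objects.first()  # assumes at least one Company row exists
print(company)  # the new __str__ prints company.name instead of "Company object (1)"

workspace = DocumentWorkspace.objects.create(name="Policies", company=company)
document = Document.objects.create(
    workspace=workspace,
    file="documents/handbook.pdf",  # FileField accepts a path relative to MEDIA_ROOT
)
print(document.processed, document.active)  # both default to False until ingestion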

View File

@@ -9,7 +9,7 @@ from .models import (
     Feedback,
     FEEDBACK_CATEGORIES,
     DocumentWorkspace,
-    Document
+    Document,
 )
@@ -99,11 +99,20 @@ class BasicUserSerializer(serializers.ModelSerializer):
 class DocumentWorkspaceSerializer(serializers.ModelSerializer):
     class Meta:
         model = DocumentWorkspace
-        fields = ['id', 'name', 'created']
-        read_only_fields = ['id', 'created']
+        fields = ["id", "name", "created"]
+        read_only_fields = ["id", "created"]
 
 
 class DocumentSerializer(serializers.ModelSerializer):
     class Meta:
         model = Document
-        fields = ['id', 'workspace', 'file', 'uploaded_at', 'processed', 'created', 'active']
-        read_only_fields = ['id', 'uploaded_at', 'processed', 'created']
+        fields = [
+            "id",
+            "workspace",
+            "file",
+            "uploaded_at",
+            "processed",
+            "created",
+            "active",
+        ]
+        read_only_fields = ["id", "uploaded_at", "processed", "created"]

View File

@@ -7,21 +7,22 @@ from diffusers import StableDiffusionPipeline, DPMSolverSinglestepScheduler
 logger = logging.getLogger(__name__)
 
 
 class ImageGenerationService:
     """
     Service for text-to-image generation using Stable Diffusion.
     Uses singleton pattern to maintain loaded model in memory.
     """
 
     _instance = None
     _model_loaded = False
 
     def __new__(cls):
         if cls._instance is None:
             cls._instance = super().__new__(cls)
             cls._instance._initialize()
         return cls._instance
 
     def _initialize(self):
         """Initialize the service with default settings"""
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -33,15 +34,15 @@ class ImageGenerationService:
             "width": 512,
             "height": 512,
         }
 
     def load_model(self):
         """Load the Stable Diffusion model"""
         if self._model_loaded:
             return
 
         try:
             logger.info(f"Loading Stable Diffusion model on {self.device}...")
             # Use DPMSolver for faster inference
             self.pipeline = StableDiffusionPipeline.from_pretrained(
                 self.model_id,
@@ -51,15 +52,15 @@ class ImageGenerationService:
                 self.pipeline.scheduler.config
             )
             self.pipeline = self.pipeline.to(self.device)
 
             # Optimizations
             if self.device == "cuda":
                 self.pipeline.enable_attention_slicing()
                 self.pipeline.enable_xformers_memory_efficient_attention()
 
             self._model_loaded = True
             logger.info("Model loaded successfully")
         except Exception as e:
             logger.error(f"Failed to load model: {str(e)}")
             raise RuntimeError(f"Model loading failed: {str(e)}")
@@ -69,45 +70,43 @@ class ImageGenerationService:
         prompt: str,
         negative_prompt: Optional[str] = None,
         output_path: Optional[str] = None,
-        **kwargs
+        **kwargs,
     ) -> Tuple[Image.Image, dict]:
         """
         Generate image from text prompt.
 
         Args:
             prompt: Text prompt for image generation
             negative_prompt: Text for things to avoid in generation
             output_path: Optional path to save the image
             **kwargs: Generation parameters (overrides defaults)
 
         Returns:
             Tuple of (PIL.Image, generation_parameters)
         """
         if not self._model_loaded:
             self.load_model()
 
         # Merge default params with overrides
         params = {**self.default_params, **kwargs}
 
         try:
             logger.info(f"Generating image with prompt: {prompt[:50]}...")
             with torch.inference_mode():
-                result = self.pipeline(
-                    prompt=prompt,
-                    negative_prompt=negative_prompt,
-                    **kwargs
-                )
+                result = self.pipeline(
+                    prompt=prompt, negative_prompt=negative_prompt, **params
+                )
             image = result.images[0]
 
             if output_path:
                 os.makedirs(os.path.dirname(output_path), exist_ok=True)
                 image.save(output_path)
                 logger.info(f"Image saved to {output_path}")
 
             return image, params
         except Exception as e:
             logger.error(f"Image generation failed: {str(e)}")
             raise RuntimeError(f"Image generation failed: {str(e)}")
@@ -118,28 +117,28 @@ class AsyncImageGenerationService:
     Asynchronous wrapper for image generation service.
     Runs the synchronous service in a thread pool.
     """
 
     def __init__(self):
         self.sync_service = ImageGenerationService()
 
     async def generate_image(
         self,
         prompt: str,
         negative_prompt: Optional[str] = None,
         output_path: Optional[str] = None,
-        **kwargs
+        **kwargs,
     ) -> Tuple[Image.Image, dict]:
         """Async version of generate_image"""
         import asyncio
         from functools import partial
 
         loop = asyncio.get_event_loop()
         func = partial(
             self.sync_service.generate_image,
             prompt=prompt,
             negative_prompt=negative_prompt,
             output_path=output_path,
-            **kwargs
+            **kwargs,
         )
         return await loop.run_in_executor(None, func)
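
The async wrapper above offloads the blocking diffusion call to the default thread pool via `partial` plus `run_in_executor`, so the event loop stays responsive while the GPU renders. A hedged usage sketch; the prompt, path, and override are illustrative:

import asyncio


async def main():
    service = AsyncImageGenerationService()
    image, params = await service.generate_image(
        prompt="a watercolor fox in a forest",  # illustrative prompt
        negative_prompt="blurry, low quality",
        output_path="data/images/fox.png",  # hypothetical path
        num_inference_steps=20,  # kwargs override a default parameter
    )
    print(image.size, params["width"], params["height"])


asyncio.run(main())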

View File

@@ -1,47 +1,50 @@
 from abc import ABC, abstractmethod
 from typing import AsyncGenerator, Generator, Optional
 
-from langchain_community.llms import Ollama
+# from langchain_community.llms import Ollama
+from langchain_ollama import OllamaLLM
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
 
 from chat_backend.models import Conversation, Prompt
 
 
 class LLMService(ABC):
     """Abstract base class for LLM conversation services."""
 
     def __init__(self):
-        self.llm = Ollama(
+        self.llm = OllamaLLM(
             model="llama3.2",
             temperature=0.7,
             top_k=50,
             top_p=0.9,
             repeat_penalty=1.1,
-            num_ctx=4096
+            num_ctx=4096,
         )
         self.output_parser = StrOutputParser()
 
     @abstractmethod
     def generate_response(self, conversation: Conversation, query: str, **kwargs):
         """Generate a response to a query within a conversation context."""
         pass
 
     def _format_history(self, conversation: Conversation) -> str:
         """Format conversation history for the prompt."""
-        prompts = Prompt.objects.filter(conversation=conversation).order_by('created_at')
+        prompts = Prompt.objects.filter(conversation=conversation).order_by(
+            "created_at"
+        )
         return "\n".join(
-            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}"
-            for prompt in prompts
+            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
         )
 
 
 class SyncLLMService(LLMService):
     """Synchronous LLM conversation service."""
 
     def __init__(self):
         super().__init__()
         self._setup_chain()
 
     def _setup_chain(self):
         """Setup the conversation chain."""
         template = """Continue the conversation based on the following history:
@@ -52,35 +55,36 @@ class SyncLLMService(LLMService):
 Response:"""
         self.prompt = ChatPromptTemplate.from_template(template)
         self.conversation_chain = (
             {
                 "history": lambda x: self._format_history(x["conversation"]),
-                "query": lambda x: x["query"]
+                "query": lambda x: x["query"],
             }
             | self.prompt
             | self.llm
             | self.output_parser
         )
 
-    def generate_response(self, conversation: Conversation, query: str, **kwargs) -> Generator[str, None, None]:
+    def generate_response(
+        self, conversation: Conversation, query: str, **kwargs
+    ) -> Generator[str, None, None]:
         """Generate response with streaming support."""
-        chain_input = {
-            "query": query,
-            "conversation": conversation
-        }
+        chain_input = {"query": query, "conversation": conversation}
+
+        print(f"chain_input:\n{chain_input}")
 
         for chunk in self.conversation_chain.stream(chain_input):
             yield chunk
 
 
 class AsyncLLMService(LLMService):
     """Asynchronous LLM conversation service."""
 
     def __init__(self):
         super().__init__()
         self._setup_chain()
 
     def _setup_chain(self):
         """Setup the conversation chain."""
         template = """Continue this conversation while maintaining context by providing a single helpful response.
@@ -97,42 +101,50 @@ class AsyncLLMService(LLMService):
 - When asked to modify something, identify what's being modified
 
 Response:"""
         self.prompt = ChatPromptTemplate.from_template(template)
         self.conversation_chain = (
             {
                 "context": lambda x: self._format_history(x["conversation"]),
-                "recent_history": lambda x: self._get_recent_messages(x["conversation"]),
-                "query": lambda x: x["query"]
+                "recent_history": lambda x: self._get_recent_messages(
+                    x["conversation"]
+                ),
+                "query": lambda x: x["query"],
             }
             | self.prompt
             | self.llm
             | self.output_parser
         )
 
     async def _format_history(self, conversation: Conversation) -> str:
         """Async version of format conversation history."""
-        prompts = await Prompt.objects.filter(conversation=conversation).order_by('created_at').alist()
+        prompts = (
+            await Prompt.objects.filter(conversation=conversation)
+            .order_by("created_at")
+            .alist()
+        )
         return "\n".join(
-            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}"
-            for prompt in prompts
+            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
         )
 
     async def _get_recent_messages(self, conversation: Conversation) -> str:
         """Async version of format conversation history."""
-        prompts = await Prompt.objects.filter(conversation=conversation).order_by('created_at').alist()[-3:]
+        prompts = (
+            await Prompt.objects.filter(conversation=conversation)
+            .order_by("created_at")
+            .alist()[-6:]
+        )
         return "\n".join(
             f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
         )
 
-    async def generate_response(self, conversation: Conversation, query: str, **kwargs) -> AsyncGenerator[str, None]:
+    async def generate_response(
+        self, conversation: Conversation, query: str, **kwargs
+    ) -> AsyncGenerator[str, None]:
         """Generate response with async streaming support."""
-        chain_input = {
-            "query": query,
-            "conversation": conversation
-        }
+        chain_input = {"query": query, "conversation": conversation}
+        print(f"LLM Chain:\n{chain_input}")
 
         async for chunk in self.conversation_chain.astream(chain_input):
             yield chunk

View File

@@ -1,27 +1,34 @@
 from enum import Enum, auto
 from typing import Dict, Any
 
 from langchain_core.prompts import ChatPromptTemplate
-from langchain_community.llms import Ollama
+
+# from langchain_community.llms import Ollama
+from langchain_ollama import OllamaLLM
 
 
 class ModerationLabel(Enum):
     NSFW = auto()
     FINE = auto()
 
 
 class ModerationClassifier:
     """
     Classifies prompts as NSFW or FINE (safe) content.
     """
 
     def __init__(self):
-        self.llm = Ollama(
+        self.llm = OllamaLLM(
             model="llama3.2",
             temperature=0.1,  # Very low for strict moderation
             top_k=10,
-            num_ctx=2048
+            num_ctx=2048,
         )
-        self.moderation_prompt = ChatPromptTemplate.from_messages([
-            ("system", """You are a strict content moderator. Classify the following prompt as either NSFW or FINE.
+        self.moderation_prompt = ChatPromptTemplate.from_messages(
+            [
+                (
+                    "system",
+                    """You are a strict content moderator. Classify the following prompt as either NSFW or FINE.
 
 NSFW includes:
 - Sexual content
@@ -44,12 +51,14 @@ Examples:
 - "Explicit sex scene" → NSFW
 - "Python tutorial" → FINE
 
-Return ONLY "NSFW" or "FINE", nothing else."""),
-            ("human", "{prompt}")
-        ])
+Return ONLY "NSFW" or "FINE", nothing else.""",
+                ),
+                ("human", "{prompt}"),
+            ]
+        )
         self.chain = self.moderation_prompt | self.llm
 
     async def classify_async(self, prompt: str) -> ModerationLabel:
         """Asynchronous classification"""
         try:
@@ -58,7 +67,7 @@ Return ONLY "NSFW" or "FINE", nothing else."""),
         except Exception as e:
             print(f"Moderation error: {e}")
             return ModerationLabel.NSFW  # Fail-safe to NSFW
 
     def classify(self, prompt: str) -> ModerationLabel:
         """Synchronous classification"""
         try:
@@ -67,7 +76,7 @@ Return ONLY "NSFW" or "FINE", nothing else."""),
         except Exception as e:
             print(f"Moderation error: {e}")
             return ModerationLabel.NSFW  # Fail-safe to NSFW
 
     def _parse_response(self, response: str) -> ModerationLabel:
         """Convert string response to ModerationLabel enum"""
         if "NSFW" in response:
@@ -76,4 +85,4 @@ Return ONLY "NSFW" or "FINE", nothing else."""),
 
 # Singleton instance
 moderation_classifier = ModerationClassifier()
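
A hedged usage sketch of the module-level singleton; the import path is an assumption, not shown in this diff:

from chat_backend.services.moderation_service import (  # assumed module path
    ModerationLabel,
    moderation_classifier,
)

label = moderation_classifier.classify("Write me a Python tutorial")
if label is ModerationLabel.FINE:
    print("prompt accepted")
else:
    print("prompt blocked")  # the classifier fails safe to NSFW on any error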

View File

@@ -1,7 +1,10 @@
 from enum import Enum, auto
 from typing import Dict, Any
 
 from langchain_core.prompts import ChatPromptTemplate
-from langchain_community.llms import Ollama
+
+# from langchain_community.llms import Ollama
+from langchain_ollama import OllamaLLM
 
 
 class PromptType(Enum):
     GENERAL_CHAT = auto()
@@ -9,23 +12,26 @@ class PromptType(Enum):
     IMAGE_GENERATION = auto()
     UNKNOWN = auto()
 
 
 class PromptClassifier:
     """
     Classifies user prompts to determine which service should handle them.
     """
 
     def __init__(self):
-        self.llm = Ollama(
-            model="llama3",
+        self.llm = OllamaLLM(
+            model="llama3.2",
             temperature=0.3,  # Lower temp for more deterministic classification
             top_k=20,
             top_p=0.9,
-            num_ctx=4096
+            num_ctx=4096,
         )
-        self.classification_prompt = ChatPromptTemplate.from_messages([
-            ("system",
-            """You are a precision prompt classifier. Strictly categorize prompts into:
+        self.classification_prompt = ChatPromptTemplate.from_messages(
+            [
+                (
+                    "system",
+                    """You are a precision prompt classifier. Strictly categorize prompts into:
 
 1. GENERAL_CHAT - Casual conversation, personal questions, or non-specific inquiries
 2. RAG - ONLY when explicitly requesting document/search-based knowledge
 3. IMAGE_GENERATION - Specific requests to create/modify images
@@ -63,12 +69,14 @@ Examples:
 - "Explain quantum computing" (General knowledge)
 - "Summarize the meeting" (No doc reference)
 
-Return ONLY the label, no explanations."""),
-            ("human", "{prompt}")
-        ])
+Return ONLY the label, no explanations.""",
+                ),
+                ("human", "{prompt}"),
+            ]
+        )
         self.chain = self.classification_prompt | self.llm
 
     async def classify_async(self, prompt: str) -> PromptType:
         """Asynchronously classify the prompt"""
         try:
@@ -77,7 +85,7 @@ Return ONLY the label, no explanations."""),
         except Exception as e:
             print(f"Classification error: {e}")
             return PromptType.UNKNOWN
 
     def classify(self, prompt: str) -> PromptType:
         """Synchronously classify the prompt"""
         try:
@@ -86,7 +94,7 @@ Return ONLY the label, no explanations."""),
         except Exception as e:
             print(f"Classification error: {e}")
             return PromptType.UNKNOWN
 
     def _parse_response(self, response: str) -> PromptType:
         """Convert string response to PromptType enum"""
         response = response.upper()
@@ -97,4 +105,4 @@ Return ONLY the label, no explanations."""),
 
 # Singleton instance for easy access
 prompt_classifier = PromptClassifier()
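
An illustrative routing sketch built on the singleton above; the import path and downstream service names are assumptions, not part of this diff:

from chat_backend.services.prompt_classifier import (  # assumed module path
    PromptType,
    prompt_classifier,
)

label = prompt_classifier.classify("Draw a castle at sunset")
if label is PromptType.IMAGE_GENERATION:
    handler = "image generation service"
elif label is PromptType.RAG:
    handler = "RAG service"
else:
    handler = "general chat service"  # GENERAL_CHAT and UNKNOWN fall through here
print(f"routing to: {handler}")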

View File

@@ -3,7 +3,9 @@ from abc import ABC, abstractmethod
 from typing import List, Dict, Any, AsyncGenerator, Generator, Optional
 
 from channels.db import database_sync_to_async
 from langchain_community.embeddings import OllamaEmbeddings
-from langchain_community.llms import Ollama
+
+# from langchain_community.llms import Ollama
+from langchain_ollama import OllamaLLM
 from langchain_community.vectorstores import Chroma
 from langchain_core.documents import Document as LangDocument
 from langchain_core.output_parsers import StrOutputParser
@@ -14,139 +16,139 @@ from langchain_community.document_loaders import (
     PyPDFLoader,
     Docx2txtLoader,
     TextLoader,
-    UnstructuredFileLoader
+    UnstructuredFileLoader,
 )
 from django.core.files.uploadedfile import UploadedFile
 
 from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
 from pathlib import Path
 
 
 @database_sync_to_async
 def get_documents(workspace: DocumentWorkspace | None = None):
     if workspace:
         return [doc for doc in Document.objects.filter(workspace=workspace)]
     else:
         return [doc for doc in Document.objects.all()]
 
 
 class RAGService(ABC):
     """Abstract base class for RAG services."""
 
     _instance = None
 
     def __new__(cls):
         if cls._instance is None:
             cls._instance = super().__new__(cls)
             cls._instance.__init__()
         return cls._instance
 
     def __init__(self):
         self.embedding_model = OllamaEmbeddings(model="llama3.2")
-        self.llm = Ollama(
+        self.llm = OllamaLLM(
             model="llama3.2",
             temperature=0.7,
             top_k=50,
             top_p=0.9,
             repeat_penalty=1.1,
-            num_ctx=4096
+            num_ctx=4096,
         )
-        self.text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=1000,
-            chunk_overlap=200
-        )
+        self.text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=1000, chunk_overlap=200
+        )
         self.vector_store = self._initialize_vector_store()
 
         # Supported file types and their loaders
         self.loader_mapping = {
-            '.pdf': PyPDFLoader,
-            '.docx': Docx2txtLoader,
-            '.txt': TextLoader,
+            ".pdf": PyPDFLoader,
+            ".docx": Docx2txtLoader,
+            ".txt": TextLoader,
             # Fallback for other file types
-            '*': UnstructuredFileLoader,
+            "*": UnstructuredFileLoader,
         }
 
     def _initialize_vector_store(self) -> Chroma:
         """Initialize and return the Chroma vector store."""
-        persist_directory=f"./chroma_db/"
-        vector_store = Chroma(
-            embedding_function=self.embedding_model,
-            persist_directory=persist_directory
-        )
+        persist_directory = f"./chroma_db/"
+        vector_store = Chroma(
+            embedding_function=self.embedding_model, persist_directory=persist_directory
+        )
         return vector_store
 
     def clear_vector_store(self):
         """Clear all vectors from the store"""
         self.vector_store.delete_collection()
         self.vector_store = self._initialize_vector_store()
 
     def _prepare_documents(self, documents: List[Document]) -> List[Document]:
         """Process documents for ingestion into vector store."""
         docs = []
         for doc in documents:
             print(f"Processing: {doc.file.name}")
-            loader_class = self._get_file_loader( doc.file.name)
+            loader_class = self._get_file_loader(doc.file.name)
             loader = loader_class(doc.file)
             chunks = self._load_and_split_documents(doc.file.path)
             if chunks:
                 self.vector_store.add_documents(chunks)
                 self.vector_store.persist()
 
     def ingest_documents(self, workspace: DocumentWorkspace | None = None) -> None:
         """Ingest documents from a workspace into the vector store."""
         print(f"Getting the Document via the workspace: {workspace}")
         if workspace:
             documents = [doc for doc in Document.objects.filter(workspace=workspace)]
         else:
             documents = [doc for doc in Document.objects.all()]
         print(f"Processing the documents : {documents}")
         self._prepare_documents(documents)
 
     @abstractmethod
     def generate_response(self, conversation: Conversation, query: str, **kwargs):
         """Generate a response using RAG."""
         pass
 
     @abstractmethod
-    def search_documents(self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4) -> List[Document]:
+    def search_documents(
+        self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4
+    ) -> List[Document]:
         """Search relevant documents from the vector store."""
         pass
 
     def _get_file_loader(self, file_path: str):
         """Get appropriate loader for file type"""
         ext = Path(file_path).suffix.lower()
-        return self.loader_mapping.get(ext, self.loader_mapping['*'])
+        return self.loader_mapping.get(ext, self.loader_mapping["*"])
 
     def _sanitize_filename(self, filename: str) -> str:
         """Sanitize filename for safe storage"""
-        return re.sub(r'[^\w\-_. ]', '_', filename)
+        return re.sub(r"[^\w\-_. ]", "_", filename)
 
     def _save_uploaded_file(self, uploaded_file: UploadedFile, save_dir: str) -> str:
         """Save uploaded file to disk"""
         os.makedirs(save_dir, exist_ok=True)
         sanitized_name = self._sanitize_filename(uploaded_file.name)
         file_path = os.path.join(save_dir, sanitized_name)
-        with open(file_path, 'wb+') as destination:
+        with open(file_path, "wb+") as destination:
             for chunk in uploaded_file.chunks():
                 destination.write(chunk)
         return file_path
 
-    def _load_and_split_documents(self, file_path: str, metadata: dict = None) -> List[Document]:
+    def _load_and_split_documents(
+        self, file_path: str, metadata: dict = None
+    ) -> List[Document]:
         """Load and split documents from file"""
         loader_class = self._get_file_loader(file_path)
         loader = loader_class(file_path)
         docs = loader.load()
         if metadata:
             for doc in docs:
                 doc.metadata.update(metadata)
         return self.text_splitter.split_documents(docs)
 
     def add_files_to_store(
@@ -154,58 +156,51 @@ class RAGService(ABC):
         file_tupls: List[UploadedFile],  # (file_path, name, workspace_id)
         workspace_id: str,
         source: str = "upload",
-        save_dir: str = "data/uploads"
+        save_dir: str = "data/uploads",
     ) -> Dict[str, Any]:
         """
         Process and add uploaded files to vector store
 
         Args:
             files: List of Django UploadedFile objects
             workspace_id: ID of the workspace these belong to
             source: Source identifier for documents
             save_dir: Directory to save uploaded files
 
         Returns:
             Dictionary with processing results
         """
-        results = {
-            'total_added': 0,
-            'failed_files': [],
-            'processed_files': []
-        }
+        results = {"total_added": 0, "failed_files": [], "processed_files": []}
 
         for file_tuple in file_tupls:
             try:
                 # Save file to disk
                 # Prepare metadata
                 metadata = {
-                    'source': file_tuple[1],
-                    'workspace_id': file_tuple[2],
-                    'original_filename': file_tuple[1],
-                    'file_path': file_tuple[0],
+                    "source": file_tuple[1],
+                    "workspace_id": file_tuple[2],
+                    "original_filename": file_tuple[1],
+                    "file_path": file_tuple[0],
                 }
 
                 # Load and split documents
                 docs = self._load_and_split_documents(file_path, metadata)
 
                 # Add to vector store
                 if docs:
                     self.vector_store.add_documents(docs)
-                    results['total_added'] += len(docs)
-                    results['processed_files'].append({
-                        'filename': file_tuple[1],
-                        'document_count': len(docs)
-                    })
+                    results["total_added"] += len(docs)
+                    results["processed_files"].append(
+                        {"filename": file_tuple[1], "document_count": len(docs)}
+                    )
             except Exception as e:
-                results['failed_files'].append({
-                    'filename': file_tuple[1],
-                    'error': str(e)
-                })
+                results["failed_files"].append(
+                    {"filename": file_tuple[1], "error": str(e)}
+                )
                 continue
 
         # Persist changes
         self.vector_store.persist()
         return results
@@ -213,11 +208,11 @@ class RAGService(ABC):
 
 class SyncRAGService(RAGService):
     """Synchronous RAG service implementation."""
 
     def __init__(self):
         super().__init__()
         self._setup_chain()
 
     def _setup_chain(self):
         """Setup the RAG chain."""
         template = """Answer the question based only on the following context:
@@ -229,31 +224,32 @@ class SyncRAGService(RAGService):
 Question: {question}
 """
         self.prompt = ChatPromptTemplate.from_template(template)
         self.rag_chain = (
             {
                 "context": self._retriever_with_history,
                 "history": lambda x: self._format_history(x["conversation"]),
-                "question": lambda x: x["query"]
+                "question": lambda x: x["query"],
             }
             | self.prompt
             | self.llm
             | StrOutputParser()
         )
 
     def _format_history(self, conversation: Conversation) -> str:
         """Format conversation history for the prompt."""
-        prompts = Prompt.objects.filter(conversation=conversation).order_by('created_at')
-        return "\n".join(
-            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}"
-            for prompt in prompts
-        )
+        prompts = Prompt.objects.filter(conversation=conversation).order_by(
+            "created_at"
+        )
+        return "\n".join(
+            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
+        )
 
     def _retriever_with_history(self, input_dict: Dict[str, Any]) -> str:
         """Retrieve documents considering conversation history."""
         query = input_dict["query"]
         conversation = input_dict["conversation"]
         # You could enhance this to consider historical context in retrieval
         relevant_docs = self.search_documents(query, conversation.workspace)
         if not relevant_docs:
@@ -262,8 +258,9 @@ class SyncRAGService(RAGService):
         else:
             return relevant_docs
 
-    def search_documents(self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4) -> List[Document]:
+    def search_documents(
+        self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4
+    ) -> List[Document]:
         """Search relevant documents from the vector store."""
         filter_dict = {}
         if workspace:
@@ -271,31 +268,27 @@ class SyncRAGService(RAGService):
         print(f"search_kwargs: {search_kwargs}")
         retriever = self.vector_store.as_retriever(
             search_type="similarity",
-            search_kwargs={
-                "k": k,
-                "filter": filter_dict if filter_dict else None
-            }
+            search_kwargs={"k": k, "filter": filter_dict if filter_dict else None},
         )
         return retriever.get_relevant_documents(query)
 
-    def generate_response(self, conversation: Conversation, query: str, **kwargs) -> Generator[str, None, None]:
+    def generate_response(
+        self, conversation: Conversation, query: str, **kwargs
+    ) -> Generator[str, None, None]:
         """Generate response with streaming support."""
-        chain_input = {
-            "query": query,
-            "conversation": conversation
-        }
+        chain_input = {"query": query, "conversation": conversation}
 
         for chunk in self.rag_chain.stream(chain_input):
             yield chunk
 
 
 class AsyncRAGService(RAGService):
     """Asynchronous RAG service implementation."""
 
     def __init__(self):
         super().__init__()
         self._setup_chain()
 
     def _setup_chain(self):
         """Setup the RAG chain."""
         template = """Answer the question based only on the following context:
@@ -307,72 +300,76 @@ class AsyncRAGService(RAGService):
 Question: {question}
 """
         self.prompt = ChatPromptTemplate.from_template(template)
         self.rag_chain = (
             {
                 "context": self._retriever_with_history,
                 "history": lambda x: self._format_history(x["conversation"]),
-                "question": lambda x: x["query"]
+                "question": lambda x: x["query"],
             }
             | self.prompt
             | self.llm
             | StrOutputParser()
         )
 
     async def _format_history(self, conversation: Conversation) -> str:
         """Format conversation history for the prompt."""
-        prompts = await Prompt.objects.filter(conversation=conversation).order_by('created_at').alist()
+        prompts = (
+            await Prompt.objects.filter(conversation=conversation)
+            .order_by("created_at")
+            .alist()
+        )
         print(f"prompts that we are seeding with are: {prompts}")
         return "\n".join(
-            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}"
-            for prompt in prompts
+            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
         )
 
     async def _retriever_with_history(self, input_dict: Dict[str, Any]) -> str:
         """Retrieve documents considering conversation history."""
         print(f"Retrieving history with input: {input_dict}")
         query = input_dict["query"]
         conversation = input_dict["conversation"]
         workspace = input_dict["workspace"]
         # You could enhance this to consider historical context in retrieval
-        docs= await self.search_documents(query, workspace)
+        docs = await self.search_documents(query, workspace)
         if not docs:
             print("Didn't find any relevant docs")
         print("\n\n".join(doc.page_content for doc in docs))
         return "\n\n".join(doc.page_content for doc in docs)
 
-    async def search_documents(self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4) -> List[Document]:
+    async def search_documents(
+        self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4
+    ) -> List[Document]:
         """Search relevant documents from the vector store."""
         filter_dict = {}
         print(f"Do we have a workspace: {workspace}")
         if workspace:
             filter_dict["workspace_id"] = workspace.id
-        search_kwargs={
-            "k": k,
-            "filter": filter_dict if filter_dict else None
-        }
+        search_kwargs = {"k": k, "filter": filter_dict if filter_dict else None}
         print(f"search_kwargs: {search_kwargs}")
         retriever = self.vector_store.as_retriever(
             search_type="mmr",
-            search_kwargs={
-                "k": k,
-                "filter": filter_dict if filter_dict else None
-            }
+            search_kwargs={"k": k, "filter": filter_dict if filter_dict else None},
         )
         return await retriever.aget_relevant_documents(query)
 
-    async def generate_response(self, conversation: Conversation, query: str, workspace: DocumentWorkspace, **kwargs) -> AsyncGenerator[str, None]:
+    async def generate_response(
+        self,
+        conversation: Conversation,
+        query: str,
+        workspace: DocumentWorkspace,
+        **kwargs,
+    ) -> AsyncGenerator[str, None]:
        """Generate response with streaming support."""
         chain_input = {
             "query": query,
             "conversation": conversation,
             "workspace": workspace,
         }
         async for chunk in self.rag_chain.astream(chain_input):
             yield chunk
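
The filter plumbing above scopes retrieval to a single workspace by matching the `workspace_id` metadata stamped on each chunk at ingestion; Chroma applies the filter inside the store, so other workspaces' chunks never reach the prompt. A hedged sketch of the same call made directly, outside the chain; the workspace id and query are illustrative:

from chat_backend.services.rag_services import SyncRAGService

rag_service = SyncRAGService()  # singleton: repeated constructions share one instance
retriever = rag_service.vector_store.as_retriever(
    search_type="mmr",
    search_kwargs={"k": 4, "filter": {"workspace_id": 42}},  # 42 is illustrative
)
docs = retriever.get_relevant_documents("vacation policy")
print([doc.metadata.get("source") for doc in docs])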

View File

@@ -5,9 +5,14 @@ from typing import List, Dict, Any
 
 from django.test import TestCase as DjangoTestCase
 
-from chat_backend.services.rag_services import RAGService, SyncRAGService, AsyncRAGService
+from chat_backend.services.rag_services import (
+    RAGService,
+    SyncRAGService,
+    AsyncRAGService,
+)
 from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
 
 
 class TestRAGService(TestCase):
     def setUp(self):
         self.rag_service = RAGService()
@@ -16,18 +21,20 @@ class TestRAGService(TestCase):
         self.rag_service.text_splitter = MagicMock()
 
     def test_initialize_vector_store(self):
-        with patch('os.path.exists', return_value=False), \
-             patch('os.makedirs') as mock_makedirs, \
-             patch('langchain_community.vectorstores.Chroma') as mock_chroma:
+        with patch("os.path.exists", return_value=False), patch(
+            "os.makedirs"
+        ) as mock_makedirs, patch(
+            "langchain_community.vectorstores.Chroma"
+        ) as mock_chroma:
             # Reset the vector store to test initialization
             self.rag_service.vector_store = None
             result = self.rag_service._initialize_vector_store()
 
             mock_makedirs.assert_called_once_with("chroma_db")
             mock_chroma.assert_called_once_with(
                 embedding_function=self.rag_service.embedding_model,
-                persist_directory="chroma_db"
+                persist_directory="chroma_db",
             )
             self.assertIsNotNone(result)
@@ -40,11 +47,13 @@ class TestRAGService(TestCase):
         mock_doc1.id = 1
         self.rag_service.text_splitter.split_text.return_value = ["chunk1", "chunk2"]
 
         result = self.rag_service._prepare_documents([mock_doc1])
 
         self.assertEqual(len(result), 2)
-        self.rag_service.text_splitter.split_text.assert_called_once_with("Test content")
+        self.rag_service.text_splitter.split_text.assert_called_once_with(
+            "Test content"
+        )
         self.assertEqual(result[0].page_content, "chunk1")
         self.assertEqual(result[0].metadata["source"], "test_source")
@@ -52,13 +61,19 @@ class TestRAGService(TestCase):
         mock_workspace = MagicMock()
         mock_document = MagicMock()
         mock_documents = [mock_document]
 
-        with patch('services.rag_services.Document.objects.filter', return_value=mock_documents):
-            self.rag_service._prepare_documents = MagicMock(return_value=["processed_doc"])
+        with patch(
+            "services.rag_services.Document.objects.filter", return_value=mock_documents
+        ):
+            self.rag_service._prepare_documents = MagicMock(
+                return_value=["processed_doc"]
+            )
             self.rag_service.ingest_documents(mock_workspace)
 
-            self.rag_service.vector_store.add_documents.assert_called_once_with(["processed_doc"])
+            self.rag_service.vector_store.add_documents.assert_called_once_with(
+                ["processed_doc"]
+            )
             self.rag_service.vector_store.persist.assert_called_once()
@@ -71,40 +86,39 @@ class TestSyncRAGService(DjangoTestCase):
         self.mock_conversation = MagicMock(spec=Conversation)
         self.mock_conversation.workspace = MagicMock()
 
         self.mock_prompt1 = MagicMock(spec=Prompt)
         self.mock_prompt1.is_user = True
         self.mock_prompt1.text = "User question"
         self.mock_prompt1.created_at = "2023-01-01"
 
         self.mock_prompt2 = MagicMock(spec=Prompt)
         self.mock_prompt2.is_user = False
         self.mock_prompt2.text = "AI response"
         self.mock_prompt2.created_at = "2023-01-02"
 
     def test_format_history(self):
-        with patch('services.rag_services.Prompt.objects.filter') as mock_filter:
-            mock_filter.return_value.order_by.return_value = [self.mock_prompt1, self.mock_prompt2]
+        with patch("services.rag_services.Prompt.objects.filter") as mock_filter:
+            mock_filter.return_value.order_by.return_value = [
+                self.mock_prompt1,
+                self.mock_prompt2,
+            ]
 
             result = self.sync_service._format_history(self.mock_conversation)
 
             expected = "User: User question\nAI: AI response"
             self.assertEqual(result, expected)
             mock_filter.assert_called_once_with(conversation=self.mock_conversation)
 
     def test_retriever_with_history(self):
-        input_dict = {
-            "query": "test query",
-            "conversation": self.mock_conversation
-        }
+        input_dict = {"query": "test query", "conversation": self.mock_conversation}
         self.sync_service.search_documents = MagicMock(return_value=["doc1", "doc2"])
 
         result = self.sync_service._retriever_with_history(input_dict)
 
         self.sync_service.search_documents.assert_called_once_with(
-            "test query",
-            self.mock_conversation.workspace
+            "test query", self.mock_conversation.workspace
         )
         self.assertEqual(result, ["doc1", "doc2"])
@@ -112,29 +126,30 @@ class TestSyncRAGService(DjangoTestCase):
         mock_retriever = MagicMock()
         mock_retriever.get_relevant_documents.return_value = ["doc1", "doc2"]
         self.sync_service.vector_store.as_retriever.return_value = mock_retriever
 
-        result = self.sync_service.search_documents("test query", self.mock_conversation.workspace)
+        result = self.sync_service.search_documents(
+            "test query", self.mock_conversation.workspace
+        )
 
         self.sync_service.vector_store.as_retriever.assert_called_once_with(
             search_type="similarity",
             search_kwargs={
                 "k": 4,
-                "filter": {"workspace_id": self.mock_conversation.workspace.id}
-            }
+                "filter": {"workspace_id": self.mock_conversation.workspace.id},
+            },
         )
         self.assertEqual(result, ["doc1", "doc2"])
 
     def test_generate_response(self):
-        chain_input = {
-            "query": "test query",
-            "conversation": self.mock_conversation
-        }
+        chain_input = {"query": "test query", "conversation": self.mock_conversation}
        mock_stream = ["chunk1", "chunk2", "chunk3"]
         self.sync_service.rag_chain.stream.return_value = mock_stream
 
-        result = list(self.sync_service.generate_response(self.mock_conversation, "test query"))
+        result = list(
+            self.sync_service.generate_response(self.mock_conversation, "test query")
+        )
 
         self.sync_service.rag_chain.stream.assert_called_once_with(chain_input)
         self.assertEqual(result, mock_stream)
@@ -148,12 +163,12 @@ class TestAsyncRAGService(DjangoTestCase):
         self.mock_conversation = MagicMock(spec=Conversation)
         self.mock_conversation.workspace = MagicMock()
 
         self.mock_prompt1 = MagicMock(spec=Prompt)
         self.mock_prompt1.is_user = True
         self.mock_prompt1.text = "User question"
         self.mock_prompt1.created_at = "2023-01-01"
 
         self.mock_prompt2 = MagicMock(spec=Prompt)
         self.mock_prompt2.is_user = False
         self.mock_prompt2.text = "AI response"
@@ -161,28 +176,29 @@ class TestAsyncRAGService(DjangoTestCase):
     async def test_format_history(self):
         mock_manager = AsyncMock()
-        mock_manager.order_by.return_value.alist.return_value = [self.mock_prompt1, self.mock_prompt2]
-
-        with patch('services.rag_services.Prompt.objects.filter', return_value=mock_manager):
+        mock_manager.order_by.return_value.alist.return_value = [
+            self.mock_prompt1,
+            self.mock_prompt2,
+        ]
+
+        with patch(
+            "services.rag_services.Prompt.objects.filter", return_value=mock_manager
+        ):
             result = await self.async_service._format_history(self.mock_conversation)
 
             expected = "User: User question\nAI: AI response"
             self.assertEqual(result, expected)
-            mock_manager.order_by.assert_called_once_with('created_at')
+            mock_manager.order_by.assert_called_once_with("created_at")
 
     async def test_retriever_with_history(self):
-        input_dict = {
-            "query": "test query",
-            "conversation": self.mock_conversation
-        }
+        input_dict = {"query": "test query", "conversation": self.mock_conversation}
         self.async_service.search_documents = AsyncMock(return_value=["doc1", "doc2"])
 
         result = await self.async_service._retriever_with_history(input_dict)
 
         self.async_service.search_documents.assert_awaited_once_with(
-            "test query",
-            self.mock_conversation.workspace
+            "test query", self.mock_conversation.workspace
         )
         self.assertEqual(result, ["doc1", "doc2"])
@@ -190,30 +206,31 @@ class TestAsyncRAGService(DjangoTestCase):
         mock_retriever = AsyncMock()
         mock_retriever.aget_relevant_documents.return_value = ["doc1", "doc2"]
         self.async_service.vector_store.as_retriever.return_value = mock_retriever
 
-        result = await self.async_service.search_documents("test query", self.mock_conversation.workspace)
+        result = await self.async_service.search_documents(
+            "test query", self.mock_conversation.workspace
+        )
 
         self.async_service.vector_store.as_retriever.assert_called_once_with(
             search_type="similarity",
             search_kwargs={
                 "k": 4,
-                "filter": {"workspace_id": self.mock_conversation.workspace.id}
-            }
+                "filter": {"workspace_id": self.mock_conversation.workspace.id},
+            },
        )
         self.assertEqual(result, ["doc1", "doc2"])
 
     async def test_generate_response(self):
-        chain_input = {
-            "query": "test query",
-            "conversation": self.mock_conversation
-        }
+        chain_input = {"query": "test query", "conversation": self.mock_conversation}
         mock_stream = ["chunk1", "chunk2", "chunk3"]
         self.async_service.rag_chain.astream.return_value = mock_stream
 
         chunks = []
-        async for chunk in self.async_service.generate_response(self.mock_conversation, "test query"):
+        async for chunk in self.async_service.generate_response(
+            self.mock_conversation, "test query"
+        ):
             chunks.append(chunk)
 
         self.async_service.rag_chain.astream.assert_awaited_once_with(chain_input)
         self.assertEqual(chunks, mock_stream)

View File

@@ -1,22 +1,28 @@
from langchain_core.prompts import ChatPromptTemplate

# from langchain_community.llms import Ollama
from langchain_ollama import OllamaLLM
from typing import Optional


class TitleGenerator:
    """
    Generates short, descriptive titles for conversations based on the first prompt.
    """

    def __init__(self):
        self.llm = OllamaLLM(
            model="llama3.2",
            temperature=0.5,  # Slightly creative but not too random
            top_k=20,
            num_ctx=2048,  # Shorter context needed for titles
        )

        self.title_prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    """You are a conversation title generator. Create a very short (2-5 word) title based on the user's first message.

Rules:
1. Keep it extremely concise
@@ -31,12 +37,14 @@ Examples:
- "Generate an image of a dragon""Dragon Image Generation" - "Generate an image of a dragon""Dragon Image Generation"
- "Find our company's privacy policy""Privacy Policy Search" - "Find our company's privacy policy""Privacy Policy Search"
Return ONLY the title, nothing else."""), Return ONLY the title, nothing else.""",
("human", "{prompt}") ),
]) ("human", "{prompt}"),
]
)
self.chain = self.title_prompt | self.llm self.chain = self.title_prompt | self.llm
async def generate_async(self, prompt: str) -> str: async def generate_async(self, prompt: str) -> str:
"""Generate title asynchronously""" """Generate title asynchronously"""
try: try:
@@ -45,7 +53,7 @@ Return ONLY the title, nothing else."""),
        except Exception as e:
            print(f"Title generation error: {e}")
            return "Conversation"

    def generate(self, prompt: str) -> str:
        """Generate title synchronously"""
        try:
@@ -54,14 +62,14 @@ Return ONLY the title, nothing else."""),
        except Exception as e:
            print(f"Title generation error: {e}")
            return "Conversation"

    def _clean_response(self, response: str) -> str:
        """Clean and format the LLM response"""
        # Remove any quotes or punctuation
        response = response.strip("\"'.!? \n\t")
        # Ensure title case and trim
        return response.title()[:50]  # Hard limit for safety


# Singleton instance
title_generator = TitleGenerator()
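
For context, a minimal usage sketch of the singleton above, e.g. from `python manage.py shell`. The import path is an assumption (this diff does not show the module's file name), so adjust it to wherever TitleGenerator actually lives:

import asyncio

# Hypothetical import path; adjust to the real module location.
from chat_backend.services.title_services import title_generator

# Synchronous path (blocks on the Ollama request)
print(title_generator.generate("How do I reset my password?"))

# Asynchronous path, as a websocket consumer would await it
print(asyncio.run(title_generator.generate_async("Generate an image of a dragon")))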

View File

@@ -3,16 +3,18 @@ from django.dispatch import receiver
from chat_backend.models import Document
from .services.rag_services import AsyncRAGService


@receiver(post_save, sender=Document)
def update_vector_on_save(sender, instance, **kwargs):
    """Update vector store when documents are saved"""
    if kwargs.get("created", False):
        rag_service = AsyncRAGService()
        rag_service.ingest_documents()


@receiver(post_delete, sender=Document)
def delete_vector_on_remove(sender, instance, **kwargs):
    """Handle document deletion by re-indexing the whole workspace"""
    rag_service = AsyncRAGService()
    rag_service.ingest_documents()
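
One caveat with the receiver above: post_save fires while the surrounding transaction may still be open, so ingest_documents() can run before the new Document row is visible to other connections. A hedged alternative sketch using Django's transaction.on_commit (everything else unchanged from the receiver above; the function name here is new):

from django.db import transaction


@receiver(post_save, sender=Document)
def update_vector_on_commit(sender, instance, created=False, **kwargs):
    """Re-index only after the transaction that created the Document commits"""
    if created:
        transaction.on_commit(lambda: AsyncRAGService().ingest_documents())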

View File

@@ -13,80 +13,75 @@ from django.core.files.uploadedfile import SimpleUploadedFile
# Minimal valid PDF bytes
VALID_PDF_BYTES = (
    b"%PDF-1.3\n"
    b"1 0 obj\n"
    b"<< /Type /Catalog /Pages 2 0 R >>\n"
    b"endobj\n"
    b"2 0 obj\n"
    b"<< /Type /Pages /Kids [3 0 R] /Count 1 >>\n"
    b"endobj\n"
    b"3 0 obj\n"
    b"<< /Type /Page /Parent 2 0 R /Resources << >> /MediaBox [0 0 612 792] /Contents 4 0 R >>\n"
    b"endobj\n"
    b"4 0 obj\n"
    b"<< /Length 44 >>\n"
    b"stream\n"
    b"BT /F1 12 Tf 72 720 Td (Test PDF) Tj ET\n"
    b"endstream\n"
    b"endobj\n"
    b"xref\n"
    b"0 5\n"
    b"0000000000 65535 f \n"
    b"0000000009 00000 n \n"
    b"0000000058 00000 n \n"
    b"0000000117 00000 n \n"
    b"0000000223 00000 n \n"
    b"trailer\n"
    b"<< /Size 5 /Root 1 0 R >>\n"
    b"startxref\n"
    b"317\n"
    b"%%EOF"
)
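
As a quick sanity check, the fixture above should parse as a one-page PDF with pypdf, which is already pinned in requirements.txt; a sketch:

import io

from pypdf import PdfReader

# Raises if the byte layout is malformed; pypdf will rebuild the xref if needed.
reader = PdfReader(io.BytesIO(VALID_PDF_BYTES))
assert len(reader.pages) == 1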
class DocumentWorkspaceViewsTestCase(APITestCase):
    def setUp(self):
        self.company = Company.objects.create(
            name="test", state="IL", zipcode="60189", address="1968 Greensboro Dr"
        )
        self.user = get_user_model().objects.create_user(
            company=self.company,
            username="testuser",
            password="testpass123",
            email="test@test.com",
        )
        self.client = APIClient()
        self.client.force_authenticate(user=self.user)
        self.workspace = DocumentWorkspace.objects.create(
            company=self.user.company, name="Test Workspace"
        )

    def test_list_workspaces(self):
        url = reverse("document_workspaces")
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(len(response.data), 1)
        self.assertEqual(response.data[0]["name"], "Test Workspace")

    def test_create_workspace(self):
        url = reverse("document_workspaces")
        data = {"name": "New Workspace"}
        response = self.client.post(url, data, format="json")
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        self.assertEqual(DocumentWorkspace.objects.count(), 2)

    def test_retrieve_workspace(self):
        url = reverse("document_workspaces")
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data[0]["name"], "Test Workspace")

    # def test_update_workspace(self):
    #     url = reverse('document_workspaces')
@@ -104,107 +99,89 @@ class DocumentWorkspaceViewsTestCase(APITestCase):
    #     self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
    #     self.assertEqual(DocumentWorkspace.objects.count(), 0)


class DocumentViewsTestCase(APITestCase):
    def setUp(self):
        self.company = Company.objects.create(
            name="test", state="IL", zipcode="60189", address="1968 Greensboro Dr"
        )
        self.user = get_user_model().objects.create_user(
            company=self.company,
            username="testuser",
            password="testpass123",
            email="test@test.com",
        )
        self.client = APIClient()
        self.client.force_authenticate(user=self.user)
        self.workspace = DocumentWorkspace.objects.create(
            company=self.user.company, name="Test Workspace"
        )

        # Create a test file
        self.test_file = SimpleUploadedFile(
            "test.pdf", VALID_PDF_BYTES, content_type="application/pdf"
        )

    def test_upload_document(self):
        url = reverse("documents")
        data = {"file": self.test_file}
        response = self.client.post(url, data, format="multipart")
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        self.assertEqual(Document.objects.count(), 1)
        document = Document.objects.first()
        self.assertEqual(document.workspace.id, self.workspace.id)
        self.assertTrue(document.processed)  # process_document() runs synchronously on upload

    def test_list_documents(self):
        # First create a document
        Document.objects.create(workspace=self.workspace, file=self.test_file)

        url = reverse("documents")
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(len(response.data), 1)
        self.assertIn("test", response.data[0]["file"])
        self.assertIn("pdf", response.data[0]["file"])

    # def test_delete_document(self):
    #     document = Document.objects.create(
    #         workspace=self.workspace,
    #         file=self.test_file
    #     )
    #     url = reverse('document-detail', args=[document.id])
    #     response = self.client.delete(url)
    #     self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
    #     self.assertEqual(Document.objects.count(), 0)

    def test_upload_invalid_file(self):
        url = reverse("documents")
        data = {"file": "not a file"}
        response = self.client.post(url, data, format="multipart")
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

    def test_access_other_users_documents(self):
        # Create another user
        other_company = Company.objects.create(
            name="test2", state="IL", zipcode="60189", address="1968 Greensboro Dr"
        )
        other_user = get_user_model().objects.create_user(
            company=other_company,
            username="otheruser",
            password="otherpass123",
            email="testing2@test.com",
        )
        other_workspace = DocumentWorkspace.objects.create(
            company=other_user.company, name="Other Workspace"
        )
        other_document = Document.objects.create(
            workspace=other_workspace, file=self.test_file
        )

        # Try to access the other user's document
        url = reverse("documents_details", kwargs={"document_id": other_document.id})
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

View File

@@ -23,8 +23,7 @@ from .views import (
    reset_password,
    DocumentWorkspaceView,
    DocumentUploadView,
    DocumentDetailView,
)
from rest_framework.routers import DefaultRouter
@@ -80,11 +79,16 @@ urlpatterns = [
name="analytics_company_usage", name="analytics_company_usage",
), ),
path("analytics/admin/", AdminAnalytics.as_view(), name="analytics_admin"), path("analytics/admin/", AdminAnalytics.as_view(), name="analytics_admin"),
# document urls # document urls
path("document_workspaces/", DocumentWorkspaceView.as_view(), name="document_workspaces"), path(
"document_workspaces/",
DocumentWorkspaceView.as_view(),
name="document_workspaces",
),
path("documents/", DocumentUploadView.as_view(), name="documents"), path("documents/", DocumentUploadView.as_view(), name="documents"),
path("documents_details/<int:document_id>", DocumentDetailView.as_view(), name="documents_details"), path(
"documents_details/<int:document_id>",
DocumentDetailView.as_view(),
name="documents_details",
),
] ]

View File

@@ -13,7 +13,7 @@ from .serializers import (
    PromptSerializer,
    FeedbackSerializer,
    DocumentWorkspaceSerializer,
    DocumentSerializer,
)
from rest_framework.views import APIView
from rest_framework.response import Response
@@ -25,7 +25,7 @@ from .models import (
    Feedback,
    PromptMetric,
    DocumentWorkspace,
    Document,
)
from django.views.decorators.cache import never_cache
from django.http import JsonResponse
@@ -99,8 +99,8 @@ class CustomUserCreate(APIView):
def send_invite_email(slug, email_to_invite):
    print("Sending invite email")
    print(f"url : https://chat.aimloperations.com/set_password?slug={slug}")
    url = f"https://chat.aimloperations.com/set_password?slug={slug}"
    subject = "Welcome to AI ML Operations, LLC Chat Services"
    from_email = "ryan@aimloperations.com"
    to = email_to_invite
@@ -113,6 +113,22 @@ def send_invite_email(slug, email_to_invite):
    msg.send(fail_silently=True)


def send_password_reset_email(slug, email_to_invite):
    print("Sending reset email")
    print(f"url : https://chat.aimloperations.com/set_password?slug={slug}")
    url = f"https://chat.aimloperations.com/set_password?slug={slug}"
    subject = "Password reset for AI ML Operations, LLC Chat Services"
    from_email = "ryan@aimloperations.com"
    to = email_to_invite
    d = {"url": url}
    html_content = get_template(r"emails/reset_email.html").render(d)
    text_content = get_template(r"emails/reset_email.txt").render(d)
    msg = EmailMultiAlternatives(subject, text_content, from_email, [to])
    msg.attach_alternative(html_content, "text/html")
    msg.send(fail_silently=True)


def send_feedback_email(feedback_obj):
    print("Sending feedback email")
    subject = "New Feedback for Chat by AI ML Operations, LLC"
@@ -176,19 +192,22 @@ class CustomUserInvite(APIView):
        return Response(status=status.HTTP_201_CREATED)


@csrf_exempt
def reset_password(request):
    if request.method == "POST":
        data = json.loads(request.body)
        token = data.get("recaptchaToken")
        payload = {
            "secret": settings.CAPTCHA_SECRET_KEY,
            "response": token,
        }
        response = requests.post(
            "https://www.google.com/recaptcha/api/siteverify", data=payload
        )
        result = response.json()
        if result.get("success") and result.get("score") >= 0.5:
            email = data.get("email")
            user = CustomUser.objects.filter(email=email).first()
            if user:
                user.set_unusable_password()
@@ -197,10 +216,10 @@ def reset_password(request):
                # send the email
                send_password_reset_email(user.slug, email)
                return JsonResponse({}, status=200)
    return JsonResponse({}, status=400)
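
For reference, a hedged client-side sketch of the payload this view expects. The /api/reset_password/ route is an assumption, since the URL pattern for reset_password is not part of this diff:

import requests

# Hypothetical route; check urls.py for the actual path() entry.
resp = requests.post(
    "https://chat.aimloperations.com/api/reset_password/",
    json={"email": "user@example.com", "recaptchaToken": "<grecaptcha-token>"},
)
print(resp.status_code)  # 200 when the captcha verifies, 400 otherwise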


class ResetUserPassword(APIView):
    http_method_names = [
        "post",
@@ -214,14 +233,16 @@ class ResetUserPassword(APIView):
        Also disable the account
        """
        print(f"Password reset for requests. {request.data}")
        token = request.data.get("recaptchaToken")
        email = request.data.get("email")
        payload = {
            "secret": settings.CAPTCHA_SECRET_KEY,
            "response": token,
        }
        response = requests.post(
            "https://www.google.com/recaptcha/api/siteverify", data=payload
        )
        result = response.json()
        if result.get("success") and result.get("score") >= 0.5:
            user = CustomUser.objects.filter(email=email).first()
            if user:
                user.set_unusable_password()
@@ -230,7 +251,7 @@ class ResetUserPassword(APIView):
                # send the email
                send_password_reset_email(user.slug, email)
        else:
            print("Captcha secret failed")
        return Response(status=status.HTTP_200_OK)
@@ -261,10 +282,16 @@ class CustomUserGet(APIView):
        email = request.user.email
        username = request.user.username
        user = CustomUser.objects.filter(email=email).last()
        print(f"Getting the user: {user}")
        try:
            serializer = CustomUserSerializer(user)
            print(f"serializer: {serializer}")
            print(serializer.data)
            return Response(serializer.data, status=status.HTTP_200_OK)
        except Exception as e:
            print(f"Exception: {e}")
            return Response({}, status=status.HTTP_400_BAD_REQUEST)


class FeedbackView(APIView):
@@ -706,6 +733,7 @@ def get_workspace(conversation_id):
    conversation = Conversation.objects.get(id=conversation_id)
    return DocumentWorkspace.objects.get(company=conversation.user.company)


@database_sync_to_async
def get_messages(conversation_id, prompt, file_string: str = None, file_type: str = ""):
    messages = []
@@ -821,19 +849,21 @@ def finish_prompt_metric(prompt_metric, response_length):
    prompt_metric.save(update_fields=["end_time", "reponse_length", "event"])
    print("finish_prompt_metric saved")


@database_sync_to_async
def get_retriever(conversation_id):
    print(f"getting workspace from conversation: {conversation_id}")
    conversation = Conversation.objects.get(id=conversation_id)
    print(f"Got conversation: {conversation}")
    workspace = DocumentWorkspace.objects.get(company=conversation.user.company)
    print(f"Got workspace: {workspace}")
    vectorstore = Chroma(
        persist_directory="./chroma_db/",
        embedding=OllamaEmbeddings(model="llama3.2"),
    )
    return vectorstore.as_retriever()
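
Note that the TestAsyncRAGService expectations earlier in this diff filter retrieval by workspace, while get_retriever() above returns an unfiltered retriever. A hedged sketch that applies the same metadata filter, assuming chunks are ingested with a "workspace_id" metadata key as those tests imply (the function name is new):

def get_workspace_retriever(workspace_id):
    vectorstore = Chroma(
        persist_directory="./chroma_db/",
        embedding=OllamaEmbeddings(model="llama3.2"),
    )
    # Mirrors the search_kwargs asserted in test_search_documents
    return vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 4, "filter": {"workspace_id": workspace_id}},
    )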


class ChatConsumerAgain(AsyncWebsocketConsumer):
    async def connect(self):
        await self.accept()
@@ -881,7 +911,10 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
print(f'received: "{message}" for conversation {conversation_id}') print(f'received: "{message}" for conversation {conversation_id}')
# check the moderation here # check the moderation here
if await moderation_classifier.classify_async(message) == ModerationLabel.NSFW: if (
await moderation_classifier.classify_async(message)
== ModerationLabel.NSFW
):
response = "Prompt has been marked as NSFW. If this is in error, submit a feedback with the prompt text." response = "Prompt has been marked as NSFW. If this is in error, submit a feedback with the prompt text."
print("this prompt has been marked as NSFW") print("this prompt has been marked as NSFW")
await self.send("CONVERSATION_ID") await self.send("CONVERSATION_ID")
@@ -902,9 +935,8 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
        prompt_type = await prompt_classifier.classify_async(message)
        print(f"prompt_type: {prompt_type} for {message}")

        prompt_metric = await create_prompt_metric(
            prompt.id,
@@ -928,22 +960,25 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
await self.send("START_OF_THE_STREAM_ENDER_GAME_42") await self.send("START_OF_THE_STREAM_ENDER_GAME_42")
if prompt_type == PromptType.RAG: if prompt_type == PromptType.RAG:
service = AsyncRAGService() service = AsyncRAGService()
#await service.ingest_documents() # await service.ingest_documents()
workspace = await get_workspace(conversation_id) workspace = await get_workspace(conversation_id)
print('Time to get the rag response') print("Time to get the rag response")
async for chunk in service.generate_response(messages, prompt.message, workspace): async for chunk in service.generate_response(
print(f"chunk: {chunk}") messages, prompt.message, workspace
):
response += chunk response += chunk
await self.send(chunk) await self.send(chunk)
elif prompt_type == PromptType.IMAGE_GENERATION: elif prompt_type == PromptType.IMAGE_GENERATION:
response = "Image Generation is not supported at this time, but it will be soon." response = "Image Generation is not supported at this time, but it will be soon."
await self.send(response) await self.send(response)
else: else:
print(f"using the AsyncLLMService\n\n{messages}\n{prompt.message}")
service = AsyncLLMService() service = AsyncLLMService()
async for chunk in service.generate_response(messages, prompt.message): async for chunk in service.generate_response(
print(f"chunk: {chunk}") messages, prompt.message
):
response += chunk response += chunk
await self.send(chunk) await self.send(chunk)
await self.send("END_OF_THE_STREAM_ENDER_GAME_42") await self.send("END_OF_THE_STREAM_ENDER_GAME_42")
@@ -954,9 +989,10 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
        if bytes_data:
            print("we have byte data")


# Document Views
class DocumentWorkspaceView(APIView):
    # permission_classes = [permissions.IsAuthenticated]

    def get(self, request):
        workspaces = DocumentWorkspace.objects.filter(company=request.user.company)
@@ -970,70 +1006,78 @@ class DocumentWorkspaceView(APIView):
            return Response(serializer.data, status=status.HTTP_201_CREATED)
        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)


class DocumentUploadView(APIView):
    # permission_classes = [permissions.IsAuthenticated]

    def get(self, request):
        print(f"request_3: {request}")
        try:
            workspace = DocumentWorkspace.objects.get(company=request.user.company)
            serializer = DocumentSerializer(
                Document.objects.filter(workspace=workspace), many=True
            )
            return Response(serializer.data, status=status.HTTP_200_OK)
        except DocumentWorkspace.DoesNotExist:
            return Response(
                {"error": "Workspace not found"}, status=status.HTTP_404_NOT_FOUND
            )

    def post(self, request):
        print(f"request: {request}")
        try:
            workspace = DocumentWorkspace.objects.get(company=request.user.company)
        except DocumentWorkspace.DoesNotExist:
            return Response(
                {"error": "Workspace not found"}, status=status.HTTP_404_NOT_FOUND
            )
        print(request.FILES)
        file = request.FILES.get("file")
        if not file:
            return Response(
                {"error": "No file provided"}, status=status.HTTP_400_BAD_REQUEST
            )
        print("have the workspace and the file")
        document = Document.objects.create(workspace=workspace, file=file)

        # process the document (currently synchronous)
        self.process_document(document)

        serializer = DocumentSerializer(document)
        return Response(serializer.data, status=status.HTTP_201_CREATED)

    def process_document(self, document):
        file_path = os.path.join(settings.MEDIA_ROOT, document.file.name)
        document.processed = True
        document.active = True
        document.save()
        service = AsyncRAGService()
        service.add_files_to_store(
            [(file_path, document.file.name, document.workspace_id)],
            workspace_id=document.workspace_id,
        )


class DocumentDetailView(APIView):
    # permission_classes = [permissions.IsAuthenticated]

    def get(self, request, document_id):
        print(f"request: {request}")
        try:
            workspace = DocumentWorkspace.objects.get(company=request.user.company)
            document = Document.objects.get(workspace=workspace, id=document_id)
        except (DocumentWorkspace.DoesNotExist, Document.DoesNotExist):
            return Response(
                {"error": "Document not found"}, status=status.HTTP_404_NOT_FOUND
            )

        serializer = DocumentSerializer(document)
        return Response(serializer.data)

View File

@@ -9,13 +9,18 @@ https://docs.djangoproject.com/en/3.2/howto/deployment/asgi/
import os

import django
from django.core.asgi import get_asgi_application

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "llm_be.settings")
django.setup()

from channels.routing import ProtocolTypeRouter, URLRouter
from channels.auth import AuthMiddlewareStack

import chat_backend.routing

application = ProtocolTypeRouter(
    {
        "http": get_asgi_application(),

View File

@@ -161,7 +161,7 @@ REST_FRAMEWORK = {
}

SIMPLE_JWT = {
    "ACCESS_TOKEN_LIFETIME": timedelta(hours=5),
    "REFRESH_TOKEN_LIFETIME": timedelta(days=14),
    "ROTATE_REFRESH_TOKENS": True,
    "BLACKLIST_AFTER_ROTATION": True,
@@ -207,4 +207,4 @@ EMAIL_PORT = 2525
EMAIL_USE_TLS = True

# Captcha
CAPTCHA_SECRET_KEY = "6LfENu4qAAAAABdrj6JTviq-LfdPP5imhE-Os7h9"

View File

@@ -44,7 +44,7 @@ durationpy==0.9
effdet==0.4.1
emoji==2.14.1
eval_type_backport==0.2.2
Faker
fastapi==0.115.9
filelock==3.17.0
filetype==1.2.0
@@ -52,13 +52,13 @@ flatbuffers==25.2.10
fonttools==4.56.0
frozenlist==1.6.0
fsspec==2025.2.0
google-api-core
google-auth
google-cloud-vision
googleapis-common-protos
greenlet==3.1.1
grpcio
grpcio-status
h11==0.14.0
html5lib==1.1
httpcore==1.0.7
@@ -121,13 +121,13 @@ oauthlib==3.2.2
olefile==0.47
ollama==0.4.7
omegaconf==2.3.0
onnx
onnxruntime
openai==1.65.4
opencv-python==4.11.0.86
opentelemetry-api
opentelemetry-exporter-otlp-proto-common
opentelemetry-exporter-otlp-proto-grpc
opentelemetry-instrumentation==0.53b1
opentelemetry-instrumentation-asgi==0.53b1
opentelemetry-instrumentation-fastapi==0.53b1
@@ -138,8 +138,8 @@ opentelemetry-util-http==0.53b1
orjson==3.10.15
overrides==7.7.0
packaging==24.2
pandas
pandasai
pathspec==0.12.1
pdf2image==1.17.0
pdfminer.six==20250506
@@ -150,7 +150,7 @@ platformdirs==4.3.6
posthog==4.0.1
propcache==0.3.1
proto-plus==1.26.1
protobuf
psutil==7.0.0
pyasn1==0.6.1
pyasn1_modules==0.4.1
@@ -160,7 +160,7 @@ pydantic==2.11.4
pydantic-settings==2.9.1
pydantic_core==2.33.2
Pygments==2.19.1
PyJWT
pyOpenSSL==25.0.0
pyparsing==3.2.1
pypdf==5.4.0