diff --git a/llm_be/chat_backend/admin.py b/llm_be/chat_backend/admin.py index e7ae303..e9b72c6 100644 --- a/llm_be/chat_backend/admin.py +++ b/llm_be/chat_backend/admin.py @@ -9,7 +9,7 @@ from .models import ( Feedback, PromptMetric, DocumentWorkspace, - Document + Document, ) # Register your models here. @@ -79,14 +79,15 @@ class PromptMetricAdmin(admin.ModelAdmin): "get_duration", ) + class DocumentWorkspaceAdmin(admin.ModelAdmin): model = DocumentWorkspace list_display = ( "name", "company", - ) + class DocumentAdmin(admin.ModelAdmin): model = Document list_display = ( diff --git a/llm_be/chat_backend/apps.py b/llm_be/chat_backend/apps.py index 524ac53..8d31a58 100644 --- a/llm_be/chat_backend/apps.py +++ b/llm_be/chat_backend/apps.py @@ -8,18 +8,19 @@ class ChatBackendConfig(AppConfig): name = "chat_backend" def ready(self): - import chat_backend.signals + import chat_backend.signals + FORCE_RELOAD = False - if True: #not settings.TESTING: # Don't run during tests + if True: # not settings.TESTING: # Don't run during tests try: from .services.rag_services import AsyncRAGService from chat_backend.models import Document - + # Check if Chroma needs initialization if Document.objects.exists(): rag_service = AsyncRAGService() - + if rag_service.vector_store._collection.count() == 0: print("Initializing ChromaDB with existing documents...") rag_service.ingest_documents() diff --git a/llm_be/chat_backend/models.py b/llm_be/chat_backend/models.py index d2b5ef2..cdb93a9 100644 --- a/llm_be/chat_backend/models.py +++ b/llm_be/chat_backend/models.py @@ -53,6 +53,9 @@ class Company(TimeInfoBase): help_text="A list of LLMs that company can use", ) + def __str__(self): + return self.name + class CustomUser(AbstractUser): company = models.ForeignKey( @@ -71,7 +74,7 @@ class CustomUser(AbstractUser): ) def get_set_password_url(self): - return f"https://www.chat.aimloperations.com/set_password?slug={self.slug}" + return f"https://chat.aimloperations.com/set_password?slug={self.slug}" FEEDBACK_CHOICE = ( @@ -220,14 +223,16 @@ class PromptMetric(TimeInfoBase): return difference.seconds return 0 + # Document Models class DocumentWorkspace(TimeInfoBase): name = models.CharField(max_length=255) company = models.ForeignKey(Company, on_delete=models.CASCADE) + class Document(TimeInfoBase): workspace = models.ForeignKey(DocumentWorkspace, on_delete=models.CASCADE) - file = models.FileField(upload_to='documents/') + file = models.FileField(upload_to="documents/") uploaded_at = models.DateTimeField(auto_now_add=True) processed = models.BooleanField(default=False) active = models.BooleanField(default=False) diff --git a/llm_be/chat_backend/serializers.py b/llm_be/chat_backend/serializers.py index c4906ec..882b190 100644 --- a/llm_be/chat_backend/serializers.py +++ b/llm_be/chat_backend/serializers.py @@ -9,7 +9,7 @@ from .models import ( Feedback, FEEDBACK_CATEGORIES, DocumentWorkspace, - Document + Document, ) @@ -99,11 +99,20 @@ class BasicUserSerializer(serializers.ModelSerializer): class DocumentWorkspaceSerializer(serializers.ModelSerializer): class Meta: model = DocumentWorkspace - fields = ['id', 'name', 'created'] - read_only_fields = ['id', 'created'] + fields = ["id", "name", "created"] + read_only_fields = ["id", "created"] + class DocumentSerializer(serializers.ModelSerializer): class Meta: model = Document - fields = ['id', 'workspace', 'file', 'uploaded_at', 'processed', 'created', 'active'] - read_only_fields = ['id', 'uploaded_at', 'processed', 'created'] \ No newline at end of file + fields = [ 
+ "id", + "workspace", + "file", + "uploaded_at", + "processed", + "created", + "active", + ] + read_only_fields = ["id", "uploaded_at", "processed", "created"] diff --git a/llm_be/chat_backend/services/image_generation.py b/llm_be/chat_backend/services/image_generation.py index f6d95d6..8e781e4 100644 --- a/llm_be/chat_backend/services/image_generation.py +++ b/llm_be/chat_backend/services/image_generation.py @@ -7,21 +7,22 @@ from diffusers import StableDiffusionPipeline, DPMSolverSinglestepScheduler logger = logging.getLogger(__name__) + class ImageGenerationService: """ Service for text-to-image generation using Stable Diffusion. Uses singleton pattern to maintain loaded model in memory. """ - + _instance = None _model_loaded = False - + def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._initialize() return cls._instance - + def _initialize(self): """Initialize the service with default settings""" self.device = "cuda" if torch.cuda.is_available() else "cpu" @@ -33,15 +34,15 @@ class ImageGenerationService: "width": 512, "height": 512, } - + def load_model(self): """Load the Stable Diffusion model""" if self._model_loaded: return - + try: logger.info(f"Loading Stable Diffusion model on {self.device}...") - + # Use DPMSolver for faster inference self.pipeline = StableDiffusionPipeline.from_pretrained( self.model_id, @@ -51,15 +52,15 @@ class ImageGenerationService: self.pipeline.scheduler.config ) self.pipeline = self.pipeline.to(self.device) - + # Optimizations if self.device == "cuda": self.pipeline.enable_attention_slicing() self.pipeline.enable_xformers_memory_efficient_attention() - + self._model_loaded = True logger.info("Model loaded successfully") - + except Exception as e: logger.error(f"Failed to load model: {str(e)}") raise RuntimeError(f"Model loading failed: {str(e)}") @@ -69,45 +70,43 @@ class ImageGenerationService: prompt: str, negative_prompt: Optional[str] = None, output_path: Optional[str] = None, - **kwargs + **kwargs, ) -> Tuple[Image.Image, dict]: """ Generate image from text prompt. - + Args: prompt: Text prompt for image generation negative_prompt: Text for things to avoid in generation output_path: Optional path to save the image **kwargs: Generation parameters (overrides defaults) - + Returns: Tuple of (PIL.Image, generation_parameters) """ if not self._model_loaded: self.load_model() - + # Merge default params with overrides params = {**self.default_params, **kwargs} - + try: logger.info(f"Generating image with prompt: {prompt[:50]}...") - + with torch.inference_mode(): result = self.pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - **params + prompt=prompt, negative_prompt=negative_prompt, **params ) - + image = result.images[0] - + if output_path: os.makedirs(os.path.dirname(output_path), exist_ok=True) image.save(output_path) logger.info(f"Image saved to {output_path}") - + return image, params - + except Exception as e: logger.error(f"Image generation failed: {str(e)}") raise RuntimeError(f"Image generation failed: {str(e)}") @@ -118,28 +117,28 @@ class AsyncImageGenerationService: Asynchronous wrapper for image generation service. Runs the synchronous service in a thread pool. 
""" - + def __init__(self): self.sync_service = ImageGenerationService() - + async def generate_image( self, prompt: str, negative_prompt: Optional[str] = None, output_path: Optional[str] = None, - **kwargs + **kwargs, ) -> Tuple[Image.Image, dict]: """Async version of generate_image""" import asyncio from functools import partial - + loop = asyncio.get_event_loop() func = partial( self.sync_service.generate_image, prompt=prompt, negative_prompt=negative_prompt, output_path=output_path, - **kwargs + **kwargs, ) - - return await loop.run_in_executor(None, func) \ No newline at end of file + + return await loop.run_in_executor(None, func) diff --git a/llm_be/chat_backend/services/llm_service.py b/llm_be/chat_backend/services/llm_service.py index 1d24f08..241e2f6 100644 --- a/llm_be/chat_backend/services/llm_service.py +++ b/llm_be/chat_backend/services/llm_service.py @@ -1,47 +1,50 @@ from abc import ABC, abstractmethod from typing import AsyncGenerator, Generator, Optional -from langchain_community.llms import Ollama +# from langchain_community.llms import Ollama +from langchain_ollama import OllamaLLM from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate from chat_backend.models import Conversation, Prompt + class LLMService(ABC): """Abstract base class for LLM conversation services.""" - + def __init__(self): - self.llm = Ollama( + self.llm = OllamaLLM( model="llama3.2", temperature=0.7, top_k=50, top_p=0.9, repeat_penalty=1.1, - num_ctx=4096 + num_ctx=4096, ) self.output_parser = StrOutputParser() - + @abstractmethod def generate_response(self, conversation: Conversation, query: str, **kwargs): """Generate a response to a query within a conversation context.""" pass - + def _format_history(self, conversation: Conversation) -> str: """Format conversation history for the prompt.""" - prompts = Prompt.objects.filter(conversation=conversation).order_by('created_at') + prompts = Prompt.objects.filter(conversation=conversation).order_by( + "created_at" + ) return "\n".join( - f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" - for prompt in prompts + f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts ) class SyncLLMService(LLMService): """Synchronous LLM conversation service.""" - + def __init__(self): super().__init__() self._setup_chain() - + def _setup_chain(self): """Setup the conversation chain.""" template = """Continue the conversation based on the following history: @@ -52,35 +55,36 @@ class SyncLLMService(LLMService): Response:""" self.prompt = ChatPromptTemplate.from_template(template) - + self.conversation_chain = ( { "history": lambda x: self._format_history(x["conversation"]), - "query": lambda x: x["query"] + "query": lambda x: x["query"], } | self.prompt | self.llm | self.output_parser ) - - def generate_response(self, conversation: Conversation, query: str, **kwargs) -> Generator[str, None, None]: + + def generate_response( + self, conversation: Conversation, query: str, **kwargs + ) -> Generator[str, None, None]: """Generate response with streaming support.""" - chain_input = { - "query": query, - "conversation": conversation - } - + chain_input = {"query": query, "conversation": conversation} + + print(f"chain_input:\n{chain_input}") + for chunk in self.conversation_chain.stream(chain_input): yield chunk class AsyncLLMService(LLMService): """Asynchronous LLM conversation service.""" - + def __init__(self): super().__init__() self._setup_chain() - + def _setup_chain(self): """Setup the 
conversation chain.""" template = """Continue this conversation while maintaining context by providing a single helpful response. @@ -97,42 +101,50 @@ class AsyncLLMService(LLMService): - When asked to modify something, identify what's being modified Response:""" - + self.prompt = ChatPromptTemplate.from_template(template) - + self.conversation_chain = ( { "context": lambda x: self._format_history(x["conversation"]), - "recent_history": lambda x: self._get_recent_messages(x["conversation"]), - "query": lambda x: x["query"] + "recent_history": lambda x: self._get_recent_messages( + x["conversation"] + ), + "query": lambda x: x["query"], } | self.prompt | self.llm | self.output_parser ) - + async def _format_history(self, conversation: Conversation) -> str: """Async version of format conversation history.""" - prompts = await Prompt.objects.filter(conversation=conversation).order_by('created_at').alist() + prompts = ( + await Prompt.objects.filter(conversation=conversation) + .order_by("created_at") + .alist() + ) return "\n".join( - f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" - for prompt in prompts + f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts ) async def _get_recent_messages(self, conversation: Conversation) -> str: """Async version of format conversation history.""" - prompts = await Prompt.objects.filter(conversation=conversation).order_by('created_at').alist()[-3:] - return "\n".join( - f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" - for prompt in prompts + prompts = ( + await Prompt.objects.filter(conversation=conversation) + .order_by("created_at") + .alist()[-6:] ) - - async def generate_response(self, conversation: Conversation, query: str, **kwargs) -> AsyncGenerator[str, None]: + return "\n".join( + f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts + ) + + async def generate_response( + self, conversation: Conversation, query: str, **kwargs + ) -> AsyncGenerator[str, None]: """Generate response with async streaming support.""" - chain_input = { - "query": query, - "conversation": conversation - } - + chain_input = {"query": query, "conversation": conversation} + print(f"LLM Chain:\n{chain_input}") + async for chunk in self.conversation_chain.astream(chain_input): - yield chunk \ No newline at end of file + yield chunk diff --git a/llm_be/chat_backend/services/moderation_classifier.py b/llm_be/chat_backend/services/moderation_classifier.py index d66968b..a2af95a 100644 --- a/llm_be/chat_backend/services/moderation_classifier.py +++ b/llm_be/chat_backend/services/moderation_classifier.py @@ -1,27 +1,34 @@ from enum import Enum, auto from typing import Dict, Any from langchain_core.prompts import ChatPromptTemplate -from langchain_community.llms import Ollama + +# from langchain_community.llms import Ollama +from langchain_ollama import OllamaLLM + class ModerationLabel(Enum): NSFW = auto() FINE = auto() + class ModerationClassifier: """ Classifies prompts as NSFW or FINE (safe) content. """ - + def __init__(self): - self.llm = Ollama( + self.llm = OllamaLLM( model="llama3.2", temperature=0.1, # Very low for strict moderation top_k=10, - num_ctx=2048 + num_ctx=2048, ) - - self.moderation_prompt = ChatPromptTemplate.from_messages([ - ("system", """You are a strict content moderator. Classify the following prompt as either NSFW or FINE. + + self.moderation_prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + """You are a strict content moderator. 
Classify the following prompt as either NSFW or FINE. NSFW includes: - Sexual content @@ -44,12 +51,14 @@ Examples: - "Explicit sex scene" → NSFW - "Python tutorial" → FINE -Return ONLY "NSFW" or "FINE", nothing else."""), - ("human", "{prompt}") - ]) - +Return ONLY "NSFW" or "FINE", nothing else.""", + ), + ("human", "{prompt}"), + ] + ) + self.chain = self.moderation_prompt | self.llm - + async def classify_async(self, prompt: str) -> ModerationLabel: """Asynchronous classification""" try: @@ -58,7 +67,7 @@ Return ONLY "NSFW" or "FINE", nothing else."""), except Exception as e: print(f"Moderation error: {e}") return ModerationLabel.NSFW # Fail-safe to NSFW - + def classify(self, prompt: str) -> ModerationLabel: """Synchronous classification""" try: @@ -67,7 +76,7 @@ Return ONLY "NSFW" or "FINE", nothing else."""), except Exception as e: print(f"Moderation error: {e}") return ModerationLabel.NSFW # Fail-safe to NSFW - + def _parse_response(self, response: str) -> ModerationLabel: """Convert string response to ModerationLabel enum""" if "NSFW" in response: @@ -76,4 +85,4 @@ Return ONLY "NSFW" or "FINE", nothing else."""), # Singleton instance -moderation_classifier = ModerationClassifier() \ No newline at end of file +moderation_classifier = ModerationClassifier() diff --git a/llm_be/chat_backend/services/prompt_classifier.py b/llm_be/chat_backend/services/prompt_classifier.py index 4548f46..cf3534c 100644 --- a/llm_be/chat_backend/services/prompt_classifier.py +++ b/llm_be/chat_backend/services/prompt_classifier.py @@ -1,7 +1,10 @@ from enum import Enum, auto from typing import Dict, Any from langchain_core.prompts import ChatPromptTemplate -from langchain_community.llms import Ollama + +# from langchain_community.llms import Ollama +from langchain_ollama import OllamaLLM + class PromptType(Enum): GENERAL_CHAT = auto() @@ -9,23 +12,26 @@ class PromptType(Enum): IMAGE_GENERATION = auto() UNKNOWN = auto() + class PromptClassifier: """ Classifies user prompts to determine which service should handle them. """ - + def __init__(self): - self.llm = Ollama( - model="llama3", + self.llm = OllamaLLM( + model="llama3.2", temperature=0.3, # Lower temp for more deterministic classification top_k=20, top_p=0.9, - num_ctx=4096 + num_ctx=4096, ) - - self.classification_prompt = ChatPromptTemplate.from_messages([ - ("system", - """You are a precision prompt classifier. Strictly categorize prompts into: + + self.classification_prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + """You are a precision prompt classifier. Strictly categorize prompts into: 1. GENERAL_CHAT - Casual conversation, personal questions, or non-specific inquiries 2. RAG - ONLY when explicitly requesting document/search-based knowledge 3. 
IMAGE_GENERATION - Specific requests to create/modify images @@ -63,12 +69,14 @@ Examples: - "Explain quantum computing" (General knowledge) - "Summarize the meeting" (No doc reference) -Return ONLY the label, no explanations."""), - ("human", "{prompt}") - ]) - +Return ONLY the label, no explanations.""", + ), + ("human", "{prompt}"), + ] + ) + self.chain = self.classification_prompt | self.llm - + async def classify_async(self, prompt: str) -> PromptType: """Asynchronously classify the prompt""" try: @@ -77,7 +85,7 @@ Return ONLY the label, no explanations."""), except Exception as e: print(f"Classification error: {e}") return PromptType.UNKNOWN - + def classify(self, prompt: str) -> PromptType: """Synchronously classify the prompt""" try: @@ -86,7 +94,7 @@ Return ONLY the label, no explanations."""), except Exception as e: print(f"Classification error: {e}") return PromptType.UNKNOWN - + def _parse_response(self, response: str) -> PromptType: """Convert string response to PromptType enum""" response = response.upper() @@ -97,4 +105,4 @@ Return ONLY the label, no explanations."""), # Singleton instance for easy access -prompt_classifier = PromptClassifier() \ No newline at end of file +prompt_classifier = PromptClassifier() diff --git a/llm_be/chat_backend/services/rag_services.py b/llm_be/chat_backend/services/rag_services.py index 4608822..208570f 100644 --- a/llm_be/chat_backend/services/rag_services.py +++ b/llm_be/chat_backend/services/rag_services.py @@ -3,7 +3,9 @@ from abc import ABC, abstractmethod from typing import List, Dict, Any, AsyncGenerator, Generator, Optional from channels.db import database_sync_to_async from langchain_community.embeddings import OllamaEmbeddings -from langchain_community.llms import Ollama + +# from langchain_community.llms import Ollama +from langchain_ollama import OllamaLLM from langchain_community.vectorstores import Chroma from langchain_core.documents import Document as LangDocument from langchain_core.output_parsers import StrOutputParser @@ -14,139 +16,139 @@ from langchain_community.document_loaders import ( PyPDFLoader, Docx2txtLoader, TextLoader, - UnstructuredFileLoader + UnstructuredFileLoader, ) from django.core.files.uploadedfile import UploadedFile from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document from pathlib import Path + @database_sync_to_async def get_documents(workspace: DocumentWorkspace | None = None): if workspace: return [doc for doc in Document.objects.filter(workspace=workspace)] else: - return [doc for doc in Document.objects.all()] - + return [doc for doc in Document.objects.all()] class RAGService(ABC): """Abstract base class for RAG services.""" + _instance = None - + def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) cls._instance.__init__() return cls._instance - + def __init__(self): self.embedding_model = OllamaEmbeddings(model="llama3.2") - self.llm = Ollama( + self.llm = OllamaLLM( model="llama3.2", temperature=0.7, top_k=50, top_p=0.9, repeat_penalty=1.1, - num_ctx=4096 + num_ctx=4096, ) self.text_splitter = RecursiveCharacterTextSplitter( - chunk_size=1000, - chunk_overlap=200 + chunk_size=1000, chunk_overlap=200 ) self.vector_store = self._initialize_vector_store() # Supported file types and their loaders self.loader_mapping = { - '.pdf': PyPDFLoader, - '.docx': Docx2txtLoader, - '.txt': TextLoader, + ".pdf": PyPDFLoader, + ".docx": Docx2txtLoader, + ".txt": TextLoader, # Fallback for other file types - '*': UnstructuredFileLoader, + "*": 
UnstructuredFileLoader, } - + def _initialize_vector_store(self) -> Chroma: """Initialize and return the Chroma vector store.""" - persist_directory=f"./chroma_db/" + persist_directory = f"./chroma_db/" vector_store = Chroma( - embedding_function=self.embedding_model, - persist_directory=persist_directory + embedding_function=self.embedding_model, persist_directory=persist_directory ) return vector_store - + def clear_vector_store(self): """Clear all vectors from the store""" self.vector_store.delete_collection() self.vector_store = self._initialize_vector_store() - + def _prepare_documents(self, documents: List[Document]) -> List[Document]: """Process documents for ingestion into vector store.""" docs = [] - + for doc in documents: print(f"Processing: {doc.file.name}") - loader_class = self._get_file_loader( doc.file.name) + loader_class = self._get_file_loader(doc.file.name) loader = loader_class(doc.file) - - + chunks = self._load_and_split_documents(doc.file.path) if chunks: self.vector_store.add_documents(chunks) self.vector_store.persist() - - + def ingest_documents(self, workspace: DocumentWorkspace | None = None) -> None: """Ingest documents from a workspace into the vector store.""" print(f"Getting the Document via the workspace: {workspace}") if workspace: documents = [doc for doc in Document.objects.filter(workspace=workspace)] else: - documents = [doc for doc in Document.objects.all()] - + documents = [doc for doc in Document.objects.all()] + print(f"Processing the documents : {documents}") self._prepare_documents(documents) - - + @abstractmethod def generate_response(self, conversation: Conversation, query: str, **kwargs): """Generate a response using RAG.""" pass - + @abstractmethod - def search_documents(self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4) -> List[Document]: + def search_documents( + self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4 + ) -> List[Document]: """Search relevant documents from the vector store.""" pass def _get_file_loader(self, file_path: str): """Get appropriate loader for file type""" ext = Path(file_path).suffix.lower() - return self.loader_mapping.get(ext, self.loader_mapping['*']) - + return self.loader_mapping.get(ext, self.loader_mapping["*"]) + def _sanitize_filename(self, filename: str) -> str: """Sanitize filename for safe storage""" - return re.sub(r'[^\w\-_. ]', '_', filename) - + return re.sub(r"[^\w\-_. 
]", "_", filename) + def _save_uploaded_file(self, uploaded_file: UploadedFile, save_dir: str) -> str: """Save uploaded file to disk""" os.makedirs(save_dir, exist_ok=True) sanitized_name = self._sanitize_filename(uploaded_file.name) file_path = os.path.join(save_dir, sanitized_name) - - with open(file_path, 'wb+') as destination: + + with open(file_path, "wb+") as destination: for chunk in uploaded_file.chunks(): destination.write(chunk) - + return file_path - def _load_and_split_documents(self, file_path: str, metadata: dict = None) -> List[Document]: + def _load_and_split_documents( + self, file_path: str, metadata: dict = None + ) -> List[Document]: """Load and split documents from file""" loader_class = self._get_file_loader(file_path) loader = loader_class(file_path) - + docs = loader.load() if metadata: for doc in docs: doc.metadata.update(metadata) - + return self.text_splitter.split_documents(docs) def add_files_to_store( @@ -154,58 +156,51 @@ class RAGService(ABC): file_tupls: List[UploadedFile], # (file_path, name,workspace_id) workspace_id: str, source: str = "upload", - save_dir: str = "data/uploads" + save_dir: str = "data/uploads", ) -> Dict[str, Any]: """ Process and add uploaded files to vector store - + Args: files: List of Django UploadedFile objects workspace_id: ID of the workspace these belong to source: Source identifier for documents save_dir: Directory to save uploaded files - + Returns: Dictionary with processing results """ - results = { - 'total_added': 0, - 'failed_files': [], - 'processed_files': [] - } - + results = {"total_added": 0, "failed_files": [], "processed_files": []} + for file_tuple in file_tupls: try: # Save file to disk - - + # Prepare metadata metadata = { - 'source': file_tuple[1], - 'workspace_id': file_tuple[2], - 'original_filename': file_tuple[1], - 'file_path': file_tuple[0], + "source": file_tuple[1], + "workspace_id": file_tuple[2], + "original_filename": file_tuple[1], + "file_path": file_tuple[0], } - + # Load and split documents docs = self._load_and_split_documents(file_path, metadata) - + # Add to vector store if docs: self.vector_store.add_documents(docs) - results['total_added'] += len(docs) - results['processed_files'].append({ - 'filename': file_tuple[1], - 'document_count': len(docs) - }) - + results["total_added"] += len(docs) + results["processed_files"].append( + {"filename": file_tuple[1], "document_count": len(docs)} + ) + except Exception as e: - results['failed_files'].append({ - 'filename': file_tuple[1], - 'error': str(e) - }) + results["failed_files"].append( + {"filename": file_tuple[1], "error": str(e)} + ) continue - + # Persist changes self.vector_store.persist() return results @@ -213,11 +208,11 @@ class RAGService(ABC): class SyncRAGService(RAGService): """Synchronous RAG service implementation.""" - + def __init__(self): super().__init__() self._setup_chain() - + def _setup_chain(self): """Setup the RAG chain.""" template = """Answer the question based only on the following context: @@ -229,31 +224,32 @@ class SyncRAGService(RAGService): Question: {question} """ self.prompt = ChatPromptTemplate.from_template(template) - + self.rag_chain = ( { "context": self._retriever_with_history, "history": lambda x: self._format_history(x["conversation"]), - "question": lambda x: x["query"] + "question": lambda x: x["query"], } | self.prompt | self.llm | StrOutputParser() ) - + def _format_history(self, conversation: Conversation) -> str: """Format conversation history for the prompt.""" - prompts = 
Prompt.objects.filter(conversation=conversation).order_by('created_at') - return "\n".join( - f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" - for prompt in prompts + prompts = Prompt.objects.filter(conversation=conversation).order_by( + "created_at" ) - + return "\n".join( + f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts + ) + def _retriever_with_history(self, input_dict: Dict[str, Any]) -> str: """Retrieve documents considering conversation history.""" query = input_dict["query"] conversation = input_dict["conversation"] - + # You could enhance this to consider historical context in retrieval relevant_docs = self.search_documents(query, conversation.workspace) if not relevant_docs: @@ -262,8 +258,9 @@ class SyncRAGService(RAGService): else: return relevant_docs - - def search_documents(self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4) -> List[Document]: + def search_documents( + self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4 + ) -> List[Document]: """Search relevant documents from the vector store.""" filter_dict = {} if workspace: @@ -271,31 +268,27 @@ class SyncRAGService(RAGService): print(f"search_kwargs: {search_kwargs}") retriever = self.vector_store.as_retriever( search_type="similarity", - search_kwargs={ - "k": k, - "filter": filter_dict if filter_dict else None - } + search_kwargs={"k": k, "filter": filter_dict if filter_dict else None}, ) return retriever.get_relevant_documents(query) - - def generate_response(self, conversation: Conversation, query: str, **kwargs) -> Generator[str, None, None]: + + def generate_response( + self, conversation: Conversation, query: str, **kwargs + ) -> Generator[str, None, None]: """Generate response with streaming support.""" - chain_input = { - "query": query, - "conversation": conversation - } - + chain_input = {"query": query, "conversation": conversation} + for chunk in self.rag_chain.stream(chain_input): yield chunk class AsyncRAGService(RAGService): """Asynchronous RAG service implementation.""" - + def __init__(self): super().__init__() self._setup_chain() - + def _setup_chain(self): """Setup the RAG chain.""" template = """Answer the question based only on the following context: @@ -307,72 +300,76 @@ class AsyncRAGService(RAGService): Question: {question} """ self.prompt = ChatPromptTemplate.from_template(template) - + self.rag_chain = ( { "context": self._retriever_with_history, "history": lambda x: self._format_history(x["conversation"]), - "question": lambda x: x["query"] + "question": lambda x: x["query"], } | self.prompt | self.llm | StrOutputParser() ) - + async def _format_history(self, conversation: Conversation) -> str: """Format conversation history for the prompt.""" - prompts = await Prompt.objects.filter(conversation=conversation).order_by('created_at').alist() + prompts = ( + await Prompt.objects.filter(conversation=conversation) + .order_by("created_at") + .alist() + ) print(f"prompts that we are seeding with are: {prompts}") return "\n".join( - f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" - for prompt in prompts + f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts ) - + async def _retriever_with_history(self, input_dict: Dict[str, Any]) -> str: """Retrieve documents considering conversation history.""" print(f"Retrieving history with input: {input_dict}") query = input_dict["query"] conversation = input_dict["conversation"] workspace = input_dict["workspace"] - + # You could enhance this to 
consider historical context in retrieval - docs= await self.search_documents(query, workspace) + docs = await self.search_documents(query, workspace) if not docs: print("Didn't find any relevant docs") - + print("\n\n".join(doc.page_content for doc in docs)) return "\n\n".join(doc.page_content for doc in docs) - - async def search_documents(self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4) -> List[Document]: + async def search_documents( + self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4 + ) -> List[Document]: """Search relevant documents from the vector store.""" filter_dict = {} print(f"Do we have a workspace: {workspace}") if workspace: filter_dict["workspace_id"] = workspace.id - search_kwargs={ - "k": k, - "filter": filter_dict if filter_dict else None - } + search_kwargs = {"k": k, "filter": filter_dict if filter_dict else None} print(f"search_kwargs: {search_kwargs}") - + retriever = self.vector_store.as_retriever( search_type="mmr", - search_kwargs={ - "k": k, - "filter": filter_dict if filter_dict else None - } + search_kwargs={"k": k, "filter": filter_dict if filter_dict else None}, ) return await retriever.aget_relevant_documents(query) - - async def generate_response(self, conversation: Conversation, query: str, workspace: DocumentWorkspace, **kwargs) -> AsyncGenerator[str, None]: + + async def generate_response( + self, + conversation: Conversation, + query: str, + workspace: DocumentWorkspace, + **kwargs, + ) -> AsyncGenerator[str, None]: """Generate response with streaming support.""" chain_input = { "query": query, "conversation": conversation, "workspace": workspace, } - + async for chunk in self.rag_chain.astream(chain_input): yield chunk diff --git a/llm_be/chat_backend/services/tests.py b/llm_be/chat_backend/services/tests.py index 8da5303..f24da83 100644 --- a/llm_be/chat_backend/services/tests.py +++ b/llm_be/chat_backend/services/tests.py @@ -5,9 +5,14 @@ from typing import List, Dict, Any from django.test import TestCase as DjangoTestCase -from chat_backend.services.rag_services import RAGService, SyncRAGService, AsyncRAGService +from chat_backend.services.rag_services import ( + RAGService, + SyncRAGService, + AsyncRAGService, +) from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document + class TestRAGService(TestCase): def setUp(self): self.rag_service = RAGService() @@ -16,18 +21,20 @@ class TestRAGService(TestCase): self.rag_service.text_splitter = MagicMock() def test_initialize_vector_store(self): - with patch('os.path.exists', return_value=False), \ - patch('os.makedirs') as mock_makedirs, \ - patch('langchain_community.vectorstores.Chroma') as mock_chroma: - + with patch("os.path.exists", return_value=False), patch( + "os.makedirs" + ) as mock_makedirs, patch( + "langchain_community.vectorstores.Chroma" + ) as mock_chroma: + # Reset the vector store to test initialization self.rag_service.vector_store = None result = self.rag_service._initialize_vector_store() - + mock_makedirs.assert_called_once_with("chroma_db") mock_chroma.assert_called_once_with( embedding_function=self.rag_service.embedding_model, - persist_directory="chroma_db" + persist_directory="chroma_db", ) self.assertIsNotNone(result) @@ -40,11 +47,13 @@ class TestRAGService(TestCase): mock_doc1.id = 1 self.rag_service.text_splitter.split_text.return_value = ["chunk1", "chunk2"] - + result = self.rag_service._prepare_documents([mock_doc1]) - + self.assertEqual(len(result), 2) - 
self.rag_service.text_splitter.split_text.assert_called_once_with("Test content") + self.rag_service.text_splitter.split_text.assert_called_once_with( + "Test content" + ) self.assertEqual(result[0].page_content, "chunk1") self.assertEqual(result[0].metadata["source"], "test_source") @@ -52,13 +61,19 @@ class TestRAGService(TestCase): mock_workspace = MagicMock() mock_document = MagicMock() mock_documents = [mock_document] - - with patch('services.rag_services.Document.objects.filter', return_value=mock_documents): - self.rag_service._prepare_documents = MagicMock(return_value=["processed_doc"]) - + + with patch( + "services.rag_services.Document.objects.filter", return_value=mock_documents + ): + self.rag_service._prepare_documents = MagicMock( + return_value=["processed_doc"] + ) + self.rag_service.ingest_documents(mock_workspace) - - self.rag_service.vector_store.add_documents.assert_called_once_with(["processed_doc"]) + + self.rag_service.vector_store.add_documents.assert_called_once_with( + ["processed_doc"] + ) self.rag_service.vector_store.persist.assert_called_once() @@ -71,40 +86,39 @@ class TestSyncRAGService(DjangoTestCase): self.mock_conversation = MagicMock(spec=Conversation) self.mock_conversation.workspace = MagicMock() - + self.mock_prompt1 = MagicMock(spec=Prompt) self.mock_prompt1.is_user = True self.mock_prompt1.text = "User question" self.mock_prompt1.created_at = "2023-01-01" - + self.mock_prompt2 = MagicMock(spec=Prompt) self.mock_prompt2.is_user = False self.mock_prompt2.text = "AI response" self.mock_prompt2.created_at = "2023-01-02" def test_format_history(self): - with patch('services.rag_services.Prompt.objects.filter') as mock_filter: - mock_filter.return_value.order_by.return_value = [self.mock_prompt1, self.mock_prompt2] - + with patch("services.rag_services.Prompt.objects.filter") as mock_filter: + mock_filter.return_value.order_by.return_value = [ + self.mock_prompt1, + self.mock_prompt2, + ] + result = self.sync_service._format_history(self.mock_conversation) - + expected = "User: User question\nAI: AI response" self.assertEqual(result, expected) mock_filter.assert_called_once_with(conversation=self.mock_conversation) def test_retriever_with_history(self): - input_dict = { - "query": "test query", - "conversation": self.mock_conversation - } - + input_dict = {"query": "test query", "conversation": self.mock_conversation} + self.sync_service.search_documents = MagicMock(return_value=["doc1", "doc2"]) - + result = self.sync_service._retriever_with_history(input_dict) - + self.sync_service.search_documents.assert_called_once_with( - "test query", - self.mock_conversation.workspace + "test query", self.mock_conversation.workspace ) self.assertEqual(result, ["doc1", "doc2"]) @@ -112,29 +126,30 @@ class TestSyncRAGService(DjangoTestCase): mock_retriever = MagicMock() mock_retriever.get_relevant_documents.return_value = ["doc1", "doc2"] self.sync_service.vector_store.as_retriever.return_value = mock_retriever - - result = self.sync_service.search_documents("test query", self.mock_conversation.workspace) - + + result = self.sync_service.search_documents( + "test query", self.mock_conversation.workspace + ) + self.sync_service.vector_store.as_retriever.assert_called_once_with( search_type="similarity", search_kwargs={ "k": 4, - "filter": {"workspace_id": self.mock_conversation.workspace.id} - } + "filter": {"workspace_id": self.mock_conversation.workspace.id}, + }, ) self.assertEqual(result, ["doc1", "doc2"]) def test_generate_response(self): - chain_input = { - 
"query": "test query", - "conversation": self.mock_conversation - } - + chain_input = {"query": "test query", "conversation": self.mock_conversation} + mock_stream = ["chunk1", "chunk2", "chunk3"] self.sync_service.rag_chain.stream.return_value = mock_stream - - result = list(self.sync_service.generate_response(self.mock_conversation, "test query")) - + + result = list( + self.sync_service.generate_response(self.mock_conversation, "test query") + ) + self.sync_service.rag_chain.stream.assert_called_once_with(chain_input) self.assertEqual(result, mock_stream) @@ -148,12 +163,12 @@ class TestAsyncRAGService(DjangoTestCase): self.mock_conversation = MagicMock(spec=Conversation) self.mock_conversation.workspace = MagicMock() - + self.mock_prompt1 = MagicMock(spec=Prompt) self.mock_prompt1.is_user = True self.mock_prompt1.text = "User question" self.mock_prompt1.created_at = "2023-01-01" - + self.mock_prompt2 = MagicMock(spec=Prompt) self.mock_prompt2.is_user = False self.mock_prompt2.text = "AI response" @@ -161,28 +176,29 @@ class TestAsyncRAGService(DjangoTestCase): async def test_format_history(self): mock_manager = AsyncMock() - mock_manager.order_by.return_value.alist.return_value = [self.mock_prompt1, self.mock_prompt2] - - with patch('services.rag_services.Prompt.objects.filter', return_value=mock_manager): + mock_manager.order_by.return_value.alist.return_value = [ + self.mock_prompt1, + self.mock_prompt2, + ] + + with patch( + "services.rag_services.Prompt.objects.filter", return_value=mock_manager + ): result = await self.async_service._format_history(self.mock_conversation) - + expected = "User: User question\nAI: AI response" self.assertEqual(result, expected) - mock_manager.order_by.assert_called_once_with('created_at') + mock_manager.order_by.assert_called_once_with("created_at") async def test_retriever_with_history(self): - input_dict = { - "query": "test query", - "conversation": self.mock_conversation - } - + input_dict = {"query": "test query", "conversation": self.mock_conversation} + self.async_service.search_documents = AsyncMock(return_value=["doc1", "doc2"]) - + result = await self.async_service._retriever_with_history(input_dict) - + self.async_service.search_documents.assert_awaited_once_with( - "test query", - self.mock_conversation.workspace + "test query", self.mock_conversation.workspace ) self.assertEqual(result, ["doc1", "doc2"]) @@ -190,30 +206,31 @@ class TestAsyncRAGService(DjangoTestCase): mock_retriever = AsyncMock() mock_retriever.aget_relevant_documents.return_value = ["doc1", "doc2"] self.async_service.vector_store.as_retriever.return_value = mock_retriever - - result = await self.async_service.search_documents("test query", self.mock_conversation.workspace) - + + result = await self.async_service.search_documents( + "test query", self.mock_conversation.workspace + ) + self.async_service.vector_store.as_retriever.assert_called_once_with( search_type="similarity", search_kwargs={ "k": 4, - "filter": {"workspace_id": self.mock_conversation.workspace.id} - } + "filter": {"workspace_id": self.mock_conversation.workspace.id}, + }, ) self.assertEqual(result, ["doc1", "doc2"]) async def test_generate_response(self): - chain_input = { - "query": "test query", - "conversation": self.mock_conversation - } - + chain_input = {"query": "test query", "conversation": self.mock_conversation} + mock_stream = ["chunk1", "chunk2", "chunk3"] self.async_service.rag_chain.astream.return_value = mock_stream - + chunks = [] - async for chunk in 
self.async_service.generate_response(self.mock_conversation, "test query"): + async for chunk in self.async_service.generate_response( + self.mock_conversation, "test query" + ): chunks.append(chunk) - + self.async_service.rag_chain.astream.assert_awaited_once_with(chain_input) - self.assertEqual(chunks, mock_stream) \ No newline at end of file + self.assertEqual(chunks, mock_stream) diff --git a/llm_be/chat_backend/services/title_generator.py b/llm_be/chat_backend/services/title_generator.py index f8209d1..eb09231 100644 --- a/llm_be/chat_backend/services/title_generator.py +++ b/llm_be/chat_backend/services/title_generator.py @@ -1,22 +1,28 @@ from langchain_core.prompts import ChatPromptTemplate -from langchain_community.llms import Ollama + +# from langchain_community.llms import Ollama +from langchain_ollama import OllamaLLM from typing import Optional + class TitleGenerator: """ Generates short, descriptive titles for conversations based on the first prompt. """ - + def __init__(self): - self.llm = Ollama( - model="llama3", + self.llm = OllamaLLM( + model="llama3.2", temperature=0.5, # Slightly creative but not too random top_k=20, - num_ctx=2048 # Shorter context needed for titles + num_ctx=2048, # Shorter context needed for titles ) - - self.title_prompt = ChatPromptTemplate.from_messages([ - ("system", """You are a conversation title generator. Create a very short (2-5 word) title based on the user's first message. + + self.title_prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + """You are a conversation title generator. Create a very short (2-5 word) title based on the user's first message. Rules: 1. Keep it extremely concise @@ -31,12 +37,14 @@ Examples: - "Generate an image of a dragon" → "Dragon Image Generation" - "Find our company's privacy policy" → "Privacy Policy Search" -Return ONLY the title, nothing else."""), - ("human", "{prompt}") - ]) - +Return ONLY the title, nothing else.""", + ), + ("human", "{prompt}"), + ] + ) + self.chain = self.title_prompt | self.llm - + async def generate_async(self, prompt: str) -> str: """Generate title asynchronously""" try: @@ -45,7 +53,7 @@ Return ONLY the title, nothing else."""), except Exception as e: print(f"Title generation error: {e}") return "Conversation" - + def generate(self, prompt: str) -> str: """Generate title synchronously""" try: @@ -54,14 +62,14 @@ Return ONLY the title, nothing else."""), except Exception as e: print(f"Title generation error: {e}") return "Conversation" - + def _clean_response(self, response: str) -> str: """Clean and format the LLM response""" # Remove any quotes or punctuation - response = response.strip('"\'.!? \n\t') + response = response.strip("\"'.!? 
\n\t") # Ensure title case and trim return response.title()[:50] # Hard limit for safety # Singleton instance -title_generator = TitleGenerator() \ No newline at end of file +title_generator = TitleGenerator() diff --git a/llm_be/chat_backend/signals.py b/llm_be/chat_backend/signals.py index c08bab1..6e811e9 100644 --- a/llm_be/chat_backend/signals.py +++ b/llm_be/chat_backend/signals.py @@ -3,16 +3,18 @@ from django.dispatch import receiver from chat_backend.models import Document from .services.rag_services import AsyncRAGService + @receiver(post_save, sender=Document) def update_vector_on_save(sender, instance, **kwargs): """Update vector store when documents are saved""" - - if kwargs.get('created', False): + + if kwargs.get("created", False): rag_service = AsyncRAGService() rag_service.ingest_documents() + @receiver(post_delete, sender=Document) def delete_vector_on_remove(sender, instance, **kwargs): """Handle document deletion by re-indexing the whole workspace""" rag_service = AsyncRAGService() - rag_service.ingest_documents() \ No newline at end of file + rag_service.ingest_documents() diff --git a/llm_be/chat_backend/tests.py b/llm_be/chat_backend/tests.py index 392daa8..a8c3418 100644 --- a/llm_be/chat_backend/tests.py +++ b/llm_be/chat_backend/tests.py @@ -13,80 +13,75 @@ from django.core.files.uploadedfile import SimpleUploadedFile # Minimal valid PDF bytes VALID_PDF_BYTES = ( - b'%PDF-1.3\n' - b'1 0 obj\n' - b'<< /Type /Catalog /Pages 2 0 R >>\n' - b'endobj\n' - b'2 0 obj\n' - b'<< /Type /Pages /Kids [3 0 R] /Count 1 >>\n' - b'endobj\n' - b'3 0 obj\n' - b'<< /Type /Page /Parent 2 0 R /Resources << >> /MediaBox [0 0 612 792] /Contents 4 0 R >>\n' - b'endobj\n' - b'4 0 obj\n' - b'<< /Length 44 >>\n' - b'stream\n' - b'BT /F1 12 Tf 72 720 Td (Test PDF) Tj ET\n' - b'endstream\n' - b'endobj\n' - b'xref\n' - b'0 5\n' - b'0000000000 65535 f \n' - b'0000000009 00000 n \n' - b'0000000058 00000 n \n' - b'0000000117 00000 n \n' - b'0000000223 00000 n \n' - b'trailer\n' - b'<< /Size 5 /Root 1 0 R >>\n' - b'startxref\n' - b'317\n' - b'%%EOF' + b"%PDF-1.3\n" + b"1 0 obj\n" + b"<< /Type /Catalog /Pages 2 0 R >>\n" + b"endobj\n" + b"2 0 obj\n" + b"<< /Type /Pages /Kids [3 0 R] /Count 1 >>\n" + b"endobj\n" + b"3 0 obj\n" + b"<< /Type /Page /Parent 2 0 R /Resources << >> /MediaBox [0 0 612 792] /Contents 4 0 R >>\n" + b"endobj\n" + b"4 0 obj\n" + b"<< /Length 44 >>\n" + b"stream\n" + b"BT /F1 12 Tf 72 720 Td (Test PDF) Tj ET\n" + b"endstream\n" + b"endobj\n" + b"xref\n" + b"0 5\n" + b"0000000000 65535 f \n" + b"0000000009 00000 n \n" + b"0000000058 00000 n \n" + b"0000000117 00000 n \n" + b"0000000223 00000 n \n" + b"trailer\n" + b"<< /Size 5 /Root 1 0 R >>\n" + b"startxref\n" + b"317\n" + b"%%EOF" ) + class DocumentWorkspaceViewsTestCase(APITestCase): def setUp(self): self.company = Company.objects.create( - name="test", - state="IL", - zipcode="60189", - address="1968 Greensboro Dr" + name="test", state="IL", zipcode="60189", address="1968 Greensboro Dr" ) self.user = get_user_model().objects.create_user( company=self.company, - username='testuser', - password='testpass123', + username="testuser", + password="testpass123", email="test@test.com", ) self.client = APIClient() self.client.force_authenticate(user=self.user) - + self.workspace = DocumentWorkspace.objects.create( - company = self.user.company, - name='Test Workspace' + company=self.user.company, name="Test Workspace" ) def test_list_workspaces(self): - url = reverse('document_workspaces') + url = reverse("document_workspaces") 
response = self.client.get(url) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data), 1) - self.assertEqual(response.data[0]['name'], 'Test Workspace') + self.assertEqual(response.data[0]["name"], "Test Workspace") def test_create_workspace(self): - url = reverse('document_workspaces') - data = { - 'name': 'New Workspace' - } - response = self.client.post(url, data, format='json') + url = reverse("document_workspaces") + data = {"name": "New Workspace"} + response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(DocumentWorkspace.objects.count(), 2) def test_retrieve_workspace(self): - url = reverse('document_workspaces') + url = reverse("document_workspaces") response = self.client.get(url) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data[0]['name'], 'Test Workspace') + self.assertEqual(response.data[0]["name"], "Test Workspace") # def test_update_workspace(self): # url = reverse('document_workspaces') @@ -104,107 +99,89 @@ class DocumentWorkspaceViewsTestCase(APITestCase): # self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) # self.assertEqual(DocumentWorkspace.objects.count(), 0) + class DocumentViewsTestCase(APITestCase): def setUp(self): self.company = Company.objects.create( - name="test", - state="IL", - zipcode="60189", - address="1968 Greensboro Dr" + name="test", state="IL", zipcode="60189", address="1968 Greensboro Dr" ) self.user = get_user_model().objects.create_user( company=self.company, - username='testuser', - password='testpass123', + username="testuser", + password="testpass123", email="test@test.com", ) self.client = APIClient() self.client.force_authenticate(user=self.user) - + self.workspace = DocumentWorkspace.objects.create( - company=self.user.company, - name='Test Workspace' + company=self.user.company, name="Test Workspace" ) - + # Create a test file self.test_file = SimpleUploadedFile( - "test.pdf", - VALID_PDF_BYTES, - content_type="application/pdf" + "test.pdf", VALID_PDF_BYTES, content_type="application/pdf" ) def test_upload_document(self): - url = reverse('documents') - data = { - 'file': self.test_file - } - response = self.client.post(url, data, format='multipart') + url = reverse("documents") + data = {"file": self.test_file} + response = self.client.post(url, data, format="multipart") self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(Document.objects.count(), 1) - + document = Document.objects.first() self.assertEqual(document.workspace.id, self.workspace.id) self.assertTrue(document.processed) # Should be False initially def test_list_documents(self): # First create a document - Document.objects.create( - workspace=self.workspace, - file=self.test_file - ) - - url = reverse('documents') + Document.objects.create(workspace=self.workspace, file=self.test_file) + + url = reverse("documents") response = self.client.get(url) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data), 1) - self.assertIn('test', response.data[0]['file']) - self.assertIn('pdf', response.data[0]['file']) + self.assertIn("test", response.data[0]["file"]) + self.assertIn("pdf", response.data[0]["file"]) # def test_delete_document(self): # document = Document.objects.create( # workspace=self.workspace, # file=self.test_file # ) - + # url = reverse('document-detail', args=[document.id]) # response = self.client.delete(url) # 
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) # self.assertEqual(Document.objects.count(), 0) def test_upload_invalid_file(self): - url = reverse('documents') - data = { - 'file': 'not a file' - } - response = self.client.post(url, data, format='multipart') + url = reverse("documents") + data = {"file": "not a file"} + response = self.client.post(url, data, format="multipart") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) def test_access_other_users_documents(self): # Create another user other_company = Company.objects.create( - name="test2", - state="IL", - zipcode="60189", - address="1968 Greensboro Dr" + name="test2", state="IL", zipcode="60189", address="1968 Greensboro Dr" ) other_user = get_user_model().objects.create_user( company=other_company, - username='otheruser', - password='otherpass123', - email="testing2@test.com" + username="otheruser", + password="otherpass123", + email="testing2@test.com", ) other_workspace = DocumentWorkspace.objects.create( - company = other_user.company, - name='Other Workspace' + company=other_user.company, name="Other Workspace" ) other_document = Document.objects.create( - workspace=other_workspace, - file=self.test_file + workspace=other_workspace, file=self.test_file ) - + # Try to access the other user's document - url = reverse('documents_details', kwargs={"document_id":other_document.id}) + url = reverse("documents_details", kwargs={"document_id": other_document.id}) response = self.client.get(url) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - \ No newline at end of file diff --git a/llm_be/chat_backend/urls.py b/llm_be/chat_backend/urls.py index 83e281a..9f32270 100644 --- a/llm_be/chat_backend/urls.py +++ b/llm_be/chat_backend/urls.py @@ -23,8 +23,7 @@ from .views import ( reset_password, DocumentWorkspaceView, DocumentUploadView, - DocumentDetailView - + DocumentDetailView, ) from rest_framework.routers import DefaultRouter @@ -80,11 +79,16 @@ urlpatterns = [ name="analytics_company_usage", ), path("analytics/admin/", AdminAnalytics.as_view(), name="analytics_admin"), - # document urls - path("document_workspaces/", DocumentWorkspaceView.as_view(), name="document_workspaces"), + path( + "document_workspaces/", + DocumentWorkspaceView.as_view(), + name="document_workspaces", + ), path("documents/", DocumentUploadView.as_view(), name="documents"), - path("documents_details/", DocumentDetailView.as_view(), name="documents_details"), - + path( + "documents_details/", + DocumentDetailView.as_view(), + name="documents_details", + ), ] - diff --git a/llm_be/chat_backend/views.py b/llm_be/chat_backend/views.py index a6a41b1..3115b96 100644 --- a/llm_be/chat_backend/views.py +++ b/llm_be/chat_backend/views.py @@ -13,7 +13,7 @@ from .serializers import ( PromptSerializer, FeedbackSerializer, DocumentWorkspaceSerializer, - DocumentSerializer + DocumentSerializer, ) from rest_framework.views import APIView from rest_framework.response import Response @@ -25,7 +25,7 @@ from .models import ( Feedback, PromptMetric, DocumentWorkspace, - Document + Document, ) from django.views.decorators.cache import never_cache from django.http import JsonResponse @@ -99,8 +99,8 @@ class CustomUserCreate(APIView): def send_invite_email(slug, email_to_invite): print("Sending invite email") - print(f"url : https://www.chat.aimloperations.com/set_password?slug={slug}") - url = f"https://www.chat.aimloperations.com/set_password?slug={slug}" + print(f"url : 
https://chat.aimloperations.com/set_password?slug={slug}") + url = f"https://chat.aimloperations.com/set_password?slug={slug}" subject = "Welcome to AI ML Operations, LLC Chat Services" from_email = "ryan@aimloperations.com" to = email_to_invite @@ -113,6 +113,22 @@ def send_invite_email(slug, email_to_invite): msg.send(fail_silently=True) +def send_password_reset_email(slug, email_to_invite): + print("Sending reset email") + print(f"url : https://www.chat.aimloperations.com/set_password?slug={slug}") + url = f"https://www.chat.aimloperations.com/set_password?slug={slug}" + subject = "Password reset for AI ML Operations, LLC Chat Services" + from_email = "ryan@aimloperations.com" + to = email_to_invite + d = {"url": url} + html_content = get_template(r"emails/reset_email.html").render(d) + text_content = get_template(r"emails/reset_email.txt").render(d) + + msg = EmailMultiAlternatives(subject, text_content, from_email, [to]) + msg.attach_alternative(html_content, "text/html") + msg.send(fail_silently=True) + + def send_feedback_email(feedback_obj): print("Sending feedback email") subject = "New Feedback for Chat by AI ML Operations, LLC" @@ -176,19 +192,22 @@ class CustomUserInvite(APIView): return Response(status=status.HTTP_201_CREATED) + @csrf_exempt def reset_password(request): if request.method == "POST": data = json.loads(request.body) - token = data.get('recaptchaToken') + token = data.get("recaptchaToken") payload = { - 'secret': settings.CAPTCHA_SECRET_KEY, - 'response': token, + "secret": settings.CAPTCHA_SECRET_KEY, + "response": token, } - response = requests.post('https://www.google.com/recaptcha/api/siteverify', data=payload) + response = requests.post( + "https://www.google.com/recaptcha/api/siteverify", data=payload + ) result = response.json() - if result.get('success') and result.get('score') >= 0.5: - email = data.get('email') + if result.get("success") and result.get("score") >= 0.5: + email = data.get("email") user = CustomUser.objects.filter(email=email).first() if user: user.set_unusable_password() @@ -197,10 +216,10 @@ def reset_password(request): # send the email send_password_reset_email(user.slug, email) JsonResponse(status=200) - - + JsonResponse(status=400) + class ResetUserPassword(APIView): http_method_names = [ "post", @@ -214,14 +233,16 @@ class ResetUserPassword(APIView): Also disable the account """ print(f"Password reset for requests. 
{request.data}") - token = request.data.get('recaptchaToken') + token = request.data.get("recaptchaToken") payload = { - 'secret': settings.CAPTCHA_SECRET_KEY, - 'response': recaptchaToken, + "secret": settings.CAPTCHA_SECRET_KEY, + "response": recaptchaToken, } - response = requests.post('https://www.google.com/recaptcha/api/siteverify', data=payload) + response = requests.post( + "https://www.google.com/recaptcha/api/siteverify", data=payload + ) result = response.json() - if result.get('success') and result.get('score') >= 0.5: + if result.get("success") and result.get("score") >= 0.5: user = CustomUser.objects.filter(email=email).first() if user: user.set_unusable_password() @@ -230,7 +251,7 @@ class ResetUserPassword(APIView): # send the email send_password_reset_email(user.slug, email) else: - print('Captcha secret failed') + print("Captcha secret failed") return Response(status=status.HTTP_200_OK) @@ -261,10 +282,16 @@ class CustomUserGet(APIView): email = request.user.email username = request.user.username - user = CustomUser.objects.get(email=email) - serializer = CustomUserSerializer(user) - - return Response(serializer.data, status=status.HTTP_200_OK) + user = CustomUser.objects.filter(email=email).last() + print(f"Getting the user: {user}") + try: + serializer = CustomUserSerializer(user) + print(f"serializer: {serializer}") + print(serializer.data) + return Response(serializer.data, status=status.HTTP_200_OK) + except Exception as e: + print(f"Exception: {e}") + return Response({}, status=status.HTTP_400_BAD_REQUEST) class FeedbackView(APIView): @@ -706,6 +733,7 @@ def get_workspace(conversation_id): conversation = Conversation.objects.get(id=conversation_id) return DocumentWorkspace.objects.get(company=conversation.user.company) + @database_sync_to_async def get_messages(conversation_id, prompt, file_string: str = None, file_type: str = ""): messages = [] @@ -821,19 +849,21 @@ def finish_prompt_metric(prompt_metric, response_length): prompt_metric.save(update_fields=["end_time", "reponse_length", "event"]) print("finish_prompt_metric saved") + @database_sync_to_async def get_retriever(conversation_id): - print(f'getting workspace from conversation: {conversation_id}') + print(f"getting workspace from conversation: {conversation_id}") conversation = Conversation.objects.get(id=conversation_id) - print(f'Got conversation: {conversation}') + print(f"Got conversation: {conversation}") workspace = DocumentWorkspace.objects.get(company=conversation.user.company) - print(f'Got workspace: {conversation}') + print(f"Got workspace: {conversation}") vectorstore = Chroma( persist_directory=f"./chroma_db/", embedding=OllamaEmbeddings(model="llama3.2"), ) return vectorstore.as_retriever() + class ChatConsumerAgain(AsyncWebsocketConsumer): async def connect(self): await self.accept() @@ -881,7 +911,10 @@ class ChatConsumerAgain(AsyncWebsocketConsumer): print(f'received: "{message}" for conversation {conversation_id}') # check the moderation here - if await moderation_classifier.classify_async(message) == ModerationLabel.NSFW: + if ( + await moderation_classifier.classify_async(message) + == ModerationLabel.NSFW + ): response = "Prompt has been marked as NSFW. If this is in error, submit a feedback with the prompt text." 
print("this prompt has been marked as NSFW") await self.send("CONVERSATION_ID") @@ -902,9 +935,8 @@ class ChatConsumerAgain(AsyncWebsocketConsumer): prompt_type = await prompt_classifier.classify_async(message) print(f"prompt_type: {prompt_type} for {message}") - - - + if file: + prompt_type = PromptType.GENERAL_CHAT prompt_metric = await create_prompt_metric( prompt.id, @@ -928,22 +960,25 @@ class ChatConsumerAgain(AsyncWebsocketConsumer): await self.send("START_OF_THE_STREAM_ENDER_GAME_42") if prompt_type == PromptType.RAG: service = AsyncRAGService() - #await service.ingest_documents() + # await service.ingest_documents() workspace = await get_workspace(conversation_id) - print('Time to get the rag response') - - async for chunk in service.generate_response(messages, prompt.message, workspace): - print(f"chunk: {chunk}") + print("Time to get the rag response") + + async for chunk in service.generate_response( + messages, prompt.message, workspace + ): response += chunk await self.send(chunk) elif prompt_type == PromptType.IMAGE_GENERATION: response = "Image Generation is not supported at this time, but it will be soon." await self.send(response) - + else: + print(f"using the AsyncLLMService\n\n{messages}\n{prompt.message}") service = AsyncLLMService() - async for chunk in service.generate_response(messages, prompt.message): - print(f"chunk: {chunk}") + async for chunk in service.generate_response( + messages, prompt.message + ): response += chunk await self.send(chunk) await self.send("END_OF_THE_STREAM_ENDER_GAME_42") @@ -954,9 +989,10 @@ class ChatConsumerAgain(AsyncWebsocketConsumer): if bytes_data: print("we have byte data") + # Document Views class DocumentWorkspaceView(APIView): - #permission_classes = [permissions.IsAuthenticated] + # permission_classes = [permissions.IsAuthenticated] def get(self, request): workspaces = DocumentWorkspace.objects.filter(company=request.user.company) @@ -970,70 +1006,78 @@ class DocumentWorkspaceView(APIView): return Response(serializer.data, status=status.HTTP_201_CREATED) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + class DocumentUploadView(APIView): - #permission_classes = [permissions.IsAuthenticated]Z + # permission_classes = [permissions.IsAuthenticated]Z def get(self, request): - print(f'request_3: {request}') + print(f"request_3: {request}") try: - workspace = DocumentWorkspace.objects.get(company=request.user.company) - serializer = DocumentSerializer(Document.objects.filter(workspace=workspace), many=True) + workspace = DocumentWorkspace.objects.get(company=request.user.company) + serializer = DocumentSerializer( + Document.objects.filter(workspace=workspace), many=True + ) return Response(serializer.data, status=status.HTTP_200_OK) except: - return Response({'error': "Workspace not found"}, status=status.HTTP_404_NOT_FOUND) - + return Response( + {"error": "Workspace not found"}, status=status.HTTP_404_NOT_FOUND + ) + def post(self, request): - print(f'request: {request}') + print(f"request: {request}") try: workspace = DocumentWorkspace.objects.get(company=request.user.company) except: - return Response({'error': "Workspace not found"}, status=status.HTTP_404_NOT_FOUND) + return Response( + {"error": "Workspace not found"}, status=status.HTTP_404_NOT_FOUND + ) print(request.FILES) - file = request.FILES.get('file') + file = request.FILES.get("file") if not file: - return Response({"error":"No file provided"}, status=status.HTTP_400_BAD_REQUEST) + return Response( + {"error": "No file provided"}, 
+            )
 
         print("have the workspace and the file")
-
-        document = Document.objects.create(
-            workspace=workspace,
-            file=file
-        )
+
+        document = Document.objects.create(workspace=workspace, file=file)
 
         # process the document in the background
         self.process_document(document)
 
         serializer = DocumentSerializer(document)
         return Response(serializer.data, status=status.HTTP_201_CREATED)
-
-
+
     def process_document(self, document):
         file_path = os.path.join(settings.MEDIA_ROOT, document.file.name)
-
+
         document.processed = True
         document.active = True
-        document.save()
+        document.save()
 
         service = AsyncRAGService()
-        service.add_files_to_store([(file_path, document.file.name, document.workspace_id)], workspace_id=document.workspace_id)
+        service.add_files_to_store(
+            [(file_path, document.file.name, document.workspace_id)],
+            workspace_id=document.workspace_id,
+        )
+
 
 class DocumentDetailView(APIView):
-    #permission_classes = [permissions.IsAuthenticated]
+    # permission_classes = [permissions.IsAuthenticated]
 
     def get(self, request, document_id):
-        print(f'request: {request}')
+        print(f"request: {request}")
         try:
             workspace = DocumentWorkspace.objects.get(company=request.user.company)
-
-            document = Document.objects.get(
-                workspace=workspace,
-                id=document_id
-            )
+
+            document = Document.objects.get(workspace=workspace, id=document_id)
         except:
-            return Response({'error': "Document not found"}, status=status.HTTP_404_NOT_FOUND)
-
-        serializer = DocumentWorkspaceSerializer(workspaces, many=True)
-        return Response(serializer.data)
\ No newline at end of file
+            return Response(
+                {"error": "Document not found"}, status=status.HTTP_404_NOT_FOUND
+            )
+
+        # Serialize the document just fetched; the previous code referenced an
+        # undefined "workspaces" name and used the workspace serializer here.
+        serializer = DocumentSerializer(document)
+        return Response(serializer.data)
diff --git a/llm_be/llm_be/asgi.py b/llm_be/llm_be/asgi.py
index cc9b0a1..3fe3936 100644
--- a/llm_be/llm_be/asgi.py
+++ b/llm_be/llm_be/asgi.py
@@ -9,13 +9,18 @@ https://docs.djangoproject.com/en/3.2/howto/deployment/asgi/
 
 import os
 
+import django
 from django.core.asgi import get_asgi_application
 
+
+os.environ.setdefault("DJANGO_SETTINGS_MODULE", "llm_be.settings")
+
+# django.setup() must run before importing chat_backend.routing so the app
+# registry (and any models the consumers import) is fully loaded.
+django.setup()
+
+
 from channels.routing import ProtocolTypeRouter, URLRouter
 from channels.auth import AuthMiddlewareStack
 import chat_backend.routing
 
-os.environ.setdefault("DJANGO_SETTINGS_MODULE", "llm_be.settings")
-
 application = ProtocolTypeRouter(
     {
         "http": get_asgi_application(),
diff --git a/llm_be/llm_be/settings.py b/llm_be/llm_be/settings.py
index 47b6d6d..e3d5c25 100644
--- a/llm_be/llm_be/settings.py
+++ b/llm_be/llm_be/settings.py
@@ -161,7 +161,7 @@ REST_FRAMEWORK = {
 }
 
 SIMPLE_JWT = {
-    "ACCESS_TOKEN_LIFETIME": timedelta(hours=24),
+    "ACCESS_TOKEN_LIFETIME": timedelta(hours=5),
     "REFRESH_TOKEN_LIFETIME": timedelta(days=14),
     "ROTATE_REFRESH_TOKENS": True,
     "BLACKLIST_AFTER_ROTATION": True,
@@ -207,4 +207,4 @@ EMAIL_PORT = 2525
 EMAIL_USE_TLS = True
 
 # Captcha
-CAPTCHA_SECRET_KEY = "6LfENu4qAAAAABdrj6JTviq-LfdPP5imhE-Os7h9"
\ No newline at end of file
+CAPTCHA_SECRET_KEY = "6LfENu4qAAAAABdrj6JTviq-LfdPP5imhE-Os7h9"
diff --git a/requirements.txt b/requirements.txt
index be65caa..c4eff53 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -44,7 +44,7 @@ durationpy==0.9
 effdet==0.4.1
 emoji==2.14.1
 eval_type_backport==0.2.2
-Faker==37.0.0
+Faker
 fastapi==0.115.9
 filelock==3.17.0
 filetype==1.2.0
@@ -52,13 +52,13 @@ flatbuffers==25.2.10
 fonttools==4.56.0
 frozenlist==1.6.0
 fsspec==2025.2.0
-google-api-core==2.24.2
-google-auth==2.39.0
-google-cloud-vision==3.10.1
-googleapis-common-protos==1.70.0
+google-api-core
+google-auth
+google-cloud-vision
+googleapis-common-protos
 greenlet==3.1.1
-grpcio==1.72.0rc1
-grpcio-status==1.72.0rc1
+grpcio
+grpcio-status
 h11==0.14.0
 html5lib==1.1
 httpcore==1.0.7
@@ -121,13 +121,13 @@ oauthlib==3.2.2
 olefile==0.47
 ollama==0.4.7
 omegaconf==2.3.0
-onnx==1.18.0
-onnxruntime==1.21.1
+onnx
+onnxruntime
 openai==1.65.4
 opencv-python==4.11.0.86
-opentelemetry-api==1.32.1
-opentelemetry-exporter-otlp-proto-common==1.32.1
-opentelemetry-exporter-otlp-proto-grpc==1.32.1
+opentelemetry-api
+opentelemetry-exporter-otlp-proto-common
+opentelemetry-exporter-otlp-proto-grpc
 opentelemetry-instrumentation==0.53b1
 opentelemetry-instrumentation-asgi==0.53b1
 opentelemetry-instrumentation-fastapi==0.53b1
@@ -138,8 +138,8 @@ opentelemetry-util-http==0.53b1
 orjson==3.10.15
 overrides==7.7.0
 packaging==24.2
-pandas==2.2.3
-pandasai==2.4.2
+pandas
+pandasai
 pathspec==0.12.1
 pdf2image==1.17.0
 pdfminer.six==20250506
@@ -150,7 +150,7 @@ platformdirs==4.3.6
 posthog==4.0.1
 propcache==0.3.1
 proto-plus==1.26.1
-protobuf==6.31.0rc2
+protobuf
 psutil==7.0.0
 pyasn1==0.6.1
 pyasn1_modules==0.4.1
@@ -160,7 +160,7 @@ pydantic==2.11.4
 pydantic-settings==2.9.1
 pydantic_core==2.33.2
 Pygments==2.19.1
-PyJWT==2.10.1
+PyJWT
 pyOpenSSL==25.0.0
 pyparsing==3.2.1
 pypdf==5.4.0