From f5d29166a6f89414d0dbfe312903c63d0235d04b Mon Sep 17 00:00:00 2001 From: Ryan Westfall Date: Wed, 14 May 2025 03:27:38 -0500 Subject: [PATCH] RAG implementation, content moderation, prompt classification, new LLM chain, document storage --- .gitignore | 3 +- llm_be/chat_backend/admin.py | 54 +- llm_be/chat_backend/apps.py | 25 + llm_be/chat_backend/client.py | 22 +- .../0020_documentworkspace_document.py | 78 ++ llm_be/chat_backend/models.py | 94 ++- llm_be/chat_backend/renderers.py | 7 +- llm_be/chat_backend/routing.py | 9 +- llm_be/chat_backend/serializers.py | 48 +- llm_be/chat_backend/services/__init__.py | 0 .../chat_backend/services/image_generation.py | 145 ++++ llm_be/chat_backend/services/llm_service.py | 138 ++++ .../services/moderation_classifier.py | 79 ++ .../services/prompt_classifier.py | 100 +++ llm_be/chat_backend/services/rag_services.py | 378 ++++++++++ llm_be/chat_backend/services/tests.py | 219 ++++++ .../chat_backend/services/title_generator.py | 67 ++ llm_be/chat_backend/signals.py | 18 + .../templates/emails/reset_email.html | 97 +++ .../templates/emails/reset_email.txt | 3 + llm_be/chat_backend/tests.py | 207 ++++++ llm_be/chat_backend/urls.py | 54 +- llm_be/chat_backend/utils.py | 9 +- llm_be/chat_backend/views.py | 693 +++++++++++++----- llm_be/llm_be/asgi.py | 19 +- llm_be/llm_be/settings.py | 153 ++-- llm_be/llm_be/urls.py | 13 +- llm_be/llm_be/wsgi.py | 2 +- llm_be/manage.py | 4 +- requirements.dev | 32 +- requirements.txt | 208 ++++++ strip_and_upgrade.py | 9 + 32 files changed, 2628 insertions(+), 359 deletions(-) create mode 100644 llm_be/chat_backend/migrations/0020_documentworkspace_document.py create mode 100644 llm_be/chat_backend/services/__init__.py create mode 100644 llm_be/chat_backend/services/image_generation.py create mode 100644 llm_be/chat_backend/services/llm_service.py create mode 100644 llm_be/chat_backend/services/moderation_classifier.py create mode 100644 llm_be/chat_backend/services/prompt_classifier.py create mode 100644 llm_be/chat_backend/services/rag_services.py create mode 100644 llm_be/chat_backend/services/tests.py create mode 100644 llm_be/chat_backend/services/title_generator.py create mode 100644 llm_be/chat_backend/signals.py create mode 100644 llm_be/chat_backend/templates/emails/reset_email.html create mode 100644 llm_be/chat_backend/templates/emails/reset_email.txt create mode 100644 requirements.txt create mode 100644 strip_and_upgrade.py diff --git a/.gitignore b/.gitignore index 0dbf2f2..b741129 100644 --- a/.gitignore +++ b/.gitignore @@ -167,4 +167,5 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ - +chroma_db/ +documents/ diff --git a/llm_be/chat_backend/admin.py b/llm_be/chat_backend/admin.py index b985b0b..e7ae303 100644 --- a/llm_be/chat_backend/admin.py +++ b/llm_be/chat_backend/admin.py @@ -1,5 +1,16 @@ from django.contrib import admin -from .models import CustomUser, Announcement, Company, LLMModels, Conversation, Prompt, Feedback, PromptMetric +from .models import ( + CustomUser, + Announcement, + Company, + LLMModels, + Conversation, + Prompt, + Feedback, + PromptMetric, + DocumentWorkspace, + Document +) # Register your models here. 
@@ -27,16 +38,16 @@ class CustomUserAdmin(admin.ModelAdmin):
         "has_signed_tos",
         "last_login",
         "slug",
-        "get_set_password_url"
+        "get_set_password_url",
     )
     search_fields = ("fields", "username", "first_name", "last_name", "slug")
 
+
 class FeedbackAdmin(admin.ModelAdmin):
     model = Feedback
     search_fields = ("status", "text", "get_user_email")
-    list_display= (
-        "status", "get_user_email", "title", "category"
-    )
+    list_display = ("status", "get_user_email", "title", "category")
+
 
 class LLMModelsAdmin(admin.ModelAdmin):
     model = LLMModels
@@ -46,7 +57,7 @@ class LLMModelsAdmin(admin.ModelAdmin):
 
 class ConversationAdmin(admin.ModelAdmin):
     model = Conversation
-    list_display = ("title", "get_user_email","deleted")
+    list_display = ("title", "get_user_email", "deleted")
     search_fields = ("title",)
 
 
@@ -55,9 +66,35 @@ class PromptAdmin(admin.ModelAdmin):
     list_display = ("message", "user_created", "get_conversation_title")
     search_fields = ("message",)
 
+
 class PromptMetricAdmin(admin.ModelAdmin):
     model = PromptMetric
-    list_display = ("event", "model_name", "prompt_length","reponse_length",'has_file','file_type', "get_duration")
+    list_display = (
+        "event",
+        "model_name",
+        "prompt_length",
+        "reponse_length",
+        "has_file",
+        "file_type",
+        "get_duration",
+    )
+
+
+class DocumentWorkspaceAdmin(admin.ModelAdmin):
+    model = DocumentWorkspace
+    list_display = (
+        "name",
+        "company",
+    )
+
+
+class DocumentAdmin(admin.ModelAdmin):
+    model = Document
+    list_display = (
+        "file",
+        "active",
+        "created",
+        "processed",
+    )
 
 
 admin.site.register(Announcement, AnnouncmentAdmin)
@@ -69,3 +106,6 @@ admin.site.register(Conversation, ConversationAdmin)
 admin.site.register(Prompt, PromptAdmin)
 admin.site.register(PromptMetric, PromptMetricAdmin)
 admin.site.register(Feedback, FeedbackAdmin)
+
+admin.site.register(DocumentWorkspace, DocumentWorkspaceAdmin)
+admin.site.register(Document, DocumentAdmin)
diff --git a/llm_be/chat_backend/apps.py b/llm_be/chat_backend/apps.py
index 24f2d5c..524ac53 100644
--- a/llm_be/chat_backend/apps.py
+++ b/llm_be/chat_backend/apps.py
@@ -1,6 +1,31 @@
 from django.apps import AppConfig
+from django.conf import settings
+from django.db import OperationalError
 
 
 class ChatBackendConfig(AppConfig):
     default_auto_field = "django.db.models.BigAutoField"
     name = "chat_backend"
+
+    def ready(self):
+        import chat_backend.signals
+        FORCE_RELOAD = False
+
+        if not getattr(settings, "TESTING", False):  # don't run during tests
+            try:
+                from .services.rag_services import AsyncRAGService
+                from chat_backend.models import Document
+
+                # Check if Chroma needs initialization
+                if Document.objects.exists():
+                    rag_service = AsyncRAGService()
+
+                    if FORCE_RELOAD:
+                        print("Force reloading ChromaDB from existing documents...")
+                        rag_service.clear_vector_store()
+                        rag_service.ingest_documents()
+                    elif rag_service.vector_store._collection.count() == 0:
+                        print("Initializing ChromaDB with existing documents...")
+                        rag_service.ingest_documents()
+            except OperationalError:
+                # Database tables might not exist yet during migration
+                pass
diff --git a/llm_be/chat_backend/client.py b/llm_be/chat_backend/client.py
index d550dbb..4425565 100644
--- a/llm_be/chat_backend/client.py
+++ b/llm_be/chat_backend/client.py
@@ -1,30 +1,32 @@
 """
 llama client - Abstract this in the future
 """
+
 import ollama
 from typing import List, Dict
 
+
 class LlamaClient(object):
-    def __init__(self, model: str='llama3'):
+    def __init__(self, model: str = "llama3"):
         self.client = ollama.Client(host="http://127.0.0.1:11434")
         self.model = model
 
     def check_if_model_exists(self) -> bool:
         raise NotImplementedError
 
-    def generate_conversation_title(self, message:str):
-        response = self.generate_single_message("Summarise the phrase in one to for words\"%s\"" % message)
-
-        raw_response = response['response'].replace("\"","")
+    def generate_conversation_title(self, message: str):
+        response = self.generate_single_message(
+            'Summarise the phrase in one to four words "%s"' % message
+        )
+
+        raw_response = response["response"].replace('"', "")
         return " ".join(raw_response.split()[:4])
 
     def generate_single_message(self, message: str):
         return ollama.generate(model=self.model, prompt=message)
 
     def get_chat_response(self, messages: List[str]):
-        return self.client.chat(model = self.model, messages=messages, stream=False)
-
-
+        return self.client.chat(model=self.model, messages=messages, stream=False)
+
     def get_streamed_chat_response(self, messages: List[str]):
-        return self.client.chat(model = self.model, messages=messages, stream=True)
-
+        return self.client.chat(model=self.model, messages=messages, stream=True)
diff --git a/llm_be/chat_backend/migrations/0020_documentworkspace_document.py b/llm_be/chat_backend/migrations/0020_documentworkspace_document.py
new file mode 100644
index 0000000..d255761
--- /dev/null
+++ b/llm_be/chat_backend/migrations/0020_documentworkspace_document.py
@@ -0,0 +1,78 @@
+# Generated by Django 5.1.7 on 2025-04-30 18:58
+
+import django.db.models.deletion
+import django.utils.timezone
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ("chat_backend", "0019_customuser_conversation_order_and_more"),
+    ]
+
+    operations = [
+        migrations.CreateModel(
+            name="DocumentWorkspace",
+            fields=[
+                (
+                    "id",
+                    models.BigAutoField(
+                        auto_created=True,
+                        primary_key=True,
+                        serialize=False,
+                        verbose_name="ID",
+                    ),
+                ),
+                ("created", models.DateTimeField(default=django.utils.timezone.now)),
+                (
+                    "last_modified",
+                    models.DateTimeField(default=django.utils.timezone.now),
+                ),
+                ("name", models.CharField(max_length=255)),
+                (
+                    "company",
+                    models.ForeignKey(
+                        on_delete=django.db.models.deletion.CASCADE,
+                        to="chat_backend.company",
+                    ),
+                ),
+            ],
+            options={
+                "abstract": False,
+            },
+        ),
+        migrations.CreateModel(
+            name="Document",
+            fields=[
+                (
+                    "id",
+                    models.BigAutoField(
+                        auto_created=True,
+                        primary_key=True,
+                        serialize=False,
+                        verbose_name="ID",
+                    ),
+                ),
+                ("created", models.DateTimeField(default=django.utils.timezone.now)),
+                (
+                    "last_modified",
+                    models.DateTimeField(default=django.utils.timezone.now),
+                ),
+                ("file", models.FileField(upload_to="documents/")),
+                ("uploaded_at", models.DateTimeField(auto_now_add=True)),
+                ("processed", models.BooleanField(default=False)),
+                ("active", models.BooleanField(default=False)),
+                (
+                    "workspace",
+                    models.ForeignKey(
+                        on_delete=django.db.models.deletion.CASCADE,
+                        to="chat_backend.documentworkspace",
+                    ),
+                ),
+            ],
+            options={
+                "abstract": False,
+            },
+        ),
+    ]
diff --git a/llm_be/chat_backend/models.py b/llm_be/chat_backend/models.py
index 7505470..d2b5ef2 100644
--- a/llm_be/chat_backend/models.py
+++ b/llm_be/chat_backend/models.py
@@ -3,9 +3,11 @@ from django.contrib.auth.models import AbstractUser
 from django.utils import timezone
 from autoslug import AutoSlugField
 from django.core.files.storage import FileSystemStorage
+
 # Create your models here.
-FILE_STORAGE = FileSystemStorage(location='prompt_files') +FILE_STORAGE = FileSystemStorage(location="prompt_files") + class TimeInfoBase(models.Model): @@ -60,12 +62,18 @@ class CustomUser(AbstractUser): help_text="Allows the edit/add/remove of users for a company", default=False ) deleted = models.BooleanField(help_text="This is to hid accounts", default=False) - has_signed_tos = models.BooleanField(default=False, help_text="If the user has signed the TOS") - slug = AutoSlugField(populate_from='email') - conversation_order = models.BooleanField(default=True, help_text='How the conversations should display') + has_signed_tos = models.BooleanField( + default=False, help_text="If the user has signed the TOS" + ) + slug = AutoSlugField(populate_from="email") + conversation_order = models.BooleanField( + default=True, help_text="How the conversations should display" + ) + def get_set_password_url(self): return f"https://www.chat.aimloperations.com/set_password?slug={self.slug}" + FEEDBACK_CHOICE = ( ("SUBMITTED", "Submitted"), ("RESOLVED", "Resolved"), @@ -74,21 +82,26 @@ FEEDBACK_CHOICE = ( ) FEEDBACK_CATEGORIES = ( - ('NOT_DEFINED', 'Not defined'), - ('BUG', 'Bug'), - ('ENHANCEMENT', 'Enhancement'), - ('OTHER', 'Other'), - ('MAX_CATEGORIES', 'Max Categories'), + ("NOT_DEFINED", "Not defined"), + ("BUG", "Bug"), + ("ENHANCEMENT", "Enhancement"), + ("OTHER", "Other"), + ("MAX_CATEGORIES", "Max Categories"), ) + class Feedback(TimeInfoBase): - title = models.TextField(max_length=64, default='') + title = models.TextField(max_length=64, default="") user = models.ForeignKey( CustomUser, on_delete=models.CASCADE, blank=True, null=True ) text = models.TextField(max_length=512) - status = models.CharField(max_length=24, choices=FEEDBACK_CHOICE, default="SUBMITTED") - category = models.CharField(max_length=24, choices=FEEDBACK_CATEGORIES, default="NOT_DEFINED") + status = models.CharField( + max_length=24, choices=FEEDBACK_CHOICE, default="SUBMITTED" + ) + category = models.CharField( + max_length=24, choices=FEEDBACK_CATEGORIES, default="NOT_DEFINED" + ) def get_user_email(self): if self.user: @@ -105,9 +118,8 @@ MONTH_CHOICES = ( ("DECEMBER", "December"), ) -month = models.CharField(max_length=9, - choices=MONTH_CHOICES, - default="JANUARY") +month = models.CharField(max_length=9, choices=MONTH_CHOICES, default="JANUARY") + class Announcement(TimeInfoBase): class Status(models.TextChoices): @@ -131,7 +143,9 @@ class Conversation(TimeInfoBase): title = models.CharField( max_length=64, help_text="The title for the conversation", default="" ) - deleted = models.BooleanField(help_text="This is to hide conversations", default=False) + deleted = models.BooleanField( + help_text="This is to hide conversations", default=False + ) def get_user_email(self): if self.user: @@ -151,20 +165,26 @@ class Prompt(TimeInfoBase): conversation = models.ForeignKey( "Conversation", on_delete=models.CASCADE, blank=True, null=True ) - file =models.FileField(upload_to=FILE_STORAGE, blank=True, null=True, help_text="file for the prompt") - file_type=models.CharField(max_length=16, blank=True, null=True, help_text='file type of the file for the prompt') + file = models.FileField( + upload_to=FILE_STORAGE, blank=True, null=True, help_text="file for the prompt" + ) + file_type = models.CharField( + max_length=16, + blank=True, + null=True, + help_text="file type of the file for the prompt", + ) def get_conversation_title(self): if self.conversation: return self.conversation.title else: return "" - + def file_exists(self): 
return self.file != None and self.file.storage.exists(self.file.name) - class PromptMetric(TimeInfoBase): PROMPT_METRIC_CHOICES = ( ("CREATED", "Created"), @@ -174,20 +194,40 @@ class PromptMetric(TimeInfoBase): ("MAX_PROMPT_METRIC_CHOICES", "Max Prompt Metric Choices"), ) prompt_id = models.IntegerField(help_text="The id of the prompt this matches to") - conversation_id = models.IntegerField(help_text="The id of the conversation this matches to") + conversation_id = models.IntegerField( + help_text="The id of the conversation this matches to" + ) event = models.CharField( - max_length=26, choices=PROMPT_METRIC_CHOICES, default='CREATED' + max_length=26, choices=PROMPT_METRIC_CHOICES, default="CREATED" ) model_name = models.CharField(max_length=215, help_text="The name of the model") start_time = models.DateTimeField() end_time = models.DateTimeField(blank=True, null=True) - prompt_length = models.IntegerField( help_text="How many characters are in the prompt") - reponse_length = models.IntegerField(blank=True, null=True, help_text="How many characters are in the response") + prompt_length = models.IntegerField( + help_text="How many characters are in the prompt" + ) + reponse_length = models.IntegerField( + blank=True, null=True, help_text="How many characters are in the response" + ) has_file = models.BooleanField(help_text="Is there a file") - file_type = models.CharField(max_length=16, help_text='The file type, if any', blank=True, null=True) + file_type = models.CharField( + max_length=16, help_text="The file type, if any", blank=True, null=True + ) def get_duration(self): - if(self.start_time and self.end_time): - difference =self.end_time - self.start_time + if self.start_time and self.end_time: + difference = self.end_time - self.start_time return difference.seconds return 0 + +# Document Models +class DocumentWorkspace(TimeInfoBase): + name = models.CharField(max_length=255) + company = models.ForeignKey(Company, on_delete=models.CASCADE) + +class Document(TimeInfoBase): + workspace = models.ForeignKey(DocumentWorkspace, on_delete=models.CASCADE) + file = models.FileField(upload_to='documents/') + uploaded_at = models.DateTimeField(auto_now_add=True) + processed = models.BooleanField(default=False) + active = models.BooleanField(default=False) diff --git a/llm_be/chat_backend/renderers.py b/llm_be/chat_backend/renderers.py index f611bb0..e7f30ae 100644 --- a/llm_be/chat_backend/renderers.py +++ b/llm_be/chat_backend/renderers.py @@ -1,8 +1,9 @@ from rest_framework.renderers import BaseRenderer + class ServerSentEventRenderer(BaseRenderer): - media_type = 'text/event-stream' - format = 'txt' + media_type = "text/event-stream" + format = "txt" def render(self, data, accepted_media_type=None, renderer_context=None): - return data \ No newline at end of file + return data diff --git a/llm_be/chat_backend/routing.py b/llm_be/chat_backend/routing.py index 70829b5..dd6acd3 100644 --- a/llm_be/chat_backend/routing.py +++ b/llm_be/chat_backend/routing.py @@ -1,7 +1,6 @@ -from django.urls import re_path +from django.urls import re_path from .views import ChatConsumerAgain - -websocket_urlpatterns = [ - re_path(r'ws/chat_again/$', ChatConsumerAgain.as_asgi()), -] \ No newline at end of file +websocket_urlpatterns = [ + re_path(r"ws/chat_again/$", ChatConsumerAgain.as_asgi()), +] diff --git a/llm_be/chat_backend/serializers.py b/llm_be/chat_backend/serializers.py index 07c0dcc..c4906ec 100644 --- a/llm_be/chat_backend/serializers.py +++ b/llm_be/chat_backend/serializers.py @@ -1,6 +1,16 @@ 
from rest_framework_simplejwt.serializers import TokenObtainPairSerializer from rest_framework import serializers -from .models import CustomUser, Announcement, Company, Conversation, Prompt, Feedback, FEEDBACK_CATEGORIES +from .models import ( + CustomUser, + Announcement, + Company, + Conversation, + Prompt, + Feedback, + FEEDBACK_CATEGORIES, + DocumentWorkspace, + Document +) class MyTokenObtainPairSerializer(TokenObtainPairSerializer): @@ -25,11 +35,13 @@ class AnnouncmentSerializer(serializers.ModelSerializer): model = Announcement fields = "__all__" + class FeedbackSerializer(serializers.ModelSerializer): class Meta: model = Feedback fields = "__all__" + class CustomUserSerializer(serializers.ModelSerializer): email = serializers.EmailField(required=True) username = serializers.CharField() @@ -58,12 +70,40 @@ class ConversationSerializer(serializers.ModelSerializer): class PromptSerializer(serializers.ModelSerializer): - + class Meta: model = Prompt - fields = ("message", "user_created", "created", "id", ) + fields = ( + "message", + "user_created", + "created", + "id", + ) + class BasicUserSerializer(serializers.ModelSerializer): class Meta: model = CustomUser - fields = ("email", "first_name", "last_name", "is_active","has_usable_password","is_company_manager",'has_signed_tos') \ No newline at end of file + fields = ( + "email", + "first_name", + "last_name", + "is_active", + "has_usable_password", + "is_company_manager", + "has_signed_tos", + ) + + +# document serializers +class DocumentWorkspaceSerializer(serializers.ModelSerializer): + class Meta: + model = DocumentWorkspace + fields = ['id', 'name', 'created'] + read_only_fields = ['id', 'created'] + +class DocumentSerializer(serializers.ModelSerializer): + class Meta: + model = Document + fields = ['id', 'workspace', 'file', 'uploaded_at', 'processed', 'created', 'active'] + read_only_fields = ['id', 'uploaded_at', 'processed', 'created'] \ No newline at end of file diff --git a/llm_be/chat_backend/services/__init__.py b/llm_be/chat_backend/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/llm_be/chat_backend/services/image_generation.py b/llm_be/chat_backend/services/image_generation.py new file mode 100644 index 0000000..f6d95d6 --- /dev/null +++ b/llm_be/chat_backend/services/image_generation.py @@ -0,0 +1,145 @@ +import os +import logging +from typing import Optional, Tuple +from PIL import Image +import torch +from diffusers import StableDiffusionPipeline, DPMSolverSinglestepScheduler + +logger = logging.getLogger(__name__) + +class ImageGenerationService: + """ + Service for text-to-image generation using Stable Diffusion. + Uses singleton pattern to maintain loaded model in memory. 
+ """ + + _instance = None + _model_loaded = False + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialize() + return cls._instance + + def _initialize(self): + """Initialize the service with default settings""" + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.model_id = "stabilityai/stable-diffusion-2-1" + self.pipeline = None + self.default_params = { + "num_inference_steps": 25, + "guidance_scale": 7.5, + "width": 512, + "height": 512, + } + + def load_model(self): + """Load the Stable Diffusion model""" + if self._model_loaded: + return + + try: + logger.info(f"Loading Stable Diffusion model on {self.device}...") + + # Use DPMSolver for faster inference + self.pipeline = StableDiffusionPipeline.from_pretrained( + self.model_id, + torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, + ) + self.pipeline.scheduler = DPMSolverSinglestepScheduler.from_config( + self.pipeline.scheduler.config + ) + self.pipeline = self.pipeline.to(self.device) + + # Optimizations + if self.device == "cuda": + self.pipeline.enable_attention_slicing() + self.pipeline.enable_xformers_memory_efficient_attention() + + self._model_loaded = True + logger.info("Model loaded successfully") + + except Exception as e: + logger.error(f"Failed to load model: {str(e)}") + raise RuntimeError(f"Model loading failed: {str(e)}") + + def generate_image( + self, + prompt: str, + negative_prompt: Optional[str] = None, + output_path: Optional[str] = None, + **kwargs + ) -> Tuple[Image.Image, dict]: + """ + Generate image from text prompt. + + Args: + prompt: Text prompt for image generation + negative_prompt: Text for things to avoid in generation + output_path: Optional path to save the image + **kwargs: Generation parameters (overrides defaults) + + Returns: + Tuple of (PIL.Image, generation_parameters) + """ + if not self._model_loaded: + self.load_model() + + # Merge default params with overrides + params = {**self.default_params, **kwargs} + + try: + logger.info(f"Generating image with prompt: {prompt[:50]}...") + + with torch.inference_mode(): + result = self.pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + **params + ) + + image = result.images[0] + + if output_path: + os.makedirs(os.path.dirname(output_path), exist_ok=True) + image.save(output_path) + logger.info(f"Image saved to {output_path}") + + return image, params + + except Exception as e: + logger.error(f"Image generation failed: {str(e)}") + raise RuntimeError(f"Image generation failed: {str(e)}") + + +class AsyncImageGenerationService: + """ + Asynchronous wrapper for image generation service. + Runs the synchronous service in a thread pool. 
+    """
+
+    def __init__(self):
+        self.sync_service = ImageGenerationService()
+
+    async def generate_image(
+        self,
+        prompt: str,
+        negative_prompt: Optional[str] = None,
+        output_path: Optional[str] = None,
+        **kwargs
+    ) -> Tuple[Image.Image, dict]:
+        """Async version of generate_image"""
+        import asyncio
+        from functools import partial
+
+        loop = asyncio.get_running_loop()
+        func = partial(
+            self.sync_service.generate_image,
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            output_path=output_path,
+            **kwargs
+        )
+
+        return await loop.run_in_executor(None, func)
\ No newline at end of file
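A brief usage sketch for the async wrapper above (hypothetical call site and paths; assumes the Stable Diffusion weights and a working PyTorch install are available locally):

    import asyncio

    from chat_backend.services.image_generation import AsyncImageGenerationService

    async def main():
        service = AsyncImageGenerationService()
        # The blocking pipeline runs in a worker thread, keeping the event loop free.
        image, params = await service.generate_image(
            prompt="a watercolor lighthouse at dawn",
            negative_prompt="blurry, low quality",
            output_path="media/generated/lighthouse.png",  # hypothetical path
        )
        print(image.size, params["num_inference_steps"])

    asyncio.run(main())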
diff --git a/llm_be/chat_backend/services/llm_service.py b/llm_be/chat_backend/services/llm_service.py
new file mode 100644
index 0000000..1d24f08
--- /dev/null
+++ b/llm_be/chat_backend/services/llm_service.py
@@ -0,0 +1,138 @@
+from abc import ABC, abstractmethod
+from typing import AsyncGenerator, Generator, Optional
+
+from langchain_community.llms import Ollama
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+
+from chat_backend.models import Conversation, Prompt
+
+class LLMService(ABC):
+    """Abstract base class for LLM conversation services."""
+
+    def __init__(self):
+        self.llm = Ollama(
+            model="llama3.2",
+            temperature=0.7,
+            top_k=50,
+            top_p=0.9,
+            repeat_penalty=1.1,
+            num_ctx=4096
+        )
+        self.output_parser = StrOutputParser()
+
+    @abstractmethod
+    def generate_response(self, conversation: Conversation, query: str, **kwargs):
+        """Generate a response to a query within a conversation context."""
+        pass
+
+    def _format_history(self, conversation: Conversation) -> str:
+        """Format conversation history for the prompt."""
+        prompts = Prompt.objects.filter(conversation=conversation).order_by('created_at')
+        return "\n".join(
+            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}"
+            for prompt in prompts
+        )
+
+
+class SyncLLMService(LLMService):
+    """Synchronous LLM conversation service."""
+
+    def __init__(self):
+        super().__init__()
+        self._setup_chain()
+
+    def _setup_chain(self):
+        """Setup the conversation chain."""
+        template = """Continue the conversation based on the following history:
+
+        {history}
+
+        Latest message: {query}
+
+        Response:"""
+        self.prompt = ChatPromptTemplate.from_template(template)
+
+        self.conversation_chain = (
+            {
+                "history": lambda x: self._format_history(x["conversation"]),
+                "query": lambda x: x["query"]
+            }
+            | self.prompt
+            | self.llm
+            | self.output_parser
+        )
+
+    def generate_response(self, conversation: Conversation, query: str, **kwargs) -> Generator[str, None, None]:
+        """Generate response with streaming support."""
+        chain_input = {
+            "query": query,
+            "conversation": conversation
+        }
+
+        for chunk in self.conversation_chain.stream(chain_input):
+            yield chunk
+
+
+class AsyncLLMService(LLMService):
+    """Asynchronous LLM conversation service."""
+
+    def __init__(self):
+        super().__init__()
+        self._setup_chain()
+
+    def _setup_chain(self):
+        """Setup the conversation chain."""
+        template = """Continue this conversation while maintaining context by providing a single helpful response.
+
+        Current context: {context}
+
+        Last 3 messages:
+        {recent_history}
+
+        Latest message: {query}
+
+        Instructions:
+        - Carefully maintain all established context
+        - If referencing previous elements (like stories), preserve all details
+        - When asked to modify something, identify what's being modified
+
+        Response:"""
+
+        self.prompt = ChatPromptTemplate.from_template(template)
+
+        self.conversation_chain = (
+            {
+                "context": lambda x: self._format_history(x["conversation"]),
+                "recent_history": lambda x: self._get_recent_messages(x["conversation"]),
+                "query": lambda x: x["query"]
+            }
+            | self.prompt
+            | self.llm
+            | self.output_parser
+        )
+
+    async def _format_history(self, conversation: Conversation) -> str:
+        """Async version of format conversation history."""
+        prompts = [
+            prompt
+            async for prompt in Prompt.objects.filter(conversation=conversation).order_by('created_at')
+        ]
+        return "\n".join(
+            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}"
+            for prompt in prompts
+        )
+
+    async def _get_recent_messages(self, conversation: Conversation) -> str:
+        """Format the last three messages of the conversation."""
+        prompts = [
+            prompt
+            async for prompt in Prompt.objects.filter(conversation=conversation).order_by('created_at')
+        ][-3:]
+        return "\n".join(
+            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}"
+            for prompt in prompts
+        )
+
+    async def generate_response(self, conversation: Conversation, query: str, **kwargs) -> AsyncGenerator[str, None]:
+        """Generate response with async streaming support."""
+        chain_input = {
+            "query": query,
+            "conversation": conversation
+        }
+
+        async for chunk in self.conversation_chain.astream(chain_input):
+            yield chunk
\ No newline at end of file
diff --git a/llm_be/chat_backend/services/moderation_classifier.py b/llm_be/chat_backend/services/moderation_classifier.py
new file mode 100644
index 0000000..d66968b
--- /dev/null
+++ b/llm_be/chat_backend/services/moderation_classifier.py
@@ -0,0 +1,79 @@
+from enum import Enum, auto
+from typing import Dict, Any
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_community.llms import Ollama
+
+class ModerationLabel(Enum):
+    NSFW = auto()
+    FINE = auto()
+
+class ModerationClassifier:
+    """
+    Classifies prompts as NSFW or FINE (safe) content.
+    """
+
+    def __init__(self):
+        self.llm = Ollama(
+            model="llama3.2",
+            temperature=0.1,  # Very low for strict moderation
+            top_k=10,
+            num_ctx=2048
+        )
+
+        self.moderation_prompt = ChatPromptTemplate.from_messages([
+            ("system", """You are a strict content moderator. Classify the following prompt as either NSFW or FINE.
+ +NSFW includes: +- Sexual content +- Violence/gore +- Hate speech +- Illegal activities +- Harassment +- Graphic/disturbing content + +FINE includes: +- Safe for work topics +- General conversation +- Professional inquiries +- Creative requests (non-explicit) +- Technical questions + +Examples: +- "How to make a bomb" → NSFW +- "Write a love poem" → FINE +- "Explicit sex scene" → NSFW +- "Python tutorial" → FINE + +Return ONLY "NSFW" or "FINE", nothing else."""), + ("human", "{prompt}") + ]) + + self.chain = self.moderation_prompt | self.llm + + async def classify_async(self, prompt: str) -> ModerationLabel: + """Asynchronous classification""" + try: + response = (await self.chain.ainvoke({"prompt": prompt})).strip().upper() + return self._parse_response(response) + except Exception as e: + print(f"Moderation error: {e}") + return ModerationLabel.NSFW # Fail-safe to NSFW + + def classify(self, prompt: str) -> ModerationLabel: + """Synchronous classification""" + try: + response = self.chain.invoke({"prompt": prompt}).strip().upper() + return self._parse_response(response) + except Exception as e: + print(f"Moderation error: {e}") + return ModerationLabel.NSFW # Fail-safe to NSFW + + def _parse_response(self, response: str) -> ModerationLabel: + """Convert string response to ModerationLabel enum""" + if "NSFW" in response: + return ModerationLabel.NSFW + return ModerationLabel.FINE # Default to FINE if unclear + + +# Singleton instance +moderation_classifier = ModerationClassifier() \ No newline at end of file diff --git a/llm_be/chat_backend/services/prompt_classifier.py b/llm_be/chat_backend/services/prompt_classifier.py new file mode 100644 index 0000000..4548f46 --- /dev/null +++ b/llm_be/chat_backend/services/prompt_classifier.py @@ -0,0 +1,100 @@ +from enum import Enum, auto +from typing import Dict, Any +from langchain_core.prompts import ChatPromptTemplate +from langchain_community.llms import Ollama + +class PromptType(Enum): + GENERAL_CHAT = auto() + RAG = auto() + IMAGE_GENERATION = auto() + UNKNOWN = auto() + +class PromptClassifier: + """ + Classifies user prompts to determine which service should handle them. + """ + + def __init__(self): + self.llm = Ollama( + model="llama3", + temperature=0.3, # Lower temp for more deterministic classification + top_k=20, + top_p=0.9, + num_ctx=4096 + ) + + self.classification_prompt = ChatPromptTemplate.from_messages([ + ("system", + """You are a precision prompt classifier. Strictly categorize prompts into: +1. GENERAL_CHAT - Casual conversation, personal questions, or non-specific inquiries +2. RAG - ONLY when explicitly requesting document/search-based knowledge +3. IMAGE_GENERATION - Specific requests to create/modify images +4. UNKNOWN - If none of the above fit + +1. IMAGE_GENERATION - ONLY if: + - Explicitly contains: "generate/create/draw/make an image/picture/photo/art/illustration" + - Requests visual content creation + - Example: "Make a picture of a castle" → IMAGE_GENERATION + +2. RAG - ONLY if: + - Explicitly mentions documents/files/data + - Uses search terms: "find/search/lookup in [source]" + - Example: "What does contracts.pdf say?" → RAG + +3. GENERAL_CHAT - DEFAULT category when: + - Doesn't meet above criteria + - Conversational/general knowledge + - Uncertain cases + - Example: "Tell me a joke" → GENERAL_CHAT + +Examples: +[Definitely RAG] +- "What does the uploaded PDF say about quarterly results?" 
+- "Search our documents for the 2023 marketing strategy" +- "Find the contract clause about termination" + +[Definitely GENERAL_CHAT] +- "How does photosynthesis work?" (General knowledge) +- "Tell me a joke" +- "What's your opinion on AI?" + +[Borderline → GENERAL_CHAT] +- "What's our company policy on X?" (No doc reference → general) +- "Explain quantum computing" (General knowledge) +- "Summarize the meeting" (No doc reference) + +Return ONLY the label, no explanations."""), + ("human", "{prompt}") + ]) + + self.chain = self.classification_prompt | self.llm + + async def classify_async(self, prompt: str) -> PromptType: + """Asynchronously classify the prompt""" + try: + response = await self.chain.ainvoke({"prompt": prompt}) + return self._parse_response(response.strip()) + except Exception as e: + print(f"Classification error: {e}") + return PromptType.UNKNOWN + + def classify(self, prompt: str) -> PromptType: + """Synchronously classify the prompt""" + try: + response = self.chain.invoke({"prompt": prompt}) + return self._parse_response(response.strip()) + except Exception as e: + print(f"Classification error: {e}") + return PromptType.UNKNOWN + + def _parse_response(self, response: str) -> PromptType: + """Convert string response to PromptType enum""" + response = response.upper() + for prompt_type in PromptType: + if prompt_type.name in response: + return prompt_type + return PromptType.UNKNOWN + + +# Singleton instance for easy access +prompt_classifier = PromptClassifier() \ No newline at end of file diff --git a/llm_be/chat_backend/services/rag_services.py b/llm_be/chat_backend/services/rag_services.py new file mode 100644 index 0000000..4608822 --- /dev/null +++ b/llm_be/chat_backend/services/rag_services.py @@ -0,0 +1,378 @@ +import os +from abc import ABC, abstractmethod +from typing import List, Dict, Any, AsyncGenerator, Generator, Optional +from channels.db import database_sync_to_async +from langchain_community.embeddings import OllamaEmbeddings +from langchain_community.llms import Ollama +from langchain_community.vectorstores import Chroma +from langchain_core.documents import Document as LangDocument +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.runnables import RunnablePassthrough +from langchain_text_splitters import RecursiveCharacterTextSplitter +from langchain_community.document_loaders import ( + PyPDFLoader, + Docx2txtLoader, + TextLoader, + UnstructuredFileLoader +) +from django.core.files.uploadedfile import UploadedFile +from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document +from pathlib import Path + +@database_sync_to_async +def get_documents(workspace: DocumentWorkspace | None = None): + if workspace: + return [doc for doc in Document.objects.filter(workspace=workspace)] + else: + return [doc for doc in Document.objects.all()] + + + +class RAGService(ABC): + """Abstract base class for RAG services.""" + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance.__init__() + return cls._instance + + def __init__(self): + self.embedding_model = OllamaEmbeddings(model="llama3.2") + self.llm = Ollama( + model="llama3.2", + temperature=0.7, + top_k=50, + top_p=0.9, + repeat_penalty=1.1, + num_ctx=4096 + ) + self.text_splitter = RecursiveCharacterTextSplitter( + chunk_size=1000, + chunk_overlap=200 + ) + self.vector_store = self._initialize_vector_store() + + # Supported file types and 
+    def _prepare_documents(self, documents: List[Document]) -> None:
+        """Load, split, and add the given documents to the vector store."""
+        for doc in documents:
+            print(f"Processing: {doc.file.name}")
+            chunks = self._load_and_split_documents(doc.file.path)
+            if chunks:
+                self.vector_store.add_documents(chunks)
+        self.vector_store.persist()
+
+    def ingest_documents(self, workspace: DocumentWorkspace | None = None) -> None:
+        """Ingest documents from a workspace into the vector store."""
+        print(f"Getting the Document via the workspace: {workspace}")
+        if workspace:
+            documents = [doc for doc in Document.objects.filter(workspace=workspace)]
+        else:
+            documents = [doc for doc in Document.objects.all()]
+
+        print(f"Processing the documents : {documents}")
+        self._prepare_documents(documents)
+
+    @abstractmethod
+    def generate_response(self, conversation: Conversation, query: str, **kwargs):
+        """Generate a response using RAG."""
+        pass
+
+    @abstractmethod
+    def search_documents(self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4) -> List[LangDocument]:
+        """Search relevant documents from the vector store."""
+        pass
+
+    def _get_file_loader(self, file_path: str):
+        """Get appropriate loader for file type"""
+        ext = Path(file_path).suffix.lower()
+        return self.loader_mapping.get(ext, self.loader_mapping['*'])
+
+    def _sanitize_filename(self, filename: str) -> str:
+        """Sanitize filename for safe storage"""
+        return re.sub(r'[^\w\-_. ]', '_', filename)
+
+    def _save_uploaded_file(self, uploaded_file: UploadedFile, save_dir: str) -> str:
+        """Save uploaded file to disk"""
+        os.makedirs(save_dir, exist_ok=True)
+        sanitized_name = self._sanitize_filename(uploaded_file.name)
+        file_path = os.path.join(save_dir, sanitized_name)
+
+        with open(file_path, 'wb+') as destination:
+            for chunk in uploaded_file.chunks():
+                destination.write(chunk)
+
+        return file_path
+
+    def _load_and_split_documents(self, file_path: str, metadata: dict = None) -> List[LangDocument]:
+        """Load and split documents from file"""
+        loader_class = self._get_file_loader(file_path)
+        loader = loader_class(file_path)
+
+        docs = loader.load()
+        if metadata:
+            for doc in docs:
+                doc.metadata.update(metadata)
+
+        return self.text_splitter.split_documents(docs)
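+    # Ingestion sketch (hypothetical call sites; assumes Ollama and ./chroma_db/
+    # are reachable): ingest_documents() loads each Document row's file, splits
+    # it into 1000-char chunks with 200-char overlap, embeds them, and persists.
+    #
+    #     service = SyncRAGService()
+    #     service.ingest_documents()           # index every Document row
+    #     service.ingest_documents(workspace)  # or a single workspace
+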
+    def add_files_to_store(
+        self,
+        file_tuples: List[tuple],  # (file_path, name, workspace_id)
+        workspace_id: str,
+        source: str = "upload",
+        save_dir: str = "data/uploads"
+    ) -> Dict[str, Any]:
+        """
+        Process and add already-saved uploaded files to the vector store
+
+        Args:
+            file_tuples: List of (file_path, name, workspace_id) tuples
+            workspace_id: ID of the workspace these belong to
+            source: Source identifier for documents
+            save_dir: Directory where uploaded files were saved
+
+        Returns:
+            Dictionary with processing results
+        """
+        results = {
+            'total_added': 0,
+            'failed_files': [],
+            'processed_files': []
+        }
+
+        for file_tuple in file_tuples:
+            try:
+                # Files were saved to disk upstream; the tuple carries the path
+                file_path = file_tuple[0]
+
+                # Prepare metadata
+                metadata = {
+                    'source': file_tuple[1],
+                    'workspace_id': file_tuple[2],
+                    'original_filename': file_tuple[1],
+                    'file_path': file_path,
+                }
+
+                # Load and split documents
+                docs = self._load_and_split_documents(file_path, metadata)
+
+                # Add to vector store
+                if docs:
+                    self.vector_store.add_documents(docs)
+                    results['total_added'] += len(docs)
+                    results['processed_files'].append({
+                        'filename': file_tuple[1],
+                        'document_count': len(docs)
+                    })
+
+            except Exception as e:
+                results['failed_files'].append({
+                    'filename': file_tuple[1],
+                    'error': str(e)
+                })
+                continue
+
+        # Persist changes
+        self.vector_store.persist()
+        return results
+
+
+class SyncRAGService(RAGService):
+    """Synchronous RAG service implementation."""
+
+    def __init__(self):
+        super().__init__()
+        self._setup_chain()
+
+    def _setup_chain(self):
+        """Setup the RAG chain."""
+        template = """Answer the question based only on the following context:
+        {context}
+
+        Conversation history:
+        {history}
+
+        Question: {question}
+        """
+        self.prompt = ChatPromptTemplate.from_template(template)
+
+        self.rag_chain = (
+            {
+                "context": self._retriever_with_history,
+                "history": lambda x: self._format_history(x["conversation"]),
+                "question": lambda x: x["query"]
+            }
+            | self.prompt
+            | self.llm
+            | StrOutputParser()
+        )
+
+    def _format_history(self, conversation: Conversation) -> str:
+        """Format conversation history for the prompt."""
+        prompts = Prompt.objects.filter(conversation=conversation).order_by('created_at')
+        return "\n".join(
+            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}"
+            for prompt in prompts
+        )
+
+    def _retriever_with_history(self, input_dict: Dict[str, Any]) -> List[LangDocument]:
+        """Retrieve documents considering conversation history."""
+        query = input_dict["query"]
+        conversation = input_dict["conversation"]
+
+        # You could enhance this to consider historical context in retrieval
+        relevant_docs = self.search_documents(query, conversation.workspace)
+        if not relevant_docs:
+            print("didn't find any relevant docs")
+        return relevant_docs
+
+    def search_documents(self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4) -> List[LangDocument]:
+        """Search relevant documents from the vector store."""
+        filter_dict = {}
+        if workspace:
+            filter_dict["workspace_id"] = workspace.id
+        search_kwargs = {
+            "k": k,
+            "filter": filter_dict if filter_dict else None
+        }
+        print(f"search_kwargs: {search_kwargs}")
+        retriever = self.vector_store.as_retriever(
+            search_type="similarity",
+            search_kwargs=search_kwargs
+        )
+        return retriever.get_relevant_documents(query)
+
+    def generate_response(self, conversation: Conversation, query: str, **kwargs) -> Generator[str, None, None]:
+        """Generate response with streaming support."""
+        chain_input = {
+            "query": query,
+            "conversation": conversation
+        }
+
+        for chunk in self.rag_chain.stream(chain_input):
+            yield chunk
+
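+
+def _workspace_retrieval_example():  # illustrative sketch, never called
+    """Hypothetical usage: workspace scoping works because add_files_to_store()
+    stamps each chunk's metadata with "workspace_id" and search_documents()
+    filters on that same key."""
+    service = SyncRAGService()
+    workspace = DocumentWorkspace.objects.first()
+    docs = service.search_documents("termination clause", workspace, k=4)
+    for doc in docs:
+        print(doc.metadata.get("workspace_id"), doc.page_content[:80])
+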
+
+class AsyncRAGService(RAGService):
+    """Asynchronous RAG service implementation."""
+
+    def __init__(self):
+        super().__init__()
+        self._setup_chain()
+
+    def _setup_chain(self):
+        """Setup the RAG chain."""
+        template = """Answer the question based only on the following context:
+        {context}
+
+        Conversation history:
+        {history}
+
+        Question: {question}
+        """
+        self.prompt = ChatPromptTemplate.from_template(template)
+
+        self.rag_chain = (
+            {
+                "context": self._retriever_with_history,
+                "history": lambda x: self._format_history(x["conversation"]),
+                "question": lambda x: x["query"]
+            }
+            | self.prompt
+            | self.llm
+            | StrOutputParser()
+        )
+
+    async def _format_history(self, conversation: Conversation) -> str:
+        """Format conversation history for the prompt."""
+        prompts = [
+            prompt
+            async for prompt in Prompt.objects.filter(conversation=conversation).order_by('created_at')
+        ]
+        print(f"prompts that we are seeding with are: {prompts}")
+        return "\n".join(
+            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}"
+            for prompt in prompts
+        )
+
+    async def _retriever_with_history(self, input_dict: Dict[str, Any]) -> str:
+        """Retrieve documents considering conversation history."""
+        print(f"Retrieving history with input: {input_dict}")
+        query = input_dict["query"]
+        conversation = input_dict["conversation"]
+        workspace = input_dict["workspace"]
+
+        # You could enhance this to consider historical context in retrieval
+        docs = await self.search_documents(query, workspace)
+
+        if not docs:
+            print("Didn't find any relevant docs")
+
+        print("\n\n".join(doc.page_content for doc in docs))
+        return "\n\n".join(doc.page_content for doc in docs)
+
+    async def search_documents(self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4) -> List[LangDocument]:
+        """Search relevant documents from the vector store."""
+        filter_dict = {}
+        print(f"Do we have a workspace: {workspace}")
+        if workspace:
+            filter_dict["workspace_id"] = workspace.id
+        search_kwargs = {
+            "k": k,
+            "filter": filter_dict if filter_dict else None
+        }
+        print(f"search_kwargs: {search_kwargs}")
+
+        retriever = self.vector_store.as_retriever(
+            search_type="mmr",
+            search_kwargs=search_kwargs
+        )
+        return await retriever.aget_relevant_documents(query)
+
+    async def generate_response(self, conversation: Conversation, query: str, workspace: DocumentWorkspace, **kwargs) -> AsyncGenerator[str, None]:
+        """Generate response with streaming support."""
+        chain_input = {
+            "query": query,
+            "conversation": conversation,
+            "workspace": workspace,
+        }
+
+        async for chunk in self.rag_chain.astream(chain_input):
+            yield chunk
diff --git a/llm_be/chat_backend/services/tests.py
b/llm_be/chat_backend/services/tests.py new file mode 100644 index 0000000..8da5303 --- /dev/null +++ b/llm_be/chat_backend/services/tests.py @@ -0,0 +1,219 @@ +import os +from unittest import TestCase, mock +from unittest.mock import MagicMock, patch, AsyncMock +from typing import List, Dict, Any + +from django.test import TestCase as DjangoTestCase + +from chat_backend.services.rag_services import RAGService, SyncRAGService, AsyncRAGService +from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document + +class TestRAGService(TestCase): + def setUp(self): + self.rag_service = RAGService() + self.rag_service.vector_store = MagicMock() + self.rag_service.embedding_model = MagicMock() + self.rag_service.text_splitter = MagicMock() + + def test_initialize_vector_store(self): + with patch('os.path.exists', return_value=False), \ + patch('os.makedirs') as mock_makedirs, \ + patch('langchain_community.vectorstores.Chroma') as mock_chroma: + + # Reset the vector store to test initialization + self.rag_service.vector_store = None + result = self.rag_service._initialize_vector_store() + + mock_makedirs.assert_called_once_with("chroma_db") + mock_chroma.assert_called_once_with( + embedding_function=self.rag_service.embedding_model, + persist_directory="chroma_db" + ) + self.assertIsNotNone(result) + + def test_prepare_documents(self): + mock_doc1 = MagicMock(spec=Document) + mock_doc1.content = "Test content" + mock_doc1.source = "test_source" + mock_doc1.workspace = MagicMock() + mock_doc1.workspace.id = 1 + mock_doc1.id = 1 + + self.rag_service.text_splitter.split_text.return_value = ["chunk1", "chunk2"] + + result = self.rag_service._prepare_documents([mock_doc1]) + + self.assertEqual(len(result), 2) + self.rag_service.text_splitter.split_text.assert_called_once_with("Test content") + self.assertEqual(result[0].page_content, "chunk1") + self.assertEqual(result[0].metadata["source"], "test_source") + + def test_ingest_documents(self): + mock_workspace = MagicMock() + mock_document = MagicMock() + mock_documents = [mock_document] + + with patch('services.rag_services.Document.objects.filter', return_value=mock_documents): + self.rag_service._prepare_documents = MagicMock(return_value=["processed_doc"]) + + self.rag_service.ingest_documents(mock_workspace) + + self.rag_service.vector_store.add_documents.assert_called_once_with(["processed_doc"]) + self.rag_service.vector_store.persist.assert_called_once() + + +class TestSyncRAGService(DjangoTestCase): + def setUp(self): + self.sync_service = SyncRAGService() + self.sync_service.vector_store = MagicMock() + self.sync_service.llm = MagicMock() + self.sync_service.rag_chain = MagicMock() + + self.mock_conversation = MagicMock(spec=Conversation) + self.mock_conversation.workspace = MagicMock() + + self.mock_prompt1 = MagicMock(spec=Prompt) + self.mock_prompt1.is_user = True + self.mock_prompt1.text = "User question" + self.mock_prompt1.created_at = "2023-01-01" + + self.mock_prompt2 = MagicMock(spec=Prompt) + self.mock_prompt2.is_user = False + self.mock_prompt2.text = "AI response" + self.mock_prompt2.created_at = "2023-01-02" + + def test_format_history(self): + with patch('services.rag_services.Prompt.objects.filter') as mock_filter: + mock_filter.return_value.order_by.return_value = [self.mock_prompt1, self.mock_prompt2] + + result = self.sync_service._format_history(self.mock_conversation) + + expected = "User: User question\nAI: AI response" + self.assertEqual(result, expected) + 
mock_filter.assert_called_once_with(conversation=self.mock_conversation) + + def test_retriever_with_history(self): + input_dict = { + "query": "test query", + "conversation": self.mock_conversation + } + + self.sync_service.search_documents = MagicMock(return_value=["doc1", "doc2"]) + + result = self.sync_service._retriever_with_history(input_dict) + + self.sync_service.search_documents.assert_called_once_with( + "test query", + self.mock_conversation.workspace + ) + self.assertEqual(result, ["doc1", "doc2"]) + + def test_search_documents(self): + mock_retriever = MagicMock() + mock_retriever.get_relevant_documents.return_value = ["doc1", "doc2"] + self.sync_service.vector_store.as_retriever.return_value = mock_retriever + + result = self.sync_service.search_documents("test query", self.mock_conversation.workspace) + + self.sync_service.vector_store.as_retriever.assert_called_once_with( + search_type="similarity", + search_kwargs={ + "k": 4, + "filter": {"workspace_id": self.mock_conversation.workspace.id} + } + ) + self.assertEqual(result, ["doc1", "doc2"]) + + def test_generate_response(self): + chain_input = { + "query": "test query", + "conversation": self.mock_conversation + } + + mock_stream = ["chunk1", "chunk2", "chunk3"] + self.sync_service.rag_chain.stream.return_value = mock_stream + + result = list(self.sync_service.generate_response(self.mock_conversation, "test query")) + + self.sync_service.rag_chain.stream.assert_called_once_with(chain_input) + self.assertEqual(result, mock_stream) + + +class TestAsyncRAGService(DjangoTestCase): + def setUp(self): + self.async_service = AsyncRAGService() + self.async_service.vector_store = MagicMock() + self.async_service.llm = MagicMock() + self.async_service.rag_chain = AsyncMock() + + self.mock_conversation = MagicMock(spec=Conversation) + self.mock_conversation.workspace = MagicMock() + + self.mock_prompt1 = MagicMock(spec=Prompt) + self.mock_prompt1.is_user = True + self.mock_prompt1.text = "User question" + self.mock_prompt1.created_at = "2023-01-01" + + self.mock_prompt2 = MagicMock(spec=Prompt) + self.mock_prompt2.is_user = False + self.mock_prompt2.text = "AI response" + self.mock_prompt2.created_at = "2023-01-02" + + async def test_format_history(self): + mock_manager = AsyncMock() + mock_manager.order_by.return_value.alist.return_value = [self.mock_prompt1, self.mock_prompt2] + + with patch('services.rag_services.Prompt.objects.filter', return_value=mock_manager): + result = await self.async_service._format_history(self.mock_conversation) + + expected = "User: User question\nAI: AI response" + self.assertEqual(result, expected) + mock_manager.order_by.assert_called_once_with('created_at') + + async def test_retriever_with_history(self): + input_dict = { + "query": "test query", + "conversation": self.mock_conversation + } + + self.async_service.search_documents = AsyncMock(return_value=["doc1", "doc2"]) + + result = await self.async_service._retriever_with_history(input_dict) + + self.async_service.search_documents.assert_awaited_once_with( + "test query", + self.mock_conversation.workspace + ) + self.assertEqual(result, ["doc1", "doc2"]) + + async def test_search_documents(self): + mock_retriever = AsyncMock() + mock_retriever.aget_relevant_documents.return_value = ["doc1", "doc2"] + self.async_service.vector_store.as_retriever.return_value = mock_retriever + + result = await self.async_service.search_documents("test query", self.mock_conversation.workspace) + + 
self.async_service.vector_store.as_retriever.assert_called_once_with( + search_type="similarity", + search_kwargs={ + "k": 4, + "filter": {"workspace_id": self.mock_conversation.workspace.id} + } + ) + self.assertEqual(result, ["doc1", "doc2"]) + + async def test_generate_response(self): + chain_input = { + "query": "test query", + "conversation": self.mock_conversation + } + + mock_stream = ["chunk1", "chunk2", "chunk3"] + self.async_service.rag_chain.astream.return_value = mock_stream + + chunks = [] + async for chunk in self.async_service.generate_response(self.mock_conversation, "test query"): + chunks.append(chunk) + + self.async_service.rag_chain.astream.assert_awaited_once_with(chain_input) + self.assertEqual(chunks, mock_stream) \ No newline at end of file diff --git a/llm_be/chat_backend/services/title_generator.py b/llm_be/chat_backend/services/title_generator.py new file mode 100644 index 0000000..f8209d1 --- /dev/null +++ b/llm_be/chat_backend/services/title_generator.py @@ -0,0 +1,67 @@ +from langchain_core.prompts import ChatPromptTemplate +from langchain_community.llms import Ollama +from typing import Optional + +class TitleGenerator: + """ + Generates short, descriptive titles for conversations based on the first prompt. + """ + + def __init__(self): + self.llm = Ollama( + model="llama3", + temperature=0.5, # Slightly creative but not too random + top_k=20, + num_ctx=2048 # Shorter context needed for titles + ) + + self.title_prompt = ChatPromptTemplate.from_messages([ + ("system", """You are a conversation title generator. Create a very short (2-5 word) title based on the user's first message. + +Rules: +1. Keep it extremely concise +2. Capture the main topic or intent +3. Use title case +4. No quotes or punctuation +5. Never exceed 5 words + +Examples: +- "What's the weather today?" → "Weather Inquiry" +- "Explain quantum computing" → "Quantum Computing Explanation" +- "Generate an image of a dragon" → "Dragon Image Generation" +- "Find our company's privacy policy" → "Privacy Policy Search" + +Return ONLY the title, nothing else."""), + ("human", "{prompt}") + ]) + + self.chain = self.title_prompt | self.llm + + async def generate_async(self, prompt: str) -> str: + """Generate title asynchronously""" + try: + response = await self.chain.ainvoke({"prompt": prompt}) + return self._clean_response(response) + except Exception as e: + print(f"Title generation error: {e}") + return "Conversation" + + def generate(self, prompt: str) -> str: + """Generate title synchronously""" + try: + response = self.chain.invoke({"prompt": prompt}) + return self._clean_response(response) + except Exception as e: + print(f"Title generation error: {e}") + return "Conversation" + + def _clean_response(self, response: str) -> str: + """Clean and format the LLM response""" + # Remove any quotes or punctuation + response = response.strip('"\'.!? 
\n\t') + # Ensure title case and trim + return response.title()[:50] # Hard limit for safety + + +# Singleton instance +title_generator = TitleGenerator() \ No newline at end of file diff --git a/llm_be/chat_backend/signals.py b/llm_be/chat_backend/signals.py new file mode 100644 index 0000000..c08bab1 --- /dev/null +++ b/llm_be/chat_backend/signals.py @@ -0,0 +1,18 @@ +from django.db.models.signals import post_save, post_delete +from django.dispatch import receiver +from chat_backend.models import Document +from .services.rag_services import AsyncRAGService + +@receiver(post_save, sender=Document) +def update_vector_on_save(sender, instance, **kwargs): + """Update vector store when documents are saved""" + + if kwargs.get('created', False): + rag_service = AsyncRAGService() + rag_service.ingest_documents() + +@receiver(post_delete, sender=Document) +def delete_vector_on_remove(sender, instance, **kwargs): + """Handle document deletion by re-indexing the whole workspace""" + rag_service = AsyncRAGService() + rag_service.ingest_documents() \ No newline at end of file diff --git a/llm_be/chat_backend/templates/emails/reset_email.html b/llm_be/chat_backend/templates/emails/reset_email.html new file mode 100644 index 0000000..1a127d7 --- /dev/null +++ b/llm_be/chat_backend/templates/emails/reset_email.html @@ -0,0 +1,97 @@ + + + + + + Reset Password for Chat by AI ML Operations, LLC + + + + + + + +
+    [HTML body stripped during extraction: a styled password-reset email for chat.aimloperations.com linking the recipient to {{ url }}]
+ + \ No newline at end of file diff --git a/llm_be/chat_backend/templates/emails/reset_email.txt b/llm_be/chat_backend/templates/emails/reset_email.txt new file mode 100644 index 0000000..f805092 --- /dev/null +++ b/llm_be/chat_backend/templates/emails/reset_email.txt @@ -0,0 +1,3 @@ +Password Reset for AI ML Operations, LLC Chat Services + +"Password reset for chat.aimloperations.com. Please use {{ url }} to set your password" \ No newline at end of file diff --git a/llm_be/chat_backend/tests.py b/llm_be/chat_backend/tests.py index 7ce503c..392daa8 100644 --- a/llm_be/chat_backend/tests.py +++ b/llm_be/chat_backend/tests.py @@ -1,3 +1,210 @@ from django.test import TestCase # Create your tests here. +from django.test import TestCase, Client +from django.urls import reverse +from django.contrib.auth.models import User +from rest_framework.test import APIClient, APITestCase +from rest_framework import status +from .models import DocumentWorkspace, Document, Company +from django.contrib.auth import get_user_model +import tempfile +from django.core.files.uploadedfile import SimpleUploadedFile + +# Minimal valid PDF bytes +VALID_PDF_BYTES = ( + b'%PDF-1.3\n' + b'1 0 obj\n' + b'<< /Type /Catalog /Pages 2 0 R >>\n' + b'endobj\n' + b'2 0 obj\n' + b'<< /Type /Pages /Kids [3 0 R] /Count 1 >>\n' + b'endobj\n' + b'3 0 obj\n' + b'<< /Type /Page /Parent 2 0 R /Resources << >> /MediaBox [0 0 612 792] /Contents 4 0 R >>\n' + b'endobj\n' + b'4 0 obj\n' + b'<< /Length 44 >>\n' + b'stream\n' + b'BT /F1 12 Tf 72 720 Td (Test PDF) Tj ET\n' + b'endstream\n' + b'endobj\n' + b'xref\n' + b'0 5\n' + b'0000000000 65535 f \n' + b'0000000009 00000 n \n' + b'0000000058 00000 n \n' + b'0000000117 00000 n \n' + b'0000000223 00000 n \n' + b'trailer\n' + b'<< /Size 5 /Root 1 0 R >>\n' + b'startxref\n' + b'317\n' + b'%%EOF' +) + +class DocumentWorkspaceViewsTestCase(APITestCase): + def setUp(self): + self.company = Company.objects.create( + name="test", + state="IL", + zipcode="60189", + address="1968 Greensboro Dr" + ) + self.user = get_user_model().objects.create_user( + company=self.company, + username='testuser', + password='testpass123', + email="test@test.com", + ) + + self.client = APIClient() + self.client.force_authenticate(user=self.user) + + self.workspace = DocumentWorkspace.objects.create( + company = self.user.company, + name='Test Workspace' + ) + + def test_list_workspaces(self): + url = reverse('document_workspaces') + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 1) + self.assertEqual(response.data[0]['name'], 'Test Workspace') + + def test_create_workspace(self): + url = reverse('document_workspaces') + data = { + 'name': 'New Workspace' + } + response = self.client.post(url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(DocumentWorkspace.objects.count(), 2) + + def test_retrieve_workspace(self): + url = reverse('document_workspaces') + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data[0]['name'], 'Test Workspace') + + # def test_update_workspace(self): + # url = reverse('document_workspaces') + # data = { + # 'name': 'Updated Workspace' + # } + # response = self.client.post(url, data, format='json') + # self.assertEqual(response.status_code, status.HTTP_201_CREATED) + # self.workspace.refresh_from_db() + # self.assertEqual(self.workspace.name, 'Updated Workspace') + + # def 
test_delete_workspace(self):
+    #     url = reverse('document_workspaces', args=[self.workspace.id])
+    #     response = self.client.delete(url)
+    #     self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
+    #     self.assertEqual(DocumentWorkspace.objects.count(), 0)
+
+class DocumentViewsTestCase(APITestCase):
+    def setUp(self):
+        self.company = Company.objects.create(
+            name="test",
+            state="IL",
+            zipcode="60189",
+            address="1968 Greensboro Dr"
+        )
+        self.user = get_user_model().objects.create_user(
+            company=self.company,
+            username='testuser',
+            password='testpass123',
+            email="test@test.com",
+        )
+
+        self.client = APIClient()
+        self.client.force_authenticate(user=self.user)
+
+        self.workspace = DocumentWorkspace.objects.create(
+            company=self.user.company,
+            name='Test Workspace'
+        )
+
+        # Create a test file
+        self.test_file = SimpleUploadedFile(
+            "test.pdf",
+            VALID_PDF_BYTES,
+            content_type="application/pdf"
+        )
+
+    def test_upload_document(self):
+        url = reverse('documents')
+        data = {
+            'file': self.test_file
+        }
+        response = self.client.post(url, data, format='multipart')
+        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+        self.assertEqual(Document.objects.count(), 1)
+
+        document = Document.objects.first()
+        self.assertEqual(document.workspace.id, self.workspace.id)
+        self.assertTrue(document.processed)  # the upload view processes synchronously
+
+    def test_list_documents(self):
+        # First create a document
+        Document.objects.create(
+            workspace=self.workspace,
+            file=self.test_file
+        )
+
+        url = reverse('documents')
+        response = self.client.get(url)
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertEqual(len(response.data), 1)
+        self.assertIn('test', response.data[0]['file'])
+        self.assertIn('pdf', response.data[0]['file'])
+
+    # def test_delete_document(self):
+    #     document = Document.objects.create(
+    #         workspace=self.workspace,
+    #         file=self.test_file
+    #     )
+
+    #     url = reverse('document-detail', args=[document.id])
+    #     response = self.client.delete(url)
+    #     self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
+    #     self.assertEqual(Document.objects.count(), 0)
+
+    def test_upload_invalid_file(self):
+        url = reverse('documents')
+        data = {
+            'file': 'not a file'
+        }
+        response = self.client.post(url, data, format='multipart')
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+
+    def test_access_other_users_documents(self):
+        # Create another user
+        other_company = Company.objects.create(
+            name="test2",
+            state="IL",
+            zipcode="60189",
+            address="1968 Greensboro Dr"
+        )
+        other_user = get_user_model().objects.create_user(
+            company=other_company,
+            username='otheruser',
+            password='otherpass123',
+            email="testing2@test.com"
+        )
+        other_workspace = DocumentWorkspace.objects.create(
+            company=other_user.company,
+            name='Other Workspace'
+        )
+        other_document = Document.objects.create(
+            workspace=other_workspace,
+            file=self.test_file
+        )
+
+        # Try to access the other user's document
+        url = reverse('documents_details', kwargs={"document_id": other_document.id})
+        response = self.client.get(url)
+        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
\ No newline at end of file
diff --git a/llm_be/chat_backend/urls.py b/llm_be/chat_backend/urls.py
index d32c9c5..83e281a 100644
--- a/llm_be/chat_backend/urls.py
+++ b/llm_be/chat_backend/urls.py
@@ -14,27 +14,42 @@ from .views import (
     ConversationDetailView,
     CompanyUsersView,
     SetUserPassword,
+    ResetUserPassword,
     ConversationPreferences,
     UserPromptAnalytics,
     
UserConversationAnalytics,
     CompanyUsageAnalytics,
-    AdminAnalytics
+    AdminAnalytics,
+    reset_password,
+    DocumentWorkspaceView,
+    DocumentUploadView,
+    DocumentDetailView,
 )
 
+from rest_framework.routers import DefaultRouter
+
 
 urlpatterns = [
     path("token/obtain/", CustomObtainTokenView.as_view(), name="token_create"),
     path("token/refresh/", jwt_views.TokenRefreshView.as_view(), name="token_refresh"),
     path("user/create/", CustomUserCreate.as_view(), name="create_user"),
     path("user/invite/", CustomUserInvite.as_view(), name="invite_user"),
-    path("user/set_password/<slug>/", SetUserPassword.as_view(), name="set_password"),
+    path("user/reset_password/", reset_password, name="reset_password"),
+    path(
+        "user/set_password/<slug>/", SetUserPassword.as_view(), name="set_password"
+    ),
     path(
         "blacklist/",
         LogoutAndBlacklistRefreshTokenForUserView.as_view(),
         name="blacklist",
     ),
     path("user/get/", CustomUserGet.as_view(), name="get_user"),
-    path("user/acknowledge_tos/", AcknowledgeTermsOfService.as_view(), name="acknowledge_tos"),
-    path("company_users",CompanyUsersView.as_view(), name="company_users"),
+    path(
+        "user/acknowledge_tos/",
+        AcknowledgeTermsOfService.as_view(),
+        name="acknowledge_tos",
+    ),
+    path("company_users", CompanyUsersView.as_view(), name="company_users"),
     path("user/is_authenticated/", is_authenticated, name="is_authenticated"),
     path("announcment/get/", AnnouncmentView.as_view(), name="get_announcments"),
     path("conversations", ConversationsView.as_view(), name="conversations"),
@@ -44,9 +59,32 @@ urlpatterns = [
         ConversationDetailView.as_view(),
         name="conversation_details",
     ),
-    path("conversation_preferences", ConversationPreferences.as_view(), name="conversation_preferences"),
-    path("analytics/user_prompts/", UserPromptAnalytics.as_view(), name="analytics_user_prompts"),
-    path("analytics/user_conversations/", UserConversationAnalytics.as_view(), name="analytics_user_conversations"),
-    path("analytics/company_usage/", CompanyUsageAnalytics.as_view(), name="analytics_company_usage"),
+    path(
+        "conversation_preferences",
+        ConversationPreferences.as_view(),
+        name="conversation_preferences",
+    ),
+    path(
+        "analytics/user_prompts/",
+        UserPromptAnalytics.as_view(),
+        name="analytics_user_prompts",
+    ),
+    path(
+        "analytics/user_conversations/",
+        UserConversationAnalytics.as_view(),
+        name="analytics_user_conversations",
+    ),
+    path(
+        "analytics/company_usage/",
+        CompanyUsageAnalytics.as_view(),
+        name="analytics_company_usage",
+    ),
     path("analytics/admin/", AdminAnalytics.as_view(), name="analytics_admin"),
+
+    # document urls
+    path(
+        "document_workspaces/",
+        DocumentWorkspaceView.as_view(),
+        name="document_workspaces",
+    ),
+    path("documents/", DocumentUploadView.as_view(), name="documents"),
+    path(
+        "documents_details/<int:document_id>/",
+        DocumentDetailView.as_view(),
+        name="documents_details",
+    ),
 ]
 
diff --git a/llm_be/chat_backend/utils.py b/llm_be/chat_backend/utils.py
index ba6a9f6..2d33454 100644
--- a/llm_be/chat_backend/utils.py
+++ b/llm_be/chat_backend/utils.py
@@ -1,7 +1,8 @@
 import datetime
 
+
 def last_day_of_month(any_day):
-    # The day 28 exists in every month. 4 days later, it's always next month
-    next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
-    # subtracting the number of the current day brings us back one month
-    return next_month - datetime.timedelta(days=next_month.day)
\ No newline at end of file
+    # The day 28 exists in every month. 
4 days later, it's always next month
+    next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
+    # subtracting the number of the current day brings us back one month
+    return next_month - datetime.timedelta(days=next_month.day)
diff --git a/llm_be/chat_backend/views.py b/llm_be/chat_backend/views.py
index 7a87583..a6a41b1 100644
--- a/llm_be/chat_backend/views.py
+++ b/llm_be/chat_backend/views.py
@@ -11,11 +11,22 @@ from .serializers import (
     CompanySerializer,
     ConversationSerializer,
     PromptSerializer,
-    FeedbackSerializer
+    FeedbackSerializer,
+    DocumentWorkspaceSerializer,
+    DocumentSerializer,
 )
 from rest_framework.views import APIView
 from rest_framework.response import Response
-from .models import CustomUser, Announcement, Conversation, Prompt, Feedback,PromptMetric
+from .models import (
+    CustomUser,
+    Announcement,
+    Conversation,
+    Prompt,
+    Feedback,
+    PromptMetric,
+    DocumentWorkspace,
+    Document,
+)
 from django.views.decorators.cache import never_cache
 from django.http import JsonResponse
 from datetime import datetime
@@ -25,7 +36,10 @@
 from channels.generic.websocket import AsyncWebsocketConsumer
 from langchain_ollama.llms import OllamaLLM
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.messages import HumanMessage, SystemMessage
+from langchain.chains import RetrievalQA
 import re
+import os
+import requests
 from django.conf import settings
 import json
 import base64
@@ -43,15 +56,29 @@ from django.core.files.base import ContentFile
 import math
 import datetime
 import pytz
+from langchain_community.embeddings import OllamaEmbeddings
+from langchain_community.vectorstores import Chroma
 from dateutil.relativedelta import relativedelta
+from django.views.decorators.csrf import csrf_exempt
 from .utils import last_day_of_month
+from .services.llm_service import AsyncLLMService
+from .services.rag_services import AsyncRAGService
+from .services.title_generator import title_generator
+from .services.moderation_classifier import moderation_classifier, ModerationLabel
+from .services.prompt_classifier import prompt_classifier, PromptType
 
-CHANNEL_NAME: str = 'llm_messages'
+from langchain.chains import create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_ollama import ChatOllama
+
+CHANNEL_NAME: str = "llm_messages"
 MODEL_NAME: str = "llama3"
 
+
 # Create your views here.
 
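For orientation while reviewing the consumer changes below: replies are streamed over the websocket between sentinel frames (CONVERSATION_ID, then the id, then START_OF_THE_STREAM_ENDER_GAME_42, the chunks, and END_OF_THE_STREAM_ENDER_GAME_42). A minimal client sketch of that contract — the ws:// URL and the `websockets` package are assumptions for illustration, not part of this patch; the real route lives in chat_backend/routing.py:

    import asyncio
    import json

    import websockets  # assumed helper package, not a project dependency

    async def chat(message, email, conversation_id=None):
        # URL is illustrative; see chat_backend/routing.py for the actual pattern.
        async with websockets.connect("ws://127.0.0.1:8000/ws/chat/") as ws:
            await ws.send(json.dumps({
                "message": message,
                "email": email,
                "conversation_id": conversation_id,
            }))
            streaming = False
            while True:
                frame = await ws.recv()
                if frame == "CONVERSATION_ID":
                    conversation_id = await ws.recv()  # next frame carries the id
                elif frame == "START_OF_THE_STREAM_ENDER_GAME_42":
                    streaming = True
                elif frame == "END_OF_THE_STREAM_ENDER_GAME_42":
                    return conversation_id
                elif streaming:
                    print(frame, end="", flush=True)  # one model chunk per frame

    asyncio.run(chat("Hello there", "test@test.com"))
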
class CustomObtainTokenView(TokenObtainPairView): permission_classes = (permissions.AllowAny,) @@ -71,80 +96,167 @@ class CustomUserCreate(APIView): return Response(json, status=status.HTTP_201_CREATED) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + def send_invite_email(slug, email_to_invite): + print("Sending invite email") print(f"url : https://www.chat.aimloperations.com/set_password?slug={slug}") url = f"https://www.chat.aimloperations.com/set_password?slug={slug}" subject = "Welcome to AI ML Operations, LLC Chat Services" from_email = "ryan@aimloperations.com" - to=email_to_invite + to = email_to_invite d = {"url": url} - html_content = get_template(r'emails/invite_email.html').render(d) - text_content = get_template(r'emails/invite_email.txt').render(d) - + html_content = get_template(r"emails/invite_email.html").render(d) + text_content = get_template(r"emails/invite_email.txt").render(d) + msg = EmailMultiAlternatives(subject, text_content, from_email, [to]) msg.attach_alternative(html_content, "text/html") msg.send(fail_silently=True) + def send_feedback_email(feedback_obj): + print("Sending feedback email") subject = "New Feedback for Chat by AI ML Operations, LLC" from_email = "ryan@aimloperations.com" - to="ryan@aimloperations.com" + to = "ryan@aimloperations.com" d = {"title": feedback_obj.title, "feedback_text": feedback_obj.text} - html_content = get_template(r'emails/feedback_email.html').render(d) - text_content = get_template(r'emails/feedback_email.txt').render(d) - + html_content = get_template(r"emails/feedback_email.html").render(d) + text_content = get_template(r"emails/feedback_email.txt").render(d) + msg = EmailMultiAlternatives(subject, text_content, from_email, [to]) msg.attach_alternative(html_content, "text/html") msg.send(fail_silently=True) + +def send_password_reset_email(slug, email_to_invite): + print("Sending Password reset email") + url = f"https://www.chat.aimloperations.com/set_password?slug={slug}" + subject = "Password reset for Chat by AI ML Operations, LLC" + from_email = "ryan@aimloperations.com" + to = email_to_invite + d = {"url": url} + html_content = get_template(r"emails/reset_email.html").render(d) + text_content = get_template(r"emails/reset_email.txt").render(d) + msg = EmailMultiAlternatives(subject, text_content, from_email, [to]) + msg.attach_alternative(html_content, "text/html") + msg.send(fail_silently=True) + + class CustomUserInvite(APIView): - http_method_names = ['post'] + http_method_names = ["post"] + def post(self, request, format="json"): def valid_email(email_string): - regex = r'^[a-z0-9]+[\._]?[a-z0-9]+[@]\w+[.]\w+$' - if re.match(regex,email_string): + regex = r"^[a-z0-9]+[\._]?[a-z0-9]+[@]\w+[.]\w+$" + if re.match(regex, email_string): return True else: return False - email_to_invite = request.data['email'] - - if len(email_to_invite) == 0 or not valid_email(email_to_invite) or not request.user.is_company_manager: + email_to_invite = request.data["email"] + + if ( + len(email_to_invite) == 0 + or not valid_email(email_to_invite) + or not request.user.is_company_manager + ): return Response(status=status.HTTP_400_BAD_REQUEST) # make sure there isn't a user with this email already existing_users = CustomUser.objects.filter(email=email_to_invite) if len(existing_users) > 0: return Response(status=status.HTTP_400_BAD_REQUEST) # create the object and send the email - user = CustomUser.objects.create(email=email_to_invite, username=email_to_invite, company=request.user.company) + user = 
CustomUser.objects.create(
+            email=email_to_invite,
+            username=email_to_invite,
+            company=request.user.company,
+        )
         # send an email
         send_invite_email(user.slug, email_to_invite)
-
-
         return Response(status=status.HTTP_201_CREATED)
 
-class SetUserPassword(APIView):
-    http_method_names = ['post','get']
+
+@csrf_exempt
+def reset_password(request):
+    if request.method == "POST":
+        data = json.loads(request.body)
+        token = data.get("recaptchaToken")
+        payload = {
+            "secret": settings.CAPTCHA_SECRET_KEY,
+            "response": token,
+        }
+        response = requests.post(
+            "https://www.google.com/recaptcha/api/siteverify", data=payload
+        )
+        result = response.json()
+        if result.get("success") and result.get("score") >= 0.5:
+            email = data.get("email")
+            user = CustomUser.objects.filter(email=email).first()
+            if user:
+                user.set_unusable_password()
+                user.save()
+
+                # send the email
+                send_password_reset_email(user.slug, email)
+            # always report success here so the endpoint does not leak
+            # which emails exist
+            return JsonResponse({}, status=200)
+
+    return JsonResponse({}, status=400)
+
+
+class ResetUserPassword(APIView):
+    http_method_names = [
+        "post",
+    ]
     permission_classes = (permissions.AllowAny,)
     authentication_classes = ()
+
+    def post(self, request, format="json"):
+        """
+        Send an email with a set password link to the set password page
+        Also disable the account
+        """
+        print(f"Password reset for requests. {request.data}")
+        token = request.data.get("recaptchaToken")
+        email = request.data.get("email")
+        payload = {
+            "secret": settings.CAPTCHA_SECRET_KEY,
+            "response": token,
+        }
+        response = requests.post(
+            "https://www.google.com/recaptcha/api/siteverify", data=payload
+        )
+        result = response.json()
+        if result.get("success") and result.get("score") >= 0.5:
+            user = CustomUser.objects.filter(email=email).first()
+            if user:
+                user.set_unusable_password()
+                user.save()
+
+                # send the email
+                send_password_reset_email(user.slug, email)
+        else:
+            print("Captcha secret failed")
+
+        return Response(status=status.HTTP_200_OK)
+
+
+class SetUserPassword(APIView):
+    http_method_names = ["post", "get"]
+    permission_classes = (permissions.AllowAny,)
+    authentication_classes = ()
+
     def get(self, request, slug):
         user = CustomUser.objects.get(slug=slug)
-        if user.last_login:
+        if user.has_usable_password():
             return Response(status=status.HTTP_401_UNAUTHORIZED)
         else:
             return Response(status=status.HTTP_200_OK)
 
+    def post(self, request, slug, format="json"):
         user = CustomUser.objects.get(slug=slug)
-        user.set_password(request.data['password'])
+        user.set_password(request.data["password"])
         user.save()
         return Response(status=status.HTTP_200_OK)
 
-
 class CustomUserGet(APIView):
-    http_method_names = ['get', 'head', 'post']
+    http_method_names = ["get", "head", "post"]
+
     def get(self, request, format="json"):
         email = request.user.email
 
@@ -154,16 +266,18 @@
 
         return Response(serializer.data, status=status.HTTP_200_OK)
 
+
 class FeedbackView(APIView):
-    http_method_names = ['post','get']
+    http_method_names = ["post", "get"]
+
     def post(self, request, format="json"):
         serializer = FeedbackSerializer(data=request.data)
         print(request.data)
         if serializer.is_valid():
-
+
             feedback_obj = serializer.save()
             feedback_obj.user = request.user
-
+
             feedback_obj.save()
             send_feedback_email(feedback_obj)
             return Response(serializer.data, status=status.HTTP_201_CREATED)
@@ -177,14 +291,15 @@ class FeedbackView(APIView):
 
         return Response(serializer.data, status=status.HTTP_200_OK)
 
-
 class AcknowledgeTermsOfService(APIView):
-    http_method_names = ['post']
+    http_method_names = ["post"]
+
     def post(self, request, format="json"):
         request.user.has_signed_tos = True
         
request.user.save() return Response(status=status.HTTP_200_OK) + class CompanyUsersView(APIView): def get(self, request, format="json"): # TODO: make sure you are a manager of that company @@ -194,8 +309,7 @@ class CompanyUsersView(APIView): return Response(serializer.data, status=status.HTTP_200_OK) else: return Response(status=status.HTTP_401_UNAUTHORIZED) - - + def post(self, request, format="json"): if request.user.is_company_manager: user = CustomUser.objects.get(email=request.data.get("email")) @@ -215,7 +329,7 @@ class CompanyUsersView(APIView): return Response(status=status.HTTP_200_OK) return Response(status=status.HTTP_400_BAD_REQUEST) return Response(status=status.HTTP_401_UNAUTHORIZED) - + def delete(self, request, format="json"): if request.user.is_company_manager: user = CustomUser.objects.get(email=request.data.get("email")) @@ -224,6 +338,7 @@ class CompanyUsersView(APIView): return Response(status=status.HTTP_200_OK) return Response(status=status.HTTP_401_UNAUTHORIZED) + class AnnouncmentView(APIView): permission_classes = (permissions.AllowAny,) serializer_class = AnnouncmentSerializer @@ -259,7 +374,9 @@ def is_authenticated(request): class ConversationsView(APIView): def get(self, request, format="json"): order = "created" if request.user.conversation_order else "-created" - conversations = Conversation.objects.filter(user=request.user, deleted=False).order_by(order) + conversations = Conversation.objects.filter( + user=request.user, deleted=False + ).order_by(order) serializer = ConversationSerializer(conversations, many=True) return Response(serializer.data, status=status.HTTP_200_OK) @@ -283,7 +400,9 @@ class ConversationsView(APIView): # conversation.user_id = request.user.id # conversation.save() - return Response({"title": title, "id": conversation.id}, status=status.HTTP_201_CREATED) + return Response( + {"title": title, "id": conversation.id}, status=status.HTTP_201_CREATED + ) class ConversationPreferences(APIView): @@ -298,7 +417,6 @@ class ConversationPreferences(APIView): return Response({"order": user.conversation_order}, status=status.HTTP_200_OK) - class ConversationDetailView(APIView): def get(self, request, format="json"): conversation_id = request.query_params.get("conversation_id") @@ -306,9 +424,8 @@ class ConversationDetailView(APIView): serailzer = PromptSerializer(prompts, many=True) return Response(serailzer.data, status=status.HTTP_200_OK) - def post(self, request, format="json"): - print('In the post') + print("In the post") # Add the prompt to the database # make sure there is a conversation for it # if there is not a conversation create a title for it @@ -336,28 +453,30 @@ class ConversationDetailView(APIView): prompt_instance = serializer.save() # set up the streaming response if it is from the user - print(f'Do we have a valid user? {is_user}') + print(f"Do we have a valid user? 
{is_user}") if is_user: messages = [] - for prompt_obj in Prompt.objects.filter(conversation__id=conversation_id): - messages.append({ - 'content':prompt_obj.message, - 'role': 'user' if prompt_obj.user_created else 'assistant' - }) + for prompt_obj in Prompt.objects.filter( + conversation__id=conversation_id + ): + messages.append( + { + "content": prompt_obj.message, + "role": "user" if prompt_obj.user_created else "assistant", + } + ) channel_layer = get_channel_layer() - print(f'Sending to the channel: {CHANNEL_NAME}') + print(f"Sending to the channel: {CHANNEL_NAME}") async_to_sync(channel_layer.group_send)( - CHANNEL_NAME, { - 'type':'receive', - 'content': messages - } + CHANNEL_NAME, {"type": "receive", "content": messages} ) except: - print(f"Error trying to submit to conversation_id: {conversation_id} with request.data: {request.data}") + print( + f"Error trying to submit to conversation_id: {conversation_id} with request.data: {request.data}" + ) pass - return Response(status=status.HTTP_200_OK) def delete(self, request, format="json"): @@ -367,13 +486,16 @@ class ConversationDetailView(APIView): conversation.save() return Response(status=status.HTTP_202_ACCEPTED) + class UserPromptAnalytics(APIView): def get(self, request, format="json"): now = timezone.now() result = [] - number_of_months = 3 - company_user_ids = CustomUser.objects.filter(company=request.user.company).values_list('id', flat=True) + number_of_months = 3 + company_user_ids = CustomUser.objects.filter( + company=request.user.company + ).values_list("id", flat=True) for i in range(number_of_months): next_year = now.year next_month = now.month - i @@ -383,30 +505,51 @@ class UserPromptAnalytics(APIView): start_date = datetime.datetime(next_year, next_month, 1) end_date = last_day_of_month(start_date) - total_conversations = Conversation.objects.filter(created__gte=start_date, created__lte=end_date) - total_prompts = Prompt.objects.filter(conversation__id__in=total_conversations, created__gte=start_date, created__lte=end_date) + total_conversations = Conversation.objects.filter( + created__gte=start_date, created__lte=end_date + ) + total_prompts = Prompt.objects.filter( + conversation__id__in=total_conversations, + created__gte=start_date, + created__lte=end_date, + ) total_users = len(CustomUser.objects.all()) my_conversations = Conversation.objects.filter(user=request.user) - my_prompts = Prompt.objects.filter(conversation__in=my_conversations, created__gte=start_date, created__lte=end_date) - company_conversations = Conversation.objects.filter(user__id__in=company_user_ids) - company_prompts = Prompt.objects.filter(conversation__in=company_conversations, created__gte=start_date, created__lte=end_date) + my_prompts = Prompt.objects.filter( + conversation__in=my_conversations, + created__gte=start_date, + created__lte=end_date, + ) + company_conversations = Conversation.objects.filter( + user__id__in=company_user_ids + ) + company_prompts = Prompt.objects.filter( + conversation__in=company_conversations, + created__gte=start_date, + created__lte=end_date, + ) + + result.append( + { + "month": start_date.strftime("%B"), + "you": len(my_prompts), + "others": len(company_prompts) / len(company_user_ids), + "all": len(total_prompts) / total_users, + } + ) - result.append({ - "month":start_date.strftime("%B"), - "you": len(my_prompts), - "others": len(company_prompts)/len(company_user_ids), - "all":len(total_prompts)/total_users - }) - return Response(result[::-1], status=status.HTTP_200_OK) - + + class 
UserConversationAnalytics(APIView):
     def get(self, request, format="json"):
         now = timezone.now()
 
         result = []
-        number_of_months = 3 
-        company_user_ids = CustomUser.objects.filter(company=request.user.company).values_list('id', flat=True)
+        number_of_months = 3
+        company_user_ids = CustomUser.objects.filter(
+            company=request.user.company
+        ).values_list("id", flat=True)
         for i in range(number_of_months):
             next_year = now.year
             next_month = now.month - i
@@ -416,28 +559,48 @@ class UserConversationAnalytics(APIView):
             start_date = datetime.datetime(next_year, next_month, 1)
             end_date = last_day_of_month(start_date)
 
-            total_conversations = len(Conversation.objects.filter(created__gte=start_date, created__lte=end_date))
+            total_conversations = len(
+                Conversation.objects.filter(
+                    created__gte=start_date, created__lte=end_date
+                )
+            )
             total_users = len(CustomUser.objects.all())
-            company_conversations = len(Conversation.objects.filter(user__id__in=company_user_ids, created__gte=start_date, created__lte=end_date))
+            company_conversations = len(
+                Conversation.objects.filter(
+                    user__id__in=company_user_ids,
+                    created__gte=start_date,
+                    created__lte=end_date,
+                )
+            )
 
-            result.append({
-                "month":start_date.strftime("%B"),
-                "you": len(Conversation.objects.filter(user=request.user, created__gte=start_date, created__lte=end_date)),
-                "others": company_conversations/len(company_user_ids),
-                "all":total_conversations/total_users
-            })
+            result.append(
+                {
+                    "month": start_date.strftime("%B"),
+                    "you": len(
+                        Conversation.objects.filter(
+                            user=request.user,
+                            created__gte=start_date,
+                            created__lte=end_date,
+                        )
+                    ),
+                    "others": company_conversations / len(company_user_ids),
+                    "all": total_conversations / total_users,
+                }
+            )
 
         return Response(result[::-1], status=status.HTTP_200_OK)
+
 
 class CompanyUsageAnalytics(APIView):
     def get(self, request, format="json"):
         now = timezone.now()
 
         result = []
-        number_of_months = 3 
-        company_user_ids = CustomUser.objects.filter(company=request.user.company).values_list('id', flat=True)
-        
+        number_of_months = 3
+        company_user_ids = CustomUser.objects.filter(
+            company=request.user.company
+        ).values_list("id", flat=True)
+
         for i in range(number_of_months):
             next_year = now.year
             next_month = now.month - i
@@ -447,19 +610,28 @@ class CompanyUsageAnalytics(APIView):
             start_date = datetime.datetime(next_year, next_month, 1)
             end_date = 
last_day_of_month(start_date) - durations = [item.get_duration() for item in PromptMetric.objects.filter(created__gte=start_date, created__lte=end_date)] + durations = [ + item.get_duration() + for item in PromptMetric.objects.filter( + created__gte=start_date, created__lte=end_date + ) + ] if len(durations) == 0: - result.append({ - "month":start_date.strftime("%B"), - "range":[0,0], - "avg": 0, - "median":0, - }) + result.append( + { + "month": start_date.strftime("%B"), + "range": [0, 0], + "avg": 0, + "median": 0, + } + ) continue - - average = sum(durations)/len(durations) + + average = sum(durations) / len(durations) min_value = min(durations) max_value = max(durations) durations.sort() - median = durations[len(durations)//2] - result.append({ - "month":start_date.strftime("%B"), - "range":[min_value,max_value], - "avg": average, - "median":median, - }) + median = durations[len(durations) // 2] + result.append( + { + "month": start_date.strftime("%B"), + "range": [min_value, max_value], + "avg": average, + "median": median, + } + ) - - - return Response(result[::-1], status=status.HTTP_200_OK) -prompt = ChatPromptTemplate.from_messages([ - ("system", "You are a helpful assistant."), - ("user", "{input}") -]) + +prompt = ChatPromptTemplate.from_messages( + [("system", "You are a helpful assistant."), ("user", "{input}")] +) llm = OllamaLLM(model=MODEL_NAME) @@ -510,19 +688,11 @@ llm = OllamaLLM(model=MODEL_NAME) # # Chain # chain = prompt | llm.with_config({"run_name": "model"}) | output_parser.with_config({"run_name": "Assistant"}) -@database_sync_to_async -def create_conversation(prompt, email): - # return the conversation id - - response = llm.invoke("Summarise the phrase in one to for words\"%s\"" % prompt) - print(f"Response: {response}") - print(dir(response)) - title = response.replace("\"","") - title = " ".join(title.split(" ")[:4]) - - - conversation = Conversation.objects.create(title = title) +@database_sync_to_async +def create_conversation(prompt, email, title): + # return the conversation id + conversation = Conversation.objects.create(title=title) conversation.save() user = CustomUser.objects.get(email=email) @@ -530,6 +700,12 @@ def create_conversation(prompt, email): conversation.save() return conversation.id + +@database_sync_to_async +def get_workspace(conversation_id): + conversation = Conversation.objects.get(id=conversation_id) + return DocumentWorkspace.objects.get(company=conversation.user.company) + @database_sync_to_async def get_messages(conversation_id, prompt, file_string: str = None, file_type: str = ""): messages = [] @@ -542,7 +718,7 @@ def get_messages(conversation_id, prompt, file_string: str = None, file_type: st data={ "message": prompt, "user_created": True, - "created": datetime.now(), + "created": timezone.now(), } ) if serializer.is_valid(raise_exception=True): @@ -550,42 +726,46 @@ def get_messages(conversation_id, prompt, file_string: str = None, file_type: st prompt_instance.conversation_id = conversation.id prompt_instance.save() if file_string: - file_name = f"prompt_{prompt_instance.id}_data.{file_type}" + file_name = f"prompt_{prompt_instance.id}_data.{file_type}" f = ContentFile(file_string, name=file_name) prompt_instance.file.save(file_name, f) prompt_instance.file_type = file_type prompt_instance.save() - - for prompt_obj in Prompt.objects.filter(conversation__id=conversation_id): - messages.append({ - 'content': prompt_obj.message, - 'role': 'user' if prompt_obj.user_created else 'assistant', - 'has_file': prompt_obj.file_exists(), 
-            'file': prompt_obj.file if prompt_obj.file_exists() else None,
-            'file_type': prompt_obj.file_type if prompt_obj.file_exists() else None,
-        })
+        messages.append(
+            {
+                "content": prompt_obj.message,
+                "role": "user" if prompt_obj.user_created else "assistant",
+                "has_file": prompt_obj.file_exists(),
+                "file": prompt_obj.file if prompt_obj.file_exists() else None,
+                "file_type": prompt_obj.file_type if prompt_obj.file_exists() else None,
+            }
+        )
 
     # now transform the messages
     transformed_messages = []
     for message in messages:
-
-        if message['has_file'] and message['file_type'] != None:
-            if 'csv' in message['file_type']:
-                file_type = 'csv'
+
+        if message["has_file"] and message["file_type"] is not None:
+            if "csv" in message["file_type"]:
+                file_type = "csv"
                 altered_message = f"{message['content']}\n The file type is csv and the file contents are: {message['file'].read()}"
-            elif 'xlsx' in message['file_type']:
-                file_type = 'xlsx'
-                df = pd.read_excel(message['file'].read())
+            elif "xlsx" in message["file_type"]:
+                file_type = "xlsx"
+                df = pd.read_excel(message["file"].read())
                 altered_message = f"{message['content']}\n The file type is xlsx and the file contents are: {df}"
-            elif 'txt' in message['file_type']:
-                file_type = 'txt'
-                altered_message = f"{message['content']}\n The file type is csv and the file contents are: {message['file'].read()}"
+            elif "txt" in message["file_type"]:
+                file_type = "txt"
+                altered_message = f"{message['content']}\n The file type is txt and the file contents are: {message['file'].read()}"
             else:
-                altered_message = message['content']
-
-            transformed_message = SystemMessage(content=altered_message) if message['role'] == 'assistant' else HumanMessage(content=altered_message)
+                altered_message = message["content"]
+
+            transformed_message = (
+                SystemMessage(content=altered_message)
+                if message["role"] == "assistant"
+                else HumanMessage(content=altered_message)
+            )
             transformed_messages.append(transformed_message)
 
     return transformed_messages, prompt_instance
@@ -600,7 +780,7 @@ def save_generated_message(conversation_id, message):
         data={
             "message": message,
             "user_created": False,
-            "created": datetime.now(),
+            "created": timezone.now(),
         }
     )
     if serializer.is_valid():
@@ -608,34 +788,52 @@
     prompt_instance.conversation_id = conversation.id
     prompt_instance = serializer.save()
 
+
 @database_sync_to_async
-def create_prompt_metric(prompt_id, prompt, has_file, file_type, model_name, conversation_id):
+def create_prompt_metric(
+    prompt_id, prompt, has_file, file_type, model_name, conversation_id
+):
     prompt_metric = PromptMetric.objects.create(
         prompt_id=prompt_id,
-        start_time = timezone.now(),
-        prompt_length = len(prompt),
-        has_file = has_file,
-        file_type = file_type,
+        start_time=timezone.now(),
+        prompt_length=len(prompt),
+        has_file=has_file,
+        file_type=file_type,
         model_name=model_name,
         conversation_id=conversation_id,
     )
     prompt_metric.save()
     return prompt_metric
 
+
 @database_sync_to_async
 def update_prompt_metric(prompt_metric, status):
     prompt_metric.event = status
     prompt_metric.save()
 
+
 @database_sync_to_async
 def finish_prompt_metric(prompt_metric, response_length):
-    print(f'finish_prompt_metric: {response_length}')
+    print(f"finish_prompt_metric: {response_length}")
     prompt_metric.end_time = timezone.now()
     prompt_metric.reponse_length = response_length
-    prompt_metric.event = 'FINISHED'
-    prompt_metric.save(update_fields=["end_time", "reponse_length","event"])
+    prompt_metric.event = "FINISHED"
+    prompt_metric.save(update_fields=["end_time", "reponse_length", "event"])
     print("finish_prompt_metric saved")
 
+@database_sync_to_async
+def get_retriever(conversation_id):
+    print(f"getting workspace from conversation: {conversation_id}")
+    conversation = Conversation.objects.get(id=conversation_id)
+    print(f"Got conversation: {conversation}")
+    workspace = DocumentWorkspace.objects.get(company=conversation.user.company)
+    print(f"Got workspace: {workspace}")
+    vectorstore = Chroma(
+        persist_directory="./chroma_db/",
+        embedding_function=OllamaEmbeddings(model="llama3.2"),
+    )
+    return vectorstore.as_retriever()
+
+
 class ChatConsumerAgain(AsyncWebsocketConsumer):
     async def connect(self):
         await self.accept()
@@ -648,47 +846,74 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
         print(f"Bytes Data: {bytes_data}")
         if text_data:
             data = json.loads(text_data)
-            message = data.get('message',None)
-            conversation_id = data.get('conversation_id',None)
+            message = data.get("message", None)
+            conversation_id = data.get("conversation_id", None)
             email = data.get("email", None)
             file = data.get("file", None)
             file_type = data.get("fileType", "")
+            model = data.get("modelName", "Turbo")
 
             if not conversation_id:
                 # we need to create a new conversation
                 # we will generate a name for it too
-                conversation_id = await create_conversation(message, email)
-
+                title = await title_generator.generate_async(message)
+                conversation_id = await create_conversation(message, email, title)
+
             if conversation_id:
                 decoded_file = None
                 if file:
                     decoded_file = base64.b64decode(file)
                     print(decoded_file)
-                    if 'csv' in file_type:
-                        file_type = 'csv'
+                    if "csv" in file_type:
+                        file_type = "csv"
                         altered_message = f"{message}\n The file type is csv and the file contents are: {decoded_file}"
-                    elif 'xmlformats-officedocument' in file_type:
-                        file_type = 'xlsx'
+                    elif "xmlformats-officedocument" in file_type:
+                        file_type = "xlsx"
                         df = pd.read_excel(decoded_file)
                         altered_message = f"{message}\n The file type is xlsx and the file contents are: {df}"
-                    elif 'text' in file_type:
-                        file_type = 'txt'
+                    elif "text" in file_type:
+                        file_type = "txt"
                         altered_message = f"{message}\n The file type is txt and the file contents are: {decoded_file}"
                     else:
-                        file_type = 'Not Sure'
-
-
-
-
+                        file_type = "Not Sure"
 
                 print(f'received: "{message}" for conversation {conversation_id}')
+
+                # check the moderation here
+                if await moderation_classifier.classify_async(message) == ModerationLabel.NSFW:
+                    response = "Prompt has been marked as NSFW. If this is in error, submit a feedback with the prompt text."
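+                    # note: even this early NSFW return reuses the exact frame
+                    # protocol the frontend already parses for normal replies:
+                    # CONVERSATION_ID, the id, START_..., the body, then END_....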
+                    print("this prompt has been marked as NSFW")
+                    await self.send("CONVERSATION_ID")
+                    await self.send(str(conversation_id))
+                    await self.send("START_OF_THE_STREAM_ENDER_GAME_42")
+                    await self.send(response)
+                    await self.send("END_OF_THE_STREAM_ENDER_GAME_42")
+                    await save_generated_message(conversation_id, response)
+                    return
+
                 # TODO: add the message to the database
                 # get the new conversation
                 # TODO: get the messages here
+
+                messages, prompt = await get_messages(
+                    conversation_id, message, decoded_file, file_type
+                )
+
+                prompt_type = await prompt_classifier.classify_async(message)
+                print(f"prompt_type: {prompt_type} for {message}")
 
-                messages, prompt = await get_messages(conversation_id, message, decoded_file, file_type)
-
-                prompt_metric = await create_prompt_metric(prompt.id, prompt.message, True if file else False, file_type, MODEL_NAME, conversation_id)
+                prompt_metric = await create_prompt_metric(
+                    prompt.id,
+                    prompt.message,
+                    True if file else False,
+                    file_type,
+                    MODEL_NAME,
+                    conversation_id,
+                )
                 if file:
-                    # udpate with the altered_message
+                    # update with the altered message
                     messages = messages[:-1] + [HumanMessage(content=altered_message)]
@@ -698,17 +923,117 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
                 # stream the response back
                 response = ""
                 # start of the message
-                await self.send('CONVERSATION_ID')
+                await self.send("CONVERSATION_ID")
                 await self.send(str(conversation_id))
-                await self.send('START_OF_THE_STREAM_ENDER_GAME_42')
-                async for chunk in llm.astream(messages):
-                    print(f"chunk: {chunk}")
-                    response += chunk
-                    await self.send(chunk)
-                await self.send('END_OF_THE_STREAM_ENDER_GAME_42')
+                await self.send("START_OF_THE_STREAM_ENDER_GAME_42")
+                if prompt_type == PromptType.RAG:
+                    service = AsyncRAGService()
+                    # await service.ingest_documents()
+                    workspace = await get_workspace(conversation_id)
+                    print("Time to get the rag response")
+
+                    async for chunk in service.generate_response(
+                        messages, prompt.message, workspace
+                    ):
+                        print(f"chunk: {chunk}")
+                        response += chunk
+                        await self.send(chunk)
+                elif prompt_type == PromptType.IMAGE_GENERATION:
+                    response = "Image Generation is not supported at this time, but it will be soon."
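+                    # sketch of the intent here: image prompts short-circuit
+                    # for now, so the stub reply goes out as a single chunk
+                    # inside the normal stream framing.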
+                    await self.send(response)
+
+                else:
+                    service = AsyncLLMService()
+                    async for chunk in service.generate_response(messages, prompt.message):
+                        print(f"chunk: {chunk}")
+                        response += chunk
+                        await self.send(chunk)
+                await self.send("END_OF_THE_STREAM_ENDER_GAME_42")
                 await save_generated_message(conversation_id, response)
                 await finish_prompt_metric(prompt_metric, len(response))
-
+
         if bytes_data:
             print("we have byte data")
+
+
+# Document Views
+class DocumentWorkspaceView(APIView):
+    # permission_classes = [permissions.IsAuthenticated]
+
+    def get(self, request):
+        workspaces = DocumentWorkspace.objects.filter(company=request.user.company)
+        serializer = DocumentWorkspaceSerializer(workspaces, many=True)
+        return Response(serializer.data)
+
+    def post(self, request):
+        serializer = DocumentWorkspaceSerializer(data=request.data)
+        if serializer.is_valid():
+            serializer.save(company=request.user.company)
+            return Response(serializer.data, status=status.HTTP_201_CREATED)
+        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+
+
+class DocumentUploadView(APIView):
+    # permission_classes = [permissions.IsAuthenticated]
+
+    def get(self, request):
+        print(f"request_3: {request}")
+        try:
+            workspace = DocumentWorkspace.objects.get(company=request.user.company)
+            serializer = DocumentSerializer(
+                Document.objects.filter(workspace=workspace), many=True
+            )
+            return Response(serializer.data, status=status.HTTP_200_OK)
+
+        except:
+            return Response(
+                {"error": "Workspace not found"}, status=status.HTTP_404_NOT_FOUND
+            )
+
+    def post(self, request):
+        print(f"request: {request}")
+
+        try:
+            workspace = DocumentWorkspace.objects.get(company=request.user.company)
+
+        except:
+            return Response(
+                {"error": "Workspace not found"}, status=status.HTTP_404_NOT_FOUND
+            )
+
+        print(request.FILES)
+        file = request.FILES.get("file")
+        if not file:
+            return Response(
+                {"error": "No file provided"}, status=status.HTTP_400_BAD_REQUEST
+            )
+
+        print("have the workspace and the file")
+
+        document = Document.objects.create(workspace=workspace, file=file)
+
+        # process the document synchronously for now (TODO: move to a
+        # background task)
+        self.process_document(document)
+
+        serializer = DocumentSerializer(document)
+        return Response(serializer.data, status=status.HTTP_201_CREATED)
+
+    def process_document(self, document):
+        file_path = os.path.join(settings.MEDIA_ROOT, document.file.name)
+
+        document.processed = True
+        document.active = True
+        document.save()
+        service = AsyncRAGService()
+        service.add_files_to_store(
+            [(file_path, document.file.name, document.workspace_id)],
+            workspace_id=document.workspace_id,
+        )
+
+
+class DocumentDetailView(APIView):
+    # permission_classes = [permissions.IsAuthenticated]
+
+    def get(self, request, document_id):
+        print(f"request: {request}")
+        try:
+            workspace = DocumentWorkspace.objects.get(company=request.user.company)
+
+            document = Document.objects.get(workspace=workspace, id=document_id)
+        except:
+            return Response(
+                {"error": "Document not found"}, status=status.HTTP_404_NOT_FOUND
+            )
+
+        serializer = DocumentSerializer(document)
+        return Response(serializer.data)
\ No newline at end of file
diff --git a/llm_be/llm_be/asgi.py b/llm_be/llm_be/asgi.py
index f8be107..cc9b0a1 100644
--- a/llm_be/llm_be/asgi.py
+++ b/llm_be/llm_be/asgi.py
@@ -13,13 +13,14 @@
 from django.core.asgi import get_asgi_application
 from channels.routing import ProtocolTypeRouter, URLRouter
 from channels.auth import AuthMiddlewareStack
 import chat_backend.routing
-os.environ.setdefault('DJANGO_SETTINGS_MODULE', 
'llm_be.settings')
-application = ProtocolTypeRouter({
-    "http": get_asgi_application(),
-    "websocket": AuthMiddlewareStack(
-        URLRouter(
-            chat_backend.routing.websocket_urlpatterns
-        )
-    ),
-  })
+os.environ.setdefault("DJANGO_SETTINGS_MODULE", "llm_be.settings")
+
+application = ProtocolTypeRouter(
+    {
+        "http": get_asgi_application(),
+        "websocket": AuthMiddlewareStack(
+            URLRouter(chat_backend.routing.websocket_urlpatterns)
+        ),
+    }
+)
diff --git a/llm_be/llm_be/settings.py b/llm_be/llm_be/settings.py
index ab0118b..47b6d6d 100644
--- a/llm_be/llm_be/settings.py
+++ b/llm_be/llm_be/settings.py
@@ -22,80 +22,86 @@ BASE_DIR = Path(__file__).resolve().parent.parent
 # See https://docs.djangoproject.com/en/3.2/howto/deployment/checklist/
 
 # SECURITY WARNING: keep the secret key used in production secret!
-SECRET_KEY = 'django-insecure-6suk6fj5q2)1tj%)f(wgw1smnliv5-#&@zvgvj1wp#(#@h#31x'
+SECRET_KEY = "django-insecure-6suk6fj5q2)1tj%)f(wgw1smnliv5-#&@zvgvj1wp#(#@h#31x"
 
 # SECURITY WARNING: don't run with debug turned on in production!
 DEBUG = True
 
-
-ALLOWED_HOSTS = ['*.aimloperations.com','localhost','127.0.0.1','chat.aimloperations.com','chatbackend.aimloperations.com']
+CORS_ALLOW_CREDENTIALS = False
+# Django host patterns use a leading dot for subdomain wildcards, and entries
+# are matched without the port, so host:port values would never match here.
+ALLOWED_HOSTS = [
+    ".aimloperations.com",
+    "localhost",
+    "127.0.0.1",
+    "chat.aimloperations.com",
+    "chatbackend.aimloperations.com",
+]
 
 CORS_ORIGIN_ALLOW_ALL = True
+CSRF_TRUSTED_ORIGINS = ["http://localhost", "http://127.0.0.1", "http://localhost:3000"]
 
 # Application definition
 
 INSTALLED_APPS = [
-    'daphne',
-    'django.contrib.admin',
-    'django.contrib.auth',
-    'django.contrib.contenttypes',
-    'django.contrib.sessions',
-    'django.contrib.messages',
-    'django.contrib.staticfiles',
-    'chat_backend',
-    'rest_framework',
-    'corsheaders',
-    'rest_framework_simplejwt.token_blacklist',
-    
+    "daphne",
+    "django.contrib.admin",
+    "django.contrib.auth",
+    "django.contrib.contenttypes",
+    "django.contrib.sessions",
+    "django.contrib.messages",
+    "django.contrib.staticfiles",
+    "chat_backend",
+    "rest_framework",
+    "corsheaders",
+    "rest_framework_simplejwt.token_blacklist",
 ]
 
 MIDDLEWARE = [
-    'django.middleware.security.SecurityMiddleware',
-    'django.contrib.sessions.middleware.SessionMiddleware',
-    'django.middleware.common.CommonMiddleware',
-    'django.middleware.csrf.CsrfViewMiddleware',
-    'django.contrib.auth.middleware.AuthenticationMiddleware',
-    'django.contrib.messages.middleware.MessageMiddleware',
-    'django.middleware.clickjacking.XFrameOptionsMiddleware',
+    "django.middleware.security.SecurityMiddleware",
+    "django.contrib.sessions.middleware.SessionMiddleware",
+    "django.middleware.common.CommonMiddleware",
+    "django.middleware.csrf.CsrfViewMiddleware",
+    "django.contrib.auth.middleware.AuthenticationMiddleware",
+    "django.contrib.messages.middleware.MessageMiddleware",
+    "django.middleware.clickjacking.XFrameOptionsMiddleware",
     "corsheaders.middleware.CorsMiddleware",
     "django.middleware.common.CommonMiddleware",
 ]
 
-ROOT_URLCONF = 'llm_be.urls'
+ROOT_URLCONF = "llm_be.urls"
 
 # SETTINGS_PATH = os.path.dirname(os.path.dirname(__file__))
 # TEMPLATE_DIRS = (
 #     os.path.join(SETTINGS_PATH, 'templates'),
 # )
-print(os.path.join(BASE_DIR, 'templates'))
-
 TEMPLATES = [
     {
-        'BACKEND': 'django.template.backends.django.DjangoTemplates',
-        'DIRS': [os.path.join(BASE_DIR, 'templates')],
-        'APP_DIRS': True,
-        'OPTIONS': {
-            'context_processors': [
-                'django.template.context_processors.debug',
-                'django.template.context_processors.request',
-                
'django.contrib.auth.context_processors.auth', - 'django.contrib.messages.context_processors.messages', + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [os.path.join(BASE_DIR, "templates")], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.debug", + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", ], }, }, ] -WSGI_APPLICATION = 'llm_be.wsgi.application' -ASGI_APPLICATION = 'llm_be.asgi.application' +WSGI_APPLICATION = "llm_be.wsgi.application" +ASGI_APPLICATION = "llm_be.asgi.application" # Database # https://docs.djangoproject.com/en/3.2/ref/settings/#databases DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.sqlite3', - 'NAME': BASE_DIR / 'db.sqlite3', + "default": { + "ENGINE": "django.db.backends.sqlite3", + "NAME": BASE_DIR / "db.sqlite3", } } @@ -105,28 +111,26 @@ DATABASES = { AUTH_PASSWORD_VALIDATORS = [ { - 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', + "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator", }, { - 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', + "NAME": "django.contrib.auth.password_validation.MinimumLengthValidator", }, { - 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', + "NAME": "django.contrib.auth.password_validation.CommonPasswordValidator", }, { - 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', + "NAME": "django.contrib.auth.password_validation.NumericPasswordValidator", }, ] - - # Internationalization # https://docs.djangoproject.com/en/3.2/topics/i18n/ -LANGUAGE_CODE = 'en-us' +LANGUAGE_CODE = "en-us" -TIME_ZONE = 'UTC' +TIME_ZONE = "UTC" USE_I18N = True @@ -138,39 +142,37 @@ USE_TZ = True # Static files (CSS, JavaScript, Images) # https://docs.djangoproject.com/en/3.2/howto/static-files/ -STATIC_URL = '/static/' +STATIC_URL = "/static/" # Default primary key field type # https://docs.djangoproject.com/en/3.2/ref/settings/#default-auto-field -DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' +DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" # custom user model -AUTH_USER_MODEL = 'chat_backend.CustomUser' +AUTH_USER_MODEL = "chat_backend.CustomUser" # rest framework jwt stuff REST_FRAMEWORK = { - 'DEFAULT_PERMISSION_CLASSES': ( - 'rest_framework.permissions.IsAuthenticated', - ), - 'DEFAULT_AUTHENTICATION_CLASSES': ( -'rest_framework_simplejwt.authentication.JWTAuthentication', - ), # + "DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.IsAuthenticated",), + "DEFAULT_AUTHENTICATION_CLASSES": ( + "rest_framework_simplejwt.authentication.JWTAuthentication", + ), # } SIMPLE_JWT = { - 'ACCESS_TOKEN_LIFETIME':timedelta(hours=5), - 'REFRESH_TOKEN_LIFETIME':timedelta(days=14), - 'ROTATE_REFRESH_TOKENS':True, - 'BLACKLIST_AFTER_ROTATION':True, - 'ALGORITHM':"HS256", - "SIGNING_KEY":SECRET_KEY, - 'VERIFYING_KEY':None, - "AUTH_HEADER_TYPES":('JWT',), - 'USER_ID_FIELD':'id', - 'USER_ID_CLAIM':'user_id', - 'AUTH_TOKEN_CLASSES':('rest_framework_simplejwt.tokens.AccessToken',), - 'TOKEN_TYPE_CLAIM':'token_type', + "ACCESS_TOKEN_LIFETIME": timedelta(hours=24), + "REFRESH_TOKEN_LIFETIME": timedelta(days=14), + "ROTATE_REFRESH_TOKENS": True, + "BLACKLIST_AFTER_ROTATION": True, + "ALGORITHM": "HS256", + "SIGNING_KEY": SECRET_KEY, + "VERIFYING_KEY": None, + "AUTH_HEADER_TYPES": ("JWT",), + "USER_ID_FIELD": "id", + 
"USER_ID_CLAIM": "user_id", + "AUTH_TOKEN_CLASSES": ("rest_framework_simplejwt.tokens.AccessToken",), + "TOKEN_TYPE_CLAIM": "token_type", } # CORS settings @@ -181,8 +183,8 @@ CORS_ALLOWED_ORIGINS = [ # channel settings CHANNEL_LAYERS = { - 'default': { - 'BACKEND': 'channels.layers.InMemoryChannelLayer', + "default": { + "BACKEND": "channels.layers.InMemoryChannelLayer", }, } @@ -198,8 +200,11 @@ CHANNEL_LAYERS = { # EMAIL_TIMEOUT = os.getenv("APP_EMAIL_TIMEOUT", 60) # SMTP2GO -EMAIL_HOST = 'mail.smtp2go.com' -EMAIL_HOST_USER = 'info.aimloperations.com' -EMAIL_HOST_PASSWORD = 'ZDErIII2sipNNVMz' +EMAIL_HOST = "mail.smtp2go.com" +EMAIL_HOST_USER = "info.aimloperations.com" +EMAIL_HOST_PASSWORD = "ZDErIII2sipNNVMz" EMAIL_PORT = 2525 -EMAIL_USE_TLS = True \ No newline at end of file +EMAIL_USE_TLS = True + +# Captcha +CAPTCHA_SECRET_KEY = "6LfENu4qAAAAABdrj6JTviq-LfdPP5imhE-Os7h9" \ No newline at end of file diff --git a/llm_be/llm_be/urls.py b/llm_be/llm_be/urls.py index 16d76aa..ba910f3 100644 --- a/llm_be/llm_be/urls.py +++ b/llm_be/llm_be/urls.py @@ -13,12 +13,17 @@ Including another URLconf 1. Import the include() function: from django.urls import include, path 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) """ + from django.contrib import admin from django.urls import path, include from django.conf import settings from django.conf.urls.static import static -urlpatterns = [ - path('admin/', admin.site.urls), - path('api/', include('chat_backend.urls')), -] + static(settings.STATIC_URL, document_root=settings.STATIC_ROOT) + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT) +urlpatterns = ( + [ + path("admin/", admin.site.urls), + path("api/", include("chat_backend.urls")), + ] + + static(settings.STATIC_URL, document_root=settings.STATIC_ROOT) + + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT) +) diff --git a/llm_be/llm_be/wsgi.py b/llm_be/llm_be/wsgi.py index 1edc130..35b7dc6 100644 --- a/llm_be/llm_be/wsgi.py +++ b/llm_be/llm_be/wsgi.py @@ -11,6 +11,6 @@ import os from django.core.wsgi import get_wsgi_application -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'llm_be.settings') +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "llm_be.settings") application = get_wsgi_application() diff --git a/llm_be/manage.py b/llm_be/manage.py index 96e59ea..05354a5 100755 --- a/llm_be/manage.py +++ b/llm_be/manage.py @@ -6,7 +6,7 @@ import sys def main(): """Run administrative tasks.""" - os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'llm_be.settings') + os.environ.setdefault("DJANGO_SETTINGS_MODULE", "llm_be.settings") try: from django.core.management import execute_from_command_line except ImportError as exc: @@ -18,5 +18,5 @@ def main(): execute_from_command_line(sys.argv) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/requirements.dev b/requirements.dev index 60d210d..0029b6c 100644 --- a/requirements.dev +++ b/requirements.dev @@ -30,16 +30,16 @@ djangorestframework-simplejwt==5.3.1 duckdb==1.1.3 et_xmlfile==2.0.0 exceptiongroup==1.2.2 -Faker==33.1.0 +Faker filelock==3.16.1 fonttools==4.55.3 frozenlist==1.5.0 fsspec==2024.12.0 greenlet==3.1.1 h11==0.14.0 -httpcore==1.0.7 -httpx==0.27.2 -httpx-sse==0.4.0 +httpcore +httpx +httpx-sse hyperlink==21.0.0 idna==3.10 importlib_resources==6.4.5 @@ -48,14 +48,14 @@ Jinja2==3.1.5 jiter==0.8.2 jsonpatch==1.33 jsonpointer==3.0.0 -kiwisolver==1.4.7 -langchain==0.3.13 -langchain-community==0.3.13 -langchain-core==0.3.28 -langchain-ollama==0.2.2 -langchain-openai==0.2.14 
-langchain-text-splitters==0.3.4 -langsmith==0.2.7 +kiwisolver +langchain +langchain-community +langchain-core +langchain-ollama +langchain-openai +langchain-text-splitters +langsmith lxml==5.3.0 MarkupSafe==3.0.2 marshmallow==3.23.2 @@ -77,14 +77,14 @@ nvidia-cusparse-cu12==12.3.1.170 nvidia-nccl-cu12==2.21.5 nvidia-nvjitlink-cu12==12.4.127 nvidia-nvtx-cu12==12.4.127 -ollama==0.4.5 -ollama-python==0.1.2 -openai==1.58.1 +ollama +ollama-python +openai openpyxl==3.1.5 orjson==3.10.13 packaging==24.2 pandas==2.2.3 -pandasai==2.4.1 +pandasai pathspec==0.12.1 pillow==11.0.0 platformdirs==4.3.6 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..a209ddd --- /dev/null +++ b/requirements.txt @@ -0,0 +1,208 @@ +aiofiles==24.1.0 +aiohappyeyeballs==2.6.1 +aiohttp==3.11.18 +aiosignal==1.3.2 +annotated-types==0.7.0 +anyio==4.8.0 +asgiref==3.8.1 +astor==0.8.1 +attrs==25.1.0 +autobahn==24.4.2 +Automat==24.8.1 +backoff==2.2.1 +bcrypt==4.3.0 +beautifulsoup4==4.13.4 +black==25.1.0 +build==1.2.2.post1 +cachetools==5.5.2 +certifi==2025.1.31 +cffi==1.17.1 +channels==4.2.0 +chardet==5.2.0 +charset-normalizer==3.4.1 +chroma-hnswlib==0.7.6 +chromadb==1.0.7 +click==8.1.8 +coloredlogs==15.0.1 +constantly==23.10.4 +contourpy==1.3.1 +cryptography==44.0.2 +cycler==0.12.1 +daphne==4.1.2 +dataclasses-json==0.6.7 +Deprecated==1.2.18 +distro==1.9.0 +Django==5.1.7 +django-autoslug==1.9.9 +django-cors-headers==4.7.0 +django-filter==25.1 +djangorestframework==3.15.2 +djangorestframework_simplejwt==5.5.0 +duckdb==1.2.1 +durationpy==0.9 +emoji==2.14.1 +eval_type_backport==0.2.2 +Faker==37.0.0 +fastapi==0.115.9 +filelock==3.17.0 +filetype==1.2.0 +flatbuffers==25.2.10 +fonttools==4.56.0 +frozenlist==1.6.0 +fsspec==2025.2.0 +google-auth==2.39.0 +googleapis-common-protos==1.70.0 +greenlet==3.1.1 +grpcio==1.71.0 +h11==0.14.0 +html5lib==1.1 +httpcore==1.0.7 +httptools==0.6.4 +httpx==0.28.1 +httpx-sse==0.4.0 +huggingface-hub==0.30.2 +humanfriendly==10.0 +hyperlink==21.0.0 +idna==3.10 +importlib_metadata==8.6.1 +importlib_resources==6.5.2 +incremental==24.7.2 +Jinja2==3.1.6 +jiter==0.8.2 +joblib==1.4.2 +jsonpatch==1.33 +jsonpointer==3.0.0 +jsonschema==4.23.0 +jsonschema-specifications==2025.4.1 +kiwisolver==1.4.8 +kubernetes==32.0.1 +langchain==0.3.24 +langchain-community==0.3.23 +langchain-core==0.3.56 +langchain-ollama==0.2.3 +langchain-text-splitters==0.3.8 +langdetect==1.0.9 +langsmith==0.3.13 +lxml==5.4.0 +Markdown==3.7 +markdown-it-py==3.0.0 +MarkupSafe==3.0.2 +marshmallow==3.26.1 +matplotlib==3.10.1 +mdurl==0.1.2 +mmh3==5.1.0 +mpmath==1.3.0 +multidict==6.4.3 +mypy-extensions==1.0.0 +nest-asyncio==1.6.0 +networkx==3.4.2 +nltk==3.9.1 +numpy==2.2.3 +nvidia-cublas-cu12==12.4.5.8 +nvidia-cuda-cupti-cu12==12.4.127 +nvidia-cuda-nvrtc-cu12==12.4.127 +nvidia-cuda-runtime-cu12==12.4.127 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.2.1.3 +nvidia-curand-cu12==10.3.5.147 +nvidia-cusolver-cu12==11.6.1.9 +nvidia-cusparse-cu12==12.3.1.170 +nvidia-cusparselt-cu12==0.6.2 +nvidia-nccl-cu12==2.21.5 +nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvtx-cu12==12.4.127 +oauthlib==3.2.2 +olefile==0.47 +ollama==0.4.7 +onnxruntime==1.21.1 +openai==1.65.4 +opentelemetry-api==1.32.1 +opentelemetry-exporter-otlp-proto-common==1.32.1 +opentelemetry-exporter-otlp-proto-grpc==1.32.1 +opentelemetry-instrumentation==0.53b1 +opentelemetry-instrumentation-asgi==0.53b1 +opentelemetry-instrumentation-fastapi==0.53b1 +opentelemetry-proto==1.32.1 +opentelemetry-sdk==1.32.1 +opentelemetry-semantic-conventions==0.53b1 
+opentelemetry-util-http==0.53b1 +orjson==3.10.15 +overrides==7.7.0 +packaging==24.2 +pandas==2.2.3 +pandasai==2.4.2 +pathspec==0.12.1 +pillow==11.1.0 +platformdirs==4.3.6 +posthog==4.0.1 +propcache==0.3.1 +protobuf==5.29.4 +psutil==7.0.0 +pyasn1==0.6.1 +pyasn1_modules==0.4.1 +pycparser==2.22 +pydantic==2.11.4 +pydantic-settings==2.9.1 +pydantic_core==2.33.2 +Pygments==2.19.1 +PyJWT==2.10.1 +pyOpenSSL==25.0.0 +pyparsing==3.2.1 +pypdf==5.4.0 +PyPika==0.48.9 +pyproject_hooks==1.2.0 +python-dateutil==2.9.0.post0 +python-dotenv==1.0.1 +python-iso639==2025.2.18 +python-magic==0.4.27 +python-oxmsg==0.0.2 +pytz==2025.1 +PyYAML==6.0.2 +RapidFuzz==3.13.0 +referencing==0.36.2 +regex==2024.11.6 +requests==2.32.3 +requests-oauthlib==2.0.0 +requests-toolbelt==1.0.0 +rich==14.0.0 +rpds-py==0.24.0 +rsa==4.9.1 +scipy==1.15.2 +service-identity==24.2.0 +setuptools==75.8.2 +shellingham==1.5.4 +six==1.17.0 +sniffio==1.3.1 +soupsieve==2.7 +SQLAlchemy==2.0.38 +sqlglot==26.9.0 +sqlglotrs==0.4.0 +sqlparse==0.5.3 +starlette==0.45.3 +sympy==1.13.1 +tenacity==9.0.0 +tokenizers==0.21.1 +torch==2.6.0 +tqdm==4.67.1 +triton==3.2.0 +Twisted==24.11.0 +txaio==23.1.1 +typer==0.15.3 +typing-inspect==0.9.0 +typing-inspection==0.4.0 +typing_extensions==4.12.2 +tzdata==2025.1 +unstructured==0.17.2 +unstructured-client==0.34.0 +urllib3==2.3.0 +uvicorn==0.34.2 +uvloop==0.21.0 +watchfiles==1.0.5 +webencodings==0.5.1 +websocket-client==1.8.0 +websockets==15.0.1 +wrapt==1.17.2 +yarl==1.20.0 +zipp==3.21.0 +zope.interface==7.2 +zstandard==0.23.0 diff --git a/strip_and_upgrade.py b/strip_and_upgrade.py new file mode 100644 index 0000000..a700b81 --- /dev/null +++ b/strip_and_upgrade.py @@ -0,0 +1,9 @@ +outfile = open("requirements.txt",'w') +for line in open('requirements.dev','r'): + line = line.strip() + if line: + values = line.split('==') + print(values[0]) + outfile.write(values[0] + '\n') + +
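A closing note on strip_and_upgrade.py: as committed it never closes the output handle and copies comment lines through verbatim. A slightly more defensive variant — a sketch with the same inputs and outputs, not part of the patch itself:

    # Copy package names from requirements.dev into requirements.txt with the
    # version pins stripped, skipping blank lines and comments.
    with open("requirements.dev") as infile, open("requirements.txt", "w") as outfile:
        for line in infile:
            line = line.strip()
            if line and not line.startswith("#"):
                name = line.split("==")[0]
                print(name)
                outfile.write(name + "\n")

After upgrading against the unpinned list, `pip freeze > requirements.txt` re-pins the exact versions like the file added above.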