diff --git a/.gitignore b/.gitignore
index 0dbf2f2..b741129 100644
--- a/.gitignore
+++ b/.gitignore
@@ -167,4 +167,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
-
+chroma_db/
+documents/
diff --git a/llm_be/chat_backend/admin.py b/llm_be/chat_backend/admin.py
index b985b0b..e7ae303 100644
--- a/llm_be/chat_backend/admin.py
+++ b/llm_be/chat_backend/admin.py
@@ -1,5 +1,16 @@
from django.contrib import admin
-from .models import CustomUser, Announcement, Company, LLMModels, Conversation, Prompt, Feedback, PromptMetric
+from .models import (
+ CustomUser,
+ Announcement,
+ Company,
+ LLMModels,
+ Conversation,
+ Prompt,
+ Feedback,
+ PromptMetric,
+ DocumentWorkspace,
+ Document,
+)
# Register your models here.
@@ -27,16 +38,16 @@ class CustomUserAdmin(admin.ModelAdmin):
"has_signed_tos",
"last_login",
"slug",
- "get_set_password_url"
+ "get_set_password_url",
)
search_fields = ("fields", "username", "first_name", "last_name", "slug")
+
class FeedbackAdmin(admin.ModelAdmin):
model = Feedback
search_fields = ("status", "text", "get_user_email")
- list_display= (
- "status", "get_user_email", "title", "category"
- )
+ list_display = ("status", "get_user_email", "title", "category")
+
class LLMModelsAdmin(admin.ModelAdmin):
model = LLMModels
@@ -46,7 +57,7 @@ class LLMModelsAdmin(admin.ModelAdmin):
class ConversationAdmin(admin.ModelAdmin):
model = Conversation
- list_display = ("title", "get_user_email","deleted")
+ list_display = ("title", "get_user_email", "deleted")
search_fields = ("title",)
@@ -55,9 +66,35 @@ class PromptAdmin(admin.ModelAdmin):
list_display = ("message", "user_created", "get_conversation_title")
search_fields = ("message",)
+
class PromptMetricAdmin(admin.ModelAdmin):
model = PromptMetric
- list_display = ("event", "model_name", "prompt_length","reponse_length",'has_file','file_type', "get_duration")
+ list_display = (
+ "event",
+ "model_name",
+ "prompt_length",
+ "reponse_length",
+ "has_file",
+ "file_type",
+ "get_duration",
+ )
+
+class DocumentWorkspaceAdmin(admin.ModelAdmin):
+ model = DocumentWorkspace
+ list_display = (
+ "name",
+ "company",
+ )
+
+class DocumentAdmin(admin.ModelAdmin):
+ model = Document
+ list_display = (
+ "file",
+ "active",
+ "created",
+ "processed",
+ )
admin.site.register(Announcement, AnnouncmentAdmin)
@@ -69,3 +106,6 @@ admin.site.register(Conversation, ConversationAdmin)
admin.site.register(Prompt, PromptAdmin)
admin.site.register(PromptMetric, PromptMetricAdmin)
admin.site.register(Feedback, FeedbackAdmin)
+
+admin.site.register(DocumentWorkspace, DocumentWorkspaceAdmin)
+admin.site.register(Document, DocumentAdmin)
diff --git a/llm_be/chat_backend/apps.py b/llm_be/chat_backend/apps.py
index 24f2d5c..524ac53 100644
--- a/llm_be/chat_backend/apps.py
+++ b/llm_be/chat_backend/apps.py
@@ -1,6 +1,31 @@
from django.apps import AppConfig
+from django.conf import settings
+from django.db import OperationalError
class ChatBackendConfig(AppConfig):
default_auto_field = "django.db.models.BigAutoField"
name = "chat_backend"
+
+ def ready(self):
+ import chat_backend.signals
+ FORCE_RELOAD = False
+
+ if not getattr(settings, "TESTING", False): # Don't run during tests
+ try:
+ from .services.rag_services import AsyncRAGService
+ from chat_backend.models import Document
+
+ # Check if Chroma needs initialization
+ if Document.objects.exists():
+ rag_service = AsyncRAGService()
+
+ if FORCE_RELOAD:
+ print("Force reloading ChromaDB with existing documents...")
+ rag_service.clear_vector_store()
+ rag_service.ingest_documents()
+ elif rag_service.vector_store._collection.count() == 0:
+ print("Initializing ChromaDB with existing documents...")
+ rag_service.ingest_documents()
+ except OperationalError:
+ # Database tables might not exist yet during migrations
+ pass
diff --git a/llm_be/chat_backend/client.py b/llm_be/chat_backend/client.py
index d550dbb..4425565 100644
--- a/llm_be/chat_backend/client.py
+++ b/llm_be/chat_backend/client.py
@@ -1,30 +1,32 @@
"""
llama client - Abstract this in the future
"""
+
import ollama
from typing import List, Dict
+
class LlamaClient(object):
- def __init__(self, model: str='llama3'):
+ def __init__(self, model: str = "llama3"):
self.client = ollama.Client(host="http://127.0.0.1:11434")
self.model = model
def check_if_model_exists(self) -> bool:
raise NotImplementedError
- def generate_conversation_title(self, message:str):
- response = self.generate_single_message("Summarise the phrase in one to for words\"%s\"" % message)
-
- raw_response = response['response'].replace("\"","")
+ def generate_conversation_title(self, message: str):
+ response = self.generate_single_message(
+ 'Summarise the phrase in one to four words: "%s"' % message
+ )
+
+ raw_response = response["response"].replace('"', "")
return " ".join(raw_response.split()[:4])
def generate_single_message(self, message: str):
return ollama.generate(model=self.model, prompt=message)
def get_chat_response(self, messages: List[str]):
- return self.client.chat(model = self.model, messages=messages, stream=False)
-
-
+ return self.client.chat(model=self.model, messages=messages, stream=False)
+
def get_streamed_chat_response(self, messages: List[str]):
- return self.client.chat(model = self.model, messages=messages, stream=True)
-
+ return self.client.chat(model=self.model, messages=messages, stream=True)
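+
+
+# Usage sketch (illustrative, not part of the module): streaming a chat reply.
+# Assumes an Ollama server on 127.0.0.1:11434 with the llama3 model pulled;
+# `messages` follows the ollama chat format.
+#
+# client = LlamaClient(model="llama3")
+# for chunk in client.get_streamed_chat_response(
+# [{"role": "user", "content": "Hello there"}]
+# ):
+# print(chunk["message"]["content"], end="", flush=True)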
diff --git a/llm_be/chat_backend/migrations/0020_documentworkspace_document.py b/llm_be/chat_backend/migrations/0020_documentworkspace_document.py
new file mode 100644
index 0000000..d255761
--- /dev/null
+++ b/llm_be/chat_backend/migrations/0020_documentworkspace_document.py
@@ -0,0 +1,78 @@
+# Generated by Django 5.1.7 on 2025-04-30 18:58
+
+import django.db.models.deletion
+import django.utils.timezone
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ("chat_backend", "0019_customuser_conversation_order_and_more"),
+ ]
+
+ operations = [
+ migrations.CreateModel(
+ name="DocumentWorkspace",
+ fields=[
+ (
+ "id",
+ models.BigAutoField(
+ auto_created=True,
+ primary_key=True,
+ serialize=False,
+ verbose_name="ID",
+ ),
+ ),
+ ("created", models.DateTimeField(default=django.utils.timezone.now)),
+ (
+ "last_modified",
+ models.DateTimeField(default=django.utils.timezone.now),
+ ),
+ ("name", models.CharField(max_length=255)),
+ (
+ "company",
+ models.ForeignKey(
+ on_delete=django.db.models.deletion.CASCADE,
+ to="chat_backend.company",
+ ),
+ ),
+ ],
+ options={
+ "abstract": False,
+ },
+ ),
+ migrations.CreateModel(
+ name="Document",
+ fields=[
+ (
+ "id",
+ models.BigAutoField(
+ auto_created=True,
+ primary_key=True,
+ serialize=False,
+ verbose_name="ID",
+ ),
+ ),
+ ("created", models.DateTimeField(default=django.utils.timezone.now)),
+ (
+ "last_modified",
+ models.DateTimeField(default=django.utils.timezone.now),
+ ),
+ ("file", models.FileField(upload_to="documents/")),
+ ("uploaded_at", models.DateTimeField(auto_now_add=True)),
+ ("processed", models.BooleanField(default=False)),
+ ("active", models.BooleanField(default=False)),
+ (
+ "workspace",
+ models.ForeignKey(
+ on_delete=django.db.models.deletion.CASCADE,
+ to="chat_backend.documentworkspace",
+ ),
+ ),
+ ],
+ options={
+ "abstract": False,
+ },
+ ),
+ ]
diff --git a/llm_be/chat_backend/models.py b/llm_be/chat_backend/models.py
index 7505470..d2b5ef2 100644
--- a/llm_be/chat_backend/models.py
+++ b/llm_be/chat_backend/models.py
@@ -3,9 +3,11 @@ from django.contrib.auth.models import AbstractUser
from django.utils import timezone
from autoslug import AutoSlugField
from django.core.files.storage import FileSystemStorage
+
# Create your models here.
-FILE_STORAGE = FileSystemStorage(location='prompt_files')
+FILE_STORAGE = FileSystemStorage(location="prompt_files")
+
class TimeInfoBase(models.Model):
@@ -60,12 +62,18 @@ class CustomUser(AbstractUser):
help_text="Allows the edit/add/remove of users for a company", default=False
)
deleted = models.BooleanField(help_text="This is to hid accounts", default=False)
- has_signed_tos = models.BooleanField(default=False, help_text="If the user has signed the TOS")
- slug = AutoSlugField(populate_from='email')
- conversation_order = models.BooleanField(default=True, help_text='How the conversations should display')
+ has_signed_tos = models.BooleanField(
+ default=False, help_text="If the user has signed the TOS"
+ )
+ slug = AutoSlugField(populate_from="email")
+ conversation_order = models.BooleanField(
+ default=True, help_text="How the conversations should display"
+ )
+
def get_set_password_url(self):
return f"https://www.chat.aimloperations.com/set_password?slug={self.slug}"
+
FEEDBACK_CHOICE = (
("SUBMITTED", "Submitted"),
("RESOLVED", "Resolved"),
@@ -74,21 +82,26 @@ FEEDBACK_CHOICE = (
)
FEEDBACK_CATEGORIES = (
- ('NOT_DEFINED', 'Not defined'),
- ('BUG', 'Bug'),
- ('ENHANCEMENT', 'Enhancement'),
- ('OTHER', 'Other'),
- ('MAX_CATEGORIES', 'Max Categories'),
+ ("NOT_DEFINED", "Not defined"),
+ ("BUG", "Bug"),
+ ("ENHANCEMENT", "Enhancement"),
+ ("OTHER", "Other"),
+ ("MAX_CATEGORIES", "Max Categories"),
)
+
class Feedback(TimeInfoBase):
- title = models.TextField(max_length=64, default='')
+ title = models.TextField(max_length=64, default="")
user = models.ForeignKey(
CustomUser, on_delete=models.CASCADE, blank=True, null=True
)
text = models.TextField(max_length=512)
- status = models.CharField(max_length=24, choices=FEEDBACK_CHOICE, default="SUBMITTED")
- category = models.CharField(max_length=24, choices=FEEDBACK_CATEGORIES, default="NOT_DEFINED")
+ status = models.CharField(
+ max_length=24, choices=FEEDBACK_CHOICE, default="SUBMITTED"
+ )
+ category = models.CharField(
+ max_length=24, choices=FEEDBACK_CATEGORIES, default="NOT_DEFINED"
+ )
def get_user_email(self):
if self.user:
@@ -105,9 +118,8 @@ MONTH_CHOICES = (
("DECEMBER", "December"),
)
-month = models.CharField(max_length=9,
- choices=MONTH_CHOICES,
- default="JANUARY")
+month = models.CharField(max_length=9, choices=MONTH_CHOICES, default="JANUARY")
+
class Announcement(TimeInfoBase):
class Status(models.TextChoices):
@@ -131,7 +143,9 @@ class Conversation(TimeInfoBase):
title = models.CharField(
max_length=64, help_text="The title for the conversation", default=""
)
- deleted = models.BooleanField(help_text="This is to hide conversations", default=False)
+ deleted = models.BooleanField(
+ help_text="This is to hide conversations", default=False
+ )
def get_user_email(self):
if self.user:
@@ -151,20 +165,26 @@ class Prompt(TimeInfoBase):
conversation = models.ForeignKey(
"Conversation", on_delete=models.CASCADE, blank=True, null=True
)
- file =models.FileField(upload_to=FILE_STORAGE, blank=True, null=True, help_text="file for the prompt")
- file_type=models.CharField(max_length=16, blank=True, null=True, help_text='file type of the file for the prompt')
+ file = models.FileField(
+ upload_to=FILE_STORAGE, blank=True, null=True, help_text="file for the prompt"
+ )
+ file_type = models.CharField(
+ max_length=16,
+ blank=True,
+ null=True,
+ help_text="file type of the file for the prompt",
+ )
def get_conversation_title(self):
if self.conversation:
return self.conversation.title
else:
return ""
-
+
def file_exists(self):
return self.file != None and self.file.storage.exists(self.file.name)
-
class PromptMetric(TimeInfoBase):
PROMPT_METRIC_CHOICES = (
("CREATED", "Created"),
@@ -174,20 +194,40 @@ class PromptMetric(TimeInfoBase):
("MAX_PROMPT_METRIC_CHOICES", "Max Prompt Metric Choices"),
)
prompt_id = models.IntegerField(help_text="The id of the prompt this matches to")
- conversation_id = models.IntegerField(help_text="The id of the conversation this matches to")
+ conversation_id = models.IntegerField(
+ help_text="The id of the conversation this matches to"
+ )
event = models.CharField(
- max_length=26, choices=PROMPT_METRIC_CHOICES, default='CREATED'
+ max_length=26, choices=PROMPT_METRIC_CHOICES, default="CREATED"
)
model_name = models.CharField(max_length=215, help_text="The name of the model")
start_time = models.DateTimeField()
end_time = models.DateTimeField(blank=True, null=True)
- prompt_length = models.IntegerField( help_text="How many characters are in the prompt")
- reponse_length = models.IntegerField(blank=True, null=True, help_text="How many characters are in the response")
+ prompt_length = models.IntegerField(
+ help_text="How many characters are in the prompt"
+ )
+ reponse_length = models.IntegerField(
+ blank=True, null=True, help_text="How many characters are in the response"
+ )
has_file = models.BooleanField(help_text="Is there a file")
- file_type = models.CharField(max_length=16, help_text='The file type, if any', blank=True, null=True)
+ file_type = models.CharField(
+ max_length=16, help_text="The file type, if any", blank=True, null=True
+ )
def get_duration(self):
- if(self.start_time and self.end_time):
- difference =self.end_time - self.start_time
+ if self.start_time and self.end_time:
+ difference = self.end_time - self.start_time
return difference.seconds
return 0
+
+# Document Models
+class DocumentWorkspace(TimeInfoBase):
+ name = models.CharField(max_length=255)
+ company = models.ForeignKey(Company, on_delete=models.CASCADE)
+
+class Document(TimeInfoBase):
+ workspace = models.ForeignKey(DocumentWorkspace, on_delete=models.CASCADE)
+ file = models.FileField(upload_to="documents/")
+ uploaded_at = models.DateTimeField(auto_now_add=True)
+ processed = models.BooleanField(default=False)
+ active = models.BooleanField(default=False)
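+
+
+# Usage sketch (illustrative only): wiring a document to a workspace with the
+# fields defined above; `company` and `uploaded_file` are assumed to exist in
+# the calling code.
+#
+# ws = DocumentWorkspace.objects.create(name="Policies", company=company)
+# Document.objects.create(workspace=ws, file=uploaded_file)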
diff --git a/llm_be/chat_backend/renderers.py b/llm_be/chat_backend/renderers.py
index f611bb0..e7f30ae 100644
--- a/llm_be/chat_backend/renderers.py
+++ b/llm_be/chat_backend/renderers.py
@@ -1,8 +1,9 @@
from rest_framework.renderers import BaseRenderer
+
class ServerSentEventRenderer(BaseRenderer):
- media_type = 'text/event-stream'
- format = 'txt'
+ media_type = "text/event-stream"
+ format = "txt"
def render(self, data, accepted_media_type=None, renderer_context=None):
- return data
\ No newline at end of file
+ return data
diff --git a/llm_be/chat_backend/routing.py b/llm_be/chat_backend/routing.py
index 70829b5..dd6acd3 100644
--- a/llm_be/chat_backend/routing.py
+++ b/llm_be/chat_backend/routing.py
@@ -1,7 +1,6 @@
-from django.urls import re_path
+from django.urls import re_path
from .views import ChatConsumerAgain
-
-websocket_urlpatterns = [
- re_path(r'ws/chat_again/$', ChatConsumerAgain.as_asgi()),
-]
\ No newline at end of file
+websocket_urlpatterns = [
+ re_path(r"ws/chat_again/$", ChatConsumerAgain.as_asgi()),
+]
diff --git a/llm_be/chat_backend/serializers.py b/llm_be/chat_backend/serializers.py
index 07c0dcc..c4906ec 100644
--- a/llm_be/chat_backend/serializers.py
+++ b/llm_be/chat_backend/serializers.py
@@ -1,6 +1,16 @@
from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
from rest_framework import serializers
-from .models import CustomUser, Announcement, Company, Conversation, Prompt, Feedback, FEEDBACK_CATEGORIES
+from .models import (
+ CustomUser,
+ Announcement,
+ Company,
+ Conversation,
+ Prompt,
+ Feedback,
+ FEEDBACK_CATEGORIES,
+ DocumentWorkspace,
+ Document,
+)
class MyTokenObtainPairSerializer(TokenObtainPairSerializer):
@@ -25,11 +35,13 @@ class AnnouncmentSerializer(serializers.ModelSerializer):
model = Announcement
fields = "__all__"
+
class FeedbackSerializer(serializers.ModelSerializer):
class Meta:
model = Feedback
fields = "__all__"
+
class CustomUserSerializer(serializers.ModelSerializer):
email = serializers.EmailField(required=True)
username = serializers.CharField()
@@ -58,12 +70,40 @@ class ConversationSerializer(serializers.ModelSerializer):
class PromptSerializer(serializers.ModelSerializer):
-
+
class Meta:
model = Prompt
- fields = ("message", "user_created", "created", "id", )
+ fields = (
+ "message",
+ "user_created",
+ "created",
+ "id",
+ )
+
class BasicUserSerializer(serializers.ModelSerializer):
class Meta:
model = CustomUser
- fields = ("email", "first_name", "last_name", "is_active","has_usable_password","is_company_manager",'has_signed_tos')
\ No newline at end of file
+ fields = (
+ "email",
+ "first_name",
+ "last_name",
+ "is_active",
+ "has_usable_password",
+ "is_company_manager",
+ "has_signed_tos",
+ )
+
+
+# document serializers
+class DocumentWorkspaceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = DocumentWorkspace
+ fields = ["id", "name", "created"]
+ read_only_fields = ["id", "created"]
+
+class DocumentSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Document
+ fields = [
+ "id",
+ "workspace",
+ "file",
+ "uploaded_at",
+ "processed",
+ "created",
+ "active",
+ ]
+ read_only_fields = ["id", "uploaded_at", "processed", "created"]
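+
+
+# Usage sketch (illustrative only): validating an upload in a DRF view;
+# `ws` and `request` are assumed to come from the surrounding view code.
+#
+# serializer = DocumentSerializer(
+# data={"workspace": ws.id, "file": request.FILES["file"]}
+# )
+# if serializer.is_valid():
+# serializer.save()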
diff --git a/llm_be/chat_backend/services/__init__.py b/llm_be/chat_backend/services/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/llm_be/chat_backend/services/image_generation.py b/llm_be/chat_backend/services/image_generation.py
new file mode 100644
index 0000000..f6d95d6
--- /dev/null
+++ b/llm_be/chat_backend/services/image_generation.py
@@ -0,0 +1,145 @@
+import os
+import logging
+from typing import Optional, Tuple
+from PIL import Image
+import torch
+from diffusers import StableDiffusionPipeline, DPMSolverSinglestepScheduler
+
+logger = logging.getLogger(__name__)
+
+class ImageGenerationService:
+ """
+ Service for text-to-image generation using Stable Diffusion.
+ Uses singleton pattern to maintain loaded model in memory.
+ """
+
+ _instance = None
+ _model_loaded = False
+
+ def __new__(cls):
+ if cls._instance is None:
+ cls._instance = super().__new__(cls)
+ cls._instance._initialize()
+ return cls._instance
+
+ def _initialize(self):
+ """Initialize the service with default settings"""
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.model_id = "stabilityai/stable-diffusion-2-1"
+ self.pipeline = None
+ self.default_params = {
+ "num_inference_steps": 25,
+ "guidance_scale": 7.5,
+ "width": 512,
+ "height": 512,
+ }
+
+ def load_model(self):
+ """Load the Stable Diffusion model"""
+ if self._model_loaded:
+ return
+
+ try:
+ logger.info(f"Loading Stable Diffusion model on {self.device}...")
+
+ # Use DPMSolver for faster inference
+ self.pipeline = StableDiffusionPipeline.from_pretrained(
+ self.model_id,
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
+ )
+ self.pipeline.scheduler = DPMSolverSinglestepScheduler.from_config(
+ self.pipeline.scheduler.config
+ )
+ self.pipeline = self.pipeline.to(self.device)
+
+ # Optimizations
+ if self.device == "cuda":
+ self.pipeline.enable_attention_slicing()
+ try:
+ self.pipeline.enable_xformers_memory_efficient_attention()
+ except Exception:
+ # xformers is optional; fall back if it is not installed
+ logger.warning("xformers unavailable; skipping memory-efficient attention")
+
+ self._model_loaded = True
+ logger.info("Model loaded successfully")
+
+ except Exception as e:
+ logger.error(f"Failed to load model: {str(e)}")
+ raise RuntimeError(f"Model loading failed: {str(e)}")
+
+ def generate_image(
+ self,
+ prompt: str,
+ negative_prompt: Optional[str] = None,
+ output_path: Optional[str] = None,
+ **kwargs
+ ) -> Tuple[Image.Image, dict]:
+ """
+ Generate image from text prompt.
+
+ Args:
+ prompt: Text prompt for image generation
+ negative_prompt: Text for things to avoid in generation
+ output_path: Optional path to save the image
+ **kwargs: Generation parameters (overrides defaults)
+
+ Returns:
+ Tuple of (PIL.Image, generation_parameters)
+ """
+ if not self._model_loaded:
+ self.load_model()
+
+ # Merge default params with overrides
+ params = {**self.default_params, **kwargs}
+
+ try:
+ logger.info(f"Generating image with prompt: {prompt[:50]}...")
+
+ with torch.inference_mode():
+ result = self.pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ **params
+ )
+
+ image = result.images[0]
+
+ if output_path:
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ image.save(output_path)
+ logger.info(f"Image saved to {output_path}")
+
+ return image, params
+
+ except Exception as e:
+ logger.error(f"Image generation failed: {str(e)}")
+ raise RuntimeError(f"Image generation failed: {str(e)}")
+
+
+class AsyncImageGenerationService:
+ """
+ Asynchronous wrapper for image generation service.
+ Runs the synchronous service in a thread pool.
+ """
+
+ def __init__(self):
+ self.sync_service = ImageGenerationService()
+
+ async def generate_image(
+ self,
+ prompt: str,
+ negative_prompt: Optional[str] = None,
+ output_path: Optional[str] = None,
+ **kwargs
+ ) -> Tuple[Image.Image, dict]:
+ """Async version of generate_image"""
+ import asyncio
+ from functools import partial
+
+ loop = asyncio.get_running_loop()
+ func = partial(
+ self.sync_service.generate_image,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ output_path=output_path,
+ **kwargs
+ )
+
+ return await loop.run_in_executor(None, func)
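+
+
+# Usage sketch (illustrative only): generating an image from async view or
+# consumer code. The first call downloads model weights, so this assumes
+# network access and ideally a GPU; `output_path` is optional.
+#
+# service = AsyncImageGenerationService()
+# image, params = await service.generate_image(
+# prompt="a lighthouse at dawn, oil painting",
+# negative_prompt="blurry, low quality",
+# output_path="media/generated/lighthouse.png",
+# )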
diff --git a/llm_be/chat_backend/services/llm_service.py b/llm_be/chat_backend/services/llm_service.py
new file mode 100644
index 0000000..1d24f08
--- /dev/null
+++ b/llm_be/chat_backend/services/llm_service.py
@@ -0,0 +1,138 @@
+from abc import ABC, abstractmethod
+from typing import AsyncGenerator, Generator, Optional
+
+from langchain_community.llms import Ollama
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+
+from chat_backend.models import Conversation, Prompt
+
+class LLMService(ABC):
+ """Abstract base class for LLM conversation services."""
+
+ def __init__(self):
+ self.llm = Ollama(
+ model="llama3.2",
+ temperature=0.7,
+ top_k=50,
+ top_p=0.9,
+ repeat_penalty=1.1,
+ num_ctx=4096
+ )
+ self.output_parser = StrOutputParser()
+
+ @abstractmethod
+ def generate_response(self, conversation: Conversation, query: str, **kwargs):
+ """Generate a response to a query within a conversation context."""
+ pass
+
+ def _format_history(self, conversation: Conversation) -> str:
+ """Format conversation history for the prompt."""
+ prompts = Prompt.objects.filter(conversation=conversation).order_by("created")
+ return "\n".join(
+ f"{'User' if prompt.user_created else 'AI'}: {prompt.message}"
+ for prompt in prompts
+ )
+
+
+class SyncLLMService(LLMService):
+ """Synchronous LLM conversation service."""
+
+ def __init__(self):
+ super().__init__()
+ self._setup_chain()
+
+ def _setup_chain(self):
+ """Setup the conversation chain."""
+ template = """Continue the conversation based on the following history:
+
+ {history}
+
+ Latest message: {query}
+
+ Response:"""
+ self.prompt = ChatPromptTemplate.from_template(template)
+
+ self.conversation_chain = (
+ {
+ "history": lambda x: self._format_history(x["conversation"]),
+ "query": lambda x: x["query"]
+ }
+ | self.prompt
+ | self.llm
+ | self.output_parser
+ )
+
+ def generate_response(self, conversation: Conversation, query: str, **kwargs) -> Generator[str, None, None]:
+ """Generate response with streaming support."""
+ chain_input = {
+ "query": query,
+ "conversation": conversation
+ }
+
+ for chunk in self.conversation_chain.stream(chain_input):
+ yield chunk
+
+
+class AsyncLLMService(LLMService):
+ """Asynchronous LLM conversation service."""
+
+ def __init__(self):
+ super().__init__()
+ self._setup_chain()
+
+ def _setup_chain(self):
+ """Setup the conversation chain."""
+ template = """Continue this conversation while maintaining context by providing a single helpful response.
+ Current context: {context}
+
+ Last 3 messages:
+ {recent_history}
+
+ Latest message: {query}
+
+ Instructions:
+ - Carefully maintain all established context
+ - If referencing previous elements (like stories), preserve all details
+ - When asked to modify something, identify what's being modified
+
+ Response:"""
+
+ self.prompt = ChatPromptTemplate.from_template(template)
+
+ async def load_context(x):
+ return await self._format_history(x["conversation"])
+
+ async def load_recent_history(x):
+ return await self._get_recent_messages(x["conversation"])
+
+ self.conversation_chain = (
+ {
+ "context": load_context,
+ "recent_history": load_recent_history,
+ "query": lambda x: x["query"]
+ }
+ | self.prompt
+ | self.llm
+ | self.output_parser
+ )
+
+ async def _format_history(self, conversation: Conversation) -> str:
+ """Async version of format conversation history."""
+ prompts = [
+ prompt
+ async for prompt in Prompt.objects.filter(
+ conversation=conversation
+ ).order_by("created")
+ ]
+ return "\n".join(
+ f"{'User' if prompt.user_created else 'AI'}: {prompt.message}"
+ for prompt in prompts
+ )
+
+ async def _get_recent_messages(self, conversation: Conversation) -> str:
+ """Format the last three messages of the conversation."""
+ prompts = [
+ prompt
+ async for prompt in Prompt.objects.filter(
+ conversation=conversation
+ ).order_by("created")
+ ]
+ return "\n".join(
+ f"{'User' if prompt.user_created else 'AI'}: {prompt.message}"
+ for prompt in prompts[-3:]
+ )
+
+ async def generate_response(self, conversation: Conversation, query: str, **kwargs) -> AsyncGenerator[str, None]:
+ """Generate response with async streaming support."""
+ chain_input = {
+ "query": query,
+ "conversation": conversation
+ }
+
+ async for chunk in self.conversation_chain.astream(chain_input):
+ yield chunk
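+
+
+# Usage sketch (illustrative only): streaming a reply from the sync service;
+# `conversation` is an existing Conversation instance.
+#
+# service = SyncLLMService()
+# for chunk in service.generate_response(conversation, "What did I just ask?"):
+# print(chunk, end="", flush=True)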
diff --git a/llm_be/chat_backend/services/moderation_classifier.py b/llm_be/chat_backend/services/moderation_classifier.py
new file mode 100644
index 0000000..d66968b
--- /dev/null
+++ b/llm_be/chat_backend/services/moderation_classifier.py
@@ -0,0 +1,79 @@
+from enum import Enum, auto
+from typing import Dict, Any
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_community.llms import Ollama
+
+class ModerationLabel(Enum):
+ NSFW = auto()
+ FINE = auto()
+
+class ModerationClassifier:
+ """
+ Classifies prompts as NSFW or FINE (safe) content.
+ """
+
+ def __init__(self):
+ self.llm = Ollama(
+ model="llama3.2",
+ temperature=0.1, # Very low for strict moderation
+ top_k=10,
+ num_ctx=2048
+ )
+
+ self.moderation_prompt = ChatPromptTemplate.from_messages([
+ ("system", """You are a strict content moderator. Classify the following prompt as either NSFW or FINE.
+
+NSFW includes:
+- Sexual content
+- Violence/gore
+- Hate speech
+- Illegal activities
+- Harassment
+- Graphic/disturbing content
+
+FINE includes:
+- Safe for work topics
+- General conversation
+- Professional inquiries
+- Creative requests (non-explicit)
+- Technical questions
+
+Examples:
+- "How to make a bomb" → NSFW
+- "Write a love poem" → FINE
+- "Explicit sex scene" → NSFW
+- "Python tutorial" → FINE
+
+Return ONLY "NSFW" or "FINE", nothing else."""),
+ ("human", "{prompt}")
+ ])
+
+ self.chain = self.moderation_prompt | self.llm
+
+ async def classify_async(self, prompt: str) -> ModerationLabel:
+ """Asynchronous classification"""
+ try:
+ response = (await self.chain.ainvoke({"prompt": prompt})).strip().upper()
+ return self._parse_response(response)
+ except Exception as e:
+ print(f"Moderation error: {e}")
+ return ModerationLabel.NSFW # Fail-safe to NSFW
+
+ def classify(self, prompt: str) -> ModerationLabel:
+ """Synchronous classification"""
+ try:
+ response = self.chain.invoke({"prompt": prompt}).strip().upper()
+ return self._parse_response(response)
+ except Exception as e:
+ print(f"Moderation error: {e}")
+ return ModerationLabel.NSFW # Fail-safe to NSFW
+
+ def _parse_response(self, response: str) -> ModerationLabel:
+ """Convert string response to ModerationLabel enum"""
+ if "NSFW" in response:
+ return ModerationLabel.NSFW
+ return ModerationLabel.FINE # Default to FINE if unclear
+
+
+# Singleton instance
+moderation_classifier = ModerationClassifier()
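+
+
+# Usage sketch: gating a prompt before it reaches the main LLM; `user_prompt`
+# is whatever string the caller holds.
+#
+# from chat_backend.services.moderation_classifier import (
+# ModerationLabel,
+# moderation_classifier,
+# )
+#
+# if moderation_classifier.classify(user_prompt) is ModerationLabel.NSFW:
+# raise ValueError("Prompt rejected by moderation")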
diff --git a/llm_be/chat_backend/services/prompt_classifier.py b/llm_be/chat_backend/services/prompt_classifier.py
new file mode 100644
index 0000000..4548f46
--- /dev/null
+++ b/llm_be/chat_backend/services/prompt_classifier.py
@@ -0,0 +1,100 @@
+from enum import Enum, auto
+from typing import Dict, Any
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_community.llms import Ollama
+
+class PromptType(Enum):
+ GENERAL_CHAT = auto()
+ RAG = auto()
+ IMAGE_GENERATION = auto()
+ UNKNOWN = auto()
+
+class PromptClassifier:
+ """
+ Classifies user prompts to determine which service should handle them.
+ """
+
+ def __init__(self):
+ self.llm = Ollama(
+ model="llama3",
+ temperature=0.3, # Lower temp for more deterministic classification
+ top_k=20,
+ top_p=0.9,
+ num_ctx=4096
+ )
+
+ self.classification_prompt = ChatPromptTemplate.from_messages([
+ ("system",
+ """You are a precision prompt classifier. Strictly categorize prompts into:
+1. GENERAL_CHAT - Casual conversation, personal questions, or non-specific inquiries
+2. RAG - ONLY when explicitly requesting document/search-based knowledge
+3. IMAGE_GENERATION - Specific requests to create/modify images
+4. UNKNOWN - If none of the above fit
+
+1. IMAGE_GENERATION - ONLY if:
+ - Explicitly contains: "generate/create/draw/make an image/picture/photo/art/illustration"
+ - Requests visual content creation
+ - Example: "Make a picture of a castle" → IMAGE_GENERATION
+
+2. RAG - ONLY if:
+ - Explicitly mentions documents/files/data
+ - Uses search terms: "find/search/lookup in [source]"
+ - Example: "What does contracts.pdf say?" → RAG
+
+3. GENERAL_CHAT - DEFAULT category when:
+ - Doesn't meet above criteria
+ - Conversational/general knowledge
+ - Uncertain cases
+ - Example: "Tell me a joke" → GENERAL_CHAT
+
+Examples:
+[Definitely RAG]
+- "What does the uploaded PDF say about quarterly results?"
+- "Search our documents for the 2023 marketing strategy"
+- "Find the contract clause about termination"
+
+[Definitely GENERAL_CHAT]
+- "How does photosynthesis work?" (General knowledge)
+- "Tell me a joke"
+- "What's your opinion on AI?"
+
+[Borderline → GENERAL_CHAT]
+- "What's our company policy on X?" (No doc reference → general)
+- "Explain quantum computing" (General knowledge)
+- "Summarize the meeting" (No doc reference)
+
+Return ONLY the label, no explanations."""),
+ ("human", "{prompt}")
+ ])
+
+ self.chain = self.classification_prompt | self.llm
+
+ async def classify_async(self, prompt: str) -> PromptType:
+ """Asynchronously classify the prompt"""
+ try:
+ response = await self.chain.ainvoke({"prompt": prompt})
+ return self._parse_response(response.strip())
+ except Exception as e:
+ print(f"Classification error: {e}")
+ return PromptType.UNKNOWN
+
+ def classify(self, prompt: str) -> PromptType:
+ """Synchronously classify the prompt"""
+ try:
+ response = self.chain.invoke({"prompt": prompt})
+ return self._parse_response(response.strip())
+ except Exception as e:
+ print(f"Classification error: {e}")
+ return PromptType.UNKNOWN
+
+ def _parse_response(self, response: str) -> PromptType:
+ """Convert string response to PromptType enum"""
+ response = response.upper()
+ for prompt_type in PromptType:
+ if prompt_type.name in response:
+ return prompt_type
+ return PromptType.UNKNOWN
+
+
+# Singleton instance for easy access
+prompt_classifier = PromptClassifier()
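+
+
+# Usage sketch: routing a prompt to a service; `user_prompt` is the incoming
+# message and the branch bodies are placeholders.
+#
+# from chat_backend.services.prompt_classifier import PromptType, prompt_classifier
+#
+# label = prompt_classifier.classify(user_prompt)
+# if label is PromptType.IMAGE_GENERATION:
+# ... # hand off to ImageGenerationService
+# elif label is PromptType.RAG:
+# ... # hand off to AsyncRAGService
+# else:
+# ... # default to the general chat service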
diff --git a/llm_be/chat_backend/services/rag_services.py b/llm_be/chat_backend/services/rag_services.py
new file mode 100644
index 0000000..4608822
--- /dev/null
+++ b/llm_be/chat_backend/services/rag_services.py
@@ -0,0 +1,378 @@
+import os
+import re
+from abc import ABC, abstractmethod
+from typing import List, Dict, Any, AsyncGenerator, Generator, Optional
+from channels.db import database_sync_to_async
+from langchain_community.embeddings import OllamaEmbeddings
+from langchain_community.llms import Ollama
+from langchain_community.vectorstores import Chroma
+from langchain_core.documents import Document as LangDocument
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import RunnablePassthrough
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import (
+ PyPDFLoader,
+ Docx2txtLoader,
+ TextLoader,
+ UnstructuredFileLoader
+)
+from django.core.files.uploadedfile import UploadedFile
+from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
+from pathlib import Path
+
+@database_sync_to_async
+def get_documents(workspace: DocumentWorkspace | None = None):
+ if workspace:
+ return [doc for doc in Document.objects.filter(workspace=workspace)]
+ else:
+ return [doc for doc in Document.objects.all()]
+
+
+
+class RAGService(ABC):
+ """Abstract base class for RAG services."""
+ _instance = None
+
+ def __new__(cls):
+ if cls._instance is None:
+ cls._instance = super().__new__(cls)
+ return cls._instance
+
+ def __init__(self):
+ self.embedding_model = OllamaEmbeddings(model="llama3.2")
+ self.llm = Ollama(
+ model="llama3.2",
+ temperature=0.7,
+ top_k=50,
+ top_p=0.9,
+ repeat_penalty=1.1,
+ num_ctx=4096
+ )
+ self.text_splitter = RecursiveCharacterTextSplitter(
+ chunk_size=1000,
+ chunk_overlap=200
+ )
+ self.vector_store = self._initialize_vector_store()
+
+ # Supported file types and their loaders
+ self.loader_mapping = {
+ '.pdf': PyPDFLoader,
+ '.docx': Docx2txtLoader,
+ '.txt': TextLoader,
+ # Fallback for other file types
+ '*': UnstructuredFileLoader,
+ }
+
+ def _initialize_vector_store(self) -> Chroma:
+ """Initialize and return the Chroma vector store."""
+ persist_directory=f"./chroma_db/"
+ vector_store = Chroma(
+ embedding_function=self.embedding_model,
+ persist_directory=persist_directory
+ )
+ return vector_store
+
+ def clear_vector_store(self):
+ """Clear all vectors from the store"""
+ self.vector_store.delete_collection()
+ self.vector_store = self._initialize_vector_store()
+
+ def _prepare_documents(self, documents: List[Document]) -> None:
+ """Load, split, and ingest documents into the vector store."""
+ for doc in documents:
+ print(f"Processing: {doc.file.name}")
+ chunks = self._load_and_split_documents(doc.file.path)
+ if chunks:
+ self.vector_store.add_documents(chunks)
+ self.vector_store.persist()
+
+
+ def ingest_documents(self, workspace: DocumentWorkspace | None = None) -> None:
+ """Ingest documents from a workspace into the vector store."""
+ print(f"Getting the Document via the workspace: {workspace}")
+ if workspace:
+ documents = [doc for doc in Document.objects.filter(workspace=workspace)]
+ else:
+ documents = [doc for doc in Document.objects.all()]
+
+ print(f"Processing the documents : {documents}")
+ self._prepare_documents(documents)
+
+
+ @abstractmethod
+ def generate_response(self, conversation: Conversation, query: str, **kwargs):
+ """Generate a response using RAG."""
+ pass
+
+ @abstractmethod
+ def search_documents(self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4) -> List[LangDocument]:
+ """Search relevant documents from the vector store."""
+ pass
+
+ def _get_file_loader(self, file_path: str):
+ """Get appropriate loader for file type"""
+ ext = Path(file_path).suffix.lower()
+ return self.loader_mapping.get(ext, self.loader_mapping['*'])
+
+ def _sanitize_filename(self, filename: str) -> str:
+ """Sanitize filename for safe storage"""
+ return re.sub(r'[^\w\-_. ]', '_', filename)
+
+ def _save_uploaded_file(self, uploaded_file: UploadedFile, save_dir: str) -> str:
+ """Save uploaded file to disk"""
+ os.makedirs(save_dir, exist_ok=True)
+ sanitized_name = self._sanitize_filename(uploaded_file.name)
+ file_path = os.path.join(save_dir, sanitized_name)
+
+ with open(file_path, 'wb+') as destination:
+ for chunk in uploaded_file.chunks():
+ destination.write(chunk)
+
+ return file_path
+
+ def _load_and_split_documents(self, file_path: str, metadata: Optional[dict] = None) -> List[LangDocument]:
+ """Load and split documents from file"""
+ loader_class = self._get_file_loader(file_path)
+ loader = loader_class(file_path)
+
+ docs = loader.load()
+ if metadata:
+ for doc in docs:
+ doc.metadata.update(metadata)
+
+ return self.text_splitter.split_documents(docs)
+
+ def add_files_to_store(
+ self,
+ file_tuples: List[tuple], # (file_path, name, workspace_id)
+ workspace_id: str,
+ source: str = "upload",
+ save_dir: str = "data/uploads"
+ ) -> Dict[str, Any]:
+ """
+ Process and add uploaded files to vector store
+
+ Args:
+ file_tuples: List of (file_path, name, workspace_id) tuples
+ workspace_id: ID of the workspace these belong to
+ source: Source identifier for documents
+ save_dir: Directory to save uploaded files
+
+ Returns:
+ Dictionary with processing results
+ """
+ results = {
+ 'total_added': 0,
+ 'failed_files': [],
+ 'processed_files': []
+ }
+
+ for file_tuple in file_tuples:
+ try:
+ file_path = file_tuple[0]
+
+ # Prepare metadata
+ metadata = {
+ 'source': file_tuple[1],
+ 'workspace_id': file_tuple[2],
+ 'original_filename': file_tuple[1],
+ 'file_path': file_path,
+ }
+
+ # Load and split documents
+ docs = self._load_and_split_documents(file_path, metadata)
+
+ # Add to vector store
+ if docs:
+ self.vector_store.add_documents(docs)
+ results['total_added'] += len(docs)
+ results['processed_files'].append({
+ 'filename': file_tuple[1],
+ 'document_count': len(docs)
+ })
+
+ except Exception as e:
+ results['failed_files'].append({
+ 'filename': file_tuple[1],
+ 'error': str(e)
+ })
+ continue
+
+ # Persist changes
+ self.vector_store.persist()
+ return results
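+
+ # Usage sketch (illustrative): each tuple is (file_path, display_name,
+ # workspace_id), matching the layout unpacked above.
+ #
+ # results = rag_service.add_files_to_store(
+ # [("data/uploads/contract.pdf", "contract.pdf", workspace.id)],
+ # workspace_id=str(workspace.id),
+ # )
+ # print(results["total_added"], results["failed_files"])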
+
+
+class SyncRAGService(RAGService):
+ """Synchronous RAG service implementation."""
+
+ def __init__(self):
+ super().__init__()
+ self._setup_chain()
+
+ def _setup_chain(self):
+ """Setup the RAG chain."""
+ template = """Answer the question based only on the following context:
+ {context}
+
+ Conversation history:
+ {history}
+
+ Question: {question}
+ """
+ self.prompt = ChatPromptTemplate.from_template(template)
+
+ self.rag_chain = (
+ {
+ "context": self._retriever_with_history,
+ "history": lambda x: self._format_history(x["conversation"]),
+ "question": lambda x: x["query"]
+ }
+ | self.prompt
+ | self.llm
+ | StrOutputParser()
+ )
+
+ def _format_history(self, conversation: Conversation) -> str:
+ """Format conversation history for the prompt."""
+ prompts = Prompt.objects.filter(conversation=conversation).order_by("created")
+ return "\n".join(
+ f"{'User' if prompt.user_created else 'AI'}: {prompt.message}"
+ for prompt in prompts
+ )
+
+ def _retriever_with_history(self, input_dict: Dict[str, Any]) -> str:
+ """Retrieve documents considering conversation history."""
+ query = input_dict["query"]
+ workspace = input_dict.get("workspace")
+
+ # You could enhance this to consider historical context in retrieval
+ relevant_docs = self.search_documents(query, workspace)
+ if not relevant_docs:
+ print("didn't find any relevant docs")
+ return "\n\n".join(doc.page_content for doc in relevant_docs)
+
+
+ def search_documents(self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4) -> List[LangDocument]:
+ """Search relevant documents from the vector store."""
+ filter_dict = {}
+ if workspace:
+ filter_dict["workspace_id"] = workspace.id
+ search_kwargs = {
+ "k": k,
+ "filter": filter_dict if filter_dict else None
+ }
+ print(f"search_kwargs: {search_kwargs}")
+ retriever = self.vector_store.as_retriever(
+ search_type="similarity",
+ search_kwargs=search_kwargs
+ )
+ return retriever.get_relevant_documents(query)
+
+ def generate_response(self, conversation: Conversation, query: str, **kwargs) -> Generator[str, None, None]:
+ """Generate response with streaming support."""
+ chain_input = {
+ "query": query,
+ "conversation": conversation
+ }
+
+ for chunk in self.rag_chain.stream(chain_input):
+ yield chunk
+
+
+class AsyncRAGService(RAGService):
+ """Asynchronous RAG service implementation."""
+
+ def __init__(self):
+ super().__init__()
+ self._setup_chain()
+
+ def _setup_chain(self):
+ """Setup the RAG chain."""
+ template = """Answer the question based only on the following context:
+ {context}
+
+ Conversation history:
+ {history}
+
+ Question: {question}
+ """
+ self.prompt = ChatPromptTemplate.from_template(template)
+
+ async def load_history(x):
+ return await self._format_history(x["conversation"])
+
+ self.rag_chain = (
+ {
+ "context": self._retriever_with_history,
+ "history": load_history,
+ "question": lambda x: x["query"]
+ }
+ | self.prompt
+ | self.llm
+ | StrOutputParser()
+ )
+
+ async def _format_history(self, conversation: Conversation) -> str:
+ """Format conversation history for the prompt."""
+ prompts = [
+ prompt
+ async for prompt in Prompt.objects.filter(
+ conversation=conversation
+ ).order_by("created")
+ ]
+ print(f"prompts that we are seeding with are: {prompts}")
+ return "\n".join(
+ f"{'User' if prompt.user_created else 'AI'}: {prompt.message}"
+ for prompt in prompts
+ )
+
+ async def _retriever_with_history(self, input_dict: Dict[str, Any]) -> str:
+ """Retrieve documents considering conversation history."""
+ print(f"Retrieving history with input: {input_dict}")
+ query = input_dict["query"]
+ workspace = input_dict["workspace"]
+
+ # You could enhance this to consider historical context in retrieval
+ docs = await self.search_documents(query, workspace)
+
+ if not docs:
+ print("Didn't find any relevant docs")
+
+ return "\n\n".join(doc.page_content for doc in docs)
+
+
+ async def search_documents(self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4) -> List[LangDocument]:
+ """Search relevant documents from the vector store."""
+ filter_dict = {}
+ print(f"Do we have a workspace: {workspace}")
+ if workspace:
+ filter_dict["workspace_id"] = workspace.id
+ search_kwargs = {
+ "k": k,
+ "filter": filter_dict if filter_dict else None
+ }
+ print(f"search_kwargs: {search_kwargs}")
+
+ retriever = self.vector_store.as_retriever(
+ search_type="mmr",
+ search_kwargs=search_kwargs
+ )
+ return await retriever.aget_relevant_documents(query)
+
+ async def generate_response(self, conversation: Conversation, query: str, workspace: DocumentWorkspace, **kwargs) -> AsyncGenerator[str, None]:
+ """Generate response with streaming support."""
+ chain_input = {
+ "query": query,
+ "conversation": conversation,
+ "workspace": workspace,
+ }
+
+ async for chunk in self.rag_chain.astream(chain_input):
+ yield chunk
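+
+
+# Usage sketch (illustrative only): streaming a RAG answer from async consumer
+# code; `conversation` and `workspace` are existing model instances and
+# `send_to_client` is a hypothetical transport helper.
+#
+# rag_service = AsyncRAGService()
+# async for chunk in rag_service.generate_response(
+# conversation, "What does the contract say about termination?", workspace
+# ):
+# await send_to_client(chunk)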
diff --git a/llm_be/chat_backend/services/tests.py b/llm_be/chat_backend/services/tests.py
new file mode 100644
index 0000000..8da5303
--- /dev/null
+++ b/llm_be/chat_backend/services/tests.py
@@ -0,0 +1,219 @@
+import os
+from unittest import TestCase, mock
+from unittest.mock import MagicMock, patch, AsyncMock
+from typing import List, Dict, Any
+
+from django.test import TestCase as DjangoTestCase
+
+from chat_backend.services.rag_services import RAGService, SyncRAGService, AsyncRAGService
+from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
+
+class _StubRAGService(RAGService):
+ """Concrete subclass so the abstract base can be instantiated in tests."""
+
+ def generate_response(self, conversation, query, **kwargs):
+ raise NotImplementedError
+
+ def search_documents(self, query, workspace=None, k=4):
+ raise NotImplementedError
+
+
+class TestRAGService(TestCase):
+ def setUp(self):
+ self.rag_service = _StubRAGService()
+ self.rag_service.vector_store = MagicMock()
+ self.rag_service.embedding_model = MagicMock()
+ self.rag_service.text_splitter = MagicMock()
+
+ def test_initialize_vector_store(self):
+ with patch('chat_backend.services.rag_services.Chroma') as mock_chroma:
+ result = self.rag_service._initialize_vector_store()
+
+ mock_chroma.assert_called_once_with(
+ embedding_function=self.rag_service.embedding_model,
+ persist_directory="./chroma_db/"
+ )
+ self.assertIsNotNone(result)
+
+ def test_prepare_documents(self):
+ mock_doc1 = MagicMock(spec=Document)
+ mock_doc1.file.name = "test.pdf"
+ mock_doc1.file.path = "/tmp/test.pdf"
+
+ self.rag_service._load_and_split_documents = MagicMock(
+ return_value=["chunk1", "chunk2"]
+ )
+
+ self.rag_service._prepare_documents([mock_doc1])
+
+ self.rag_service._load_and_split_documents.assert_called_once_with("/tmp/test.pdf")
+ self.rag_service.vector_store.add_documents.assert_called_once_with(["chunk1", "chunk2"])
+ self.rag_service.vector_store.persist.assert_called_once()
+
+ def test_ingest_documents(self):
+ mock_workspace = MagicMock()
+ mock_documents = [MagicMock()]
+
+ with patch('chat_backend.services.rag_services.Document.objects.filter', return_value=mock_documents):
+ self.rag_service._prepare_documents = MagicMock()
+
+ self.rag_service.ingest_documents(mock_workspace)
+
+ self.rag_service._prepare_documents.assert_called_once_with(mock_documents)
+
+
+class TestSyncRAGService(DjangoTestCase):
+ def setUp(self):
+ self.sync_service = SyncRAGService()
+ self.sync_service.vector_store = MagicMock()
+ self.sync_service.llm = MagicMock()
+ self.sync_service.rag_chain = MagicMock()
+
+ self.mock_conversation = MagicMock(spec=Conversation)
+ self.mock_conversation.workspace = MagicMock()
+
+ self.mock_prompt1 = MagicMock(spec=Prompt)
+ self.mock_prompt1.user_created = True
+ self.mock_prompt1.message = "User question"
+
+ self.mock_prompt2 = MagicMock(spec=Prompt)
+ self.mock_prompt2.user_created = False
+ self.mock_prompt2.message = "AI response"
+
+ def test_format_history(self):
+ with patch('chat_backend.services.rag_services.Prompt.objects.filter') as mock_filter:
+ mock_filter.return_value.order_by.return_value = [self.mock_prompt1, self.mock_prompt2]
+
+ result = self.sync_service._format_history(self.mock_conversation)
+
+ expected = "User: User question\nAI: AI response"
+ self.assertEqual(result, expected)
+ mock_filter.assert_called_once_with(conversation=self.mock_conversation)
+
+ def test_retriever_with_history(self):
+ input_dict = {
+ "query": "test query",
+ "conversation": self.mock_conversation,
+ "workspace": self.mock_conversation.workspace,
+ }
+
+ mock_doc = MagicMock()
+ mock_doc.page_content = "doc content"
+ self.sync_service.search_documents = MagicMock(return_value=[mock_doc])
+
+ result = self.sync_service._retriever_with_history(input_dict)
+
+ self.sync_service.search_documents.assert_called_once_with(
+ "test query",
+ self.mock_conversation.workspace
+ )
+ self.assertEqual(result, "doc content")
+
+ def test_search_documents(self):
+ mock_retriever = MagicMock()
+ mock_retriever.get_relevant_documents.return_value = ["doc1", "doc2"]
+ self.sync_service.vector_store.as_retriever.return_value = mock_retriever
+
+ result = self.sync_service.search_documents("test query", self.mock_conversation.workspace)
+
+ self.sync_service.vector_store.as_retriever.assert_called_once_with(
+ search_type="similarity",
+ search_kwargs={
+ "k": 4,
+ "filter": {"workspace_id": self.mock_conversation.workspace.id}
+ }
+ )
+ self.assertEqual(result, ["doc1", "doc2"])
+
+ def test_generate_response(self):
+ chain_input = {
+ "query": "test query",
+ "conversation": self.mock_conversation
+ }
+
+ mock_stream = ["chunk1", "chunk2", "chunk3"]
+ self.sync_service.rag_chain.stream.return_value = mock_stream
+
+ result = list(self.sync_service.generate_response(self.mock_conversation, "test query"))
+
+ self.sync_service.rag_chain.stream.assert_called_once_with(chain_input)
+ self.assertEqual(result, mock_stream)
+
+
+class TestAsyncRAGService(DjangoTestCase):
+ def setUp(self):
+ self.async_service = AsyncRAGService()
+ self.async_service.vector_store = MagicMock()
+ self.async_service.llm = MagicMock()
+ self.async_service.rag_chain = AsyncMock()
+
+ self.mock_conversation = MagicMock(spec=Conversation)
+ self.mock_conversation.workspace = MagicMock()
+
+ self.mock_prompt1 = MagicMock(spec=Prompt)
+ self.mock_prompt1.user_created = True
+ self.mock_prompt1.message = "User question"
+
+ self.mock_prompt2 = MagicMock(spec=Prompt)
+ self.mock_prompt2.user_created = False
+ self.mock_prompt2.message = "AI response"
+
+ async def test_format_history(self):
+ mock_qs = MagicMock()
+ mock_qs.__aiter__.return_value = [self.mock_prompt1, self.mock_prompt2]
+
+ with patch('chat_backend.services.rag_services.Prompt.objects.filter') as mock_filter:
+ mock_filter.return_value.order_by.return_value = mock_qs
+ result = await self.async_service._format_history(self.mock_conversation)
+
+ expected = "User: User question\nAI: AI response"
+ self.assertEqual(result, expected)
+ mock_filter.return_value.order_by.assert_called_once_with("created")
+
+ async def test_retriever_with_history(self):
+ input_dict = {
+ "query": "test query",
+ "conversation": self.mock_conversation,
+ "workspace": self.mock_conversation.workspace,
+ }
+
+ mock_doc = MagicMock()
+ mock_doc.page_content = "doc content"
+ self.async_service.search_documents = AsyncMock(return_value=[mock_doc])
+
+ result = await self.async_service._retriever_with_history(input_dict)
+
+ self.async_service.search_documents.assert_awaited_once_with(
+ "test query",
+ self.mock_conversation.workspace
+ )
+ self.assertEqual(result, "doc content")
+
+ async def test_search_documents(self):
+ mock_retriever = AsyncMock()
+ mock_retriever.aget_relevant_documents.return_value = ["doc1", "doc2"]
+ self.async_service.vector_store.as_retriever.return_value = mock_retriever
+
+ result = await self.async_service.search_documents("test query", self.mock_conversation.workspace)
+
+ self.async_service.vector_store.as_retriever.assert_called_once_with(
+ search_type="similarity",
+ search_kwargs={
+ "k": 4,
+ "filter": {"workspace_id": self.mock_conversation.workspace.id}
+ }
+ )
+ self.assertEqual(result, ["doc1", "doc2"])
+
+ async def test_generate_response(self):
+ chain_input = {
+ "query": "test query",
+ "conversation": self.mock_conversation,
+ "workspace": None,
+ }
+
+ mock_stream = ["chunk1", "chunk2", "chunk3"]
+
+ async def fake_astream(value):
+ for chunk in mock_stream:
+ yield chunk
+
+ self.async_service.rag_chain = MagicMock()
+ self.async_service.rag_chain.astream = MagicMock(side_effect=fake_astream)
+
+ chunks = []
+ async for chunk in self.async_service.generate_response(
+ self.mock_conversation, "test query", None
+ ):
+ chunks.append(chunk)
+
+ self.async_service.rag_chain.astream.assert_called_once_with(chain_input)
+ self.assertEqual(chunks, mock_stream)
\ No newline at end of file
diff --git a/llm_be/chat_backend/services/title_generator.py b/llm_be/chat_backend/services/title_generator.py
new file mode 100644
index 0000000..f8209d1
--- /dev/null
+++ b/llm_be/chat_backend/services/title_generator.py
@@ -0,0 +1,67 @@
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_community.llms import Ollama
+from typing import Optional
+
+class TitleGenerator:
+ """
+ Generates short, descriptive titles for conversations based on the first prompt.
+ """
+
+ def __init__(self):
+ self.llm = Ollama(
+ model="llama3",
+ temperature=0.5, # Slightly creative but not too random
+ top_k=20,
+ num_ctx=2048 # Shorter context needed for titles
+ )
+
+ self.title_prompt = ChatPromptTemplate.from_messages([
+ ("system", """You are a conversation title generator. Create a very short (2-5 word) title based on the user's first message.
+
+Rules:
+1. Keep it extremely concise
+2. Capture the main topic or intent
+3. Use title case
+4. No quotes or punctuation
+5. Never exceed 5 words
+
+Examples:
+- "What's the weather today?" → "Weather Inquiry"
+- "Explain quantum computing" → "Quantum Computing Explanation"
+- "Generate an image of a dragon" → "Dragon Image Generation"
+- "Find our company's privacy policy" → "Privacy Policy Search"
+
+Return ONLY the title, nothing else."""),
+ ("human", "{prompt}")
+ ])
+
+ self.chain = self.title_prompt | self.llm
+
+ async def generate_async(self, prompt: str) -> str:
+ """Generate title asynchronously"""
+ try:
+ response = await self.chain.ainvoke({"prompt": prompt})
+ return self._clean_response(response)
+ except Exception as e:
+ print(f"Title generation error: {e}")
+ return "Conversation"
+
+ def generate(self, prompt: str) -> str:
+ """Generate title synchronously"""
+ try:
+ response = self.chain.invoke({"prompt": prompt})
+ return self._clean_response(response)
+ except Exception as e:
+ print(f"Title generation error: {e}")
+ return "Conversation"
+
+ def _clean_response(self, response: str) -> str:
+ """Clean and format the LLM response"""
+ # Remove any quotes or punctuation
+ response = response.strip('"\'.!? \n\t')
+ # Ensure title case and trim
+ return response.title()[:50] # Hard limit for safety
+
+
+# Singleton instance
+title_generator = TitleGenerator()
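+
+
+# Usage sketch: titling a new conversation from its first prompt;
+# `conversation` is an existing Conversation instance.
+#
+# conversation.title = title_generator.generate("How do I rotate API keys safely?")
+# conversation.save()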
diff --git a/llm_be/chat_backend/signals.py b/llm_be/chat_backend/signals.py
new file mode 100644
index 0000000..c08bab1
--- /dev/null
+++ b/llm_be/chat_backend/signals.py
@@ -0,0 +1,18 @@
+from django.db.models.signals import post_save, post_delete
+from django.dispatch import receiver
+from chat_backend.models import Document
+from .services.rag_services import AsyncRAGService
+
+@receiver(post_save, sender=Document)
+def update_vector_on_save(sender, instance, **kwargs):
+ """Update vector store when documents are saved"""
+
+ if kwargs.get('created', False):
+ rag_service = AsyncRAGService()
+ rag_service.ingest_documents()
+
+@receiver(post_delete, sender=Document)
+def delete_vector_on_remove(sender, instance, **kwargs):
+ """Handle document deletion by re-indexing the whole workspace"""
+ rag_service = AsyncRAGService()
+ rag_service.ingest_documents()
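+
+
+# These receivers are registered by the `import chat_backend.signals` call in
+# ChatBackendConfig.ready(); nothing else needs to import this module.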
diff --git a/llm_be/chat_backend/templates/emails/reset_email.html b/llm_be/chat_backend/templates/emails/reset_email.html
new file mode 100644
index 0000000..1a127d7
--- /dev/null
+++ b/llm_be/chat_backend/templates/emails/reset_email.html
@@ -0,0 +1,97 @@
+
+
+
+
+
+ Reset Password for Chat by AI ML Operations, LLC
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Hello,
+ There has been a request for a password reset. If you didn't request this, please email ryan@aimloperations.com
+
+ Please click the link to set your password.
+ Once you have set your password go here to get started.
+
+ Thank you.
+
+
+
+
+
+ |
+
+
+
+
\ No newline at end of file
diff --git a/llm_be/chat_backend/templates/emails/reset_email.txt b/llm_be/chat_backend/templates/emails/reset_email.txt
new file mode 100644
index 0000000..f805092
--- /dev/null
+++ b/llm_be/chat_backend/templates/emails/reset_email.txt
@@ -0,0 +1,3 @@
+Password Reset for AI ML Operations, LLC Chat Services
+
+"Password reset for chat.aimloperations.com. Please use {{ url }} to set your password"
\ No newline at end of file
diff --git a/llm_be/chat_backend/tests.py b/llm_be/chat_backend/tests.py
index 7ce503c..392daa8 100644
--- a/llm_be/chat_backend/tests.py
+++ b/llm_be/chat_backend/tests.py
@@ -1,3 +1,210 @@
from django.test import TestCase
# Create your tests here.
+from django.test import Client
+from django.urls import reverse
+from rest_framework.test import APIClient, APITestCase
+from rest_framework import status
+from .models import DocumentWorkspace, Document, Company
+from django.contrib.auth import get_user_model
+import tempfile
+from django.core.files.uploadedfile import SimpleUploadedFile
+
+# Minimal valid PDF bytes
+VALID_PDF_BYTES = (
+ b'%PDF-1.3\n'
+ b'1 0 obj\n'
+ b'<< /Type /Catalog /Pages 2 0 R >>\n'
+ b'endobj\n'
+ b'2 0 obj\n'
+ b'<< /Type /Pages /Kids [3 0 R] /Count 1 >>\n'
+ b'endobj\n'
+ b'3 0 obj\n'
+ b'<< /Type /Page /Parent 2 0 R /Resources << >> /MediaBox [0 0 612 792] /Contents 4 0 R >>\n'
+ b'endobj\n'
+ b'4 0 obj\n'
+ b'<< /Length 44 >>\n'
+ b'stream\n'
+ b'BT /F1 12 Tf 72 720 Td (Test PDF) Tj ET\n'
+ b'endstream\n'
+ b'endobj\n'
+ b'xref\n'
+ b'0 5\n'
+ b'0000000000 65535 f \n'
+ b'0000000009 00000 n \n'
+ b'0000000058 00000 n \n'
+ b'0000000117 00000 n \n'
+ b'0000000223 00000 n \n'
+ b'trailer\n'
+ b'<< /Size 5 /Root 1 0 R >>\n'
+ b'startxref\n'
+ b'317\n'
+ b'%%EOF'
+)
+
+class DocumentWorkspaceViewsTestCase(APITestCase):
+ def setUp(self):
+ self.company = Company.objects.create(
+ name="test",
+ state="IL",
+ zipcode="60189",
+ address="1968 Greensboro Dr"
+ )
+ self.user = get_user_model().objects.create_user(
+ company=self.company,
+ username='testuser',
+ password='testpass123',
+ email="test@test.com",
+ )
+
+ self.client = APIClient()
+ self.client.force_authenticate(user=self.user)
+
+ self.workspace = DocumentWorkspace.objects.create(
+ company = self.user.company,
+ name='Test Workspace'
+ )
+
+ def test_list_workspaces(self):
+ url = reverse('document_workspaces')
+ response = self.client.get(url)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(len(response.data), 1)
+ self.assertEqual(response.data[0]['name'], 'Test Workspace')
+
+ def test_create_workspace(self):
+ url = reverse('document_workspaces')
+ data = {
+ 'name': 'New Workspace'
+ }
+ response = self.client.post(url, data, format='json')
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(DocumentWorkspace.objects.count(), 2)
+
+ def test_retrieve_workspace(self):
+ url = reverse('document_workspaces')
+ response = self.client.get(url)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data[0]['name'], 'Test Workspace')
+
+ # def test_update_workspace(self):
+ # url = reverse('document_workspaces')
+ # data = {
+ # 'name': 'Updated Workspace'
+ # }
+ # response = self.client.post(url, data, format='json')
+ # self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ # self.workspace.refresh_from_db()
+ # self.assertEqual(self.workspace.name, 'Updated Workspace')
+
+ # def test_delete_workspace(self):
+ # url = reverse('document_workspaces', args=[self.workspace.id])
+ # response = self.client.delete(url)
+ # self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
+ # self.assertEqual(DocumentWorkspace.objects.count(), 0)
+
+class DocumentViewsTestCase(APITestCase):
+ def setUp(self):
+ self.company = Company.objects.create(
+ name="test",
+ state="IL",
+ zipcode="60189",
+ address="1968 Greensboro Dr"
+ )
+ self.user = get_user_model().objects.create_user(
+ company=self.company,
+ username='testuser',
+ password='testpass123',
+ email="test@test.com",
+ )
+
+ self.client = APIClient()
+ self.client.force_authenticate(user=self.user)
+
+ self.workspace = DocumentWorkspace.objects.create(
+ company=self.user.company,
+ name='Test Workspace'
+ )
+
+ # Create a test file
+ self.test_file = SimpleUploadedFile(
+ "test.pdf",
+ VALID_PDF_BYTES,
+ content_type="application/pdf"
+ )
+
+ def test_upload_document(self):
+ url = reverse('documents')
+ data = {
+ 'file': self.test_file
+ }
+ response = self.client.post(url, data, format='multipart')
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(Document.objects.count(), 1)
+
+ document = Document.objects.first()
+ self.assertEqual(document.workspace.id, self.workspace.id)
+ self.assertTrue(document.processed)  # set synchronously by the upload view
+
+ def test_list_documents(self):
+ # First create a document
+ Document.objects.create(
+ workspace=self.workspace,
+ file=self.test_file
+ )
+
+ url = reverse('documents')
+ response = self.client.get(url)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(len(response.data), 1)
+ self.assertIn('test', response.data[0]['file'])
+ self.assertIn('pdf', response.data[0]['file'])
+
+ # def test_delete_document(self):
+ # document = Document.objects.create(
+ # workspace=self.workspace,
+ # file=self.test_file
+ # )
+
+ # url = reverse('document-detail', args=[document.id])
+ # response = self.client.delete(url)
+ # self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
+ # self.assertEqual(Document.objects.count(), 0)
+
+ def test_upload_invalid_file(self):
+ url = reverse('documents')
+ data = {
+ 'file': 'not a file'
+ }
+ response = self.client.post(url, data, format='multipart')
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+
+ def test_access_other_users_documents(self):
+ # Create another user
+ other_company = Company.objects.create(
+ name="test2",
+ state="IL",
+ zipcode="60189",
+ address="1968 Greensboro Dr"
+ )
+ other_user = get_user_model().objects.create_user(
+ company=other_company,
+ username='otheruser',
+ password='otherpass123',
+ email="testing2@test.com"
+ )
+ other_workspace = DocumentWorkspace.objects.create(
+ company=other_user.company,
+ name='Other Workspace'
+ )
+ other_document = Document.objects.create(
+ workspace=other_workspace,
+ file=self.test_file
+ )
+
+ # Try to access the other user's document
+ url = reverse('documents_details', kwargs={"document_id": other_document.id})
+ response = self.client.get(url)
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
\ No newline at end of file
diff --git a/llm_be/chat_backend/urls.py b/llm_be/chat_backend/urls.py
index d32c9c5..83e281a 100644
--- a/llm_be/chat_backend/urls.py
+++ b/llm_be/chat_backend/urls.py
@@ -14,27 +14,42 @@ from .views import (
ConversationDetailView,
CompanyUsersView,
SetUserPassword,
+ ResetUserPassword,
ConversationPreferences,
UserPromptAnalytics,
UserConversationAnalytics,
CompanyUsageAnalytics,
- AdminAnalytics
+ AdminAnalytics,
+ reset_password,
+ DocumentWorkspaceView,
+ DocumentUploadView,
+ DocumentDetailView,
)
+from rest_framework.routers import DefaultRouter
+
urlpatterns = [
path("token/obtain/", CustomObtainTokenView.as_view(), name="token_create"),
path("token/refresh/", jwt_views.TokenRefreshView.as_view(), name="token_refresh"),
path("user/create/", CustomUserCreate.as_view(), name="create_user"),
path("user/invite/", CustomUserInvite.as_view(), name="invite_user"),
- path("user/set_password//", SetUserPassword.as_view(), name="set_password"),
+ path("user/reset_password/", reset_password, name="reset_password"),
+ path(
+ "user/set_password//", SetUserPassword.as_view(), name="set_password"
+ ),
path(
"blacklist/",
LogoutAndBlacklistRefreshTokenForUserView.as_view(),
name="blacklist",
),
path("user/get/", CustomUserGet.as_view(), name="get_user"),
- path("user/acknowledge_tos/", AcknowledgeTermsOfService.as_view(), name="acknowledge_tos"),
- path("company_users",CompanyUsersView.as_view(), name="company_users"),
+ path(
+ "user/acknowledge_tos/",
+ AcknowledgeTermsOfService.as_view(),
+ name="acknowledge_tos",
+ ),
+ path("company_users", CompanyUsersView.as_view(), name="company_users"),
path("user/is_authenticated/", is_authenticated, name="is_authenticated"),
path("announcment/get/", AnnouncmentView.as_view(), name="get_announcments"),
path("conversations", ConversationsView.as_view(), name="conversations"),
@@ -44,9 +59,32 @@ urlpatterns = [
ConversationDetailView.as_view(),
name="conversation_details",
),
- path("conversation_preferences", ConversationPreferences.as_view(), name="conversation_preferences"),
- path("analytics/user_prompts/", UserPromptAnalytics.as_view(), name="analytics_user_prompts"),
- path("analytics/user_conversations/", UserConversationAnalytics.as_view(), name="analytics_user_conversations"),
- path("analytics/company_usage/", CompanyUsageAnalytics.as_view(), name="analytics_company_usage"),
+ path(
+ "conversation_preferences",
+ ConversationPreferences.as_view(),
+ name="conversation_preferences",
+ ),
+ path(
+ "analytics/user_prompts/",
+ UserPromptAnalytics.as_view(),
+ name="analytics_user_prompts",
+ ),
+ path(
+ "analytics/user_conversations/",
+ UserConversationAnalytics.as_view(),
+ name="analytics_user_conversations",
+ ),
+ path(
+ "analytics/company_usage/",
+ CompanyUsageAnalytics.as_view(),
+ name="analytics_company_usage",
+ ),
path("analytics/admin/", AdminAnalytics.as_view(), name="analytics_admin"),
+
+ # document urls
+ path("document_workspaces/", DocumentWorkspaceView.as_view(), name="document_workspaces"),
+ path("documents/", DocumentUploadView.as_view(), name="documents"),
+ path("documents_details/", DocumentDetailView.as_view(), name="documents_details"),
+
]
+
diff --git a/llm_be/chat_backend/utils.py b/llm_be/chat_backend/utils.py
index ba6a9f6..2d33454 100644
--- a/llm_be/chat_backend/utils.py
+++ b/llm_be/chat_backend/utils.py
@@ -1,7 +1,8 @@
import datetime
+
def last_day_of_month(any_day):
- # The day 28 exists in every month. 4 days later, it's always next month
- next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
- # subtracting the number of the current day brings us back one month
- return next_month - datetime.timedelta(days=next_month.day)
\ No newline at end of file
+ # The day 28 exists in every month. 4 days later, it's always next month
+ next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
+ # subtracting the number of the current day brings us back one month
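+ # e.g. 2024-02-10 -> replace(day=28) gives 2024-02-28, +4 days is 2024-03-03,
+ # and subtracting next_month.day (3) lands on 2024-02-29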
+ return next_month - datetime.timedelta(days=next_month.day)
diff --git a/llm_be/chat_backend/views.py b/llm_be/chat_backend/views.py
index 7a87583..a6a41b1 100644
--- a/llm_be/chat_backend/views.py
+++ b/llm_be/chat_backend/views.py
@@ -11,11 +11,22 @@ from .serializers import (
CompanySerializer,
ConversationSerializer,
PromptSerializer,
- FeedbackSerializer
+ FeedbackSerializer,
+ DocumentWorkspaceSerializer,
+ DocumentSerializer
)
from rest_framework.views import APIView
from rest_framework.response import Response
-from .models import CustomUser, Announcement, Conversation, Prompt, Feedback,PromptMetric
+from .models import (
+ CustomUser,
+ Announcement,
+ Conversation,
+ Prompt,
+ Feedback,
+ PromptMetric,
+ DocumentWorkspace,
+ Document
+)
from django.views.decorators.cache import never_cache
from django.http import JsonResponse
from datetime import datetime
@@ -25,7 +36,9 @@ from channels.generic.websocket import AsyncWebsocketConsumer
from langchain_ollama.llms import OllamaLLM
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage, SystemMessage
+from langchain.chains import RetrievalQA
import re
+import os
from django.conf import settings
import json
import base64
@@ -43,15 +56,27 @@ from django.core.files.base import ContentFile
import math
import datetime
import pytz
+from langchain_community.embeddings import OllamaEmbeddings
from dateutil.relativedelta import relativedelta
+from django.views.decorators.csrf import csrf_exempt
from .utils import last_day_of_month
+from .services.llm_service import AsyncLLMService
+from .services.rag_services import AsyncRAGService
+from .services.title_generator import title_generator
+from .services.moderation_classifier import moderation_classifier, ModerationLabel
+from .services.prompt_classifier import prompt_classifier, PromptType
-CHANNEL_NAME: str = 'llm_messages'
+from langchain.chains import create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_ollama import ChatOllama
+
+CHANNEL_NAME: str = "llm_messages"
MODEL_NAME: str = "llama3"
+
# Create your views here.
class CustomObtainTokenView(TokenObtainPairView):
permission_classes = (permissions.AllowAny,)
@@ -71,80 +96,167 @@ class CustomUserCreate(APIView):
return Response(json, status=status.HTTP_201_CREATED)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+
def send_invite_email(slug, email_to_invite):
+ print("Sending invite email")
print(f"url : https://www.chat.aimloperations.com/set_password?slug={slug}")
url = f"https://www.chat.aimloperations.com/set_password?slug={slug}"
subject = "Welcome to AI ML Operations, LLC Chat Services"
from_email = "ryan@aimloperations.com"
- to=email_to_invite
+ to = email_to_invite
d = {"url": url}
- html_content = get_template(r'emails/invite_email.html').render(d)
- text_content = get_template(r'emails/invite_email.txt').render(d)
-
+ html_content = get_template(r"emails/invite_email.html").render(d)
+ text_content = get_template(r"emails/invite_email.txt").render(d)
+
msg = EmailMultiAlternatives(subject, text_content, from_email, [to])
msg.attach_alternative(html_content, "text/html")
msg.send(fail_silently=True)
+
def send_feedback_email(feedback_obj):
+ print("Sending feedback email")
subject = "New Feedback for Chat by AI ML Operations, LLC"
from_email = "ryan@aimloperations.com"
- to="ryan@aimloperations.com"
+ to = "ryan@aimloperations.com"
d = {"title": feedback_obj.title, "feedback_text": feedback_obj.text}
- html_content = get_template(r'emails/feedback_email.html').render(d)
- text_content = get_template(r'emails/feedback_email.txt').render(d)
-
+ html_content = get_template(r"emails/feedback_email.html").render(d)
+ text_content = get_template(r"emails/feedback_email.txt").render(d)
+
msg = EmailMultiAlternatives(subject, text_content, from_email, [to])
msg.attach_alternative(html_content, "text/html")
msg.send(fail_silently=True)
+
+def send_password_reset_email(slug, email_to_invite):
+ print("Sending Password reset email")
+ url = f"https://www.chat.aimloperations.com/set_password?slug={slug}"
+ subject = "Password reset for Chat by AI ML Operations, LLC"
+ from_email = "ryan@aimloperations.com"
+ to = email_to_invite
+ d = {"url": url}
+ html_content = get_template(r"emails/reset_email.html").render(d)
+ text_content = get_template(r"emails/reset_email.txt").render(d)
+ msg = EmailMultiAlternatives(subject, text_content, from_email, [to])
+ msg.attach_alternative(html_content, "text/html")
+ msg.send(fail_silently=True)
+
+
class CustomUserInvite(APIView):
- http_method_names = ['post']
+ http_method_names = ["post"]
+
def post(self, request, format="json"):
def valid_email(email_string):
- regex = r'^[a-z0-9]+[\._]?[a-z0-9]+[@]\w+[.]\w+$'
- if re.match(regex,email_string):
+ regex = r"^[a-z0-9]+[\._]?[a-z0-9]+[@]\w+[.]\w+$"
+ if re.match(regex, email_string):
return True
else:
return False
- email_to_invite = request.data['email']
-
- if len(email_to_invite) == 0 or not valid_email(email_to_invite) or not request.user.is_company_manager:
+ email_to_invite = request.data["email"]
+
+ if (
+ len(email_to_invite) == 0
+ or not valid_email(email_to_invite)
+ or not request.user.is_company_manager
+ ):
return Response(status=status.HTTP_400_BAD_REQUEST)
# make sure there isn't a user with this email already
existing_users = CustomUser.objects.filter(email=email_to_invite)
if len(existing_users) > 0:
return Response(status=status.HTTP_400_BAD_REQUEST)
# create the object and send the email
- user = CustomUser.objects.create(email=email_to_invite, username=email_to_invite, company=request.user.company)
+ user = CustomUser.objects.create(
+ email=email_to_invite,
+ username=email_to_invite,
+ company=request.user.company,
+ )
# send an email
send_invite_email(user.slug, email_to_invite)
-
-
return Response(status=status.HTTP_201_CREATED)
-class SetUserPassword(APIView):
- http_method_names = ['post','get']
+@csrf_exempt
+def reset_password(request):
+ if request.method == "POST":
+ data = json.loads(request.body)
+ token = data.get('recaptchaToken')
+ payload = {
+ 'secret': settings.CAPTCHA_SECRET_KEY,
+ 'response': token,
+ }
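+ # Verify the reCAPTCHA token server-side; v3 returns a score in [0, 1],
+ # and 0.5 is a common threshold for treating the request as human.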
+ response = requests.post('https://www.google.com/recaptcha/api/siteverify', data=payload)
+ result = response.json()
+ if result.get('success') and result.get('score') >= 0.5:
+ email = data.get('email')
+ user = CustomUser.objects.filter(email=email).first()
+ if user:
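+ # Invalidate the current password so the account can only be recovered
+ # through the emailed set-password link.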
+ user.set_unusable_password()
+ user.save()
+
+ # send the email
+ send_password_reset_email(user.slug, email)
+ return JsonResponse({"detail": "ok"}, status=200)
+
+ return JsonResponse({"detail": "invalid request"}, status=400)
+
+class ResetUserPassword(APIView):
+ http_method_names = [
+ "post",
+ ]
permission_classes = (permissions.AllowAny,)
authentication_classes = ()
+
+ def post(self, request, format="json"):
+ """
+ Send an email with a link to the set-password page and
+ invalidate the user's current password.
+ """
+ print(f"Password reset for requests. {request.data}")
+ token = request.data.get('recaptchaToken')
+ payload = {
+ 'secret': settings.CAPTCHA_SECRET_KEY,
+ 'response': token,
+ }
+ response = requests.post('https://www.google.com/recaptcha/api/siteverify', data=payload)
+ result = response.json()
+ if result.get('success') and result.get('score') >= 0.5:
+ email = request.data.get('email')
+ user = CustomUser.objects.filter(email=email).first()
+ if user:
+ user.set_unusable_password()
+ user.save()
+
+ # send the email
+ send_password_reset_email(user.slug, email)
+ else:
+ print('Captcha verification failed')
+
+ return Response(status=status.HTTP_200_OK)
+
+
+class SetUserPassword(APIView):
+ http_method_names = ["post", "get"]
+ permission_classes = (permissions.AllowAny,)
+ authentication_classes = ()
+
def get(self, request, slug):
user = CustomUser.objects.get(slug=slug)
- if user.last_login:
+ if user.has_usable_password():
return Response(status=status.HTTP_401_UNAUTHORIZED)
else:
return Response(status=status.HTTP_200_OK)
+
def post(self, request, slug, format="json"):
user = CustomUser.objects.get(slug=slug)
- user.set_password(request.data['password'])
+ user.set_password(request.data["password"])
user.save()
return Response(status=status.HTTP_200_OK)
-
class CustomUserGet(APIView):
- http_method_names = ['get', 'head', 'post']
+ http_method_names = ["get", "head", "post"]
+
def get(self, request, format="json"):
email = request.user.email
@@ -154,16 +266,18 @@ class CustomUserGet(APIView):
return Response(serializer.data, status=status.HTTP_200_OK)
+
class FeedbackView(APIView):
- http_method_names = ['post','get']
+ http_method_names = ["post", "get"]
+
def post(self, request, format="json"):
serializer = FeedbackSerializer(data=request.data)
print(request.data)
if serializer.is_valid():
-
+
feedback_obj = serializer.save()
feedback_obj.user = request.user
-
+
feedback_obj.save()
send_feedback_email(feedback_obj)
return Response(serializer.data, status=status.HTTP_201_CREATED)
@@ -177,14 +291,15 @@ class FeedbackView(APIView):
return Response(serializer.data, status=status.HTTP_200_OK)
-
class AcknowledgeTermsOfService(APIView):
- http_method_names = ['post']
+ http_method_names = ["post"]
+
def post(self, request, format="json"):
request.user.has_signed_tos = True
request.user.save()
return Response(status=status.HTTP_200_OK)
+
class CompanyUsersView(APIView):
def get(self, request, format="json"):
# TODO: make sure you are a manager of that company
@@ -194,8 +309,7 @@ class CompanyUsersView(APIView):
return Response(serializer.data, status=status.HTTP_200_OK)
else:
return Response(status=status.HTTP_401_UNAUTHORIZED)
-
-
+
def post(self, request, format="json"):
if request.user.is_company_manager:
user = CustomUser.objects.get(email=request.data.get("email"))
@@ -215,7 +329,7 @@ class CompanyUsersView(APIView):
return Response(status=status.HTTP_200_OK)
return Response(status=status.HTTP_400_BAD_REQUEST)
return Response(status=status.HTTP_401_UNAUTHORIZED)
-
+
def delete(self, request, format="json"):
if request.user.is_company_manager:
user = CustomUser.objects.get(email=request.data.get("email"))
@@ -224,6 +338,7 @@ class CompanyUsersView(APIView):
return Response(status=status.HTTP_200_OK)
return Response(status=status.HTTP_401_UNAUTHORIZED)
+
class AnnouncmentView(APIView):
permission_classes = (permissions.AllowAny,)
serializer_class = AnnouncmentSerializer
@@ -259,7 +374,9 @@ def is_authenticated(request):
class ConversationsView(APIView):
def get(self, request, format="json"):
order = "created" if request.user.conversation_order else "-created"
- conversations = Conversation.objects.filter(user=request.user, deleted=False).order_by(order)
+ conversations = Conversation.objects.filter(
+ user=request.user, deleted=False
+ ).order_by(order)
serializer = ConversationSerializer(conversations, many=True)
return Response(serializer.data, status=status.HTTP_200_OK)
@@ -283,7 +400,9 @@ class ConversationsView(APIView):
# conversation.user_id = request.user.id
# conversation.save()
- return Response({"title": title, "id": conversation.id}, status=status.HTTP_201_CREATED)
+ return Response(
+ {"title": title, "id": conversation.id}, status=status.HTTP_201_CREATED
+ )
class ConversationPreferences(APIView):
@@ -298,7 +417,6 @@ class ConversationPreferences(APIView):
return Response({"order": user.conversation_order}, status=status.HTTP_200_OK)
-
class ConversationDetailView(APIView):
def get(self, request, format="json"):
conversation_id = request.query_params.get("conversation_id")
@@ -306,9 +424,8 @@ class ConversationDetailView(APIView):
serailzer = PromptSerializer(prompts, many=True)
return Response(serailzer.data, status=status.HTTP_200_OK)
-
def post(self, request, format="json"):
- print('In the post')
+ print("In the post")
# Add the prompt to the database
# make sure there is a conversation for it
# if there is not a conversation create a title for it
@@ -336,28 +453,30 @@ class ConversationDetailView(APIView):
prompt_instance = serializer.save()
# set up the streaming response if it is from the user
- print(f'Do we have a valid user? {is_user}')
+ print(f"Do we have a valid user? {is_user}")
if is_user:
messages = []
- for prompt_obj in Prompt.objects.filter(conversation__id=conversation_id):
- messages.append({
- 'content':prompt_obj.message,
- 'role': 'user' if prompt_obj.user_created else 'assistant'
- })
+ for prompt_obj in Prompt.objects.filter(
+ conversation__id=conversation_id
+ ):
+ messages.append(
+ {
+ "content": prompt_obj.message,
+ "role": "user" if prompt_obj.user_created else "assistant",
+ }
+ )
channel_layer = get_channel_layer()
- print(f'Sending to the channel: {CHANNEL_NAME}')
+ print(f"Sending to the channel: {CHANNEL_NAME}")
async_to_sync(channel_layer.group_send)(
- CHANNEL_NAME, {
- 'type':'receive',
- 'content': messages
- }
+ CHANNEL_NAME, {"type": "receive", "content": messages}
)
except:
- print(f"Error trying to submit to conversation_id: {conversation_id} with request.data: {request.data}")
+ print(
+ f"Error trying to submit to conversation_id: {conversation_id} with request.data: {request.data}"
+ )
pass
-
return Response(status=status.HTTP_200_OK)
def delete(self, request, format="json"):
@@ -367,13 +486,16 @@ class ConversationDetailView(APIView):
conversation.save()
return Response(status=status.HTTP_202_ACCEPTED)
+
class UserPromptAnalytics(APIView):
def get(self, request, format="json"):
now = timezone.now()
result = []
- number_of_months = 3
- company_user_ids = CustomUser.objects.filter(company=request.user.company).values_list('id', flat=True)
+ number_of_months = 3
+ company_user_ids = CustomUser.objects.filter(
+ company=request.user.company
+ ).values_list("id", flat=True)
for i in range(number_of_months):
next_year = now.year
next_month = now.month - i
@@ -383,30 +505,51 @@ class UserPromptAnalytics(APIView):
start_date = datetime.datetime(next_year, next_month, 1)
end_date = last_day_of_month(start_date)
- total_conversations = Conversation.objects.filter(created__gte=start_date, created__lte=end_date)
- total_prompts = Prompt.objects.filter(conversation__id__in=total_conversations, created__gte=start_date, created__lte=end_date)
+ total_conversations = Conversation.objects.filter(
+ created__gte=start_date, created__lte=end_date
+ )
+ total_prompts = Prompt.objects.filter(
+ conversation__id__in=total_conversations,
+ created__gte=start_date,
+ created__lte=end_date,
+ )
total_users = len(CustomUser.objects.all())
my_conversations = Conversation.objects.filter(user=request.user)
- my_prompts = Prompt.objects.filter(conversation__in=my_conversations, created__gte=start_date, created__lte=end_date)
- company_conversations = Conversation.objects.filter(user__id__in=company_user_ids)
- company_prompts = Prompt.objects.filter(conversation__in=company_conversations, created__gte=start_date, created__lte=end_date)
+ my_prompts = Prompt.objects.filter(
+ conversation__in=my_conversations,
+ created__gte=start_date,
+ created__lte=end_date,
+ )
+ company_conversations = Conversation.objects.filter(
+ user__id__in=company_user_ids
+ )
+ company_prompts = Prompt.objects.filter(
+ conversation__in=company_conversations,
+ created__gte=start_date,
+ created__lte=end_date,
+ )
+
+ result.append(
+ {
+ "month": start_date.strftime("%B"),
+ "you": len(my_prompts),
+ "others": len(company_prompts) / len(company_user_ids),
+ "all": len(total_prompts) / total_users,
+ }
+ )
- result.append({
- "month":start_date.strftime("%B"),
- "you": len(my_prompts),
- "others": len(company_prompts)/len(company_user_ids),
- "all":len(total_prompts)/total_users
- })
-
return Response(result[::-1], status=status.HTTP_200_OK)
-
+
+
class UserConversationAnalytics(APIView):
def get(self, request, format="json"):
now = timezone.now()
result = []
- number_of_months = 3
- company_user_ids = CustomUser.objects.filter(company=request.user.company).values_list('id', flat=True)
+ number_of_months = 3
+ company_user_ids = CustomUser.objects.filter(
+ company=request.user.company
+ ).values_list("id", flat=True)
for i in range(number_of_months):
next_year = now.year
next_month = now.month - i
@@ -416,28 +559,48 @@ class UserConversationAnalytics(APIView):
start_date = datetime.datetime(next_year, next_month, 1)
end_date = last_day_of_month(start_date)
- total_conversations = len(Conversation.objects.filter(created__gte=start_date, created__lte=end_date))
+ total_conversations = len(
+ Conversation.objects.filter(
+ created__gte=start_date, created__lte=end_date
+ )
+ )
total_users = len(CustomUser.objects.all())
- company_conversations = len(Conversation.objects.filter(user__id__in=company_user_ids, created__gte=start_date, created__lte=end_date))
+ company_conversations = len(
+ Conversation.objects.filter(
+ user__id__in=company_user_ids,
+ created__gte=start_date,
+ created__lte=end_date,
+ )
+ )
- result.append({
- "month":start_date.strftime("%B"),
- "you": len(Conversation.objects.filter(user=request.user, created__gte=start_date, created__lte=end_date)),
- "others": company_conversations/len(company_user_ids),
- "all":total_conversations/total_users
- })
+ result.append(
+ {
+ "month": start_date.strftime("%B"),
+ "you": len(
+ Conversation.objects.filter(
+ user=request.user,
+ created__gte=start_date,
+ created__lte=end_date,
+ )
+ ),
+ "others": company_conversations / len(company_user_ids),
+ "all": total_conversations / total_users,
+ }
+ )
-
return Response(result[::-1], status=status.HTTP_200_OK)
+
class CompanyUsageAnalytics(APIView):
def get(self, request, format="json"):
now = timezone.now()
result = []
- number_of_months = 3
- company_user_ids = CustomUser.objects.filter(company=request.user.company).values_list('id', flat=True)
-
+ number_of_months = 3
+ company_user_ids = CustomUser.objects.filter(
+ company=request.user.company
+ ).values_list("id", flat=True)
+
for i in range(number_of_months):
next_year = now.year
next_month = now.month - i
@@ -447,19 +610,28 @@ class CompanyUsageAnalytics(APIView):
start_date = datetime.datetime(next_year, next_month, 1)
end_date = last_day_of_month(start_date)
- conversations = Conversation.objects.filter(user__id__in=company_user_ids, created__gte=start_date, created__lte=end_date)
-
- conversation_user_ids = conversations.values_list("user__id", flat=True).distinct()
- result.append({
- "month":start_date.strftime("%B"),
- "used":len(conversation_user_ids),
- "not_used":len(company_user_ids) - len(conversation_user_ids)
- })
+ conversations = Conversation.objects.filter(
+ user__id__in=company_user_ids,
+ created__gte=start_date,
+ created__lte=end_date,
+ )
+
+ conversation_user_ids = conversations.values_list(
+ "user__id", flat=True
+ ).distinct()
+ result.append(
+ {
+ "month": start_date.strftime("%B"),
+ "used": len(conversation_user_ids),
+ "not_used": len(company_user_ids) - len(conversation_user_ids),
+ }
+ )
return Response(result[::-1], status=status.HTTP_200_OK)
-
+
+
class AdminAnalytics(APIView):
def get(self, request, format="json"):
- number_of_months = 3
+ number_of_months = 3
result = []
now = timezone.now()
@@ -472,37 +644,43 @@ class AdminAnalytics(APIView):
start_date = datetime.datetime(next_year, next_month, 1)
end_date = last_day_of_month(start_date)
- durations = [item.get_duration() for item in PromptMetric.objects.filter(created__gte=start_date, created__lte=end_date)]
+ durations = [
+ item.get_duration()
+ for item in PromptMetric.objects.filter(
+ created__gte=start_date, created__lte=end_date
+ )
+ ]
if len(durations) == 0:
- result.append({
- "month":start_date.strftime("%B"),
- "range":[0,0],
- "avg": 0,
- "median":0,
- })
+ result.append(
+ {
+ "month": start_date.strftime("%B"),
+ "range": [0, 0],
+ "avg": 0,
+ "median": 0,
+ }
+ )
continue
-
- average = sum(durations)/len(durations)
+
+ average = sum(durations) / len(durations)
min_value = min(durations)
max_value = max(durations)
durations.sort()
- median = durations[len(durations)//2]
- result.append({
- "month":start_date.strftime("%B"),
- "range":[min_value,max_value],
- "avg": average,
- "median":median,
- })
+ median = durations[len(durations) // 2]
+ result.append(
+ {
+ "month": start_date.strftime("%B"),
+ "range": [min_value, max_value],
+ "avg": average,
+ "median": median,
+ }
+ )
-
-
-
return Response(result[::-1], status=status.HTTP_200_OK)
-prompt = ChatPromptTemplate.from_messages([
- ("system", "You are a helpful assistant."),
- ("user", "{input}")
-])
+
+prompt = ChatPromptTemplate.from_messages(
+ [("system", "You are a helpful assistant."), ("user", "{input}")]
+)
llm = OllamaLLM(model=MODEL_NAME)
@@ -510,19 +688,11 @@ llm = OllamaLLM(model=MODEL_NAME)
# # Chain
# chain = prompt | llm.with_config({"run_name": "model"}) | output_parser.with_config({"run_name": "Assistant"})
-@database_sync_to_async
-def create_conversation(prompt, email):
- # return the conversation id
-
- response = llm.invoke("Summarise the phrase in one to for words\"%s\"" % prompt)
- print(f"Response: {response}")
- print(dir(response))
- title = response.replace("\"","")
- title = " ".join(title.split(" ")[:4])
-
-
- conversation = Conversation.objects.create(title = title)
+@database_sync_to_async
+def create_conversation(prompt, email, title):
+ # return the conversation id
+ conversation = Conversation.objects.create(title=title)
conversation.save()
user = CustomUser.objects.get(email=email)
@@ -530,6 +700,12 @@ def create_conversation(prompt, email):
conversation.save()
return conversation.id
+
+@database_sync_to_async
+def get_workspace(conversation_id):
+ conversation = Conversation.objects.get(id=conversation_id)
+ return DocumentWorkspace.objects.get(company=conversation.user.company)
+
@database_sync_to_async
def get_messages(conversation_id, prompt, file_string: str = None, file_type: str = ""):
messages = []
@@ -542,7 +718,7 @@ def get_messages(conversation_id, prompt, file_string: str = None, file_type: st
data={
"message": prompt,
"user_created": True,
- "created": datetime.now(),
+ "created": timezone.now(),
}
)
if serializer.is_valid(raise_exception=True):
@@ -550,42 +726,46 @@ def get_messages(conversation_id, prompt, file_string: str = None, file_type: st
prompt_instance.conversation_id = conversation.id
prompt_instance.save()
if file_string:
- file_name = f"prompt_{prompt_instance.id}_data.{file_type}"
+ file_name = f"prompt_{prompt_instance.id}_data.{file_type}"
f = ContentFile(file_string, name=file_name)
prompt_instance.file.save(file_name, f)
prompt_instance.file_type = file_type
prompt_instance.save()
-
-
for prompt_obj in Prompt.objects.filter(conversation__id=conversation_id):
- messages.append({
- 'content': prompt_obj.message,
- 'role': 'user' if prompt_obj.user_created else 'assistant',
- 'has_file': prompt_obj.file_exists(),
- 'file': prompt_obj.file if prompt_obj.file_exists() else None,
- 'file_type': prompt_obj.file_type if prompt_obj.file_exists() else None,
- })
+ messages.append(
+ {
+ "content": prompt_obj.message,
+ "role": "user" if prompt_obj.user_created else "assistant",
+ "has_file": prompt_obj.file_exists(),
+ "file": prompt_obj.file if prompt_obj.file_exists() else None,
+ "file_type": prompt_obj.file_type if prompt_obj.file_exists() else None,
+ }
+ )
# now transform the messages
transformed_messages = []
for message in messages:
-
- if message['has_file'] and message['file_type'] != None:
- if 'csv' in message['file_type']:
- file_type = 'csv'
+
+ if message["has_file"] and message["file_type"] != None:
+ if "csv" in message["file_type"]:
+ file_type = "csv"
altered_message = f"{message['content']}\n The file type is csv and the file contents are: {message['file'].read()}"
- elif 'xlsx' in message['file_type']:
- file_type = 'xlsx'
- df = pd.read_excel(message['file'].read())
+ elif "xlsx" in message["file_type"]:
+ file_type = "xlsx"
+ df = pd.read_excel(message["file"].read())
altered_message = f"{message['content']}\n The file type is xlsx and the file contents are: {df}"
- elif 'txt' in message['file_type']:
- file_type = 'txt'
+ elif "txt" in message["file_type"]:
+ file_type = "txt"
altered_message = f"{message['content']}\n The file type is csv and the file contents are: {message['file'].read()}"
else:
- altered_message = message['content']
-
- transformed_message = SystemMessage(content=altered_message) if message['role'] == 'assistant' else HumanMessage(content=altered_message)
+ altered_message = message["content"]
+
+ transformed_message = (
+ SystemMessage(content=altered_message)
+ if message["role"] == "assistant"
+ else HumanMessage(content=altered_message)
+ )
transformed_messages.append(transformed_message)
return transformed_messages, prompt_instance
@@ -600,7 +780,7 @@ def save_generated_message(conversation_id, message):
data={
"message": message,
"user_created": False,
- "created": datetime.now(),
+ "created": timezone.now(),
}
)
if serializer.is_valid():
@@ -608,34 +788,52 @@ def save_generated_message(conversation_id, message):
prompt_instance.conversation_id = conversation.id
prompt_instance = serializer.save()
+
@database_sync_to_async
-def create_prompt_metric(prompt_id, prompt, has_file, file_type, model_name, conversation_id):
+def create_prompt_metric(
+ prompt_id, prompt, has_file, file_type, model_name, conversation_id
+):
prompt_metric = PromptMetric.objects.create(
prompt_id=prompt_id,
- start_time = timezone.now(),
- prompt_length = len(prompt),
- has_file = has_file,
- file_type = file_type,
+ start_time=timezone.now(),
+ prompt_length=len(prompt),
+ has_file=has_file,
+ file_type=file_type,
model_name=model_name,
conversation_id=conversation_id,
)
prompt_metric.save()
return prompt_metric
+
@database_sync_to_async
def update_prompt_metric(prompt_metric, status):
prompt_metric.event = status
prompt_metric.save()
+
@database_sync_to_async
def finish_prompt_metric(prompt_metric, response_length):
- print(f'finish_prompt_metric: {response_length}')
+ print(f"finish_prompt_metric: {response_length}")
prompt_metric.end_time = timezone.now()
prompt_metric.reponse_length = response_length
- prompt_metric.event = 'FINISHED'
- prompt_metric.save(update_fields=["end_time", "reponse_length","event"])
+ prompt_metric.event = "FINISHED"
+ prompt_metric.save(update_fields=["end_time", "reponse_length", "event"])
print("finish_prompt_metric saved")
+@database_sync_to_async
+def get_retriever(conversation_id):
+ print(f'getting workspace from conversation: {conversation_id}')
+ conversation = Conversation.objects.get(id=conversation_id)
+ print(f'Got conversation: {conversation}')
+ workspace = DocumentWorkspace.objects.get(company=conversation.user.company)
+ print(f'Got workspace: {workspace}')
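+ # Assumes `Chroma` is imported from langchain_community.vectorstores; the
+ # embedding model here must match the one used at ingestion time, otherwise
+ # similarity search quality degrades.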
+ vectorstore = Chroma(
+ persist_directory="./chroma_db/",
+ embedding_function=OllamaEmbeddings(model="llama3.2"),
+ )
+ return vectorstore.as_retriever()
+
class ChatConsumerAgain(AsyncWebsocketConsumer):
async def connect(self):
await self.accept()
@@ -648,47 +846,74 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
print(f"Bytes Data: {bytes_data}")
if text_data:
data = json.loads(text_data)
- message = data.get('message',None)
- conversation_id = data.get('conversation_id',None)
+ message = data.get("message", None)
+ conversation_id = data.get("conversation_id", None)
email = data.get("email", None)
file = data.get("file", None)
file_type = data.get("fileType", "")
+ model = data.get("modelName", "Turbo")
if not conversation_id:
# we need to create a new conversation
# we will generate a name for it too
- conversation_id = await create_conversation(message, email)
-
+ title = await title_generator.generate_async(message)
+ conversation_id = await create_conversation(message, email, title)
+
if conversation_id:
decoded_file = None
if file:
decoded_file = base64.b64decode(file)
print(decoded_file)
- if 'csv' in file_type:
- file_type = 'csv'
+ if "csv" in file_type:
+ file_type = "csv"
altered_message = f"{message}\n The file type is csv and the file contents are: {decoded_file}"
- elif 'xmlformats-officedocument' in file_type:
- file_type = 'xlsx'
+ elif "xmlformats-officedocument" in file_type:
+ file_type = "xlsx"
df = pd.read_excel(decoded_file)
altered_message = f"{message}\n The file type is xlsx and the file contents are: {df}"
- elif 'text' in file_type:
- file_type = 'txt'
+ elif "text" in file_type:
+ file_type = "txt"
altered_message = f"{message}\n The file type is txt and the file contents are: {decoded_file}"
else:
- file_type = 'Not Sure'
-
-
-
-
+ file_type = "Not Sure"
+
print(f'received: "{message}" for conversation {conversation_id}')
+
+ # check the moderation here
+ if await moderation_classifier.classify_async(message) == ModerationLabel.NSFW:
+ response = "Prompt has been marked as NSFW. If this is in error, submit a feedback with the prompt text."
+ print("this prompt has been marked as NSFW")
+ await self.send("CONVERSATION_ID")
+ await self.send(str(conversation_id))
+ await self.send("START_OF_THE_STREAM_ENDER_GAME_42")
+ await self.send(response)
+ await self.send("END_OF_THE_STREAM_ENDER_GAME_42")
+ await save_generated_message(conversation_id, response)
+ return
+
# TODO: add the message to the database
# get the new conversation
# TODO: get the messages here
+
+ messages, prompt = await get_messages(
+ conversation_id, message, decoded_file, file_type
+ )
+
+ prompt_type = await prompt_classifier.classify_async(message)
+ print(f"prompt_type: {prompt_type} for {message}")
+
- messages, prompt = await get_messages(conversation_id, message, decoded_file, file_type)
-
- prompt_metric = await create_prompt_metric(prompt.id, prompt.message, True if file else False, file_type, MODEL_NAME, conversation_id)
+
+ prompt_metric = await create_prompt_metric(
+ prompt.id,
+ prompt.message,
+ True if file else False,
+ file_type,
+ MODEL_NAME,
+ conversation_id,
+ )
if file:
+ # update with the altered_message
messages = messages[:-1] + [HumanMessage(content=altered_message)]
@@ -698,17 +923,117 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
# stream the response back
response = ""
# start of the message
- await self.send('CONVERSATION_ID')
+ await self.send("CONVERSATION_ID")
await self.send(str(conversation_id))
- await self.send('START_OF_THE_STREAM_ENDER_GAME_42')
- async for chunk in llm.astream(messages):
- print(f"chunk: {chunk}")
- response += chunk
- await self.send(chunk)
- await self.send('END_OF_THE_STREAM_ENDER_GAME_42')
+ await self.send("START_OF_THE_STREAM_ENDER_GAME_42")
+ if prompt_type == PromptType.RAG:
+ service = AsyncRAGService()
+ #await service.ingest_documents()
+ workspace = await get_workspace(conversation_id)
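+ # Scope retrieval to the company's workspace so the answer only draws on
+ # documents uploaded by the user's company.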
+ print('Time to get the rag response')
+
+ async for chunk in service.generate_response(messages, prompt.message, workspace):
+ print(f"chunk: {chunk}")
+ response += chunk
+ await self.send(chunk)
+ elif prompt_type == PromptType.IMAGE_GENERATION:
+ response = "Image Generation is not supported at this time, but it will be soon."
+ await self.send(response)
+
+ else:
+ service = AsyncLLMService()
+ async for chunk in service.generate_response(messages, prompt.message):
+ print(f"chunk: {chunk}")
+ response += chunk
+ await self.send(chunk)
+ await self.send("END_OF_THE_STREAM_ENDER_GAME_42")
await save_generated_message(conversation_id, response)
await finish_prompt_metric(prompt_metric, len(response))
-
+
if bytes_data:
print("we have byte data")
+
+# Document Views
+class DocumentWorkspaceView(APIView):
+ #permission_classes = [permissions.IsAuthenticated]
+
+ def get(self, request):
+ workspaces = DocumentWorkspace.objects.filter(company=request.user.company)
+ serializer = DocumentWorkspaceSerializer(workspaces, many=True)
+ return Response(serializer.data)
+
+ def post(self, request):
+ serializer = DocumentWorkspaceSerializer(data=request.data)
+ if serializer.is_valid():
+ serializer.save(company=request.user.company)
+ return Response(serializer.data, status=status.HTTP_201_CREATED)
+ return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+
+class DocumentUploadView(APIView):
+ #permission_classes = [permissions.IsAuthenticated]
+
+ def get(self, request):
+ print(f'request_3: {request}')
+ try:
+ workspace = DocumentWorkspace.objects.get(company=request.user.company)
+ serializer = DocumentSerializer(Document.objects.filter(workspace=workspace), many=True)
+ return Response(serializer.data, status=status.HTTP_200_OK)
+
+ except DocumentWorkspace.DoesNotExist:
+ return Response({'error': "Workspace not found"}, status=status.HTTP_404_NOT_FOUND)
+
+ def post(self, request):
+ print(f'request: {request}')
+
+ try:
+ workspace = DocumentWorkspace.objects.get(company=request.user.company)
+
+ except DocumentWorkspace.DoesNotExist:
+ return Response({'error': "Workspace not found"}, status=status.HTTP_404_NOT_FOUND)
+
+ print(request.FILES)
+ file = request.FILES.get('file')
+ if not file:
+ return Response({"error":"No file provided"}, status=status.HTTP_400_BAD_REQUEST)
+
+ print("have the workspace and the file")
+
+ document = Document.objects.create(
+ workspace=workspace,
+ file=file
+ )
+
+ # process the document (runs synchronously for now)
+ self.process_document(document)
+
+ serializer = DocumentSerializer(document)
+ return Response(serializer.data, status=status.HTTP_201_CREATED)
+
+
+ def process_document(self, document):
+ file_path = os.path.join(settings.MEDIA_ROOT, document.file.name)
+
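+ # Mark the document processed/active before ingesting; add_files_to_store
+ # runs synchronously below, so a failure there would leave stale flags.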
+ document.processed = True
+ document.active = True
+ document.save()
+ service = AsyncRAGService()
+ service.add_files_to_store([(file_path, document.file.name, document.workspace_id)], workspace_id=document.workspace_id)
+
+class DocumentDetailView(APIView):
+ #permission_classes = [permissions.IsAuthenticated]
+
+ def get(self, request, document_id):
+ print(f'request: {request}')
+ try:
+ workspace = DocumentWorkspace.objects.get(company=request.user.company)
+
+ document = Document.objects.get(
+ workspace=workspace,
+ id=document_id
+ )
+ except (DocumentWorkspace.DoesNotExist, Document.DoesNotExist):
+ return Response({'error': "Document not found"}, status=status.HTTP_404_NOT_FOUND)
+
+ serializer = DocumentSerializer(document)
+ return Response(serializer.data, status=status.HTTP_200_OK)
\ No newline at end of file
diff --git a/llm_be/llm_be/asgi.py b/llm_be/llm_be/asgi.py
index f8be107..cc9b0a1 100644
--- a/llm_be/llm_be/asgi.py
+++ b/llm_be/llm_be/asgi.py
@@ -13,13 +13,14 @@ from django.core.asgi import get_asgi_application
from channels.routing import ProtocolTypeRouter, URLRouter
from channels.auth import AuthMiddlewareStack
import chat_backend.routing
-os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'llm_be.settings')
-application = ProtocolTypeRouter({
- "http": get_asgi_application(),
- "websocket": AuthMiddlewareStack(
- URLRouter(
- chat_backend.routing.websocket_urlpatterns
- )
- ),
- })
+os.environ.setdefault("DJANGO_SETTINGS_MODULE", "llm_be.settings")
+
+application = ProtocolTypeRouter(
+ {
+ "http": get_asgi_application(),
+ "websocket": AuthMiddlewareStack(
+ URLRouter(chat_backend.routing.websocket_urlpatterns)
+ ),
+ }
+)
diff --git a/llm_be/llm_be/settings.py b/llm_be/llm_be/settings.py
index ab0118b..47b6d6d 100644
--- a/llm_be/llm_be/settings.py
+++ b/llm_be/llm_be/settings.py
@@ -22,80 +22,86 @@ BASE_DIR = Path(__file__).resolve().parent.parent
# See https://docs.djangoproject.com/en/3.2/howto/deployment/checklist/
# SECURITY WARNING: keep the secret key used in production secret!
-SECRET_KEY = 'django-insecure-6suk6fj5q2)1tj%)f(wgw1smnliv5-#&@zvgvj1wp#(#@h#31x'
+SECRET_KEY = "django-insecure-6suk6fj5q2)1tj%)f(wgw1smnliv5-#&@zvgvj1wp#(#@h#31x"
# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = True
-
-ALLOWED_HOSTS = ['*.aimloperations.com','localhost','127.0.0.1','chat.aimloperations.com','chatbackend.aimloperations.com']
+CORS_ALLOW_CREDENTIALS = False
+ALLOWED_HOSTS = [
+ "*.aimloperations.com",
+ "localhost",
+ "127.0.0.1",
+ "localhost:3000",
+ "127.0.0.1:3000",
+ "chat.aimloperations.com",
+ "chatbackend.aimloperations.com",
+]
CORS_ORIGIN_ALLOW_ALL = True
+CSRF_TRUSTED_ORIGINS = ["http://localhost", "http://127.0.0.1", "http://localhost:3000"]
# Application definition
INSTALLED_APPS = [
- 'daphne',
- 'django.contrib.admin',
- 'django.contrib.auth',
- 'django.contrib.contenttypes',
- 'django.contrib.sessions',
- 'django.contrib.messages',
- 'django.contrib.staticfiles',
- 'chat_backend',
- 'rest_framework',
- 'corsheaders',
- 'rest_framework_simplejwt.token_blacklist',
-
+ "daphne",
+ "django.contrib.admin",
+ "django.contrib.auth",
+ "django.contrib.contenttypes",
+ "django.contrib.sessions",
+ "django.contrib.messages",
+ "django.contrib.staticfiles",
+ "chat_backend",
+ "rest_framework",
+ "corsheaders",
+ "rest_framework_simplejwt.token_blacklist",
]
MIDDLEWARE = [
- 'django.middleware.security.SecurityMiddleware',
- 'django.contrib.sessions.middleware.SessionMiddleware',
- 'django.middleware.common.CommonMiddleware',
- 'django.middleware.csrf.CsrfViewMiddleware',
- 'django.contrib.auth.middleware.AuthenticationMiddleware',
- 'django.contrib.messages.middleware.MessageMiddleware',
- 'django.middleware.clickjacking.XFrameOptionsMiddleware',
+ "django.middleware.security.SecurityMiddleware",
+ "django.contrib.sessions.middleware.SessionMiddleware",
+ "django.middleware.common.CommonMiddleware",
+ "django.middleware.csrf.CsrfViewMiddleware",
+ "django.contrib.auth.middleware.AuthenticationMiddleware",
+ "django.contrib.messages.middleware.MessageMiddleware",
+ "django.middleware.clickjacking.XFrameOptionsMiddleware",
"corsheaders.middleware.CorsMiddleware",
"django.middleware.common.CommonMiddleware",
]
-ROOT_URLCONF = 'llm_be.urls'
+ROOT_URLCONF = "llm_be.urls"
# SETTINGS_PATH = os.path.dirname(os.path.dirname(__file__))
# TEMPLATE_DIRS = (
# os.path.join(SETTINGS_PATH, 'templates'),
# )
-print(os.path.join(BASE_DIR, 'templates'))
-
TEMPLATES = [
{
- 'BACKEND': 'django.template.backends.django.DjangoTemplates',
- 'DIRS': [os.path.join(BASE_DIR, 'templates')],
- 'APP_DIRS': True,
- 'OPTIONS': {
- 'context_processors': [
- 'django.template.context_processors.debug',
- 'django.template.context_processors.request',
- 'django.contrib.auth.context_processors.auth',
- 'django.contrib.messages.context_processors.messages',
+ "BACKEND": "django.template.backends.django.DjangoTemplates",
+ "DIRS": [os.path.join(BASE_DIR, "templates")],
+ "APP_DIRS": True,
+ "OPTIONS": {
+ "context_processors": [
+ "django.template.context_processors.debug",
+ "django.template.context_processors.request",
+ "django.contrib.auth.context_processors.auth",
+ "django.contrib.messages.context_processors.messages",
],
},
},
]
-WSGI_APPLICATION = 'llm_be.wsgi.application'
-ASGI_APPLICATION = 'llm_be.asgi.application'
+WSGI_APPLICATION = "llm_be.wsgi.application"
+ASGI_APPLICATION = "llm_be.asgi.application"
# Database
# https://docs.djangoproject.com/en/3.2/ref/settings/#databases
DATABASES = {
- 'default': {
- 'ENGINE': 'django.db.backends.sqlite3',
- 'NAME': BASE_DIR / 'db.sqlite3',
+ "default": {
+ "ENGINE": "django.db.backends.sqlite3",
+ "NAME": BASE_DIR / "db.sqlite3",
}
}
@@ -105,28 +111,26 @@ DATABASES = {
AUTH_PASSWORD_VALIDATORS = [
{
- 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
+ "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator",
},
{
- 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
+ "NAME": "django.contrib.auth.password_validation.MinimumLengthValidator",
},
{
- 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
+ "NAME": "django.contrib.auth.password_validation.CommonPasswordValidator",
},
{
- 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
+ "NAME": "django.contrib.auth.password_validation.NumericPasswordValidator",
},
]
-
-
# Internationalization
# https://docs.djangoproject.com/en/3.2/topics/i18n/
-LANGUAGE_CODE = 'en-us'
+LANGUAGE_CODE = "en-us"
-TIME_ZONE = 'UTC'
+TIME_ZONE = "UTC"
USE_I18N = True
@@ -138,39 +142,37 @@ USE_TZ = True
# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/3.2/howto/static-files/
-STATIC_URL = '/static/'
+STATIC_URL = "/static/"
# Default primary key field type
# https://docs.djangoproject.com/en/3.2/ref/settings/#default-auto-field
-DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
+DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"
# custom user model
-AUTH_USER_MODEL = 'chat_backend.CustomUser'
+AUTH_USER_MODEL = "chat_backend.CustomUser"
# rest framework jwt stuff
REST_FRAMEWORK = {
- 'DEFAULT_PERMISSION_CLASSES': (
- 'rest_framework.permissions.IsAuthenticated',
- ),
- 'DEFAULT_AUTHENTICATION_CLASSES': (
-'rest_framework_simplejwt.authentication.JWTAuthentication',
- ), #
+ "DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.IsAuthenticated",),
+ "DEFAULT_AUTHENTICATION_CLASSES": (
+ "rest_framework_simplejwt.authentication.JWTAuthentication",
+ ), #
}
SIMPLE_JWT = {
- 'ACCESS_TOKEN_LIFETIME':timedelta(hours=5),
- 'REFRESH_TOKEN_LIFETIME':timedelta(days=14),
- 'ROTATE_REFRESH_TOKENS':True,
- 'BLACKLIST_AFTER_ROTATION':True,
- 'ALGORITHM':"HS256",
- "SIGNING_KEY":SECRET_KEY,
- 'VERIFYING_KEY':None,
- "AUTH_HEADER_TYPES":('JWT',),
- 'USER_ID_FIELD':'id',
- 'USER_ID_CLAIM':'user_id',
- 'AUTH_TOKEN_CLASSES':('rest_framework_simplejwt.tokens.AccessToken',),
- 'TOKEN_TYPE_CLAIM':'token_type',
+ "ACCESS_TOKEN_LIFETIME": timedelta(hours=24),
+ "REFRESH_TOKEN_LIFETIME": timedelta(days=14),
+ "ROTATE_REFRESH_TOKENS": True,
+ "BLACKLIST_AFTER_ROTATION": True,
+ "ALGORITHM": "HS256",
+ "SIGNING_KEY": SECRET_KEY,
+ "VERIFYING_KEY": None,
+ "AUTH_HEADER_TYPES": ("JWT",),
+ "USER_ID_FIELD": "id",
+ "USER_ID_CLAIM": "user_id",
+ "AUTH_TOKEN_CLASSES": ("rest_framework_simplejwt.tokens.AccessToken",),
+ "TOKEN_TYPE_CLAIM": "token_type",
}
# CORS settings
@@ -181,8 +183,8 @@ CORS_ALLOWED_ORIGINS = [
# channel settings
CHANNEL_LAYERS = {
- 'default': {
- 'BACKEND': 'channels.layers.InMemoryChannelLayer',
+ "default": {
+ "BACKEND": "channels.layers.InMemoryChannelLayer",
},
}
@@ -198,8 +200,11 @@ CHANNEL_LAYERS = {
# EMAIL_TIMEOUT = os.getenv("APP_EMAIL_TIMEOUT", 60)
# SMTP2GO
-EMAIL_HOST = 'mail.smtp2go.com'
-EMAIL_HOST_USER = 'info.aimloperations.com'
-EMAIL_HOST_PASSWORD = 'ZDErIII2sipNNVMz'
+EMAIL_HOST = "mail.smtp2go.com"
+EMAIL_HOST_USER = "info.aimloperations.com"
+EMAIL_HOST_PASSWORD = "ZDErIII2sipNNVMz"
EMAIL_PORT = 2525
-EMAIL_USE_TLS = True
\ No newline at end of file
+EMAIL_USE_TLS = True
+
+# Captcha
+CAPTCHA_SECRET_KEY = "6LfENu4qAAAAABdrj6JTviq-LfdPP5imhE-Os7h9"
\ No newline at end of file
diff --git a/llm_be/llm_be/urls.py b/llm_be/llm_be/urls.py
index 16d76aa..ba910f3 100644
--- a/llm_be/llm_be/urls.py
+++ b/llm_be/llm_be/urls.py
@@ -13,12 +13,17 @@ Including another URLconf
1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
+
from django.contrib import admin
from django.urls import path, include
from django.conf import settings
from django.conf.urls.static import static
-urlpatterns = [
- path('admin/', admin.site.urls),
- path('api/', include('chat_backend.urls')),
-] + static(settings.STATIC_URL, document_root=settings.STATIC_ROOT) + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)
+urlpatterns = (
+ [
+ path("admin/", admin.site.urls),
+ path("api/", include("chat_backend.urls")),
+ ]
+ + static(settings.STATIC_URL, document_root=settings.STATIC_ROOT)
+ + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)
+)
diff --git a/llm_be/llm_be/wsgi.py b/llm_be/llm_be/wsgi.py
index 1edc130..35b7dc6 100644
--- a/llm_be/llm_be/wsgi.py
+++ b/llm_be/llm_be/wsgi.py
@@ -11,6 +11,6 @@ import os
from django.core.wsgi import get_wsgi_application
-os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'llm_be.settings')
+os.environ.setdefault("DJANGO_SETTINGS_MODULE", "llm_be.settings")
application = get_wsgi_application()
diff --git a/llm_be/manage.py b/llm_be/manage.py
index 96e59ea..05354a5 100755
--- a/llm_be/manage.py
+++ b/llm_be/manage.py
@@ -6,7 +6,7 @@ import sys
def main():
"""Run administrative tasks."""
- os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'llm_be.settings')
+ os.environ.setdefault("DJANGO_SETTINGS_MODULE", "llm_be.settings")
try:
from django.core.management import execute_from_command_line
except ImportError as exc:
@@ -18,5 +18,5 @@ def main():
execute_from_command_line(sys.argv)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/requirements.dev b/requirements.dev
index 60d210d..0029b6c 100644
--- a/requirements.dev
+++ b/requirements.dev
@@ -30,16 +30,16 @@ djangorestframework-simplejwt==5.3.1
duckdb==1.1.3
et_xmlfile==2.0.0
exceptiongroup==1.2.2
-Faker==33.1.0
+Faker
filelock==3.16.1
fonttools==4.55.3
frozenlist==1.5.0
fsspec==2024.12.0
greenlet==3.1.1
h11==0.14.0
-httpcore==1.0.7
-httpx==0.27.2
-httpx-sse==0.4.0
+httpcore
+httpx
+httpx-sse
hyperlink==21.0.0
idna==3.10
importlib_resources==6.4.5
@@ -48,14 +48,14 @@ Jinja2==3.1.5
jiter==0.8.2
jsonpatch==1.33
jsonpointer==3.0.0
-kiwisolver==1.4.7
-langchain==0.3.13
-langchain-community==0.3.13
-langchain-core==0.3.28
-langchain-ollama==0.2.2
-langchain-openai==0.2.14
-langchain-text-splitters==0.3.4
-langsmith==0.2.7
+kiwisolver
+langchain
+langchain-community
+langchain-core
+langchain-ollama
+langchain-openai
+langchain-text-splitters
+langsmith
lxml==5.3.0
MarkupSafe==3.0.2
marshmallow==3.23.2
@@ -77,14 +77,14 @@ nvidia-cusparse-cu12==12.3.1.170
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
-ollama==0.4.5
-ollama-python==0.1.2
-openai==1.58.1
+ollama
+ollama-python
+openai
openpyxl==3.1.5
orjson==3.10.13
packaging==24.2
pandas==2.2.3
-pandasai==2.4.1
+pandasai
pathspec==0.12.1
pillow==11.0.0
platformdirs==4.3.6
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..a209ddd
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,208 @@
+aiofiles==24.1.0
+aiohappyeyeballs==2.6.1
+aiohttp==3.11.18
+aiosignal==1.3.2
+annotated-types==0.7.0
+anyio==4.8.0
+asgiref==3.8.1
+astor==0.8.1
+attrs==25.1.0
+autobahn==24.4.2
+Automat==24.8.1
+backoff==2.2.1
+bcrypt==4.3.0
+beautifulsoup4==4.13.4
+black==25.1.0
+build==1.2.2.post1
+cachetools==5.5.2
+certifi==2025.1.31
+cffi==1.17.1
+channels==4.2.0
+chardet==5.2.0
+charset-normalizer==3.4.1
+chroma-hnswlib==0.7.6
+chromadb==1.0.7
+click==8.1.8
+coloredlogs==15.0.1
+constantly==23.10.4
+contourpy==1.3.1
+cryptography==44.0.2
+cycler==0.12.1
+daphne==4.1.2
+dataclasses-json==0.6.7
+Deprecated==1.2.18
+distro==1.9.0
+Django==5.1.7
+django-autoslug==1.9.9
+django-cors-headers==4.7.0
+django-filter==25.1
+djangorestframework==3.15.2
+djangorestframework_simplejwt==5.5.0
+duckdb==1.2.1
+durationpy==0.9
+emoji==2.14.1
+eval_type_backport==0.2.2
+Faker==37.0.0
+fastapi==0.115.9
+filelock==3.17.0
+filetype==1.2.0
+flatbuffers==25.2.10
+fonttools==4.56.0
+frozenlist==1.6.0
+fsspec==2025.2.0
+google-auth==2.39.0
+googleapis-common-protos==1.70.0
+greenlet==3.1.1
+grpcio==1.71.0
+h11==0.14.0
+html5lib==1.1
+httpcore==1.0.7
+httptools==0.6.4
+httpx==0.28.1
+httpx-sse==0.4.0
+huggingface-hub==0.30.2
+humanfriendly==10.0
+hyperlink==21.0.0
+idna==3.10
+importlib_metadata==8.6.1
+importlib_resources==6.5.2
+incremental==24.7.2
+Jinja2==3.1.6
+jiter==0.8.2
+joblib==1.4.2
+jsonpatch==1.33
+jsonpointer==3.0.0
+jsonschema==4.23.0
+jsonschema-specifications==2025.4.1
+kiwisolver==1.4.8
+kubernetes==32.0.1
+langchain==0.3.24
+langchain-community==0.3.23
+langchain-core==0.3.56
+langchain-ollama==0.2.3
+langchain-text-splitters==0.3.8
+langdetect==1.0.9
+langsmith==0.3.13
+lxml==5.4.0
+Markdown==3.7
+markdown-it-py==3.0.0
+MarkupSafe==3.0.2
+marshmallow==3.26.1
+matplotlib==3.10.1
+mdurl==0.1.2
+mmh3==5.1.0
+mpmath==1.3.0
+multidict==6.4.3
+mypy-extensions==1.0.0
+nest-asyncio==1.6.0
+networkx==3.4.2
+nltk==3.9.1
+numpy==2.2.3
+nvidia-cublas-cu12==12.4.5.8
+nvidia-cuda-cupti-cu12==12.4.127
+nvidia-cuda-nvrtc-cu12==12.4.127
+nvidia-cuda-runtime-cu12==12.4.127
+nvidia-cudnn-cu12==9.1.0.70
+nvidia-cufft-cu12==11.2.1.3
+nvidia-curand-cu12==10.3.5.147
+nvidia-cusolver-cu12==11.6.1.9
+nvidia-cusparse-cu12==12.3.1.170
+nvidia-cusparselt-cu12==0.6.2
+nvidia-nccl-cu12==2.21.5
+nvidia-nvjitlink-cu12==12.4.127
+nvidia-nvtx-cu12==12.4.127
+oauthlib==3.2.2
+olefile==0.47
+ollama==0.4.7
+onnxruntime==1.21.1
+openai==1.65.4
+opentelemetry-api==1.32.1
+opentelemetry-exporter-otlp-proto-common==1.32.1
+opentelemetry-exporter-otlp-proto-grpc==1.32.1
+opentelemetry-instrumentation==0.53b1
+opentelemetry-instrumentation-asgi==0.53b1
+opentelemetry-instrumentation-fastapi==0.53b1
+opentelemetry-proto==1.32.1
+opentelemetry-sdk==1.32.1
+opentelemetry-semantic-conventions==0.53b1
+opentelemetry-util-http==0.53b1
+orjson==3.10.15
+overrides==7.7.0
+packaging==24.2
+pandas==2.2.3
+pandasai==2.4.2
+pathspec==0.12.1
+pillow==11.1.0
+platformdirs==4.3.6
+posthog==4.0.1
+propcache==0.3.1
+protobuf==5.29.4
+psutil==7.0.0
+pyasn1==0.6.1
+pyasn1_modules==0.4.1
+pycparser==2.22
+pydantic==2.11.4
+pydantic-settings==2.9.1
+pydantic_core==2.33.2
+Pygments==2.19.1
+PyJWT==2.10.1
+pyOpenSSL==25.0.0
+pyparsing==3.2.1
+pypdf==5.4.0
+PyPika==0.48.9
+pyproject_hooks==1.2.0
+python-dateutil==2.9.0.post0
+python-dotenv==1.0.1
+python-iso639==2025.2.18
+python-magic==0.4.27
+python-oxmsg==0.0.2
+pytz==2025.1
+PyYAML==6.0.2
+RapidFuzz==3.13.0
+referencing==0.36.2
+regex==2024.11.6
+requests==2.32.3
+requests-oauthlib==2.0.0
+requests-toolbelt==1.0.0
+rich==14.0.0
+rpds-py==0.24.0
+rsa==4.9.1
+scipy==1.15.2
+service-identity==24.2.0
+setuptools==75.8.2
+shellingham==1.5.4
+six==1.17.0
+sniffio==1.3.1
+soupsieve==2.7
+SQLAlchemy==2.0.38
+sqlglot==26.9.0
+sqlglotrs==0.4.0
+sqlparse==0.5.3
+starlette==0.45.3
+sympy==1.13.1
+tenacity==9.0.0
+tokenizers==0.21.1
+torch==2.6.0
+tqdm==4.67.1
+triton==3.2.0
+Twisted==24.11.0
+txaio==23.1.1
+typer==0.15.3
+typing-inspect==0.9.0
+typing-inspection==0.4.0
+typing_extensions==4.12.2
+tzdata==2025.1
+unstructured==0.17.2
+unstructured-client==0.34.0
+urllib3==2.3.0
+uvicorn==0.34.2
+uvloop==0.21.0
+watchfiles==1.0.5
+webencodings==0.5.1
+websocket-client==1.8.0
+websockets==15.0.1
+wrapt==1.17.2
+yarl==1.20.0
+zipp==3.21.0
+zope.interface==7.2
+zstandard==0.23.0
diff --git a/strip_and_upgrade.py b/strip_and_upgrade.py
new file mode 100644
index 0000000..a700b81
--- /dev/null
+++ b/strip_and_upgrade.py
@@ -0,0 +1,9 @@
+# Strip '==version' pins from requirements.dev and write bare package names.
+with open("requirements.dev") as infile, open("requirements.txt", "w") as outfile:
+ for line in infile:
+ name = line.strip().split("==")[0]
+ if name:
+ print(name)
+ outfile.write(name + "\n")