RAG implementation, content moderation, prompt classification, new LLM chain, document storage
.gitignore (vendored)
@@ -167,4 +167,5 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+chroma_db/
+documents/
llm_be/chat_backend/admin.py
@@ -1,5 +1,16 @@
 from django.contrib import admin
-from .models import CustomUser, Announcement, Company, LLMModels, Conversation, Prompt, Feedback, PromptMetric
+from .models import (
+    CustomUser,
+    Announcement,
+    Company,
+    LLMModels,
+    Conversation,
+    Prompt,
+    Feedback,
+    PromptMetric,
+    DocumentWorkspace,
+    Document
+)

 # Register your models here.
@@ -27,16 +38,16 @@ class CustomUserAdmin(admin.ModelAdmin):
         "has_signed_tos",
         "last_login",
         "slug",
-        "get_set_password_url"
+        "get_set_password_url",
     )
     search_fields = ("fields", "username", "first_name", "last_name", "slug")


 class FeedbackAdmin(admin.ModelAdmin):
     model = Feedback
     search_fields = ("status", "text", "get_user_email")
-    list_display= (
-        "status", "get_user_email", "title", "category"
-    )
+    list_display = ("status", "get_user_email", "title", "category")


 class LLMModelsAdmin(admin.ModelAdmin):
     model = LLMModels
@@ -46,7 +57,7 @@ class LLMModelsAdmin(admin.ModelAdmin):

 class ConversationAdmin(admin.ModelAdmin):
     model = Conversation
-    list_display = ("title", "get_user_email","deleted")
+    list_display = ("title", "get_user_email", "deleted")
     search_fields = ("title",)

@@ -55,9 +66,35 @@ class PromptAdmin(admin.ModelAdmin):
     list_display = ("message", "user_created", "get_conversation_title")
     search_fields = ("message",)


 class PromptMetricAdmin(admin.ModelAdmin):
     model = PromptMetric
-    list_display = ("event", "model_name", "prompt_length","reponse_length",'has_file','file_type', "get_duration")
+    list_display = (
+        "event",
+        "model_name",
+        "prompt_length",
+        "reponse_length",
+        "has_file",
+        "file_type",
+        "get_duration",
+    )
+
+
+class DocumentWorkspaceAdmin(admin.ModelAdmin):
+    model = DocumentWorkspace
+    list_display = (
+        "name",
+        "company",
+    )
+
+
+class DocumentAdmin(admin.ModelAdmin):
+    model = Document
+    list_display = (
+        "file",
+        "active",
+        "created",
+        "processed",
+    )


 admin.site.register(Announcement, AnnouncmentAdmin)
@@ -69,3 +106,6 @@ admin.site.register(Conversation, ConversationAdmin)
 admin.site.register(Prompt, PromptAdmin)
 admin.site.register(PromptMetric, PromptMetricAdmin)
 admin.site.register(Feedback, FeedbackAdmin)
+admin.site.register(DocumentWorkspace, DocumentWorkspaceAdmin)
+admin.site.register(Document, DocumentAdmin)
llm_be/chat_backend/apps.py
@@ -1,6 +1,31 @@
 from django.apps import AppConfig
+from django.conf import settings
+from django.db import OperationalError


 class ChatBackendConfig(AppConfig):
     default_auto_field = "django.db.models.BigAutoField"
     name = "chat_backend"
+
+    def ready(self):
+        import chat_backend.signals
+        FORCE_RELOAD = False
+
+        if True:  # not settings.TESTING: don't run during tests
+            try:
+                from .services.rag_services import AsyncRAGService
+                from chat_backend.models import Document
+
+                # Check if Chroma needs initialization
+                if Document.objects.exists():
+                    rag_service = AsyncRAGService()
+
+                    if rag_service.vector_store._collection.count() == 0:
+                        print("Initializing ChromaDB with existing documents...")
+                        rag_service.ingest_documents()
+                    if FORCE_RELOAD:
+                        print("Force Reload ChromaDB with existing documents...")
+                        rag_service.clear_vector_store()
+            except OperationalError:
+                # Database tables might not exist yet during migration
+                pass
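The `if True:` above stands in for the commented-out `settings.TESTING` check. A minimal sketch of the intended gate, assuming a TESTING flag is defined in the project settings (it is not shown in this diff):

    # Sketch only: TESTING is an assumed settings flag, e.g. set in settings.py
    # via TESTING = "test" in sys.argv, so Chroma init is skipped under test.
    if not getattr(settings, "TESTING", False):
        try:
            ...  # same Chroma-initialization body as in ready() above
        except OperationalError:
            pass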
@@ -1,30 +1,32 @@
 """
 llama client - Abstract this in the future
 """

 import ollama
 from typing import List, Dict


 class LlamaClient(object):
-    def __init__(self, model: str='llama3'):
+    def __init__(self, model: str = "llama3"):
         self.client = ollama.Client(host="http://127.0.0.1:11434")
         self.model = model

     def check_if_model_exists(self) -> bool:
         raise NotImplementedError

-    def generate_conversation_title(self, message:str):
-        response = self.generate_single_message("Summarise the phrase in one to for words\"%s\"" % message)
-        raw_response = response['response'].replace("\"","")
+    def generate_conversation_title(self, message: str):
+        response = self.generate_single_message(
+            'Summarise the phrase in one to four words "%s"' % message
+        )
+
+        raw_response = response["response"].replace('"', "")
         return " ".join(raw_response.split()[:4])

     def generate_single_message(self, message: str):
         return ollama.generate(model=self.model, prompt=message)

     def get_chat_response(self, messages: List[str]):
-        return self.client.chat(model = self.model, messages=messages, stream=False)
+        return self.client.chat(model=self.model, messages=messages, stream=False)

     def get_streamed_chat_response(self, messages: List[str]):
-        return self.client.chat(model = self.model, messages=messages, stream=True)
+        return self.client.chat(model=self.model, messages=messages, stream=True)
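A usage sketch for the client (assumes a local Ollama server on 127.0.0.1:11434 with the llama3 model pulled; note that despite the `List[str]` hint, `ollama.Client.chat` expects role/content dicts):

    client = LlamaClient(model="llama3")
    print(client.generate_conversation_title("How do I rotate my API keys safely?"))

    # Streaming chat: each chunk carries a partial assistant message
    for chunk in client.get_streamed_chat_response(
        [{"role": "user", "content": "Hello!"}]
    ):
        print(chunk["message"]["content"], end="", flush=True)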
@@ -0,0 +1,78 @@
+# Generated by Django 5.1.7 on 2025-04-30 18:58
+
+import django.db.models.deletion
+import django.utils.timezone
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ("chat_backend", "0019_customuser_conversation_order_and_more"),
+    ]
+
+    operations = [
+        migrations.CreateModel(
+            name="DocumentWorkspace",
+            fields=[
+                (
+                    "id",
+                    models.BigAutoField(
+                        auto_created=True,
+                        primary_key=True,
+                        serialize=False,
+                        verbose_name="ID",
+                    ),
+                ),
+                ("created", models.DateTimeField(default=django.utils.timezone.now)),
+                (
+                    "last_modified",
+                    models.DateTimeField(default=django.utils.timezone.now),
+                ),
+                ("name", models.CharField(max_length=255)),
+                (
+                    "company",
+                    models.ForeignKey(
+                        on_delete=django.db.models.deletion.CASCADE,
+                        to="chat_backend.company",
+                    ),
+                ),
+            ],
+            options={
+                "abstract": False,
+            },
+        ),
+        migrations.CreateModel(
+            name="Document",
+            fields=[
+                (
+                    "id",
+                    models.BigAutoField(
+                        auto_created=True,
+                        primary_key=True,
+                        serialize=False,
+                        verbose_name="ID",
+                    ),
+                ),
+                ("created", models.DateTimeField(default=django.utils.timezone.now)),
+                (
+                    "last_modified",
+                    models.DateTimeField(default=django.utils.timezone.now),
+                ),
+                ("file", models.FileField(upload_to="documents/")),
+                ("uploaded_at", models.DateTimeField(auto_now_add=True)),
+                ("processed", models.BooleanField(default=False)),
+                ("active", models.BooleanField(default=False)),
+                (
+                    "workspace",
+                    models.ForeignKey(
+                        on_delete=django.db.models.deletion.CASCADE,
+                        to="chat_backend.documentworkspace",
+                    ),
+                ),
+            ],
+            options={
+                "abstract": False,
+            },
+        ),
+    ]
llm_be/chat_backend/models.py
@@ -3,9 +3,11 @@ from django.contrib.auth.models import AbstractUser
 from django.utils import timezone
 from autoslug import AutoSlugField
 from django.core.files.storage import FileSystemStorage

 # Create your models here.

-FILE_STORAGE = FileSystemStorage(location='prompt_files')
+FILE_STORAGE = FileSystemStorage(location="prompt_files")


 class TimeInfoBase(models.Model):
@@ -60,12 +62,18 @@ class CustomUser(AbstractUser):
         help_text="Allows the edit/add/remove of users for a company", default=False
     )
     deleted = models.BooleanField(help_text="This is to hide accounts", default=False)
-    has_signed_tos = models.BooleanField(default=False, help_text="If the user has signed the TOS")
-    slug = AutoSlugField(populate_from='email')
-    conversation_order = models.BooleanField(default=True, help_text='How the conversations should display')
+    has_signed_tos = models.BooleanField(
+        default=False, help_text="If the user has signed the TOS"
+    )
+    slug = AutoSlugField(populate_from="email")
+    conversation_order = models.BooleanField(
+        default=True, help_text="How the conversations should display"
+    )

     def get_set_password_url(self):
         return f"https://www.chat.aimloperations.com/set_password?slug={self.slug}"


 FEEDBACK_CHOICE = (
     ("SUBMITTED", "Submitted"),
     ("RESOLVED", "Resolved"),
@@ -74,21 +82,26 @@ FEEDBACK_CHOICE = (
 )

 FEEDBACK_CATEGORIES = (
-    ('NOT_DEFINED', 'Not defined'),
-    ('BUG', 'Bug'),
-    ('ENHANCEMENT', 'Enhancement'),
-    ('OTHER', 'Other'),
-    ('MAX_CATEGORIES', 'Max Categories'),
+    ("NOT_DEFINED", "Not defined"),
+    ("BUG", "Bug"),
+    ("ENHANCEMENT", "Enhancement"),
+    ("OTHER", "Other"),
+    ("MAX_CATEGORIES", "Max Categories"),
 )


 class Feedback(TimeInfoBase):
-    title = models.TextField(max_length=64, default='')
+    title = models.TextField(max_length=64, default="")
     user = models.ForeignKey(
         CustomUser, on_delete=models.CASCADE, blank=True, null=True
     )
     text = models.TextField(max_length=512)
-    status = models.CharField(max_length=24, choices=FEEDBACK_CHOICE, default="SUBMITTED")
-    category = models.CharField(max_length=24, choices=FEEDBACK_CATEGORIES, default="NOT_DEFINED")
+    status = models.CharField(
+        max_length=24, choices=FEEDBACK_CHOICE, default="SUBMITTED"
+    )
+    category = models.CharField(
+        max_length=24, choices=FEEDBACK_CATEGORIES, default="NOT_DEFINED"
+    )

     def get_user_email(self):
         if self.user:
@@ -105,9 +118,8 @@ MONTH_CHOICES = (
         ("DECEMBER", "December"),
     )

-    month = models.CharField(max_length=9,
-                             choices=MONTH_CHOICES,
-                             default="JANUARY")
+    month = models.CharField(max_length=9, choices=MONTH_CHOICES, default="JANUARY")


 class Announcement(TimeInfoBase):
     class Status(models.TextChoices):
@@ -131,7 +143,9 @@ class Conversation(TimeInfoBase):
     title = models.CharField(
         max_length=64, help_text="The title for the conversation", default=""
     )
-    deleted = models.BooleanField(help_text="This is to hide conversations", default=False)
+    deleted = models.BooleanField(
+        help_text="This is to hide conversations", default=False
+    )

     def get_user_email(self):
         if self.user:
@@ -151,20 +165,26 @@ class Prompt(TimeInfoBase):
     conversation = models.ForeignKey(
         "Conversation", on_delete=models.CASCADE, blank=True, null=True
     )
-    file =models.FileField(upload_to=FILE_STORAGE, blank=True, null=True, help_text="file for the prompt")
-    file_type=models.CharField(max_length=16, blank=True, null=True, help_text='file type of the file for the prompt')
+    file = models.FileField(
+        upload_to=FILE_STORAGE, blank=True, null=True, help_text="file for the prompt"
+    )
+    file_type = models.CharField(
+        max_length=16,
+        blank=True,
+        null=True,
+        help_text="file type of the file for the prompt",
+    )

     def get_conversation_title(self):
         if self.conversation:
             return self.conversation.title
         else:
             return ""

     def file_exists(self):
         return self.file != None and self.file.storage.exists(self.file.name)


 class PromptMetric(TimeInfoBase):
     PROMPT_METRIC_CHOICES = (
         ("CREATED", "Created"),
@@ -174,20 +194,40 @@ class PromptMetric(TimeInfoBase):
         ("MAX_PROMPT_METRIC_CHOICES", "Max Prompt Metric Choices"),
     )
     prompt_id = models.IntegerField(help_text="The id of the prompt this matches to")
-    conversation_id = models.IntegerField(help_text="The id of the conversation this matches to")
+    conversation_id = models.IntegerField(
+        help_text="The id of the conversation this matches to"
+    )
     event = models.CharField(
-        max_length=26, choices=PROMPT_METRIC_CHOICES, default='CREATED'
+        max_length=26, choices=PROMPT_METRIC_CHOICES, default="CREATED"
     )
     model_name = models.CharField(max_length=215, help_text="The name of the model")
     start_time = models.DateTimeField()
     end_time = models.DateTimeField(blank=True, null=True)
-    prompt_length = models.IntegerField( help_text="How many characters are in the prompt")
-    reponse_length = models.IntegerField(blank=True, null=True, help_text="How many characters are in the response")
+    prompt_length = models.IntegerField(
+        help_text="How many characters are in the prompt"
+    )
+    reponse_length = models.IntegerField(
+        blank=True, null=True, help_text="How many characters are in the response"
+    )
     has_file = models.BooleanField(help_text="Is there a file")
-    file_type = models.CharField(max_length=16, help_text='The file type, if any', blank=True, null=True)
+    file_type = models.CharField(
+        max_length=16, help_text="The file type, if any", blank=True, null=True
+    )

     def get_duration(self):
-        if(self.start_time and self.end_time):
-            difference =self.end_time - self.start_time
+        if self.start_time and self.end_time:
+            difference = self.end_time - self.start_time
             return difference.seconds
         return 0
+
+
+# Document Models
+class DocumentWorkspace(TimeInfoBase):
+    name = models.CharField(max_length=255)
+    company = models.ForeignKey(Company, on_delete=models.CASCADE)
+
+
+class Document(TimeInfoBase):
+    workspace = models.ForeignKey(DocumentWorkspace, on_delete=models.CASCADE)
+    file = models.FileField(upload_to='documents/')
+    uploaded_at = models.DateTimeField(auto_now_add=True)
+    processed = models.BooleanField(default=False)
+    active = models.BooleanField(default=False)
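The new document models hang off `Company` through `DocumentWorkspace`, so every stored file belongs to exactly one workspace. A quick shell sketch (assumes an existing `Company` row):

    from chat_backend.models import Company, DocumentWorkspace, Document

    company = Company.objects.first()  # assumed to exist
    workspace = DocumentWorkspace.objects.create(name="Contracts", company=company)
    doc = Document.objects.create(workspace=workspace, file="documents/nda.pdf")
    # processed/active default to False until the RAG service ingests the file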
llm_be/chat_backend/renderers.py
@@ -1,8 +1,9 @@
 from rest_framework.renderers import BaseRenderer


 class ServerSentEventRenderer(BaseRenderer):
-    media_type = 'text/event-stream'
-    format = 'txt'
+    media_type = "text/event-stream"
+    format = "txt"

     def render(self, data, accepted_media_type=None, renderer_context=None):
         return data
llm_be/chat_backend/routing.py
@@ -1,7 +1,6 @@
 from django.urls import re_path
 from .views import ChatConsumerAgain

-websocket_urlpatterns = [
-    re_path(r'ws/chat_again/$', ChatConsumerAgain.as_asgi()),
-]
+websocket_urlpatterns = [
+    re_path(r"ws/chat_again/$", ChatConsumerAgain.as_asgi()),
+]
llm_be/chat_backend/serializers.py
@@ -1,6 +1,16 @@
 from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
 from rest_framework import serializers
-from .models import CustomUser, Announcement, Company, Conversation, Prompt, Feedback, FEEDBACK_CATEGORIES
+from .models import (
+    CustomUser,
+    Announcement,
+    Company,
+    Conversation,
+    Prompt,
+    Feedback,
+    FEEDBACK_CATEGORIES,
+    DocumentWorkspace,
+    Document
+)


 class MyTokenObtainPairSerializer(TokenObtainPairSerializer):
@@ -25,11 +35,13 @@ class AnnouncmentSerializer(serializers.ModelSerializer):
         model = Announcement
         fields = "__all__"

+
 class FeedbackSerializer(serializers.ModelSerializer):
     class Meta:
         model = Feedback
         fields = "__all__"

+
 class CustomUserSerializer(serializers.ModelSerializer):
     email = serializers.EmailField(required=True)
     username = serializers.CharField()
@@ -58,12 +70,40 @@ class ConversationSerializer(serializers.ModelSerializer):


 class PromptSerializer(serializers.ModelSerializer):

     class Meta:
         model = Prompt
-        fields = ("message", "user_created", "created", "id", )
+        fields = (
+            "message",
+            "user_created",
+            "created",
+            "id",
+        )


 class BasicUserSerializer(serializers.ModelSerializer):
     class Meta:
         model = CustomUser
-        fields = ("email", "first_name", "last_name", "is_active","has_usable_password","is_company_manager",'has_signed_tos')
+        fields = (
+            "email",
+            "first_name",
+            "last_name",
+            "is_active",
+            "has_usable_password",
+            "is_company_manager",
+            "has_signed_tos",
+        )
+
+
+# document serializers
+class DocumentWorkspaceSerializer(serializers.ModelSerializer):
+    class Meta:
+        model = DocumentWorkspace
+        fields = ['id', 'name', 'created']
+        read_only_fields = ['id', 'created']
+
+
+class DocumentSerializer(serializers.ModelSerializer):
+    class Meta:
+        model = Document
+        fields = ['id', 'workspace', 'file', 'uploaded_at', 'processed', 'created', 'active']
+        read_only_fields = ['id', 'uploaded_at', 'processed', 'created']
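A sketch of how the new workspace serializer behaves (hypothetical shell session). Note that `DocumentWorkspaceSerializer` omits the required `company` foreign key, so it must be supplied at save time:

    serializer = DocumentWorkspaceSerializer(data={"name": "Contracts"})
    serializer.is_valid()                  # True: id/created are read-only inputs
    workspace = serializer.save(company=company)  # company passed explicitly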
llm_be/chat_backend/services/__init__.py (new file, 0 lines)
llm_be/chat_backend/services/image_generation.py (new file, 145 lines)
@@ -0,0 +1,145 @@
+import os
+import logging
+from typing import Optional, Tuple
+from PIL import Image
+import torch
+from diffusers import StableDiffusionPipeline, DPMSolverSinglestepScheduler
+
+logger = logging.getLogger(__name__)
+
+class ImageGenerationService:
+    """
+    Service for text-to-image generation using Stable Diffusion.
+    Uses singleton pattern to maintain loaded model in memory.
+    """
+
+    _instance = None
+    _model_loaded = False
+
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+            cls._instance._initialize()
+        return cls._instance
+
+    def _initialize(self):
+        """Initialize the service with default settings"""
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model_id = "stabilityai/stable-diffusion-2-1"
+        self.pipeline = None
+        self.default_params = {
+            "num_inference_steps": 25,
+            "guidance_scale": 7.5,
+            "width": 512,
+            "height": 512,
+        }
+
+    def load_model(self):
+        """Load the Stable Diffusion model"""
+        if self._model_loaded:
+            return
+
+        try:
+            logger.info(f"Loading Stable Diffusion model on {self.device}...")
+
+            # Use DPMSolver for faster inference
+            self.pipeline = StableDiffusionPipeline.from_pretrained(
+                self.model_id,
+                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
+            )
+            self.pipeline.scheduler = DPMSolverSinglestepScheduler.from_config(
+                self.pipeline.scheduler.config
+            )
+            self.pipeline = self.pipeline.to(self.device)
+
+            # Optimizations
+            if self.device == "cuda":
+                self.pipeline.enable_attention_slicing()
+                self.pipeline.enable_xformers_memory_efficient_attention()
+
+            self._model_loaded = True
+            logger.info("Model loaded successfully")
+
+        except Exception as e:
+            logger.error(f"Failed to load model: {str(e)}")
+            raise RuntimeError(f"Model loading failed: {str(e)}")
+
+    def generate_image(
+        self,
+        prompt: str,
+        negative_prompt: Optional[str] = None,
+        output_path: Optional[str] = None,
+        **kwargs
+    ) -> Tuple[Image.Image, dict]:
+        """
+        Generate image from text prompt.
+
+        Args:
+            prompt: Text prompt for image generation
+            negative_prompt: Text for things to avoid in generation
+            output_path: Optional path to save the image
+            **kwargs: Generation parameters (overrides defaults)
+
+        Returns:
+            Tuple of (PIL.Image, generation_parameters)
+        """
+        if not self._model_loaded:
+            self.load_model()
+
+        # Merge default params with overrides
+        params = {**self.default_params, **kwargs}
+
+        try:
+            logger.info(f"Generating image with prompt: {prompt[:50]}...")
+
+            with torch.inference_mode():
+                result = self.pipeline(
+                    prompt=prompt,
+                    negative_prompt=negative_prompt,
+                    **params
+                )
+
+            image = result.images[0]
+
+            if output_path:
+                os.makedirs(os.path.dirname(output_path), exist_ok=True)
+                image.save(output_path)
+                logger.info(f"Image saved to {output_path}")
+
+            return image, params
+
+        except Exception as e:
+            logger.error(f"Image generation failed: {str(e)}")
+            raise RuntimeError(f"Image generation failed: {str(e)}")
+
+
+class AsyncImageGenerationService:
+    """
+    Asynchronous wrapper for image generation service.
+    Runs the synchronous service in a thread pool.
+    """
+
+    def __init__(self):
+        self.sync_service = ImageGenerationService()
+
+    async def generate_image(
+        self,
+        prompt: str,
+        negative_prompt: Optional[str] = None,
+        output_path: Optional[str] = None,
+        **kwargs
+    ) -> Tuple[Image.Image, dict]:
+        """Async version of generate_image"""
+        import asyncio
+        from functools import partial
+
+        loop = asyncio.get_event_loop()
+        func = partial(
+            self.sync_service.generate_image,
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            output_path=output_path,
+            **kwargs
+        )
+
+        return await loop.run_in_executor(None, func)
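A usage sketch for the new service (assumes the diffusers/torch stack is installed and the stabilityai/stable-diffusion-2-1 weights can be downloaded):

    service = ImageGenerationService()   # singleton: repeat calls reuse the loaded model
    image, params = service.generate_image(
        "a lighthouse at dawn, oil painting",
        negative_prompt="blurry, low quality",
        output_path="generated/lighthouse.png",  # hypothetical path
        num_inference_steps=30,                  # overrides the default of 25
    )
    print(image.size, params["guidance_scale"])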
llm_be/chat_backend/services/llm_service.py (new file, 138 lines)
@@ -0,0 +1,138 @@
+from abc import ABC, abstractmethod
+from typing import AsyncGenerator, Generator, Optional
+
+from langchain_community.llms import Ollama
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+
+from chat_backend.models import Conversation, Prompt
+
+class LLMService(ABC):
+    """Abstract base class for LLM conversation services."""
+
+    def __init__(self):
+        self.llm = Ollama(
+            model="llama3.2",
+            temperature=0.7,
+            top_k=50,
+            top_p=0.9,
+            repeat_penalty=1.1,
+            num_ctx=4096
+        )
+        self.output_parser = StrOutputParser()
+
+    @abstractmethod
+    def generate_response(self, conversation: Conversation, query: str, **kwargs):
+        """Generate a response to a query within a conversation context."""
+        pass
+
+    def _format_history(self, conversation: Conversation) -> str:
+        """Format conversation history for the prompt."""
+        # TimeInfoBase defines "created", not "created_at"
+        prompts = Prompt.objects.filter(conversation=conversation).order_by("created")
+        return "\n".join(
+            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}"
+            for prompt in prompts
+        )
+
+
+class SyncLLMService(LLMService):
+    """Synchronous LLM conversation service."""
+
+    def __init__(self):
+        super().__init__()
+        self._setup_chain()
+
+    def _setup_chain(self):
+        """Setup the conversation chain."""
+        template = """Continue the conversation based on the following history:
+
+{history}
+
+Latest message: {query}
+
+Response:"""
+        self.prompt = ChatPromptTemplate.from_template(template)
+
+        self.conversation_chain = (
+            {
+                "history": lambda x: self._format_history(x["conversation"]),
+                "query": lambda x: x["query"]
+            }
+            | self.prompt
+            | self.llm
+            | self.output_parser
+        )
+
+    def generate_response(self, conversation: Conversation, query: str, **kwargs) -> Generator[str, None, None]:
+        """Generate response with streaming support."""
+        chain_input = {
+            "query": query,
+            "conversation": conversation
+        }
+
+        for chunk in self.conversation_chain.stream(chain_input):
+            yield chunk
+
+
+class AsyncLLMService(LLMService):
+    """Asynchronous LLM conversation service."""
+
+    def __init__(self):
+        super().__init__()
+        self._setup_chain()
+
+    def _setup_chain(self):
+        """Setup the conversation chain."""
+        template = """Continue this conversation while maintaining context by providing a single helpful response.
+Current context: {context}
+
+Last 3 messages:
+{recent_history}
+
+Latest message: {query}
+
+Instructions:
+- Carefully maintain all established context
+- If referencing previous elements (like stories), preserve all details
+- When asked to modify something, identify what's being modified
+
+Response:"""
+
+        self.prompt = ChatPromptTemplate.from_template(template)
+
+        self.conversation_chain = (
+            {
+                "context": lambda x: self._format_history(x["conversation"]),
+                "recent_history": lambda x: self._get_recent_messages(x["conversation"]),
+                "query": lambda x: x["query"]
+            }
+            | self.prompt
+            | self.llm
+            | self.output_parser
+        )
+
+    async def _format_history(self, conversation: Conversation) -> str:
+        """Async version of format conversation history."""
+        # QuerySets have no alist(); collect rows with async iteration instead
+        prompts = [
+            prompt
+            async for prompt in Prompt.objects.filter(conversation=conversation).order_by("created")
+        ]
+        return "\n".join(
+            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}"
+            for prompt in prompts
+        )
+
+    async def _get_recent_messages(self, conversation: Conversation) -> str:
+        """Format only the last three messages of the conversation."""
+        prompts = [
+            prompt
+            async for prompt in Prompt.objects.filter(conversation=conversation).order_by("created")
+        ][-3:]
+        return "\n".join(
+            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}"
+            for prompt in prompts
+        )
+
+    async def generate_response(self, conversation: Conversation, query: str, **kwargs) -> AsyncGenerator[str, None]:
+        """Generate response with async streaming support."""
+        chain_input = {
+            "query": query,
+            "conversation": conversation
+        }
+
+        async for chunk in self.conversation_chain.astream(chain_input):
+            yield chunk
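A streaming usage sketch (assumes an existing Conversation row and a running Ollama server with the llama3.2 model):

    from chat_backend.services.llm_service import SyncLLMService
    from chat_backend.models import Conversation

    service = SyncLLMService()
    conversation = Conversation.objects.first()  # assumed to exist
    for chunk in service.generate_response(conversation, "Summarize our discussion"):
        print(chunk, end="", flush=True)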
llm_be/chat_backend/services/moderation_classifier.py (new file, 79 lines)
@@ -0,0 +1,79 @@
+from enum import Enum, auto
+from typing import Dict, Any
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_community.llms import Ollama
+
+class ModerationLabel(Enum):
+    NSFW = auto()
+    FINE = auto()
+
+class ModerationClassifier:
+    """
+    Classifies prompts as NSFW or FINE (safe) content.
+    """
+
+    def __init__(self):
+        self.llm = Ollama(
+            model="llama3.2",
+            temperature=0.1,  # Very low for strict moderation
+            top_k=10,
+            num_ctx=2048
+        )
+
+        self.moderation_prompt = ChatPromptTemplate.from_messages([
+            ("system", """You are a strict content moderator. Classify the following prompt as either NSFW or FINE.
+
+NSFW includes:
+- Sexual content
+- Violence/gore
+- Hate speech
+- Illegal activities
+- Harassment
+- Graphic/disturbing content
+
+FINE includes:
+- Safe for work topics
+- General conversation
+- Professional inquiries
+- Creative requests (non-explicit)
+- Technical questions
+
+Examples:
+- "How to make a bomb" → NSFW
+- "Write a love poem" → FINE
+- "Explicit sex scene" → NSFW
+- "Python tutorial" → FINE
+
+Return ONLY "NSFW" or "FINE", nothing else."""),
+            ("human", "{prompt}")
+        ])
+
+        self.chain = self.moderation_prompt | self.llm
+
+    async def classify_async(self, prompt: str) -> ModerationLabel:
+        """Asynchronous classification"""
+        try:
+            response = (await self.chain.ainvoke({"prompt": prompt})).strip().upper()
+            return self._parse_response(response)
+        except Exception as e:
+            print(f"Moderation error: {e}")
+            return ModerationLabel.NSFW  # Fail-safe to NSFW
+
+    def classify(self, prompt: str) -> ModerationLabel:
+        """Synchronous classification"""
+        try:
+            response = self.chain.invoke({"prompt": prompt}).strip().upper()
+            return self._parse_response(response)
+        except Exception as e:
+            print(f"Moderation error: {e}")
+            return ModerationLabel.NSFW  # Fail-safe to NSFW
+
+    def _parse_response(self, response: str) -> ModerationLabel:
+        """Convert string response to ModerationLabel enum"""
+        if "NSFW" in response:
+            return ModerationLabel.NSFW
+        return ModerationLabel.FINE  # Default to FINE if unclear
+
+
+# Singleton instance
+moderation_classifier = ModerationClassifier()
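Usage sketch: the classifier fails closed (any error returns NSFW), so callers can branch on the label directly:

    from chat_backend.services.moderation_classifier import moderation_classifier, ModerationLabel

    label = moderation_classifier.classify("Write a haiku about autumn")
    if label is ModerationLabel.FINE:
        ...  # forward the prompt to the chat pipeline
    else:
        ...  # reject and log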
llm_be/chat_backend/services/prompt_classifier.py (new file, 100 lines)
@@ -0,0 +1,100 @@
+from enum import Enum, auto
+from typing import Dict, Any
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_community.llms import Ollama
+
+class PromptType(Enum):
+    GENERAL_CHAT = auto()
+    RAG = auto()
+    IMAGE_GENERATION = auto()
+    UNKNOWN = auto()
+
+class PromptClassifier:
+    """
+    Classifies user prompts to determine which service should handle them.
+    """
+
+    def __init__(self):
+        self.llm = Ollama(
+            model="llama3",
+            temperature=0.3,  # Lower temp for more deterministic classification
+            top_k=20,
+            top_p=0.9,
+            num_ctx=4096
+        )
+
+        self.classification_prompt = ChatPromptTemplate.from_messages([
+            ("system",
+             """You are a precision prompt classifier. Strictly categorize prompts into:
+1. GENERAL_CHAT - Casual conversation, personal questions, or non-specific inquiries
+2. RAG - ONLY when explicitly requesting document/search-based knowledge
+3. IMAGE_GENERATION - Specific requests to create/modify images
+4. UNKNOWN - If none of the above fit
+
+1. IMAGE_GENERATION - ONLY if:
+   - Explicitly contains: "generate/create/draw/make an image/picture/photo/art/illustration"
+   - Requests visual content creation
+   - Example: "Make a picture of a castle" → IMAGE_GENERATION
+
+2. RAG - ONLY if:
+   - Explicitly mentions documents/files/data
+   - Uses search terms: "find/search/lookup in [source]"
+   - Example: "What does contracts.pdf say?" → RAG
+
+3. GENERAL_CHAT - DEFAULT category when:
+   - Doesn't meet above criteria
+   - Conversational/general knowledge
+   - Uncertain cases
+   - Example: "Tell me a joke" → GENERAL_CHAT
+
+Examples:
+[Definitely RAG]
+- "What does the uploaded PDF say about quarterly results?"
+- "Search our documents for the 2023 marketing strategy"
+- "Find the contract clause about termination"
+
+[Definitely GENERAL_CHAT]
+- "How does photosynthesis work?" (General knowledge)
+- "Tell me a joke"
+- "What's your opinion on AI?"
+
+[Borderline → GENERAL_CHAT]
+- "What's our company policy on X?" (No doc reference → general)
+- "Explain quantum computing" (General knowledge)
+- "Summarize the meeting" (No doc reference)
+
+Return ONLY the label, no explanations."""),
+            ("human", "{prompt}")
+        ])
+
+        self.chain = self.classification_prompt | self.llm
+
+    async def classify_async(self, prompt: str) -> PromptType:
+        """Asynchronously classify the prompt"""
+        try:
+            response = await self.chain.ainvoke({"prompt": prompt})
+            return self._parse_response(response.strip())
+        except Exception as e:
+            print(f"Classification error: {e}")
+            return PromptType.UNKNOWN
+
+    def classify(self, prompt: str) -> PromptType:
+        """Synchronously classify the prompt"""
+        try:
+            response = self.chain.invoke({"prompt": prompt})
+            return self._parse_response(response.strip())
+        except Exception as e:
+            print(f"Classification error: {e}")
+            return PromptType.UNKNOWN
+
+    def _parse_response(self, response: str) -> PromptType:
+        """Convert string response to PromptType enum"""
+        response = response.upper()
+        for prompt_type in PromptType:
+            if prompt_type.name in response:
+                return prompt_type
+        return PromptType.UNKNOWN
+
+
+# Singleton instance for easy access
+prompt_classifier = PromptClassifier()
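A routing sketch combining the two classifiers (hypothetical dispatch; the actual consumer wiring is not part of this diff):

    from chat_backend.services.moderation_classifier import moderation_classifier, ModerationLabel
    from chat_backend.services.prompt_classifier import prompt_classifier, PromptType

    def route(prompt: str) -> str:
        if moderation_classifier.classify(prompt) is ModerationLabel.NSFW:
            return "rejected"
        kind = prompt_classifier.classify(prompt)
        if kind is PromptType.IMAGE_GENERATION:
            return "image_service"
        if kind is PromptType.RAG:
            return "rag_service"
        return "chat_service"  # GENERAL_CHAT and UNKNOWN fall through to plain chat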
llm_be/chat_backend/services/rag_services.py (new file, 378 lines)
@@ -0,0 +1,378 @@
+import os
+import re  # used by _sanitize_filename
+from abc import ABC, abstractmethod
+from typing import List, Dict, Any, AsyncGenerator, Generator, Optional
+from channels.db import database_sync_to_async
+from langchain_community.embeddings import OllamaEmbeddings
+from langchain_community.llms import Ollama
+from langchain_community.vectorstores import Chroma
+from langchain_core.documents import Document as LangDocument
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import RunnablePassthrough
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import (
+    PyPDFLoader,
+    Docx2txtLoader,
+    TextLoader,
+    UnstructuredFileLoader
+)
+from django.core.files.uploadedfile import UploadedFile
+from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
+from pathlib import Path
+
+@database_sync_to_async
+def get_documents(workspace: DocumentWorkspace | None = None):
+    if workspace:
+        return [doc for doc in Document.objects.filter(workspace=workspace)]
+    else:
+        return [doc for doc in Document.objects.all()]
+
+
+class RAGService(ABC):
+    """Abstract base class for RAG services."""
+    _instance = None
+
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+            cls._instance.__init__()
+        return cls._instance
+
+    def __init__(self):
+        self.embedding_model = OllamaEmbeddings(model="llama3.2")
+        self.llm = Ollama(
+            model="llama3.2",
+            temperature=0.7,
+            top_k=50,
+            top_p=0.9,
+            repeat_penalty=1.1,
+            num_ctx=4096
+        )
+        self.text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=1000,
+            chunk_overlap=200
+        )
+        self.vector_store = self._initialize_vector_store()
+
+        # Supported file types and their loaders
+        self.loader_mapping = {
+            '.pdf': PyPDFLoader,
+            '.docx': Docx2txtLoader,
+            '.txt': TextLoader,
+            # Fallback for other file types
+            '*': UnstructuredFileLoader,
+        }
+
+    def _initialize_vector_store(self) -> Chroma:
+        """Initialize and return the Chroma vector store."""
+        persist_directory = "./chroma_db/"
+        vector_store = Chroma(
+            embedding_function=self.embedding_model,
+            persist_directory=persist_directory
+        )
+        return vector_store
+
+    def clear_vector_store(self):
+        """Clear all vectors from the store"""
+        self.vector_store.delete_collection()
+        self.vector_store = self._initialize_vector_store()
+
+    def _prepare_documents(self, documents: List[Document]) -> None:
+        """Load, split, and ingest documents into the vector store."""
+        for doc in documents:
+            print(f"Processing: {doc.file.name}")
+            chunks = self._load_and_split_documents(doc.file.path)
+            if chunks:
+                self.vector_store.add_documents(chunks)
+                self.vector_store.persist()
+
+    def ingest_documents(self, workspace: DocumentWorkspace | None = None) -> None:
+        """Ingest documents from a workspace into the vector store."""
+        print(f"Getting the Document via the workspace: {workspace}")
+        if workspace:
+            documents = [doc for doc in Document.objects.filter(workspace=workspace)]
+        else:
+            documents = [doc for doc in Document.objects.all()]
+
+        print(f"Processing the documents: {documents}")
+        self._prepare_documents(documents)
+
+    @abstractmethod
+    def generate_response(self, conversation: Conversation, query: str, **kwargs):
+        """Generate a response using RAG."""
+        pass
+
+    @abstractmethod
+    def search_documents(self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4) -> List[Document]:
+        """Search relevant documents from the vector store."""
+        pass
+
+    def _get_file_loader(self, file_path: str):
+        """Get appropriate loader for file type"""
+        ext = Path(file_path).suffix.lower()
+        return self.loader_mapping.get(ext, self.loader_mapping['*'])
+
+    def _sanitize_filename(self, filename: str) -> str:
+        """Sanitize filename for safe storage"""
+        return re.sub(r'[^\w\-_. ]', '_', filename)
+
+    def _save_uploaded_file(self, uploaded_file: UploadedFile, save_dir: str) -> str:
+        """Save uploaded file to disk"""
+        os.makedirs(save_dir, exist_ok=True)
+        sanitized_name = self._sanitize_filename(uploaded_file.name)
+        file_path = os.path.join(save_dir, sanitized_name)
+
+        with open(file_path, 'wb+') as destination:
+            for chunk in uploaded_file.chunks():
+                destination.write(chunk)
+
+        return file_path
+
+    def _load_and_split_documents(self, file_path: str, metadata: dict = None) -> List[LangDocument]:
+        """Load and split documents from file"""
+        loader_class = self._get_file_loader(file_path)
+        loader = loader_class(file_path)
+
+        docs = loader.load()
+        if metadata:
+            for doc in docs:
+                doc.metadata.update(metadata)
+
+        return self.text_splitter.split_documents(docs)
+
+    def add_files_to_store(
+        self,
+        file_tupls: List[tuple],  # (file_path, name, workspace_id)
+        workspace_id: str,
+        source: str = "upload",
+        save_dir: str = "data/uploads"
+    ) -> Dict[str, Any]:
+        """
+        Process and add uploaded files to vector store
+
+        Args:
+            file_tupls: List of (file_path, name, workspace_id) tuples
+            workspace_id: ID of the workspace these belong to
+            source: Source identifier for documents
+            save_dir: Directory to save uploaded files
+
+        Returns:
+            Dictionary with processing results
+        """
+        results = {
+            'total_added': 0,
+            'failed_files': [],
+            'processed_files': []
+        }
+
+        for file_tuple in file_tupls:
+            try:
+                # The file is already on disk; unpack its path from the tuple
+                file_path = file_tuple[0]
+
+                # Prepare metadata
+                metadata = {
+                    'source': file_tuple[1],
+                    'workspace_id': file_tuple[2],
+                    'original_filename': file_tuple[1],
+                    'file_path': file_tuple[0],
+                }
+
+                # Load and split documents
+                docs = self._load_and_split_documents(file_path, metadata)
+
+                # Add to vector store
+                if docs:
+                    self.vector_store.add_documents(docs)
+                    results['total_added'] += len(docs)
+                    results['processed_files'].append({
+                        'filename': file_tuple[1],
+                        'document_count': len(docs)
+                    })
+
+            except Exception as e:
+                results['failed_files'].append({
+                    'filename': file_tuple[1],
+                    'error': str(e)
+                })
+                continue
+
+        # Persist changes
+        self.vector_store.persist()
+        return results
+
+
+class SyncRAGService(RAGService):
+    """Synchronous RAG service implementation."""
+
+    def __init__(self):
+        super().__init__()
+        self._setup_chain()
+
+    def _setup_chain(self):
+        """Setup the RAG chain."""
+        template = """Answer the question based only on the following context:
+{context}
+
+Conversation history:
+{history}
+
+Question: {question}
+"""
+        self.prompt = ChatPromptTemplate.from_template(template)
+
+        self.rag_chain = (
+            {
+                "context": self._retriever_with_history,
+                "history": lambda x: self._format_history(x["conversation"]),
+                "question": lambda x: x["query"]
+            }
+            | self.prompt
+            | self.llm
+            | StrOutputParser()
+        )
+
+    def _format_history(self, conversation: Conversation) -> str:
+        """Format conversation history for the prompt."""
+        prompts = Prompt.objects.filter(conversation=conversation).order_by("created")
+        return "\n".join(
+            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}"
+            for prompt in prompts
+        )
+
+    def _retriever_with_history(self, input_dict: Dict[str, Any]) -> str:
+        """Retrieve documents considering conversation history."""
+        query = input_dict["query"]
+        conversation = input_dict["conversation"]
+
+        # You could enhance this to consider historical context in retrieval
+        relevant_docs = self.search_documents(query, conversation.workspace)
+        if not relevant_docs:
+            print("didn't find any relevant docs")
+        return relevant_docs
+
+    def search_documents(self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4) -> List[Document]:
+        """Search relevant documents from the vector store."""
+        filter_dict = {}
+        if workspace:
+            filter_dict["workspace_id"] = workspace.id
+        # Build search_kwargs before logging it (the original printed it first)
+        search_kwargs = {
+            "k": k,
+            "filter": filter_dict if filter_dict else None
+        }
+        print(f"search_kwargs: {search_kwargs}")
+        retriever = self.vector_store.as_retriever(
+            search_type="similarity",
+            search_kwargs=search_kwargs
+        )
+        return retriever.get_relevant_documents(query)
+
+    def generate_response(self, conversation: Conversation, query: str, **kwargs) -> Generator[str, None, None]:
+        """Generate response with streaming support."""
+        chain_input = {
+            "query": query,
+            "conversation": conversation
+        }
+
+        for chunk in self.rag_chain.stream(chain_input):
+            yield chunk
+
+
+class AsyncRAGService(RAGService):
+    """Asynchronous RAG service implementation."""
+
+    def __init__(self):
+        super().__init__()
+        self._setup_chain()
+
+    def _setup_chain(self):
+        """Setup the RAG chain."""
+        template = """Answer the question based only on the following context:
+{context}
+
+Conversation history:
+{history}
+
+Question: {question}
+"""
+        self.prompt = ChatPromptTemplate.from_template(template)
+
+        self.rag_chain = (
+            {
+                "context": self._retriever_with_history,
+                "history": lambda x: self._format_history(x["conversation"]),
+                "question": lambda x: x["query"]
+            }
+            | self.prompt
+            | self.llm
+            | StrOutputParser()
+        )
+
+    async def _format_history(self, conversation: Conversation) -> str:
+        """Format conversation history for the prompt."""
+        # QuerySets have no alist(); collect rows with async iteration instead
+        prompts = [
+            prompt
+            async for prompt in Prompt.objects.filter(conversation=conversation).order_by("created")
+        ]
+        print(f"prompts that we are seeding with are: {prompts}")
+        return "\n".join(
+            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}"
+            for prompt in prompts
+        )
+
+    async def _retriever_with_history(self, input_dict: Dict[str, Any]) -> str:
+        """Retrieve documents considering conversation history."""
+        print(f"Retrieving history with input: {input_dict}")
+        query = input_dict["query"]
+        conversation = input_dict["conversation"]
+        workspace = input_dict["workspace"]
+
+        # You could enhance this to consider historical context in retrieval
+        docs = await self.search_documents(query, workspace)
+
+        if not docs:
+            print("Didn't find any relevant docs")
+
+        print("\n\n".join(doc.page_content for doc in docs))
+        return "\n\n".join(doc.page_content for doc in docs)
+
+    async def search_documents(self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4) -> List[Document]:
+        """Search relevant documents from the vector store."""
+        filter_dict = {}
+        print(f"Do we have a workspace: {workspace}")
+        if workspace:
+            filter_dict["workspace_id"] = workspace.id
+        search_kwargs = {
+            "k": k,
+            "filter": filter_dict if filter_dict else None
+        }
+        print(f"search_kwargs: {search_kwargs}")
+
+        retriever = self.vector_store.as_retriever(
+            search_type="mmr",
+            search_kwargs=search_kwargs
+        )
+        return await retriever.aget_relevant_documents(query)
+
+    async def generate_response(self, conversation: Conversation, query: str, workspace: DocumentWorkspace, **kwargs) -> AsyncGenerator[str, None]:
+        """Generate response with streaming support."""
+        chain_input = {
+            "query": query,
+            "conversation": conversation,
+            "workspace": workspace,
+        }
+
+        async for chunk in self.rag_chain.astream(chain_input):
+            yield chunk
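An async usage sketch for the RAG pipeline (assumes documents were already ingested into ./chroma_db and an Ollama server is running):

    import asyncio
    from chat_backend.services.rag_services import AsyncRAGService

    async def ask(conversation, workspace, question: str):
        service = AsyncRAGService()   # singleton: reuses the shared Chroma store
        async for chunk in service.generate_response(conversation, question, workspace):
            print(chunk, end="", flush=True)

    # asyncio.run(ask(conversation, workspace, "What does the NDA say about term?"))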
llm_be/chat_backend/services/tests.py (new file, 219 lines)
@@ -0,0 +1,219 @@
import os
from unittest import TestCase, mock
from unittest.mock import MagicMock, patch, AsyncMock
from typing import List, Dict, Any

from django.test import TestCase as DjangoTestCase

from chat_backend.services.rag_services import RAGService, SyncRAGService, AsyncRAGService
from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document


class TestRAGService(TestCase):
    def setUp(self):
        self.rag_service = RAGService()
        self.rag_service.vector_store = MagicMock()
        self.rag_service.embedding_model = MagicMock()
        self.rag_service.text_splitter = MagicMock()

    def test_initialize_vector_store(self):
        with patch('os.path.exists', return_value=False), \
                patch('os.makedirs') as mock_makedirs, \
                patch('langchain_community.vectorstores.Chroma') as mock_chroma:

            # Reset the vector store to test initialization
            self.rag_service.vector_store = None
            result = self.rag_service._initialize_vector_store()

            mock_makedirs.assert_called_once_with("chroma_db")
            mock_chroma.assert_called_once_with(
                embedding_function=self.rag_service.embedding_model,
                persist_directory="chroma_db"
            )
            self.assertIsNotNone(result)

    def test_prepare_documents(self):
        mock_doc1 = MagicMock(spec=Document)
        mock_doc1.content = "Test content"
        mock_doc1.source = "test_source"
        mock_doc1.workspace = MagicMock()
        mock_doc1.workspace.id = 1
        mock_doc1.id = 1

        self.rag_service.text_splitter.split_text.return_value = ["chunk1", "chunk2"]

        result = self.rag_service._prepare_documents([mock_doc1])

        self.assertEqual(len(result), 2)
        self.rag_service.text_splitter.split_text.assert_called_once_with("Test content")
        self.assertEqual(result[0].page_content, "chunk1")
        self.assertEqual(result[0].metadata["source"], "test_source")

    def test_ingest_documents(self):
        mock_workspace = MagicMock()
        mock_document = MagicMock()
        mock_documents = [mock_document]

        # Patch the model where the service looks it up, i.e. the service module.
        with patch('chat_backend.services.rag_services.Document.objects.filter', return_value=mock_documents):
            self.rag_service._prepare_documents = MagicMock(return_value=["processed_doc"])

            self.rag_service.ingest_documents(mock_workspace)

            self.rag_service.vector_store.add_documents.assert_called_once_with(["processed_doc"])
            self.rag_service.vector_store.persist.assert_called_once()


class TestSyncRAGService(DjangoTestCase):
    def setUp(self):
        self.sync_service = SyncRAGService()
        self.sync_service.vector_store = MagicMock()
        self.sync_service.llm = MagicMock()
        self.sync_service.rag_chain = MagicMock()

        self.mock_conversation = MagicMock(spec=Conversation)
        self.mock_conversation.workspace = MagicMock()

        self.mock_prompt1 = MagicMock(spec=Prompt)
        self.mock_prompt1.is_user = True
        self.mock_prompt1.text = "User question"
        self.mock_prompt1.created_at = "2023-01-01"

        self.mock_prompt2 = MagicMock(spec=Prompt)
        self.mock_prompt2.is_user = False
        self.mock_prompt2.text = "AI response"
        self.mock_prompt2.created_at = "2023-01-02"

    def test_format_history(self):
        with patch('chat_backend.services.rag_services.Prompt.objects.filter') as mock_filter:
            mock_filter.return_value.order_by.return_value = [self.mock_prompt1, self.mock_prompt2]

            result = self.sync_service._format_history(self.mock_conversation)

            expected = "User: User question\nAI: AI response"
            self.assertEqual(result, expected)
            mock_filter.assert_called_once_with(conversation=self.mock_conversation)

    def test_retriever_with_history(self):
        input_dict = {
            "query": "test query",
            "conversation": self.mock_conversation
        }

        self.sync_service.search_documents = MagicMock(return_value=["doc1", "doc2"])

        result = self.sync_service._retriever_with_history(input_dict)

        self.sync_service.search_documents.assert_called_once_with(
            "test query",
            self.mock_conversation.workspace
        )
        self.assertEqual(result, ["doc1", "doc2"])

    def test_search_documents(self):
        mock_retriever = MagicMock()
        mock_retriever.get_relevant_documents.return_value = ["doc1", "doc2"]
        self.sync_service.vector_store.as_retriever.return_value = mock_retriever

        result = self.sync_service.search_documents("test query", self.mock_conversation.workspace)

        self.sync_service.vector_store.as_retriever.assert_called_once_with(
            search_type="similarity",
            search_kwargs={
                "k": 4,
                "filter": {"workspace_id": self.mock_conversation.workspace.id}
            }
        )
        self.assertEqual(result, ["doc1", "doc2"])

    def test_generate_response(self):
        chain_input = {
            "query": "test query",
            "conversation": self.mock_conversation
        }

        mock_stream = ["chunk1", "chunk2", "chunk3"]
        self.sync_service.rag_chain.stream.return_value = mock_stream

        result = list(self.sync_service.generate_response(self.mock_conversation, "test query"))

        self.sync_service.rag_chain.stream.assert_called_once_with(chain_input)
        self.assertEqual(result, mock_stream)


class TestAsyncRAGService(DjangoTestCase):
    def setUp(self):
        self.async_service = AsyncRAGService()
        self.async_service.vector_store = MagicMock()
        self.async_service.llm = MagicMock()
        self.async_service.rag_chain = AsyncMock()

        self.mock_conversation = MagicMock(spec=Conversation)
        self.mock_conversation.workspace = MagicMock()

        self.mock_prompt1 = MagicMock(spec=Prompt)
        self.mock_prompt1.is_user = True
        self.mock_prompt1.text = "User question"
        self.mock_prompt1.created_at = "2023-01-01"

        self.mock_prompt2 = MagicMock(spec=Prompt)
        self.mock_prompt2.is_user = False
        self.mock_prompt2.text = "AI response"
        self.mock_prompt2.created_at = "2023-01-02"

    async def test_format_history(self):
        # _format_history iterates the queryset with `async for`, so the
        # mocked order_by() must return an async iterable.
        async def prompt_iter():
            yield self.mock_prompt1
            yield self.mock_prompt2

        mock_qs = MagicMock()
        mock_qs.order_by.return_value = prompt_iter()

        with patch('chat_backend.services.rag_services.Prompt.objects.filter', return_value=mock_qs):
            result = await self.async_service._format_history(self.mock_conversation)

        expected = "User: User question\nAI: AI response"
        self.assertEqual(result, expected)
        mock_qs.order_by.assert_called_once_with('created_at')

    async def test_retriever_with_history(self):
        # _retriever_with_history reads the workspace from the input dict and
        # joins the page_content of the retrieved documents.
        input_dict = {
            "query": "test query",
            "conversation": self.mock_conversation,
            "workspace": self.mock_conversation.workspace,
        }

        mock_docs = [MagicMock(page_content="doc1"), MagicMock(page_content="doc2")]
        self.async_service.search_documents = AsyncMock(return_value=mock_docs)

        result = await self.async_service._retriever_with_history(input_dict)

        self.async_service.search_documents.assert_awaited_once_with(
            "test query",
            self.mock_conversation.workspace
        )
        self.assertEqual(result, "doc1\n\ndoc2")

    async def test_search_documents(self):
        mock_retriever = AsyncMock()
        mock_retriever.aget_relevant_documents.return_value = ["doc1", "doc2"]
        self.async_service.vector_store.as_retriever.return_value = mock_retriever

        result = await self.async_service.search_documents("test query", self.mock_conversation.workspace)

        self.async_service.vector_store.as_retriever.assert_called_once_with(
            search_type="mmr",
            search_kwargs={
                "k": 4,
                "filter": {"workspace_id": self.mock_conversation.workspace.id}
            }
        )
        self.assertEqual(result, ["doc1", "doc2"])

    async def test_generate_response(self):
        chain_input = {
            "query": "test query",
            "conversation": self.mock_conversation,
            "workspace": self.mock_conversation.workspace,
        }

        mock_stream = ["chunk1", "chunk2", "chunk3"]

        # astream() is called synchronously and must return an async iterator.
        async def astream_mock(_input):
            for chunk in mock_stream:
                yield chunk

        self.async_service.rag_chain.astream = MagicMock(side_effect=astream_mock)

        chunks = []
        async for chunk in self.async_service.generate_response(
            self.mock_conversation, "test query", self.mock_conversation.workspace
        ):
            chunks.append(chunk)

        self.async_service.rag_chain.astream.assert_called_once_with(chain_input)
        self.assertEqual(chunks, mock_stream)
llm_be/chat_backend/services/title_generator.py (new file, 67 lines)
@@ -0,0 +1,67 @@
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.llms import Ollama
from typing import Optional


class TitleGenerator:
    """
    Generates short, descriptive titles for conversations based on the first prompt.
    """

    def __init__(self):
        self.llm = Ollama(
            model="llama3",
            temperature=0.5,  # Slightly creative but not too random
            top_k=20,
            num_ctx=2048,  # Shorter context needed for titles
        )

        self.title_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are a conversation title generator. Create a very short (2-5 word) title based on the user's first message.

Rules:
1. Keep it extremely concise
2. Capture the main topic or intent
3. Use title case
4. No quotes or punctuation
5. Never exceed 5 words

Examples:
- "What's the weather today?" → "Weather Inquiry"
- "Explain quantum computing" → "Quantum Computing Explanation"
- "Generate an image of a dragon" → "Dragon Image Generation"
- "Find our company's privacy policy" → "Privacy Policy Search"

Return ONLY the title, nothing else."""),
            ("human", "{prompt}")
        ])

        self.chain = self.title_prompt | self.llm

    async def generate_async(self, prompt: str) -> str:
        """Generate a title asynchronously."""
        try:
            response = await self.chain.ainvoke({"prompt": prompt})
            return self._clean_response(response)
        except Exception as e:
            print(f"Title generation error: {e}")
            return "Conversation"

    def generate(self, prompt: str) -> str:
        """Generate a title synchronously."""
        try:
            response = self.chain.invoke({"prompt": prompt})
            return self._clean_response(response)
        except Exception as e:
            print(f"Title generation error: {e}")
            return "Conversation"

    def _clean_response(self, response: str) -> str:
        """Clean and format the LLM response."""
        # Remove any surrounding quotes, punctuation, and whitespace
        response = response.strip('"\'.!? \n\t')
        # Ensure title case and trim
        return response.title()[:50]  # Hard limit for safety


# Singleton instance
title_generator = TitleGenerator()
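A quick illustrative usage of the singleton (the printed title is only an example; actual output depends on the llama3 model):

from chat_backend.services.title_generator import title_generator

title = title_generator.generate("How do I reset my password?")
print(title)  # e.g. "Password Reset Help"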
llm_be/chat_backend/signals.py (new file, 18 lines)
@@ -0,0 +1,18 @@
from django.db.models.signals import post_save, post_delete
from django.dispatch import receiver
from chat_backend.models import Document
from .services.rag_services import AsyncRAGService


@receiver(post_save, sender=Document)
def update_vector_on_save(sender, instance, **kwargs):
    """Update the vector store when a new document is saved."""
    if kwargs.get('created', False):
        rag_service = AsyncRAGService()
        # ingest_documents expects the workspace whose documents it (re)indexes.
        rag_service.ingest_documents(instance.workspace)


@receiver(post_delete, sender=Document)
def delete_vector_on_remove(sender, instance, **kwargs):
    """Handle document deletion by re-indexing the whole workspace."""
    rag_service = AsyncRAGService()
    rag_service.ingest_documents(instance.workspace)
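Note that these receivers only fire if the signals module is imported at startup; the conventional place for that is AppConfig.ready(). A sketch of what chat_backend/apps.py would need (the ChatBackendConfig class name is assumed, since that file is not part of this diff):

from django.apps import AppConfig


class ChatBackendConfig(AppConfig):
    default_auto_field = "django.db.models.BigAutoField"
    name = "chat_backend"

    def ready(self):
        from . import signals  # noqa: F401 -- registers the post_save/post_delete hooks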
llm_be/chat_backend/templates/emails/reset_email.html (new file, 97 lines)
@@ -0,0 +1,97 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Reset Password for Chat by AI ML Operations, LLC</title>
    <style>
        /* Basic reset for email clients */
        body, table, td, a {
            -webkit-text-size-adjust: 100%;
            -ms-text-size-adjust: 100%;
        }
        table, td {
            mso-table-lspace: 0pt;
            mso-table-rspace: 0pt;
        }
        img {
            border: 0;
            height: auto;
            line-height: 100%;
            outline: none;
            text-decoration: none;
            -ms-interpolation-mode: bicubic;
        }
        body {
            margin: 0;
            padding: 0;
            font-family: Arial, sans-serif;
            background-color: #f4f4f4;
        }
        .email-container {
            max-width: 600px;
            margin: 0 auto;
            background-color: #ffffff;
            border: 1px solid #dddddd;
        }
        .header {
            background-color: #007BFF;
            color: #ffffff;
            padding: 20px;
            text-align: center;
        }
        .content {
            padding: 20px;
            color: #333333;
        }
        .footer {
            background-color: #f4f4f4;
            color: #777777;
            text-align: center;
            padding: 10px;
            font-size: 12px;
        }
        .feedback-title {
            font-size: 18px;
            font-weight: bold;
            margin-bottom: 10px;
        }
        .feedback-text {
            font-size: 14px;
            line-height: 1.5;
        }
    </style>
</head>
<body>
    <table role="presentation" width="100%" cellspacing="0" cellpadding="0" border="0" align="center">
        <tr>
            <td>
                <!-- Email Container -->
                <div class="email-container">
                    <!-- Header -->
                    <div class="header">
                        <h1>Password Reset for AI ML Operations, LLC Chat Services</h1>
                    </div>

                    <!-- Content -->
                    <div class="content">
                        <p>Hello,</p>
                        <p>There has been a request to reset your password. If you didn't request this, please email ryan@aimloperations.com.</p>

                        <p>Please click this <a href="{{ url }}">link</a> to set your password.</p>
                        <p>Once you have set your password, go <a href="https://chat.aimloperations.com">here</a> to get started.</p>

                        <p>Thank you.</p>
                    </div>

                    <!-- Footer -->
                    <div class="footer">
                        <p>This is an automated message. Please do not reply to this email.</p>
                        <p>© 2023-2025 AI ML Operations, LLC. All rights reserved.</p>
                    </div>
                </div>
            </td>
        </tr>
    </table>
</body>
</html>
llm_be/chat_backend/templates/emails/reset_email.txt (new file, 3 lines)
@@ -0,0 +1,3 @@
Password Reset for AI ML Operations, LLC Chat Services

Password reset for chat.aimloperations.com. Please use {{ url }} to set your password.
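For context, a sketch of how these two templates would typically be rendered and sent together. The send_reset_email helper is an assumption (the actual sending code is not part of this diff), but the template paths and the {{ url }} context variable match the files above.

from django.core.mail import EmailMultiAlternatives
from django.template.loader import render_to_string


def send_reset_email(user, reset_url):
    context = {"url": reset_url}
    text_body = render_to_string("emails/reset_email.txt", context)
    html_body = render_to_string("emails/reset_email.html", context)

    message = EmailMultiAlternatives(
        subject="Password Reset for AI ML Operations, LLC Chat Services",
        body=text_body,
        to=[user.email],
    )
    message.attach_alternative(html_body, "text/html")
    message.send()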
llm_be/chat_backend/tests.py
@@ -1,3 +1,210 @@
from django.test import TestCase

# Create your tests here.
from django.test import TestCase, Client
from django.urls import reverse
from django.contrib.auth.models import User
from rest_framework.test import APIClient, APITestCase
from rest_framework import status
from .models import DocumentWorkspace, Document, Company
from django.contrib.auth import get_user_model
import tempfile
from django.core.files.uploadedfile import SimpleUploadedFile


# Minimal valid PDF bytes
VALID_PDF_BYTES = (
    b'%PDF-1.3\n'
    b'1 0 obj\n'
    b'<< /Type /Catalog /Pages 2 0 R >>\n'
    b'endobj\n'
    b'2 0 obj\n'
    b'<< /Type /Pages /Kids [3 0 R] /Count 1 >>\n'
    b'endobj\n'
    b'3 0 obj\n'
    b'<< /Type /Page /Parent 2 0 R /Resources << >> /MediaBox [0 0 612 792] /Contents 4 0 R >>\n'
    b'endobj\n'
    b'4 0 obj\n'
    b'<< /Length 44 >>\n'
    b'stream\n'
    b'BT /F1 12 Tf 72 720 Td (Test PDF) Tj ET\n'
    b'endstream\n'
    b'endobj\n'
    b'xref\n'
    b'0 5\n'
    b'0000000000 65535 f \n'
    b'0000000009 00000 n \n'
    b'0000000058 00000 n \n'
    b'0000000117 00000 n \n'
    b'0000000223 00000 n \n'
    b'trailer\n'
    b'<< /Size 5 /Root 1 0 R >>\n'
    b'startxref\n'
    b'317\n'
    b'%%EOF'
)


class DocumentWorkspaceViewsTestCase(APITestCase):
    def setUp(self):
        self.company = Company.objects.create(
            name="test",
            state="IL",
            zipcode="60189",
            address="1968 Greensboro Dr"
        )
        self.user = get_user_model().objects.create_user(
            company=self.company,
            username='testuser',
            password='testpass123',
            email="test@test.com",
        )

        self.client = APIClient()
        self.client.force_authenticate(user=self.user)

        self.workspace = DocumentWorkspace.objects.create(
            company=self.user.company,
            name='Test Workspace'
        )

    def test_list_workspaces(self):
        url = reverse('document_workspaces')
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(len(response.data), 1)
        self.assertEqual(response.data[0]['name'], 'Test Workspace')

    def test_create_workspace(self):
        url = reverse('document_workspaces')
        data = {
            'name': 'New Workspace'
        }
        response = self.client.post(url, data, format='json')
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        self.assertEqual(DocumentWorkspace.objects.count(), 2)

    def test_retrieve_workspace(self):
        url = reverse('document_workspaces')
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data[0]['name'], 'Test Workspace')

    # def test_update_workspace(self):
    #     url = reverse('document_workspaces')
    #     data = {
    #         'name': 'Updated Workspace'
    #     }
    #     response = self.client.post(url, data, format='json')
    #     self.assertEqual(response.status_code, status.HTTP_201_CREATED)
    #     self.workspace.refresh_from_db()
    #     self.assertEqual(self.workspace.name, 'Updated Workspace')

    # def test_delete_workspace(self):
    #     url = reverse('document_workspaces', args=[self.workspace.id])
    #     response = self.client.delete(url)
    #     self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
    #     self.assertEqual(DocumentWorkspace.objects.count(), 0)


class DocumentViewsTestCase(APITestCase):
    def setUp(self):
        self.company = Company.objects.create(
            name="test",
            state="IL",
            zipcode="60189",
            address="1968 Greensboro Dr"
        )
        self.user = get_user_model().objects.create_user(
            company=self.company,
            username='testuser',
            password='testpass123',
            email="test@test.com",
        )

        self.client = APIClient()
        self.client.force_authenticate(user=self.user)

        self.workspace = DocumentWorkspace.objects.create(
            company=self.user.company,
            name='Test Workspace'
        )

        # Create a test file
        self.test_file = SimpleUploadedFile(
            "test.pdf",
            VALID_PDF_BYTES,
            content_type="application/pdf"
        )

    def test_upload_document(self):
        url = reverse('documents')
        data = {
            'file': self.test_file
        }
        response = self.client.post(url, data, format='multipart')
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        self.assertEqual(Document.objects.count(), 1)

        document = Document.objects.first()
        self.assertEqual(document.workspace.id, self.workspace.id)
        self.assertTrue(document.processed)  # upload is expected to mark the document as processed

    def test_list_documents(self):
        # First create a document
        Document.objects.create(
            workspace=self.workspace,
            file=self.test_file
        )

        url = reverse('documents')
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(len(response.data), 1)
        self.assertIn('test', response.data[0]['file'])
        self.assertIn('pdf', response.data[0]['file'])

    # def test_delete_document(self):
    #     document = Document.objects.create(
    #         workspace=self.workspace,
    #         file=self.test_file
    #     )

    #     url = reverse('document-detail', args=[document.id])
    #     response = self.client.delete(url)
    #     self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
    #     self.assertEqual(Document.objects.count(), 0)

    def test_upload_invalid_file(self):
        url = reverse('documents')
        data = {
            'file': 'not a file'
        }
        response = self.client.post(url, data, format='multipart')
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

    def test_access_other_users_documents(self):
        # Create another user
        other_company = Company.objects.create(
            name="test2",
            state="IL",
            zipcode="60189",
            address="1968 Greensboro Dr"
        )
        other_user = get_user_model().objects.create_user(
            company=other_company,
            username='otheruser',
            password='otherpass123',
            email="testing2@test.com"
        )
        other_workspace = DocumentWorkspace.objects.create(
            company=other_user.company,
            name='Other Workspace'
        )
        other_document = Document.objects.create(
            workspace=other_workspace,
            file=self.test_file
        )

        # Try to access the other user's document
        url = reverse('documents_details', kwargs={"document_id": other_document.id})
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
llm_be/chat_backend/urls.py
@@ -14,27 +14,42 @@ from .views import (
     ConversationDetailView,
     CompanyUsersView,
     SetUserPassword,
+    ResetUserPassword,
     ConversationPreferences,
     UserPromptAnalytics,
     UserConversationAnalytics,
     CompanyUsageAnalytics,
-    AdminAnalytics
+    AdminAnalytics,
+    reset_password,
+    DocumentWorkspaceView,
+    DocumentUploadView,
+    DocumentDetailView,
 )
+from rest_framework.routers import DefaultRouter
+
 
 urlpatterns = [
     path("token/obtain/", CustomObtainTokenView.as_view(), name="token_create"),
     path("token/refresh/", jwt_views.TokenRefreshView.as_view(), name="token_refresh"),
     path("user/create/", CustomUserCreate.as_view(), name="create_user"),
     path("user/invite/", CustomUserInvite.as_view(), name="invite_user"),
-    path("user/set_password/<slug:slug>/", SetUserPassword.as_view(), name="set_password"),
+    path("user/reset_password/", reset_password, name="reset_password"),
+    path(
+        "user/set_password/<slug:slug>/", SetUserPassword.as_view(), name="set_password"
+    ),
     path(
         "blacklist/",
         LogoutAndBlacklistRefreshTokenForUserView.as_view(),
         name="blacklist",
     ),
     path("user/get/", CustomUserGet.as_view(), name="get_user"),
-    path("user/acknowledge_tos/", AcknowledgeTermsOfService.as_view(), name="acknowledge_tos"),
-    path("company_users",CompanyUsersView.as_view(), name="company_users"),
+    path(
+        "user/acknowledge_tos/",
+        AcknowledgeTermsOfService.as_view(),
+        name="acknowledge_tos",
+    ),
+    path("company_users", CompanyUsersView.as_view(), name="company_users"),
     path("user/is_authenticated/", is_authenticated, name="is_authenticated"),
     path("announcment/get/", AnnouncmentView.as_view(), name="get_announcments"),
     path("conversations", ConversationsView.as_view(), name="conversations"),
@@ -44,9 +59,32 @@ urlpatterns = [
         ConversationDetailView.as_view(),
         name="conversation_details",
     ),
-    path("conversation_preferences", ConversationPreferences.as_view(), name="conversation_preferences"),
-    path("analytics/user_prompts/", UserPromptAnalytics.as_view(), name="analytics_user_prompts"),
-    path("analytics/user_conversations/", UserConversationAnalytics.as_view(), name="analytics_user_conversations"),
-    path("analytics/company_usage/", CompanyUsageAnalytics.as_view(), name="analytics_company_usage"),
+    path(
+        "conversation_preferences",
+        ConversationPreferences.as_view(),
+        name="conversation_preferences",
+    ),
+    path(
+        "analytics/user_prompts/",
+        UserPromptAnalytics.as_view(),
+        name="analytics_user_prompts",
+    ),
+    path(
+        "analytics/user_conversations/",
+        UserConversationAnalytics.as_view(),
+        name="analytics_user_conversations",
+    ),
+    path(
+        "analytics/company_usage/",
+        CompanyUsageAnalytics.as_view(),
+        name="analytics_company_usage",
+    ),
     path("analytics/admin/", AdminAnalytics.as_view(), name="analytics_admin"),
+
+    # document urls
+    path("document_workspaces/", DocumentWorkspaceView.as_view(), name="document_workspaces"),
+    path("documents/", DocumentUploadView.as_view(), name="documents"),
+    path("documents_details/<int:document_id>", DocumentDetailView.as_view(), name="documents_details"),
 ]
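A sketch of how a client could exercise the new document routes. The host, file name, and token handling are illustrative; the paths and the multipart file field come from the urls and tests in this commit.

import requests

BASE = "http://127.0.0.1:8000/api"
access_token = "<JWT access token>"  # obtained from token/obtain/
headers = {"Authorization": f"JWT {access_token}"}

# Create a workspace, upload a PDF, then fetch one document's details.
requests.post(f"{BASE}/document_workspaces/", json={"name": "Policies"}, headers=headers)
with open("handbook.pdf", "rb") as fh:
    requests.post(f"{BASE}/documents/", files={"file": fh}, headers=headers)
details = requests.get(f"{BASE}/documents_details/1", headers=headers).json()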
@@ -1,7 +1,8 @@
 import datetime
 
 
 def last_day_of_month(any_day):
     # The day 28 exists in every month. 4 days later, it's always next month
     next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
     # subtracting the number of the current day brings us back one month
     return next_month - datetime.timedelta(days=next_month.day)
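A couple of worked examples of the 28-plus-4 trick (the import path for the helper is assumed):

import datetime
from chat_backend.utils import last_day_of_month  # module path assumed

assert last_day_of_month(datetime.date(2024, 2, 10)) == datetime.date(2024, 2, 29)  # leap year
assert last_day_of_month(datetime.date(2023, 1, 1)) == datetime.date(2023, 1, 31)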
[File diff suppressed because it is too large]
llm_be/llm_be/asgi.py
@@ -13,13 +13,14 @@ from django.core.asgi import get_asgi_application
 from channels.routing import ProtocolTypeRouter, URLRouter
 from channels.auth import AuthMiddlewareStack
 import chat_backend.routing
 
-os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'llm_be.settings')
+os.environ.setdefault("DJANGO_SETTINGS_MODULE", "llm_be.settings")
 
-application = ProtocolTypeRouter({
-    "http": get_asgi_application(),
-    "websocket": AuthMiddlewareStack(
-        URLRouter(
-            chat_backend.routing.websocket_urlpatterns
-        )
-    ),
-})
+application = ProtocolTypeRouter(
+    {
+        "http": get_asgi_application(),
+        "websocket": AuthMiddlewareStack(
+            URLRouter(chat_backend.routing.websocket_urlpatterns)
+        ),
+    }
+)
llm_be/llm_be/settings.py
@@ -22,80 +22,86 @@ BASE_DIR = Path(__file__).resolve().parent.parent
 # See https://docs.djangoproject.com/en/3.2/howto/deployment/checklist/
 
 # SECURITY WARNING: keep the secret key used in production secret!
-SECRET_KEY = 'django-insecure-6suk6fj5q2)1tj%)f(wgw1smnliv5-#&@zvgvj1wp#(#@h#31x'
+SECRET_KEY = "django-insecure-6suk6fj5q2)1tj%)f(wgw1smnliv5-#&@zvgvj1wp#(#@h#31x"
 
 # SECURITY WARNING: don't run with debug turned on in production!
 DEBUG = True
+CORS_ALLOW_CREDENTIALS = False
-ALLOWED_HOSTS = ['*.aimloperations.com','localhost','127.0.0.1','chat.aimloperations.com','chatbackend.aimloperations.com']
+ALLOWED_HOSTS = [
+    "*.aimloperations.com",
+    "localhost",
+    "127.0.0.1",
+    "localhost:3000",
+    "127.0.0.1:3000",
+    "chat.aimloperations.com",
+    "chatbackend.aimloperations.com",
+]
 CORS_ORIGIN_ALLOW_ALL = True
+CSRF_TRUSTED_ORIGINS = ["http://localhost", "http://127.0.0.1", "http://localhost:3000"]
 
 
 # Application definition
 
 INSTALLED_APPS = [
-    'daphne',
-    'django.contrib.admin',
-    'django.contrib.auth',
-    'django.contrib.contenttypes',
-    'django.contrib.sessions',
-    'django.contrib.messages',
-    'django.contrib.staticfiles',
-    'chat_backend',
-    'rest_framework',
-    'corsheaders',
-    'rest_framework_simplejwt.token_blacklist',
+    "daphne",
+    "django.contrib.admin",
+    "django.contrib.auth",
+    "django.contrib.contenttypes",
+    "django.contrib.sessions",
+    "django.contrib.messages",
+    "django.contrib.staticfiles",
+    "chat_backend",
+    "rest_framework",
+    "corsheaders",
+    "rest_framework_simplejwt.token_blacklist",
 ]
 
 MIDDLEWARE = [
-    'django.middleware.security.SecurityMiddleware',
-    'django.contrib.sessions.middleware.SessionMiddleware',
-    'django.middleware.common.CommonMiddleware',
-    'django.middleware.csrf.CsrfViewMiddleware',
-    'django.contrib.auth.middleware.AuthenticationMiddleware',
-    'django.contrib.messages.middleware.MessageMiddleware',
-    'django.middleware.clickjacking.XFrameOptionsMiddleware',
+    "django.middleware.security.SecurityMiddleware",
+    "django.contrib.sessions.middleware.SessionMiddleware",
+    "django.middleware.common.CommonMiddleware",
+    "django.middleware.csrf.CsrfViewMiddleware",
+    "django.contrib.auth.middleware.AuthenticationMiddleware",
+    "django.contrib.messages.middleware.MessageMiddleware",
+    "django.middleware.clickjacking.XFrameOptionsMiddleware",
     "corsheaders.middleware.CorsMiddleware",
     "django.middleware.common.CommonMiddleware",
 ]
 
-ROOT_URLCONF = 'llm_be.urls'
+ROOT_URLCONF = "llm_be.urls"
 
 # SETTINGS_PATH = os.path.dirname(os.path.dirname(__file__))
 # TEMPLATE_DIRS = (
 #     os.path.join(SETTINGS_PATH, 'templates'),
 # )
 
-print(os.path.join(BASE_DIR, 'templates'))
 
 TEMPLATES = [
     {
-        'BACKEND': 'django.template.backends.django.DjangoTemplates',
-        'DIRS': [os.path.join(BASE_DIR, 'templates')],
-        'APP_DIRS': True,
-        'OPTIONS': {
-            'context_processors': [
-                'django.template.context_processors.debug',
-                'django.template.context_processors.request',
-                'django.contrib.auth.context_processors.auth',
-                'django.contrib.messages.context_processors.messages',
+        "BACKEND": "django.template.backends.django.DjangoTemplates",
+        "DIRS": [os.path.join(BASE_DIR, "templates")],
+        "APP_DIRS": True,
+        "OPTIONS": {
+            "context_processors": [
+                "django.template.context_processors.debug",
+                "django.template.context_processors.request",
+                "django.contrib.auth.context_processors.auth",
+                "django.contrib.messages.context_processors.messages",
             ],
         },
     },
 ]
 
-WSGI_APPLICATION = 'llm_be.wsgi.application'
-ASGI_APPLICATION = 'llm_be.asgi.application'
+WSGI_APPLICATION = "llm_be.wsgi.application"
+ASGI_APPLICATION = "llm_be.asgi.application"
 
 
 # Database
 # https://docs.djangoproject.com/en/3.2/ref/settings/#databases
 
 DATABASES = {
-    'default': {
-        'ENGINE': 'django.db.backends.sqlite3',
-        'NAME': BASE_DIR / 'db.sqlite3',
+    "default": {
+        "ENGINE": "django.db.backends.sqlite3",
+        "NAME": BASE_DIR / "db.sqlite3",
     }
 }
 
@@ -105,28 +111,26 @@ DATABASES = {
 
 AUTH_PASSWORD_VALIDATORS = [
     {
-        'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
+        "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator",
     },
     {
-        'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
+        "NAME": "django.contrib.auth.password_validation.MinimumLengthValidator",
     },
     {
-        'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
+        "NAME": "django.contrib.auth.password_validation.CommonPasswordValidator",
    },
     {
-        'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
+        "NAME": "django.contrib.auth.password_validation.NumericPasswordValidator",
     },
 ]
 
 
 # Internationalization
 # https://docs.djangoproject.com/en/3.2/topics/i18n/
 
-LANGUAGE_CODE = 'en-us'
+LANGUAGE_CODE = "en-us"
 
-TIME_ZONE = 'UTC'
+TIME_ZONE = "UTC"
 
 USE_I18N = True
 
@@ -138,39 +142,37 @@ USE_TZ = True
 # Static files (CSS, JavaScript, Images)
 # https://docs.djangoproject.com/en/3.2/howto/static-files/
 
-STATIC_URL = '/static/'
+STATIC_URL = "/static/"
 
 # Default primary key field type
 # https://docs.djangoproject.com/en/3.2/ref/settings/#default-auto-field
 
-DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
+DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"
 
 # custom user model
-AUTH_USER_MODEL = 'chat_backend.CustomUser'
+AUTH_USER_MODEL = "chat_backend.CustomUser"
 
 # rest framework jwt stuff
 REST_FRAMEWORK = {
-    'DEFAULT_PERMISSION_CLASSES': (
-        'rest_framework.permissions.IsAuthenticated',
-    ),
-    'DEFAULT_AUTHENTICATION_CLASSES': (
-        'rest_framework_simplejwt.authentication.JWTAuthentication',
-    ), #
+    "DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.IsAuthenticated",),
+    "DEFAULT_AUTHENTICATION_CLASSES": (
+        "rest_framework_simplejwt.authentication.JWTAuthentication",
+    ),
 }
 
 SIMPLE_JWT = {
-    'ACCESS_TOKEN_LIFETIME':timedelta(hours=5),
-    'REFRESH_TOKEN_LIFETIME':timedelta(days=14),
-    'ROTATE_REFRESH_TOKENS':True,
-    'BLACKLIST_AFTER_ROTATION':True,
-    'ALGORITHM':"HS256",
-    "SIGNING_KEY":SECRET_KEY,
-    'VERIFYING_KEY':None,
-    "AUTH_HEADER_TYPES":('JWT',),
-    'USER_ID_FIELD':'id',
-    'USER_ID_CLAIM':'user_id',
-    'AUTH_TOKEN_CLASSES':('rest_framework_simplejwt.tokens.AccessToken',),
-    'TOKEN_TYPE_CLAIM':'token_type',
+    "ACCESS_TOKEN_LIFETIME": timedelta(hours=24),
+    "REFRESH_TOKEN_LIFETIME": timedelta(days=14),
+    "ROTATE_REFRESH_TOKENS": True,
+    "BLACKLIST_AFTER_ROTATION": True,
+    "ALGORITHM": "HS256",
+    "SIGNING_KEY": SECRET_KEY,
+    "VERIFYING_KEY": None,
+    "AUTH_HEADER_TYPES": ("JWT",),
+    "USER_ID_FIELD": "id",
+    "USER_ID_CLAIM": "user_id",
+    "AUTH_TOKEN_CLASSES": ("rest_framework_simplejwt.tokens.AccessToken",),
+    "TOKEN_TYPE_CLAIM": "token_type",
 }
 
 # CORS settings
@@ -181,8 +183,8 @@ CORS_ALLOWED_ORIGINS = [
 
 # channel settings
 CHANNEL_LAYERS = {
-    'default': {
-        'BACKEND': 'channels.layers.InMemoryChannelLayer',
+    "default": {
+        "BACKEND": "channels.layers.InMemoryChannelLayer",
     },
 }
 
@@ -198,8 +200,11 @@ CHANNEL_LAYERS = {
 # EMAIL_TIMEOUT = os.getenv("APP_EMAIL_TIMEOUT", 60)
 
 # SMTP2GO
-EMAIL_HOST = 'mail.smtp2go.com'
-EMAIL_HOST_USER = 'info.aimloperations.com'
-EMAIL_HOST_PASSWORD = 'ZDErIII2sipNNVMz'
+EMAIL_HOST = "mail.smtp2go.com"
+EMAIL_HOST_USER = "info.aimloperations.com"
+EMAIL_HOST_PASSWORD = "ZDErIII2sipNNVMz"
 EMAIL_PORT = 2525
 EMAIL_USE_TLS = True
 
+# Captcha
+CAPTCHA_SECRET_KEY = "6LfENu4qAAAAABdrj6JTviq-LfdPP5imhE-Os7h9"
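With AUTH_HEADER_TYPES set to ("JWT",), clients must send the access token with a JWT prefix rather than the more common Bearer. A sketch of the assumed flow (the credential field names accepted by token/obtain/ are an assumption; the host is illustrative):

import requests

BASE = "http://127.0.0.1:8000/api"

tokens = requests.post(
    f"{BASE}/token/obtain/",
    json={"email": "test@test.com", "password": "testpass123"},
).json()

# Access tokens now last 24 hours; refresh tokens rotate, and the old
# refresh token is blacklisted after each rotation.
resp = requests.get(
    f"{BASE}/conversations",
    headers={"Authorization": f"JWT {tokens['access']}"},
)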
llm_be/llm_be/urls.py
@@ -13,12 +13,17 @@ Including another URLconf
     1. Import the include() function: from django.urls import include, path
     2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
 """
+
 from django.contrib import admin
 from django.urls import path, include
 from django.conf import settings
 from django.conf.urls.static import static
 
-urlpatterns = [
-    path('admin/', admin.site.urls),
-    path('api/', include('chat_backend.urls')),
-] + static(settings.STATIC_URL, document_root=settings.STATIC_ROOT) + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)
+urlpatterns = (
+    [
+        path("admin/", admin.site.urls),
+        path("api/", include("chat_backend.urls")),
+    ]
+    + static(settings.STATIC_URL, document_root=settings.STATIC_ROOT)
+    + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)
+)
llm_be/llm_be/wsgi.py
@@ -11,6 +11,6 @@ import os
 
 from django.core.wsgi import get_wsgi_application
 
-os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'llm_be.settings')
+os.environ.setdefault("DJANGO_SETTINGS_MODULE", "llm_be.settings")
 
 application = get_wsgi_application()
llm_be/manage.py
@@ -6,7 +6,7 @@ import sys
 
 def main():
     """Run administrative tasks."""
-    os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'llm_be.settings')
+    os.environ.setdefault("DJANGO_SETTINGS_MODULE", "llm_be.settings")
     try:
         from django.core.management import execute_from_command_line
     except ImportError as exc:
@@ -18,5 +18,5 @@ def main():
     execute_from_command_line(sys.argv)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
@@ -30,16 +30,16 @@ djangorestframework-simplejwt==5.3.1
 duckdb==1.1.3
 et_xmlfile==2.0.0
 exceptiongroup==1.2.2
-Faker==33.1.0
+Faker
 filelock==3.16.1
 fonttools==4.55.3
 frozenlist==1.5.0
 fsspec==2024.12.0
 greenlet==3.1.1
 h11==0.14.0
-httpcore==1.0.7
-httpx==0.27.2
-httpx-sse==0.4.0
+httpcore
+httpx
+httpx-sse
 hyperlink==21.0.0
 idna==3.10
 importlib_resources==6.4.5
@@ -48,14 +48,14 @@ Jinja2==3.1.5
 jiter==0.8.2
 jsonpatch==1.33
 jsonpointer==3.0.0
-kiwisolver==1.4.7
-langchain==0.3.13
-langchain-community==0.3.13
-langchain-core==0.3.28
-langchain-ollama==0.2.2
-langchain-openai==0.2.14
-langchain-text-splitters==0.3.4
-langsmith==0.2.7
+kiwisolver
+langchain
+langchain-community
+langchain-core
+langchain-ollama
+langchain-openai
+langchain-text-splitters
+langsmith
 lxml==5.3.0
 MarkupSafe==3.0.2
 marshmallow==3.23.2
@@ -77,14 +77,14 @@ nvidia-cusparse-cu12==12.3.1.170
 nvidia-nccl-cu12==2.21.5
 nvidia-nvjitlink-cu12==12.4.127
 nvidia-nvtx-cu12==12.4.127
-ollama==0.4.5
-ollama-python==0.1.2
-openai==1.58.1
+ollama
+ollama-python
+openai
 openpyxl==3.1.5
 orjson==3.10.13
 packaging==24.2
 pandas==2.2.3
-pandasai==2.4.1
+pandasai
 pathspec==0.12.1
 pillow==11.0.0
 platformdirs==4.3.6
requirements.txt (new file, 208 lines)
@@ -0,0 +1,208 @@
aiofiles==24.1.0
aiohappyeyeballs==2.6.1
aiohttp==3.11.18
aiosignal==1.3.2
annotated-types==0.7.0
anyio==4.8.0
asgiref==3.8.1
astor==0.8.1
attrs==25.1.0
autobahn==24.4.2
Automat==24.8.1
backoff==2.2.1
bcrypt==4.3.0
beautifulsoup4==4.13.4
black==25.1.0
build==1.2.2.post1
cachetools==5.5.2
certifi==2025.1.31
cffi==1.17.1
channels==4.2.0
chardet==5.2.0
charset-normalizer==3.4.1
chroma-hnswlib==0.7.6
chromadb==1.0.7
click==8.1.8
coloredlogs==15.0.1
constantly==23.10.4
contourpy==1.3.1
cryptography==44.0.2
cycler==0.12.1
daphne==4.1.2
dataclasses-json==0.6.7
Deprecated==1.2.18
distro==1.9.0
Django==5.1.7
django-autoslug==1.9.9
django-cors-headers==4.7.0
django-filter==25.1
djangorestframework==3.15.2
djangorestframework_simplejwt==5.5.0
duckdb==1.2.1
durationpy==0.9
emoji==2.14.1
eval_type_backport==0.2.2
Faker==37.0.0
fastapi==0.115.9
filelock==3.17.0
filetype==1.2.0
flatbuffers==25.2.10
fonttools==4.56.0
frozenlist==1.6.0
fsspec==2025.2.0
google-auth==2.39.0
googleapis-common-protos==1.70.0
greenlet==3.1.1
grpcio==1.71.0
h11==0.14.0
html5lib==1.1
httpcore==1.0.7
httptools==0.6.4
httpx==0.28.1
httpx-sse==0.4.0
huggingface-hub==0.30.2
humanfriendly==10.0
hyperlink==21.0.0
idna==3.10
importlib_metadata==8.6.1
importlib_resources==6.5.2
incremental==24.7.2
Jinja2==3.1.6
jiter==0.8.2
joblib==1.4.2
jsonpatch==1.33
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2025.4.1
kiwisolver==1.4.8
kubernetes==32.0.1
langchain==0.3.24
langchain-community==0.3.23
langchain-core==0.3.56
langchain-ollama==0.2.3
langchain-text-splitters==0.3.8
langdetect==1.0.9
langsmith==0.3.13
lxml==5.4.0
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==3.0.2
marshmallow==3.26.1
matplotlib==3.10.1
mdurl==0.1.2
mmh3==5.1.0
mpmath==1.3.0
multidict==6.4.3
mypy-extensions==1.0.0
nest-asyncio==1.6.0
networkx==3.4.2
nltk==3.9.1
numpy==2.2.3
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-cusparselt-cu12==0.6.2
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
oauthlib==3.2.2
olefile==0.47
ollama==0.4.7
onnxruntime==1.21.1
openai==1.65.4
opentelemetry-api==1.32.1
opentelemetry-exporter-otlp-proto-common==1.32.1
opentelemetry-exporter-otlp-proto-grpc==1.32.1
opentelemetry-instrumentation==0.53b1
opentelemetry-instrumentation-asgi==0.53b1
opentelemetry-instrumentation-fastapi==0.53b1
opentelemetry-proto==1.32.1
opentelemetry-sdk==1.32.1
opentelemetry-semantic-conventions==0.53b1
opentelemetry-util-http==0.53b1
orjson==3.10.15
overrides==7.7.0
packaging==24.2
pandas==2.2.3
pandasai==2.4.2
pathspec==0.12.1
pillow==11.1.0
platformdirs==4.3.6
posthog==4.0.1
propcache==0.3.1
protobuf==5.29.4
psutil==7.0.0
pyasn1==0.6.1
pyasn1_modules==0.4.1
pycparser==2.22
pydantic==2.11.4
pydantic-settings==2.9.1
pydantic_core==2.33.2
Pygments==2.19.1
PyJWT==2.10.1
pyOpenSSL==25.0.0
pyparsing==3.2.1
pypdf==5.4.0
PyPika==0.48.9
pyproject_hooks==1.2.0
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-iso639==2025.2.18
python-magic==0.4.27
python-oxmsg==0.0.2
pytz==2025.1
PyYAML==6.0.2
RapidFuzz==3.13.0
referencing==0.36.2
regex==2024.11.6
requests==2.32.3
requests-oauthlib==2.0.0
requests-toolbelt==1.0.0
rich==14.0.0
rpds-py==0.24.0
rsa==4.9.1
scipy==1.15.2
service-identity==24.2.0
setuptools==75.8.2
shellingham==1.5.4
six==1.17.0
sniffio==1.3.1
soupsieve==2.7
SQLAlchemy==2.0.38
sqlglot==26.9.0
sqlglotrs==0.4.0
sqlparse==0.5.3
starlette==0.45.3
sympy==1.13.1
tenacity==9.0.0
tokenizers==0.21.1
torch==2.6.0
tqdm==4.67.1
triton==3.2.0
Twisted==24.11.0
txaio==23.1.1
typer==0.15.3
typing-inspect==0.9.0
typing-inspection==0.4.0
typing_extensions==4.12.2
tzdata==2025.1
unstructured==0.17.2
unstructured-client==0.34.0
urllib3==2.3.0
uvicorn==0.34.2
uvloop==0.21.0
watchfiles==1.0.5
webencodings==0.5.1
websocket-client==1.8.0
websockets==15.0.1
wrapt==1.17.2
yarl==1.20.0
zipp==3.21.0
zope.interface==7.2
zstandard==0.23.0
strip_and_upgrade.py (new file, 9 lines)
@@ -0,0 +1,9 @@
# Strip version pins from requirements.dev and write the bare package
# names to requirements.txt, so that `pip install -U -r requirements.txt`
# pulls the latest releases before re-freezing.
with open("requirements.dev", "r") as infile, open("requirements.txt", "w") as outfile:
    for line in infile:
        line = line.strip()
        if line:
            name = line.split("==")[0]
            print(name)
            outfile.write(name + "\n")