Syncing with updates from prod and formatted
@@ -9,7 +9,7 @@ from .models import (
     Feedback,
     PromptMetric,
     DocumentWorkspace,
-    Document
+    Document,
 )

 # Register your models here.
@@ -79,14 +79,15 @@ class PromptMetricAdmin(admin.ModelAdmin):
         "get_duration",
     )


 class DocumentWorkspaceAdmin(admin.ModelAdmin):
     model = DocumentWorkspace
     list_display = (
         "name",
         "company",

     )


 class DocumentAdmin(admin.ModelAdmin):
     model = Document
     list_display = (
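
Note: these hunks only show the `ModelAdmin` classes; the matching registrations fall outside the diff context. A sketch of the usual pattern (hypothetical, not part of this commit):

```python
from django.contrib import admin

# Hypothetical registrations matching the admin classes in the hunks above
admin.site.register(DocumentWorkspace, DocumentWorkspaceAdmin)
admin.site.register(Document, DocumentAdmin)
```
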
@@ -8,18 +8,19 @@ class ChatBackendConfig(AppConfig):
     name = "chat_backend"

     def ready(self):
         import chat_backend.signals

         FORCE_RELOAD = False

-        if True: #not settings.TESTING: # Don't run during tests
+        if True:  # not settings.TESTING:  # Don't run during tests
             try:
                 from .services.rag_services import AsyncRAGService
                 from chat_backend.models import Document

                 # Check if Chroma needs initialization
                 if Document.objects.exists():
                     rag_service = AsyncRAGService()

                     if rag_service.vector_store._collection.count() == 0:
                         print("Initializing ChromaDB with existing documents...")
                         rag_service.ingest_documents()
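
Note: the `if True:` comments out the `settings.TESTING` check, so the ChromaDB warm-up also runs during tests. A minimal sketch of the intended guard, assuming a `TESTING` flag defined in settings (the flag itself is not shown in this diff):

```python
# Sketch: the guard this hunk disables, assuming settings defines a
# hypothetical TESTING flag, e.g. in settings.py:
#     import sys
#     TESTING = "test" in sys.argv
from django.conf import settings

if not settings.TESTING:  # skip the ChromaDB warm-up while tests run
    ...
```
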
@@ -53,6 +53,9 @@ class Company(TimeInfoBase):
         help_text="A list of LLMs that company can use",
     )

+    def __str__(self):
+        return self.name
+

 class CustomUser(AbstractUser):
     company = models.ForeignKey(
@@ -71,7 +74,7 @@ class CustomUser(AbstractUser):
     )

     def get_set_password_url(self):
-        return f"https://www.chat.aimloperations.com/set_password?slug={self.slug}"
+        return f"https://chat.aimloperations.com/set_password?slug={self.slug}"


 FEEDBACK_CHOICE = (
@@ -220,14 +223,16 @@ class PromptMetric(TimeInfoBase):
             return difference.seconds
         return 0

+
 # Document Models
 class DocumentWorkspace(TimeInfoBase):
     name = models.CharField(max_length=255)
     company = models.ForeignKey(Company, on_delete=models.CASCADE)

+
 class Document(TimeInfoBase):
     workspace = models.ForeignKey(DocumentWorkspace, on_delete=models.CASCADE)
-    file = models.FileField(upload_to='documents/')
+    file = models.FileField(upload_to="documents/")
     uploaded_at = models.DateTimeField(auto_now_add=True)
     processed = models.BooleanField(default=False)
     active = models.BooleanField(default=False)
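
Note: with `Company.__str__` added, related objects render by name in the admin and shell. A quick illustrative session (fields taken from this diff; any other required fields omitted):

```python
company = Company.objects.create(name="Acme")
workspace = DocumentWorkspace.objects.create(name="Contracts", company=company)
document = Document.objects.create(workspace=workspace, file="documents/nda.pdf")

print(company)             # -> Acme  (the new __str__)
print(document.processed)  # -> False (default until ingestion flips it)
```
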
@@ -9,7 +9,7 @@ from .models import (
     Feedback,
     FEEDBACK_CATEGORIES,
     DocumentWorkspace,
-    Document
+    Document,
 )


@@ -99,11 +99,20 @@ class BasicUserSerializer(serializers.ModelSerializer):
 class DocumentWorkspaceSerializer(serializers.ModelSerializer):
     class Meta:
         model = DocumentWorkspace
-        fields = ['id', 'name', 'created']
-        read_only_fields = ['id', 'created']
+        fields = ["id", "name", "created"]
+        read_only_fields = ["id", "created"]


 class DocumentSerializer(serializers.ModelSerializer):
     class Meta:
         model = Document
-        fields = ['id', 'workspace', 'file', 'uploaded_at', 'processed', 'created', 'active']
-        read_only_fields = ['id', 'uploaded_at', 'processed', 'created']
+        fields = [
+            "id",
+            "workspace",
+            "file",
+            "uploaded_at",
+            "processed",
+            "created",
+            "active",
+        ]
+        read_only_fields = ["id", "uploaded_at", "processed", "created"]
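
Note: a minimal usage sketch of the reformatted `DocumentSerializer` (the DRF view plumbing is assumed, not shown in this diff):

```python
# Hypothetical upload handler fragment
serializer = DocumentSerializer(
    data={"workspace": workspace.pk, "file": request.FILES["file"], "active": True}
)
serializer.is_valid(raise_exception=True)
document = serializer.save()  # id, uploaded_at, processed, created stay server-assigned
```
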
@@ -7,21 +7,22 @@ from diffusers import StableDiffusionPipeline, DPMSolverSinglestepScheduler

 logger = logging.getLogger(__name__)


 class ImageGenerationService:
     """
     Service for text-to-image generation using Stable Diffusion.
     Uses singleton pattern to maintain loaded model in memory.
     """

     _instance = None
     _model_loaded = False

     def __new__(cls):
         if cls._instance is None:
             cls._instance = super().__new__(cls)
             cls._instance._initialize()
         return cls._instance

     def _initialize(self):
         """Initialize the service with default settings"""
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -33,15 +34,15 @@ class ImageGenerationService:
             "width": 512,
             "height": 512,
         }

     def load_model(self):
         """Load the Stable Diffusion model"""
         if self._model_loaded:
             return

         try:
             logger.info(f"Loading Stable Diffusion model on {self.device}...")

             # Use DPMSolver for faster inference
             self.pipeline = StableDiffusionPipeline.from_pretrained(
                 self.model_id,
@@ -51,15 +52,15 @@ class ImageGenerationService:
                 self.pipeline.scheduler.config
             )
             self.pipeline = self.pipeline.to(self.device)

             # Optimizations
             if self.device == "cuda":
                 self.pipeline.enable_attention_slicing()
                 self.pipeline.enable_xformers_memory_efficient_attention()

             self._model_loaded = True
             logger.info("Model loaded successfully")

         except Exception as e:
             logger.error(f"Failed to load model: {str(e)}")
             raise RuntimeError(f"Model loading failed: {str(e)}")
@@ -69,45 +70,43 @@ class ImageGenerationService:
         prompt: str,
         negative_prompt: Optional[str] = None,
         output_path: Optional[str] = None,
-        **kwargs
+        **kwargs,
     ) -> Tuple[Image.Image, dict]:
         """
         Generate image from text prompt.

         Args:
             prompt: Text prompt for image generation
             negative_prompt: Text for things to avoid in generation
             output_path: Optional path to save the image
             **kwargs: Generation parameters (overrides defaults)

         Returns:
             Tuple of (PIL.Image, generation_parameters)
         """
         if not self._model_loaded:
             self.load_model()

         # Merge default params with overrides
         params = {**self.default_params, **kwargs}

         try:
             logger.info(f"Generating image with prompt: {prompt[:50]}...")

             with torch.inference_mode():
                 result = self.pipeline(
-                    prompt=prompt,
-                    negative_prompt=negative_prompt,
-                    **params
+                    prompt=prompt, negative_prompt=negative_prompt, **params
                 )

             image = result.images[0]

             if output_path:
                 os.makedirs(os.path.dirname(output_path), exist_ok=True)
                 image.save(output_path)
                 logger.info(f"Image saved to {output_path}")

             return image, params

         except Exception as e:
             logger.error(f"Image generation failed: {str(e)}")
             raise RuntimeError(f"Image generation failed: {str(e)}")
@@ -118,28 +117,28 @@ class AsyncImageGenerationService:
     Asynchronous wrapper for image generation service.
     Runs the synchronous service in a thread pool.
     """

     def __init__(self):
         self.sync_service = ImageGenerationService()

     async def generate_image(
         self,
         prompt: str,
         negative_prompt: Optional[str] = None,
         output_path: Optional[str] = None,
-        **kwargs
+        **kwargs,
     ) -> Tuple[Image.Image, dict]:
         """Async version of generate_image"""
         import asyncio
         from functools import partial

         loop = asyncio.get_event_loop()
         func = partial(
             self.sync_service.generate_image,
             prompt=prompt,
             negative_prompt=negative_prompt,
             output_path=output_path,
-            **kwargs
+            **kwargs,
         )

         return await loop.run_in_executor(None, func)
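
Note: the async wrapper keeps the event loop free while Stable Diffusion runs in a worker thread (inside a coroutine, `asyncio.get_running_loop()` is the preferred spelling over `get_event_loop()`). A hedged driver sketch; the first call downloads and loads the model:

```python
import asyncio

async def main():
    service = AsyncImageGenerationService()
    image, params = await service.generate_image(
        prompt="a watercolor fox",
        output_path="data/outputs/fox.png",  # parent dir is created by the service
        num_inference_steps=20,              # forwarded to the pipeline via **kwargs
    )
    print(image.size, params["width"], params["height"])

asyncio.run(main())
```
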
@@ -1,47 +1,50 @@
 from abc import ABC, abstractmethod
 from typing import AsyncGenerator, Generator, Optional

-from langchain_community.llms import Ollama
+# from langchain_community.llms import Ollama
+from langchain_ollama import OllamaLLM
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate

 from chat_backend.models import Conversation, Prompt


 class LLMService(ABC):
     """Abstract base class for LLM conversation services."""

     def __init__(self):
-        self.llm = Ollama(
+        self.llm = OllamaLLM(
             model="llama3.2",
             temperature=0.7,
             top_k=50,
             top_p=0.9,
             repeat_penalty=1.1,
-            num_ctx=4096
+            num_ctx=4096,
         )
         self.output_parser = StrOutputParser()

     @abstractmethod
     def generate_response(self, conversation: Conversation, query: str, **kwargs):
         """Generate a response to a query within a conversation context."""
         pass

     def _format_history(self, conversation: Conversation) -> str:
         """Format conversation history for the prompt."""
-        prompts = Prompt.objects.filter(conversation=conversation).order_by('created_at')
+        prompts = Prompt.objects.filter(conversation=conversation).order_by(
+            "created_at"
+        )
         return "\n".join(
-            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}"
-            for prompt in prompts
+            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
         )


 class SyncLLMService(LLMService):
     """Synchronous LLM conversation service."""

     def __init__(self):
         super().__init__()
         self._setup_chain()

     def _setup_chain(self):
         """Setup the conversation chain."""
         template = """Continue the conversation based on the following history:
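
Note: this commit swaps the deprecated `langchain_community.llms.Ollama` class for the `langchain-ollama` package across the services. A minimal sketch of the replacement in isolation (`pip install langchain-ollama` and a running Ollama server assumed):

```python
from langchain_ollama import OllamaLLM

llm = OllamaLLM(model="llama3.2", temperature=0.7, num_ctx=4096)
for chunk in llm.stream("Reply with one word: hello?"):
    print(chunk, end="", flush=True)  # OllamaLLM streams plain strings
```
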
@@ -52,35 +55,36 @@ class SyncLLMService(LLMService):

 Response:"""
         self.prompt = ChatPromptTemplate.from_template(template)

         self.conversation_chain = (
             {
                 "history": lambda x: self._format_history(x["conversation"]),
-                "query": lambda x: x["query"]
+                "query": lambda x: x["query"],
             }
             | self.prompt
             | self.llm
             | self.output_parser
         )

-    def generate_response(self, conversation: Conversation, query: str, **kwargs) -> Generator[str, None, None]:
+    def generate_response(
+        self, conversation: Conversation, query: str, **kwargs
+    ) -> Generator[str, None, None]:
         """Generate response with streaming support."""
-        chain_input = {
-            "query": query,
-            "conversation": conversation
-        }
+        chain_input = {"query": query, "conversation": conversation}
+        print(f"chain_input:\n{chain_input}")

         for chunk in self.conversation_chain.stream(chain_input):
             yield chunk


 class AsyncLLMService(LLMService):
     """Asynchronous LLM conversation service."""

     def __init__(self):
         super().__init__()
         self._setup_chain()

     def _setup_chain(self):
         """Setup the conversation chain."""
         template = """Continue this conversation while maintaining context by providing a single helpful response.
@@ -97,42 +101,50 @@ class AsyncLLMService(LLMService):
 - When asked to modify something, identify what's being modified

 Response:"""

         self.prompt = ChatPromptTemplate.from_template(template)

         self.conversation_chain = (
             {
                 "context": lambda x: self._format_history(x["conversation"]),
-                "recent_history": lambda x: self._get_recent_messages(x["conversation"]),
-                "query": lambda x: x["query"]
+                "recent_history": lambda x: self._get_recent_messages(
+                    x["conversation"]
+                ),
+                "query": lambda x: x["query"],
             }
             | self.prompt
             | self.llm
             | self.output_parser
         )

     async def _format_history(self, conversation: Conversation) -> str:
         """Async version of format conversation history."""
-        prompts = await Prompt.objects.filter(conversation=conversation).order_by('created_at').alist()
+        prompts = (
+            await Prompt.objects.filter(conversation=conversation)
+            .order_by("created_at")
+            .alist()
+        )
         return "\n".join(
-            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}"
-            for prompt in prompts
+            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
         )

     async def _get_recent_messages(self, conversation: Conversation) -> str:
         """Async version of format conversation history."""
-        prompts = await Prompt.objects.filter(conversation=conversation).order_by('created_at').alist()[-3:]
-        return "\n".join(
-            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}"
-            for prompt in prompts
-        )
+        prompts = (
+            await Prompt.objects.filter(conversation=conversation)
+            .order_by("created_at")
+            .alist()[-6:]
+        )
+        return "\n".join(
+            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
+        )

-    async def generate_response(self, conversation: Conversation, query: str, **kwargs) -> AsyncGenerator[str, None]:
+    async def generate_response(
+        self, conversation: Conversation, query: str, **kwargs
+    ) -> AsyncGenerator[str, None]:
         """Generate response with async streaming support."""
-        chain_input = {
-            "query": query,
-            "conversation": conversation
-        }
+        chain_input = {"query": query, "conversation": conversation}
+        print(f"LLM Chain:\n{chain_input}")

         async for chunk in self.conversation_chain.astream(chain_input):
             yield chunk
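
Note: Django QuerySets expose async iteration (`async for`, `aget`, `acount`) but, as far as I can tell, no `.alist()` method, and slicing the awaited call as written above would slice a coroutine object rather than a list. A working equivalent under that assumption:

```python
# Sketch: fetch the last six prompts without the (apparently nonexistent) .alist()
async def _get_recent_messages(self, conversation: Conversation) -> str:
    ordered = Prompt.objects.filter(conversation=conversation).order_by("created_at")
    prompts = [prompt async for prompt in ordered][-6:]
    return "\n".join(
        f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
    )
```
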
@@ -1,27 +1,34 @@
 from enum import Enum, auto
 from typing import Dict, Any
 from langchain_core.prompts import ChatPromptTemplate
-from langchain_community.llms import Ollama
+
+# from langchain_community.llms import Ollama
+from langchain_ollama import OllamaLLM


 class ModerationLabel(Enum):
     NSFW = auto()
     FINE = auto()


 class ModerationClassifier:
     """
     Classifies prompts as NSFW or FINE (safe) content.
     """

     def __init__(self):
-        self.llm = Ollama(
+        self.llm = OllamaLLM(
             model="llama3.2",
             temperature=0.1,  # Very low for strict moderation
             top_k=10,
-            num_ctx=2048
+            num_ctx=2048,
         )

-        self.moderation_prompt = ChatPromptTemplate.from_messages([
-            ("system", """You are a strict content moderator. Classify the following prompt as either NSFW or FINE.
+        self.moderation_prompt = ChatPromptTemplate.from_messages(
+            [
+                (
+                    "system",
+                    """You are a strict content moderator. Classify the following prompt as either NSFW or FINE.

 NSFW includes:
 - Sexual content
@@ -44,12 +51,14 @@ Examples:
 - "Explicit sex scene" → NSFW
 - "Python tutorial" → FINE

-Return ONLY "NSFW" or "FINE", nothing else."""),
-            ("human", "{prompt}")
-        ])
+Return ONLY "NSFW" or "FINE", nothing else.""",
+                ),
+                ("human", "{prompt}"),
+            ]
+        )

         self.chain = self.moderation_prompt | self.llm

     async def classify_async(self, prompt: str) -> ModerationLabel:
         """Asynchronous classification"""
         try:
@@ -58,7 +67,7 @@ Return ONLY "NSFW" or "FINE", nothing else."""),
         except Exception as e:
             print(f"Moderation error: {e}")
             return ModerationLabel.NSFW  # Fail-safe to NSFW

     def classify(self, prompt: str) -> ModerationLabel:
         """Synchronous classification"""
         try:
@@ -67,7 +76,7 @@ Return ONLY "NSFW" or "FINE", nothing else."""),
         except Exception as e:
             print(f"Moderation error: {e}")
             return ModerationLabel.NSFW  # Fail-safe to NSFW

     def _parse_response(self, response: str) -> ModerationLabel:
         """Convert string response to ModerationLabel enum"""
         if "NSFW" in response:
@@ -76,4 +85,4 @@ Return ONLY "NSFW" or "FINE", nothing else."""),


 # Singleton instance
 moderation_classifier = ModerationClassifier()
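
Note: the classifier fails safe, treating any parsing or transport error as NSFW. A usage sketch with the module-level singleton:

```python
label = moderation_classifier.classify("Write a Python tutorial")
if label is ModerationLabel.FINE:
    ...  # hand off to the normal chat pipeline
else:
    ...  # block the prompt or ask the user to rephrase
```
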
@@ -1,7 +1,10 @@
 from enum import Enum, auto
 from typing import Dict, Any
 from langchain_core.prompts import ChatPromptTemplate
-from langchain_community.llms import Ollama
+
+# from langchain_community.llms import Ollama
+from langchain_ollama import OllamaLLM


 class PromptType(Enum):
     GENERAL_CHAT = auto()
@@ -9,23 +12,26 @@ class PromptType(Enum):
     IMAGE_GENERATION = auto()
     UNKNOWN = auto()


 class PromptClassifier:
     """
     Classifies user prompts to determine which service should handle them.
     """

     def __init__(self):
-        self.llm = Ollama(
-            model="llama3",
+        self.llm = OllamaLLM(
+            model="llama3.2",
             temperature=0.3,  # Lower temp for more deterministic classification
             top_k=20,
             top_p=0.9,
-            num_ctx=4096
+            num_ctx=4096,
         )

-        self.classification_prompt = ChatPromptTemplate.from_messages([
-            ("system",
-             """You are a precision prompt classifier. Strictly categorize prompts into:
+        self.classification_prompt = ChatPromptTemplate.from_messages(
+            [
+                (
+                    "system",
+                    """You are a precision prompt classifier. Strictly categorize prompts into:
 1. GENERAL_CHAT - Casual conversation, personal questions, or non-specific inquiries
 2. RAG - ONLY when explicitly requesting document/search-based knowledge
 3. IMAGE_GENERATION - Specific requests to create/modify images
@@ -63,12 +69,14 @@ Examples:
 - "Explain quantum computing" (General knowledge)
 - "Summarize the meeting" (No doc reference)

-Return ONLY the label, no explanations."""),
-            ("human", "{prompt}")
-        ])
+Return ONLY the label, no explanations.""",
+                ),
+                ("human", "{prompt}"),
+            ]
+        )

         self.chain = self.classification_prompt | self.llm

     async def classify_async(self, prompt: str) -> PromptType:
         """Asynchronously classify the prompt"""
         try:
@@ -77,7 +85,7 @@ Return ONLY the label, no explanations."""),
         except Exception as e:
             print(f"Classification error: {e}")
             return PromptType.UNKNOWN

     def classify(self, prompt: str) -> PromptType:
         """Synchronously classify the prompt"""
         try:
@@ -86,7 +94,7 @@ Return ONLY the label, no explanations."""),
         except Exception as e:
             print(f"Classification error: {e}")
             return PromptType.UNKNOWN

     def _parse_response(self, response: str) -> PromptType:
         """Convert string response to PromptType enum"""
         response = response.upper()
@@ -97,4 +105,4 @@ Return ONLY the label, no explanations."""),


 # Singleton instance for easy access
 prompt_classifier = PromptClassifier()
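
Note: the label returned here drives service dispatch; a hedged routing sketch using the service classes from this commit (the dispatcher itself is not shown in the diff):

```python
label = await prompt_classifier.classify_async(user_text)
if label is PromptType.RAG:
    service = AsyncRAGService()
elif label is PromptType.IMAGE_GENERATION:
    service = AsyncImageGenerationService()
else:  # GENERAL_CHAT or UNKNOWN falls back to plain chat
    service = AsyncLLMService()
```
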
@@ -3,7 +3,9 @@ from abc import ABC, abstractmethod
 from typing import List, Dict, Any, AsyncGenerator, Generator, Optional
 from channels.db import database_sync_to_async
 from langchain_community.embeddings import OllamaEmbeddings
-from langchain_community.llms import Ollama
+
+# from langchain_community.llms import Ollama
+from langchain_ollama import OllamaLLM
 from langchain_community.vectorstores import Chroma
 from langchain_core.documents import Document as LangDocument
 from langchain_core.output_parsers import StrOutputParser
@@ -14,139 +16,139 @@ from langchain_community.document_loaders import (
     PyPDFLoader,
     Docx2txtLoader,
     TextLoader,
-    UnstructuredFileLoader
+    UnstructuredFileLoader,
 )
 from django.core.files.uploadedfile import UploadedFile
 from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
 from pathlib import Path


 @database_sync_to_async
 def get_documents(workspace: DocumentWorkspace | None = None):
     if workspace:
         return [doc for doc in Document.objects.filter(workspace=workspace)]
     else:
         return [doc for doc in Document.objects.all()]


 class RAGService(ABC):
     """Abstract base class for RAG services."""

     _instance = None

     def __new__(cls):
         if cls._instance is None:
             cls._instance = super().__new__(cls)
             cls._instance.__init__()
         return cls._instance

     def __init__(self):
         self.embedding_model = OllamaEmbeddings(model="llama3.2")
-        self.llm = Ollama(
+        self.llm = OllamaLLM(
             model="llama3.2",
             temperature=0.7,
             top_k=50,
             top_p=0.9,
             repeat_penalty=1.1,
-            num_ctx=4096
+            num_ctx=4096,
         )
         self.text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=1000,
-            chunk_overlap=200
+            chunk_size=1000, chunk_overlap=200
         )
         self.vector_store = self._initialize_vector_store()

         # Supported file types and their loaders
         self.loader_mapping = {
-            '.pdf': PyPDFLoader,
-            '.docx': Docx2txtLoader,
-            '.txt': TextLoader,
+            ".pdf": PyPDFLoader,
+            ".docx": Docx2txtLoader,
+            ".txt": TextLoader,
             # Fallback for other file types
-            '*': UnstructuredFileLoader,
+            "*": UnstructuredFileLoader,
         }

     def _initialize_vector_store(self) -> Chroma:
         """Initialize and return the Chroma vector store."""
-        persist_directory=f"./chroma_db/"
+        persist_directory = f"./chroma_db/"
         vector_store = Chroma(
-            embedding_function=self.embedding_model,
-            persist_directory=persist_directory
+            embedding_function=self.embedding_model, persist_directory=persist_directory
         )
         return vector_store

     def clear_vector_store(self):
         """Clear all vectors from the store"""
         self.vector_store.delete_collection()
         self.vector_store = self._initialize_vector_store()

     def _prepare_documents(self, documents: List[Document]) -> List[Document]:
         """Process documents for ingestion into vector store."""
         docs = []

         for doc in documents:
             print(f"Processing: {doc.file.name}")
-            loader_class = self._get_file_loader( doc.file.name)
+            loader_class = self._get_file_loader(doc.file.name)
             loader = loader_class(doc.file)

             chunks = self._load_and_split_documents(doc.file.path)
             if chunks:
                 self.vector_store.add_documents(chunks)
                 self.vector_store.persist()

     def ingest_documents(self, workspace: DocumentWorkspace | None = None) -> None:
         """Ingest documents from a workspace into the vector store."""
         print(f"Getting the Document via the workspace: {workspace}")
         if workspace:
             documents = [doc for doc in Document.objects.filter(workspace=workspace)]
         else:
             documents = [doc for doc in Document.objects.all()]

         print(f"Processing the documents : {documents}")
         self._prepare_documents(documents)

     @abstractmethod
     def generate_response(self, conversation: Conversation, query: str, **kwargs):
         """Generate a response using RAG."""
         pass

     @abstractmethod
-    def search_documents(self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4) -> List[Document]:
+    def search_documents(
+        self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4
+    ) -> List[Document]:
         """Search relevant documents from the vector store."""
         pass

     def _get_file_loader(self, file_path: str):
         """Get appropriate loader for file type"""
         ext = Path(file_path).suffix.lower()
-        return self.loader_mapping.get(ext, self.loader_mapping['*'])
+        return self.loader_mapping.get(ext, self.loader_mapping["*"])

     def _sanitize_filename(self, filename: str) -> str:
         """Sanitize filename for safe storage"""
-        return re.sub(r'[^\w\-_. ]', '_', filename)
+        return re.sub(r"[^\w\-_. ]", "_", filename)

     def _save_uploaded_file(self, uploaded_file: UploadedFile, save_dir: str) -> str:
         """Save uploaded file to disk"""
         os.makedirs(save_dir, exist_ok=True)
         sanitized_name = self._sanitize_filename(uploaded_file.name)
         file_path = os.path.join(save_dir, sanitized_name)

-        with open(file_path, 'wb+') as destination:
+        with open(file_path, "wb+") as destination:
             for chunk in uploaded_file.chunks():
                 destination.write(chunk)

         return file_path

-    def _load_and_split_documents(self, file_path: str, metadata: dict = None) -> List[Document]:
+    def _load_and_split_documents(
+        self, file_path: str, metadata: dict = None
+    ) -> List[Document]:
         """Load and split documents from file"""
         loader_class = self._get_file_loader(file_path)
         loader = loader_class(file_path)

         docs = loader.load()
         if metadata:
             for doc in docs:
                 doc.metadata.update(metadata)

         return self.text_splitter.split_documents(docs)

     def add_files_to_store(
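
Note: one caveat with `RAGService.__new__` as written: because `__new__` returns an instance of `cls`, Python calls `__init__` again on every `RAGService()` construction, so the LLM, splitter, and vector store are rebuilt each time despite the singleton. A common guard, as a sketch:

```python
class SingletonService:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if getattr(self, "_initialized", False):
            return  # repeated construction skips the expensive setup
        self._initialized = True
        # ... one-time setup (embeddings, LLM, vector store) ...
```
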
@@ -154,58 +156,51 @@ class RAGService(ABC):
         file_tupls: List[UploadedFile],  # (file_path, name,workspace_id)
         workspace_id: str,
         source: str = "upload",
-        save_dir: str = "data/uploads"
+        save_dir: str = "data/uploads",
     ) -> Dict[str, Any]:
         """
         Process and add uploaded files to vector store

         Args:
             files: List of Django UploadedFile objects
             workspace_id: ID of the workspace these belong to
             source: Source identifier for documents
             save_dir: Directory to save uploaded files

         Returns:
             Dictionary with processing results
         """
-        results = {
-            'total_added': 0,
-            'failed_files': [],
-            'processed_files': []
-        }
+        results = {"total_added": 0, "failed_files": [], "processed_files": []}

         for file_tuple in file_tupls:
             try:
                 # Save file to disk

                 # Prepare metadata
                 metadata = {
-                    'source': file_tuple[1],
-                    'workspace_id': file_tuple[2],
-                    'original_filename': file_tuple[1],
-                    'file_path': file_tuple[0],
+                    "source": file_tuple[1],
+                    "workspace_id": file_tuple[2],
+                    "original_filename": file_tuple[1],
+                    "file_path": file_tuple[0],
                 }

                 # Load and split documents
                 docs = self._load_and_split_documents(file_path, metadata)

                 # Add to vector store
                 if docs:
                     self.vector_store.add_documents(docs)
-                    results['total_added'] += len(docs)
-                    results['processed_files'].append({
-                        'filename': file_tuple[1],
-                        'document_count': len(docs)
-                    })
+                    results["total_added"] += len(docs)
+                    results["processed_files"].append(
+                        {"filename": file_tuple[1], "document_count": len(docs)}
+                    )

             except Exception as e:
-                results['failed_files'].append({
-                    'filename': file_tuple[1],
-                    'error': str(e)
-                })
+                results["failed_files"].append(
+                    {"filename": file_tuple[1], "error": str(e)}
+                )
                 continue

         # Persist changes
         self.vector_store.persist()
         return results
@@ -213,11 +208,11 @@ class RAGService(ABC):

 class SyncRAGService(RAGService):
     """Synchronous RAG service implementation."""

     def __init__(self):
         super().__init__()
         self._setup_chain()

     def _setup_chain(self):
         """Setup the RAG chain."""
         template = """Answer the question based only on the following context:
@@ -229,31 +224,32 @@ class SyncRAGService(RAGService):
 Question: {question}
 """
         self.prompt = ChatPromptTemplate.from_template(template)

         self.rag_chain = (
             {
                 "context": self._retriever_with_history,
                 "history": lambda x: self._format_history(x["conversation"]),
-                "question": lambda x: x["query"]
+                "question": lambda x: x["query"],
             }
             | self.prompt
             | self.llm
             | StrOutputParser()
         )

     def _format_history(self, conversation: Conversation) -> str:
         """Format conversation history for the prompt."""
-        prompts = Prompt.objects.filter(conversation=conversation).order_by('created_at')
-        return "\n".join(
-            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}"
-            for prompt in prompts
-        )
+        prompts = Prompt.objects.filter(conversation=conversation).order_by(
+            "created_at"
+        )
+        return "\n".join(
+            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
+        )

     def _retriever_with_history(self, input_dict: Dict[str, Any]) -> str:
         """Retrieve documents considering conversation history."""
         query = input_dict["query"]
         conversation = input_dict["conversation"]

         # You could enhance this to consider historical context in retrieval
         relevant_docs = self.search_documents(query, conversation.workspace)
         if not relevant_docs:
@@ -262,8 +258,9 @@ class SyncRAGService(RAGService):
         else:
             return relevant_docs

-    def search_documents(self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4) -> List[Document]:
+    def search_documents(
+        self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4
+    ) -> List[Document]:
         """Search relevant documents from the vector store."""
         filter_dict = {}
         if workspace:
@@ -271,31 +268,27 @@ class SyncRAGService(RAGService):
         print(f"search_kwargs: {search_kwargs}")
         retriever = self.vector_store.as_retriever(
             search_type="similarity",
-            search_kwargs={
-                "k": k,
-                "filter": filter_dict if filter_dict else None
-            }
+            search_kwargs={"k": k, "filter": filter_dict if filter_dict else None},
         )
         return retriever.get_relevant_documents(query)

-    def generate_response(self, conversation: Conversation, query: str, **kwargs) -> Generator[str, None, None]:
+    def generate_response(
+        self, conversation: Conversation, query: str, **kwargs
+    ) -> Generator[str, None, None]:
         """Generate response with streaming support."""
-        chain_input = {
-            "query": query,
-            "conversation": conversation
-        }
+        chain_input = {"query": query, "conversation": conversation}

         for chunk in self.rag_chain.stream(chain_input):
             yield chunk


 class AsyncRAGService(RAGService):
     """Asynchronous RAG service implementation."""

     def __init__(self):
         super().__init__()
         self._setup_chain()

     def _setup_chain(self):
         """Setup the RAG chain."""
         template = """Answer the question based only on the following context:
@@ -307,72 +300,76 @@ class AsyncRAGService(RAGService):
 Question: {question}
 """
         self.prompt = ChatPromptTemplate.from_template(template)

         self.rag_chain = (
             {
                 "context": self._retriever_with_history,
                 "history": lambda x: self._format_history(x["conversation"]),
-                "question": lambda x: x["query"]
+                "question": lambda x: x["query"],
             }
             | self.prompt
             | self.llm
             | StrOutputParser()
         )

     async def _format_history(self, conversation: Conversation) -> str:
         """Format conversation history for the prompt."""
-        prompts = await Prompt.objects.filter(conversation=conversation).order_by('created_at').alist()
+        prompts = (
+            await Prompt.objects.filter(conversation=conversation)
+            .order_by("created_at")
+            .alist()
+        )
         print(f"prompts that we are seeding with are: {prompts}")
         return "\n".join(
-            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}"
-            for prompt in prompts
+            f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
         )

     async def _retriever_with_history(self, input_dict: Dict[str, Any]) -> str:
         """Retrieve documents considering conversation history."""
         print(f"Retrieving history with input: {input_dict}")
         query = input_dict["query"]
         conversation = input_dict["conversation"]
         workspace = input_dict["workspace"]

         # You could enhance this to consider historical context in retrieval
-        docs= await self.search_documents(query, workspace)
+        docs = await self.search_documents(query, workspace)

         if not docs:
             print("Didn't find any relevant docs")

         print("\n\n".join(doc.page_content for doc in docs))
         return "\n\n".join(doc.page_content for doc in docs)

-    async def search_documents(self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4) -> List[Document]:
+    async def search_documents(
+        self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4
+    ) -> List[Document]:
         """Search relevant documents from the vector store."""
         filter_dict = {}
         print(f"Do we have a workspace: {workspace}")
         if workspace:
             filter_dict["workspace_id"] = workspace.id
-        search_kwargs={
-            "k": k,
-            "filter": filter_dict if filter_dict else None
-        }
+        search_kwargs = {"k": k, "filter": filter_dict if filter_dict else None}
         print(f"search_kwargs: {search_kwargs}")

         retriever = self.vector_store.as_retriever(
             search_type="mmr",
-            search_kwargs={
-                "k": k,
-                "filter": filter_dict if filter_dict else None
-            }
+            search_kwargs={"k": k, "filter": filter_dict if filter_dict else None},
         )
         return await retriever.aget_relevant_documents(query)

-    async def generate_response(self, conversation: Conversation, query: str, workspace: DocumentWorkspace, **kwargs) -> AsyncGenerator[str, None]:
+    async def generate_response(
+        self,
+        conversation: Conversation,
+        query: str,
+        workspace: DocumentWorkspace,
+        **kwargs,
+    ) -> AsyncGenerator[str, None]:
         """Generate response with streaming support."""
         chain_input = {
             "query": query,
             "conversation": conversation,
             "workspace": workspace,
         }

         async for chunk in self.rag_chain.astream(chain_input):
             yield chunk
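
Note: end to end, the async service streams a grounded answer; a hedged driver sketch (a running Ollama server and previously ingested documents assumed):

```python
async def answer(
    conversation: Conversation, workspace: DocumentWorkspace, question: str
) -> str:
    service = AsyncRAGService()
    parts = []
    async for chunk in service.generate_response(conversation, question, workspace):
        parts.append(chunk)
    return "".join(parts)
```
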
@@ -5,9 +5,14 @@ from typing import List, Dict, Any
|
|||||||
|
|
||||||
from django.test import TestCase as DjangoTestCase
|
from django.test import TestCase as DjangoTestCase
|
||||||
|
|
||||||
from chat_backend.services.rag_services import RAGService, SyncRAGService, AsyncRAGService
|
from chat_backend.services.rag_services import (
|
||||||
|
RAGService,
|
||||||
|
SyncRAGService,
|
||||||
|
AsyncRAGService,
|
||||||
|
)
|
||||||
from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
|
from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
|
||||||
|
|
||||||
|
|
||||||
class TestRAGService(TestCase):
|
class TestRAGService(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.rag_service = RAGService()
|
self.rag_service = RAGService()
|
||||||
@@ -16,18 +21,20 @@ class TestRAGService(TestCase):
         self.rag_service.text_splitter = MagicMock()

     def test_initialize_vector_store(self):
-        with patch('os.path.exists', return_value=False), \
-             patch('os.makedirs') as mock_makedirs, \
-             patch('langchain_community.vectorstores.Chroma') as mock_chroma:
+        with patch("os.path.exists", return_value=False), patch(
+            "os.makedirs"
+        ) as mock_makedirs, patch(
+            "langchain_community.vectorstores.Chroma"
+        ) as mock_chroma:

             # Reset the vector store to test initialization
             self.rag_service.vector_store = None
             result = self.rag_service._initialize_vector_store()

             mock_makedirs.assert_called_once_with("chroma_db")
             mock_chroma.assert_called_once_with(
                 embedding_function=self.rag_service.embedding_model,
-                persist_directory="chroma_db"
+                persist_directory="chroma_db",
             )
             self.assertIsNotNone(result)
@@ -40,11 +47,13 @@ class TestRAGService(TestCase):
         mock_doc1.id = 1

         self.rag_service.text_splitter.split_text.return_value = ["chunk1", "chunk2"]

         result = self.rag_service._prepare_documents([mock_doc1])

         self.assertEqual(len(result), 2)
-        self.rag_service.text_splitter.split_text.assert_called_once_with("Test content")
+        self.rag_service.text_splitter.split_text.assert_called_once_with(
+            "Test content"
+        )
         self.assertEqual(result[0].page_content, "chunk1")
         self.assertEqual(result[0].metadata["source"], "test_source")
@@ -52,13 +61,19 @@ class TestRAGService(TestCase):
         mock_workspace = MagicMock()
         mock_document = MagicMock()
         mock_documents = [mock_document]

-        with patch('services.rag_services.Document.objects.filter', return_value=mock_documents):
-            self.rag_service._prepare_documents = MagicMock(return_value=["processed_doc"])
+        with patch(
+            "services.rag_services.Document.objects.filter", return_value=mock_documents
+        ):
+            self.rag_service._prepare_documents = MagicMock(
+                return_value=["processed_doc"]
+            )

             self.rag_service.ingest_documents(mock_workspace)

-            self.rag_service.vector_store.add_documents.assert_called_once_with(["processed_doc"])
+            self.rag_service.vector_store.add_documents.assert_called_once_with(
+                ["processed_doc"]
+            )
             self.rag_service.vector_store.persist.assert_called_once()
@@ -71,40 +86,39 @@ class TestSyncRAGService(DjangoTestCase):

         self.mock_conversation = MagicMock(spec=Conversation)
         self.mock_conversation.workspace = MagicMock()

         self.mock_prompt1 = MagicMock(spec=Prompt)
         self.mock_prompt1.is_user = True
         self.mock_prompt1.text = "User question"
         self.mock_prompt1.created_at = "2023-01-01"

         self.mock_prompt2 = MagicMock(spec=Prompt)
         self.mock_prompt2.is_user = False
         self.mock_prompt2.text = "AI response"
         self.mock_prompt2.created_at = "2023-01-02"

     def test_format_history(self):
-        with patch('services.rag_services.Prompt.objects.filter') as mock_filter:
-            mock_filter.return_value.order_by.return_value = [self.mock_prompt1, self.mock_prompt2]
+        with patch("services.rag_services.Prompt.objects.filter") as mock_filter:
+            mock_filter.return_value.order_by.return_value = [
+                self.mock_prompt1,
+                self.mock_prompt2,
+            ]

             result = self.sync_service._format_history(self.mock_conversation)

             expected = "User: User question\nAI: AI response"
             self.assertEqual(result, expected)
             mock_filter.assert_called_once_with(conversation=self.mock_conversation)

     def test_retriever_with_history(self):
-        input_dict = {
-            "query": "test query",
-            "conversation": self.mock_conversation
-        }
+        input_dict = {"query": "test query", "conversation": self.mock_conversation}

         self.sync_service.search_documents = MagicMock(return_value=["doc1", "doc2"])

         result = self.sync_service._retriever_with_history(input_dict)

         self.sync_service.search_documents.assert_called_once_with(
-            "test query",
-            self.mock_conversation.workspace
+            "test query", self.mock_conversation.workspace
         )
         self.assertEqual(result, ["doc1", "doc2"])
@@ -112,29 +126,30 @@ class TestSyncRAGService(DjangoTestCase):
         mock_retriever = MagicMock()
         mock_retriever.get_relevant_documents.return_value = ["doc1", "doc2"]
         self.sync_service.vector_store.as_retriever.return_value = mock_retriever

-        result = self.sync_service.search_documents("test query", self.mock_conversation.workspace)
+        result = self.sync_service.search_documents(
+            "test query", self.mock_conversation.workspace
+        )

         self.sync_service.vector_store.as_retriever.assert_called_once_with(
             search_type="similarity",
             search_kwargs={
                 "k": 4,
-                "filter": {"workspace_id": self.mock_conversation.workspace.id}
-            }
+                "filter": {"workspace_id": self.mock_conversation.workspace.id},
+            },
         )
         self.assertEqual(result, ["doc1", "doc2"])

     def test_generate_response(self):
-        chain_input = {
-            "query": "test query",
-            "conversation": self.mock_conversation
-        }
+        chain_input = {"query": "test query", "conversation": self.mock_conversation}

         mock_stream = ["chunk1", "chunk2", "chunk3"]
         self.sync_service.rag_chain.stream.return_value = mock_stream

-        result = list(self.sync_service.generate_response(self.mock_conversation, "test query"))
+        result = list(
+            self.sync_service.generate_response(self.mock_conversation, "test query")
+        )

         self.sync_service.rag_chain.stream.assert_called_once_with(chain_input)
         self.assertEqual(result, mock_stream)
@@ -148,12 +163,12 @@ class TestAsyncRAGService(DjangoTestCase):

         self.mock_conversation = MagicMock(spec=Conversation)
         self.mock_conversation.workspace = MagicMock()

         self.mock_prompt1 = MagicMock(spec=Prompt)
         self.mock_prompt1.is_user = True
         self.mock_prompt1.text = "User question"
         self.mock_prompt1.created_at = "2023-01-01"

         self.mock_prompt2 = MagicMock(spec=Prompt)
         self.mock_prompt2.is_user = False
         self.mock_prompt2.text = "AI response"
@@ -161,28 +176,29 @@ class TestAsyncRAGService(DjangoTestCase):

     async def test_format_history(self):
         mock_manager = AsyncMock()
-        mock_manager.order_by.return_value.alist.return_value = [self.mock_prompt1, self.mock_prompt2]
-        with patch('services.rag_services.Prompt.objects.filter', return_value=mock_manager):
+        mock_manager.order_by.return_value.alist.return_value = [
+            self.mock_prompt1,
+            self.mock_prompt2,
+        ]
+
+        with patch(
+            "services.rag_services.Prompt.objects.filter", return_value=mock_manager
+        ):
             result = await self.async_service._format_history(self.mock_conversation)

             expected = "User: User question\nAI: AI response"
             self.assertEqual(result, expected)
-            mock_manager.order_by.assert_called_once_with('created_at')
+            mock_manager.order_by.assert_called_once_with("created_at")

     async def test_retriever_with_history(self):
-        input_dict = {
-            "query": "test query",
-            "conversation": self.mock_conversation
-        }
+        input_dict = {"query": "test query", "conversation": self.mock_conversation}

         self.async_service.search_documents = AsyncMock(return_value=["doc1", "doc2"])

         result = await self.async_service._retriever_with_history(input_dict)

         self.async_service.search_documents.assert_awaited_once_with(
-            "test query",
-            self.mock_conversation.workspace
+            "test query", self.mock_conversation.workspace
         )
         self.assertEqual(result, ["doc1", "doc2"])
@@ -190,30 +206,31 @@ class TestAsyncRAGService(DjangoTestCase):
         mock_retriever = AsyncMock()
         mock_retriever.aget_relevant_documents.return_value = ["doc1", "doc2"]
         self.async_service.vector_store.as_retriever.return_value = mock_retriever

-        result = await self.async_service.search_documents("test query", self.mock_conversation.workspace)
+        result = await self.async_service.search_documents(
+            "test query", self.mock_conversation.workspace
+        )

         self.async_service.vector_store.as_retriever.assert_called_once_with(
             search_type="similarity",
             search_kwargs={
                 "k": 4,
-                "filter": {"workspace_id": self.mock_conversation.workspace.id}
-            }
+                "filter": {"workspace_id": self.mock_conversation.workspace.id},
+            },
         )
         self.assertEqual(result, ["doc1", "doc2"])

     async def test_generate_response(self):
-        chain_input = {
-            "query": "test query",
-            "conversation": self.mock_conversation
-        }
+        chain_input = {"query": "test query", "conversation": self.mock_conversation}

         mock_stream = ["chunk1", "chunk2", "chunk3"]
         self.async_service.rag_chain.astream.return_value = mock_stream

         chunks = []
-        async for chunk in self.async_service.generate_response(self.mock_conversation, "test query"):
+        async for chunk in self.async_service.generate_response(
+            self.mock_conversation, "test query"
+        ):
             chunks.append(chunk)

         self.async_service.rag_chain.astream.assert_awaited_once_with(chain_input)
         self.assertEqual(chunks, mock_stream)
@@ -1,22 +1,28 @@
 from langchain_core.prompts import ChatPromptTemplate
-from langchain_community.llms import Ollama
+# from langchain_community.llms import Ollama
+from langchain_ollama import OllamaLLM
 from typing import Optional


 class TitleGenerator:
     """
     Generates short, descriptive titles for conversations based on the first prompt.
     """

     def __init__(self):
-        self.llm = Ollama(
-            model="llama3",
+        self.llm = OllamaLLM(
+            model="llama3.2",
             temperature=0.5,  # Slightly creative but not too random
             top_k=20,
-            num_ctx=2048  # Shorter context needed for titles
+            num_ctx=2048,  # Shorter context needed for titles
         )

-        self.title_prompt = ChatPromptTemplate.from_messages([
-            ("system", """You are a conversation title generator. Create a very short (2-5 word) title based on the user's first message.
+        self.title_prompt = ChatPromptTemplate.from_messages(
+            [
+                (
+                    "system",
+                    """You are a conversation title generator. Create a very short (2-5 word) title based on the user's first message.

 Rules:
 1. Keep it extremely concise
@@ -31,12 +37,14 @@ Examples:
 - "Generate an image of a dragon" → "Dragon Image Generation"
 - "Find our company's privacy policy" → "Privacy Policy Search"

-Return ONLY the title, nothing else."""),
-            ("human", "{prompt}")
-        ])
+Return ONLY the title, nothing else.""",
+                ),
+                ("human", "{prompt}"),
+            ]
+        )

         self.chain = self.title_prompt | self.llm

     async def generate_async(self, prompt: str) -> str:
         """Generate title asynchronously"""
         try:
@@ -45,7 +53,7 @@ Return ONLY the title, nothing else."""),
         except Exception as e:
             print(f"Title generation error: {e}")
             return "Conversation"

     def generate(self, prompt: str) -> str:
         """Generate title synchronously"""
         try:
@@ -54,14 +62,14 @@ Return ONLY the title, nothing else."""),
         except Exception as e:
             print(f"Title generation error: {e}")
             return "Conversation"

     def _clean_response(self, response: str) -> str:
         """Clean and format the LLM response"""
         # Remove any quotes or punctuation
-        response = response.strip('"\'.!? \n\t')
+        response = response.strip("\"'.!? \n\t")
         # Ensure title case and trim
         return response.title()[:50]  # Hard limit for safety


 # Singleton instance
 title_generator = TitleGenerator()
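The import swap in this file tracks LangChain's move of the Ollama wrapper out of langchain_community into the dedicated langchain-ollama package; the constructor arguments carry over unchanged. A minimal usage sketch (the prompt string is illustrative, not from the commit):

    # pip install langchain-ollama  -- replaces the deprecated community import
    from langchain_ollama import OllamaLLM

    llm = OllamaLLM(model="llama3.2", temperature=0.5, top_k=20, num_ctx=2048)
    title = llm.invoke("Create a 2-5 word title for: How do I reset my password?")
    print(title)  # plain string in, plain string out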
@@ -3,16 +3,18 @@ from django.dispatch import receiver
 from chat_backend.models import Document
 from .services.rag_services import AsyncRAGService


 @receiver(post_save, sender=Document)
 def update_vector_on_save(sender, instance, **kwargs):
     """Update vector store when documents are saved"""

-    if kwargs.get('created', False):
+    if kwargs.get("created", False):
         rag_service = AsyncRAGService()
         rag_service.ingest_documents()


 @receiver(post_delete, sender=Document)
 def delete_vector_on_remove(sender, instance, **kwargs):
     """Handle document deletion by re-indexing the whole workspace"""
     rag_service = AsyncRAGService()
     rag_service.ingest_documents()
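For context on the receiver above: Django passes created=True to post_save handlers only on the initial insert, which is why the handler re-ingests for brand-new documents and ignores plain updates. A minimal, self-contained sketch of that dispatch behavior (the lazy string sender is the standard alternative to importing the model directly):

    from django.db.models.signals import post_save
    from django.dispatch import receiver

    @receiver(post_save, sender="chat_backend.Document")  # lazy reference to the same model
    def log_saves(sender, instance, created, **kwargs):
        if created:
            print("first save(): the ingest branch above runs")
        else:
            print("update: kwargs['created'] is False, ingest is skipped")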
@@ -13,80 +13,75 @@ from django.core.files.uploadedfile import SimpleUploadedFile

 # Minimal valid PDF bytes
 VALID_PDF_BYTES = (
-    b'%PDF-1.3\n'
-    b'1 0 obj\n'
-    b'<< /Type /Catalog /Pages 2 0 R >>\n'
-    b'endobj\n'
-    b'2 0 obj\n'
-    b'<< /Type /Pages /Kids [3 0 R] /Count 1 >>\n'
-    b'endobj\n'
-    b'3 0 obj\n'
-    b'<< /Type /Page /Parent 2 0 R /Resources << >> /MediaBox [0 0 612 792] /Contents 4 0 R >>\n'
-    b'endobj\n'
-    b'4 0 obj\n'
-    b'<< /Length 44 >>\n'
-    b'stream\n'
-    b'BT /F1 12 Tf 72 720 Td (Test PDF) Tj ET\n'
-    b'endstream\n'
-    b'endobj\n'
-    b'xref\n'
-    b'0 5\n'
-    b'0000000000 65535 f \n'
-    b'0000000009 00000 n \n'
-    b'0000000058 00000 n \n'
-    b'0000000117 00000 n \n'
-    b'0000000223 00000 n \n'
-    b'trailer\n'
-    b'<< /Size 5 /Root 1 0 R >>\n'
-    b'startxref\n'
-    b'317\n'
-    b'%%EOF'
+    b"%PDF-1.3\n"
+    b"1 0 obj\n"
+    b"<< /Type /Catalog /Pages 2 0 R >>\n"
+    b"endobj\n"
+    b"2 0 obj\n"
+    b"<< /Type /Pages /Kids [3 0 R] /Count 1 >>\n"
+    b"endobj\n"
+    b"3 0 obj\n"
+    b"<< /Type /Page /Parent 2 0 R /Resources << >> /MediaBox [0 0 612 792] /Contents 4 0 R >>\n"
+    b"endobj\n"
+    b"4 0 obj\n"
+    b"<< /Length 44 >>\n"
+    b"stream\n"
+    b"BT /F1 12 Tf 72 720 Td (Test PDF) Tj ET\n"
+    b"endstream\n"
+    b"endobj\n"
+    b"xref\n"
+    b"0 5\n"
+    b"0000000000 65535 f \n"
+    b"0000000009 00000 n \n"
+    b"0000000058 00000 n \n"
+    b"0000000117 00000 n \n"
+    b"0000000223 00000 n \n"
+    b"trailer\n"
+    b"<< /Size 5 /Root 1 0 R >>\n"
+    b"startxref\n"
+    b"317\n"
+    b"%%EOF"
 )


 class DocumentWorkspaceViewsTestCase(APITestCase):
     def setUp(self):
         self.company = Company.objects.create(
-            name="test",
-            state="IL",
-            zipcode="60189",
-            address="1968 Greensboro Dr"
+            name="test", state="IL", zipcode="60189", address="1968 Greensboro Dr"
         )
         self.user = get_user_model().objects.create_user(
             company=self.company,
-            username='testuser',
-            password='testpass123',
+            username="testuser",
+            password="testpass123",
             email="test@test.com",
         )

         self.client = APIClient()
         self.client.force_authenticate(user=self.user)

         self.workspace = DocumentWorkspace.objects.create(
-            company = self.user.company,
-            name='Test Workspace'
+            company=self.user.company, name="Test Workspace"
         )

     def test_list_workspaces(self):
-        url = reverse('document_workspaces')
+        url = reverse("document_workspaces")
         response = self.client.get(url)
         self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.assertEqual(len(response.data), 1)
-        self.assertEqual(response.data[0]['name'], 'Test Workspace')
+        self.assertEqual(response.data[0]["name"], "Test Workspace")

     def test_create_workspace(self):
-        url = reverse('document_workspaces')
-        data = {
-            'name': 'New Workspace'
-        }
-        response = self.client.post(url, data, format='json')
+        url = reverse("document_workspaces")
+        data = {"name": "New Workspace"}
+        response = self.client.post(url, data, format="json")
         self.assertEqual(response.status_code, status.HTTP_201_CREATED)
         self.assertEqual(DocumentWorkspace.objects.count(), 2)

     def test_retrieve_workspace(self):
-        url = reverse('document_workspaces')
+        url = reverse("document_workspaces")
         response = self.client.get(url)
         self.assertEqual(response.status_code, status.HTTP_200_OK)
-        self.assertEqual(response.data[0]['name'], 'Test Workspace')
+        self.assertEqual(response.data[0]["name"], "Test Workspace")

     # def test_update_workspace(self):
     #     url = reverse('document_workspaces')
@@ -104,107 +99,89 @@ class DocumentWorkspaceViewsTestCase(APITestCase):
     #     self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
     #     self.assertEqual(DocumentWorkspace.objects.count(), 0)


 class DocumentViewsTestCase(APITestCase):
     def setUp(self):
         self.company = Company.objects.create(
-            name="test",
-            state="IL",
-            zipcode="60189",
-            address="1968 Greensboro Dr"
+            name="test", state="IL", zipcode="60189", address="1968 Greensboro Dr"
         )
         self.user = get_user_model().objects.create_user(
             company=self.company,
-            username='testuser',
-            password='testpass123',
+            username="testuser",
+            password="testpass123",
             email="test@test.com",
         )

         self.client = APIClient()
         self.client.force_authenticate(user=self.user)

         self.workspace = DocumentWorkspace.objects.create(
-            company=self.user.company,
-            name='Test Workspace'
+            company=self.user.company, name="Test Workspace"
         )

         # Create a test file
         self.test_file = SimpleUploadedFile(
-            "test.pdf",
-            VALID_PDF_BYTES,
-            content_type="application/pdf"
+            "test.pdf", VALID_PDF_BYTES, content_type="application/pdf"
         )

     def test_upload_document(self):
-        url = reverse('documents')
-        data = {
-            'file': self.test_file
-        }
-        response = self.client.post(url, data, format='multipart')
+        url = reverse("documents")
+        data = {"file": self.test_file}
+        response = self.client.post(url, data, format="multipart")
         self.assertEqual(response.status_code, status.HTTP_201_CREATED)
         self.assertEqual(Document.objects.count(), 1)

         document = Document.objects.first()
         self.assertEqual(document.workspace.id, self.workspace.id)
         self.assertTrue(document.processed)  # Should be False initially

     def test_list_documents(self):
         # First create a document
-        Document.objects.create(
-            workspace=self.workspace,
-            file=self.test_file
-        )
-
-        url = reverse('documents')
+        Document.objects.create(workspace=self.workspace, file=self.test_file)
+
+        url = reverse("documents")
         response = self.client.get(url)
         self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.assertEqual(len(response.data), 1)
-        self.assertIn('test', response.data[0]['file'])
-        self.assertIn('pdf', response.data[0]['file'])
+        self.assertIn("test", response.data[0]["file"])
+        self.assertIn("pdf", response.data[0]["file"])

     # def test_delete_document(self):
     #     document = Document.objects.create(
     #         workspace=self.workspace,
     #         file=self.test_file
     #     )

     #     url = reverse('document-detail', args=[document.id])
     #     response = self.client.delete(url)
     #     self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
     #     self.assertEqual(Document.objects.count(), 0)

     def test_upload_invalid_file(self):
-        url = reverse('documents')
-        data = {
-            'file': 'not a file'
-        }
-        response = self.client.post(url, data, format='multipart')
+        url = reverse("documents")
+        data = {"file": "not a file"}
+        response = self.client.post(url, data, format="multipart")
         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

     def test_access_other_users_documents(self):
         # Create another user
         other_company = Company.objects.create(
-            name="test2",
-            state="IL",
-            zipcode="60189",
-            address="1968 Greensboro Dr"
+            name="test2", state="IL", zipcode="60189", address="1968 Greensboro Dr"
         )
         other_user = get_user_model().objects.create_user(
             company=other_company,
-            username='otheruser',
-            password='otherpass123',
-            email="testing2@test.com"
+            username="otheruser",
+            password="otherpass123",
+            email="testing2@test.com",
         )
         other_workspace = DocumentWorkspace.objects.create(
-            company = other_user.company,
-            name='Other Workspace'
+            company=other_user.company, name="Other Workspace"
         )
         other_document = Document.objects.create(
-            workspace=other_workspace,
-            file=self.test_file
+            workspace=other_workspace, file=self.test_file
         )

         # Try to access the other user's document
-        url = reverse('documents_details', kwargs={"document_id":other_document.id})
+        url = reverse("documents_details", kwargs={"document_id": other_document.id})
         response = self.client.get(url)
         self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
@@ -23,8 +23,7 @@ from .views import (
     reset_password,
     DocumentWorkspaceView,
     DocumentUploadView,
-    DocumentDetailView
-
+    DocumentDetailView,
 )
 from rest_framework.routers import DefaultRouter
@@ -80,11 +79,16 @@ urlpatterns = [
         name="analytics_company_usage",
     ),
     path("analytics/admin/", AdminAnalytics.as_view(), name="analytics_admin"),

     # document urls
-    path("document_workspaces/", DocumentWorkspaceView.as_view(), name="document_workspaces"),
+    path(
+        "document_workspaces/",
+        DocumentWorkspaceView.as_view(),
+        name="document_workspaces",
+    ),
     path("documents/", DocumentUploadView.as_view(), name="documents"),
-    path("documents_details/<int:document_id>", DocumentDetailView.as_view(), name="documents_details"),
+    path(
+        "documents_details/<int:document_id>",
+        DocumentDetailView.as_view(),
+        name="documents_details",
+    ),
 ]
@@ -13,7 +13,7 @@ from .serializers import (
     PromptSerializer,
     FeedbackSerializer,
     DocumentWorkspaceSerializer,
-    DocumentSerializer
+    DocumentSerializer,
 )
 from rest_framework.views import APIView
 from rest_framework.response import Response
@@ -25,7 +25,7 @@ from .models import (
     Feedback,
     PromptMetric,
     DocumentWorkspace,
-    Document
+    Document,
 )
 from django.views.decorators.cache import never_cache
 from django.http import JsonResponse
@@ -99,8 +99,8 @@ class CustomUserCreate(APIView):

 def send_invite_email(slug, email_to_invite):
     print("Sending invite email")
-    print(f"url : https://www.chat.aimloperations.com/set_password?slug={slug}")
-    url = f"https://www.chat.aimloperations.com/set_password?slug={slug}"
+    print(f"url : https://chat.aimloperations.com/set_password?slug={slug}")
+    url = f"https://chat.aimloperations.com/set_password?slug={slug}"
     subject = "Welcome to AI ML Operations, LLC Chat Services"
     from_email = "ryan@aimloperations.com"
     to = email_to_invite
@@ -113,6 +113,22 @@ def send_invite_email(slug, email_to_invite):
     msg.send(fail_silently=True)


+def send_password_reset_email(slug, email_to_invite):
+    print("Sending reset email")
+    print(f"url : https://www.chat.aimloperations.com/set_password?slug={slug}")
+    url = f"https://www.chat.aimloperations.com/set_password?slug={slug}"
+    subject = "Password reset for AI ML Operations, LLC Chat Services"
+    from_email = "ryan@aimloperations.com"
+    to = email_to_invite
+    d = {"url": url}
+    html_content = get_template(r"emails/reset_email.html").render(d)
+    text_content = get_template(r"emails/reset_email.txt").render(d)
+
+    msg = EmailMultiAlternatives(subject, text_content, from_email, [to])
+    msg.attach_alternative(html_content, "text/html")
+    msg.send(fail_silently=True)
+
+
 def send_feedback_email(feedback_obj):
     print("Sending feedback email")
     subject = "New Feedback for Chat by AI ML Operations, LLC"
@@ -176,19 +192,22 @@ class CustomUserInvite(APIView):

         return Response(status=status.HTTP_201_CREATED)


 @csrf_exempt
 def reset_password(request):
     if request.method == "POST":
         data = json.loads(request.body)
-        token = data.get('recaptchaToken')
+        token = data.get("recaptchaToken")
         payload = {
-            'secret': settings.CAPTCHA_SECRET_KEY,
-            'response': token,
+            "secret": settings.CAPTCHA_SECRET_KEY,
+            "response": token,
         }
-        response = requests.post('https://www.google.com/recaptcha/api/siteverify', data=payload)
+        response = requests.post(
+            "https://www.google.com/recaptcha/api/siteverify", data=payload
+        )
         result = response.json()
-        if result.get('success') and result.get('score') >= 0.5:
-            email = data.get('email')
+        if result.get("success") and result.get("score") >= 0.5:
+            email = data.get("email")
             user = CustomUser.objects.filter(email=email).first()
             if user:
                 user.set_unusable_password()
@@ -197,10 +216,10 @@ def reset_password(request):
                 # send the email
                 send_password_reset_email(user.slug, email)
                 JsonResponse(status=200)


         JsonResponse(status=400)


 class ResetUserPassword(APIView):
     http_method_names = [
         "post",
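One caveat worth flagging in the hunk above: JsonResponse(status=200) and JsonResponse(status=400) are constructed but never returned, and JsonResponse requires a serializable body as its first argument, so those lines would raise TypeError at runtime if reached. A minimal sketch of the presumable intent (the response bodies are illustrative, not from the commit):

    from django.http import JsonResponse

    def captcha_result_response(passed: bool) -> JsonResponse:
        # JsonResponse needs data; `return` is what actually sends it to the client.
        if passed:
            return JsonResponse({"detail": "reset email sent"}, status=200)
        return JsonResponse({"detail": "captcha verification failed"}, status=400)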
@@ -214,14 +233,16 @@ class ResetUserPassword(APIView):
         Also disable the account
         """
         print(f"Password reset for requests. {request.data}")
-        token = request.data.get('recaptchaToken')
+        token = request.data.get("recaptchaToken")
         payload = {
-            'secret': settings.CAPTCHA_SECRET_KEY,
-            'response': recaptchaToken,
+            "secret": settings.CAPTCHA_SECRET_KEY,
+            "response": recaptchaToken,
         }
-        response = requests.post('https://www.google.com/recaptcha/api/siteverify', data=payload)
+        response = requests.post(
+            "https://www.google.com/recaptcha/api/siteverify", data=payload
+        )
         result = response.json()
-        if result.get('success') and result.get('score') >= 0.5:
+        if result.get("success") and result.get("score") >= 0.5:
             user = CustomUser.objects.filter(email=email).first()
             if user:
                 user.set_unusable_password()
@@ -230,7 +251,7 @@ class ResetUserPassword(APIView):
                 # send the email
                 send_password_reset_email(user.slug, email)
         else:
-            print('Captcha secret failed')
+            print("Captcha secret failed")

         return Response(status=status.HTTP_200_OK)
@@ -261,10 +282,16 @@ class CustomUserGet(APIView):

         email = request.user.email
         username = request.user.username
-        user = CustomUser.objects.get(email=email)
-        serializer = CustomUserSerializer(user)
-
-        return Response(serializer.data, status=status.HTTP_200_OK)
+        user = CustomUser.objects.filter(email=email).last()
+        print(f"Getting the user: {user}")
+        try:
+            serializer = CustomUserSerializer(user)
+            print(f"serializer: {serializer}")
+            print(serializer.data)
+            return Response(serializer.data, status=status.HTTP_200_OK)
+        except Exception as e:
+            print(f"Exception: {e}")
+            return Response({}, status=status.HTTP_400_BAD_REQUEST)


 class FeedbackView(APIView):
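The change above from objects.get(email=email) to objects.filter(email=email).last() swaps exception-driven flow for a nullable result: get() raises DoesNotExist or MultipleObjectsReturned, while filter().last() returns the most recent match or None, which the new try/except then absorbs when the serializer is handed None. A minimal sketch of the two behaviors, assuming the CustomUser model from this codebase:

    # get(): exactly one row, or an exception
    try:
        user = CustomUser.objects.get(email=email)
    except (CustomUser.DoesNotExist, CustomUser.MultipleObjectsReturned):
        user = None

    # filter().last(): tolerant of zero or many rows, never raises here
    user = CustomUser.objects.filter(email=email).last()  # newest match, or None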
@@ -706,6 +733,7 @@ def get_workspace(conversation_id):
     conversation = Conversation.objects.get(id=conversation_id)
     return DocumentWorkspace.objects.get(company=conversation.user.company)

+
 @database_sync_to_async
 def get_messages(conversation_id, prompt, file_string: str = None, file_type: str = ""):
     messages = []
@@ -821,19 +849,21 @@ def finish_prompt_metric(prompt_metric, response_length):
     prompt_metric.save(update_fields=["end_time", "reponse_length", "event"])
     print("finish_prompt_metric saved")


 @database_sync_to_async
 def get_retriever(conversation_id):
-    print(f'getting workspace from conversation: {conversation_id}')
+    print(f"getting workspace from conversation: {conversation_id}")
     conversation = Conversation.objects.get(id=conversation_id)
-    print(f'Got conversation: {conversation}')
+    print(f"Got conversation: {conversation}")
     workspace = DocumentWorkspace.objects.get(company=conversation.user.company)
-    print(f'Got workspace: {conversation}')
+    print(f"Got workspace: {conversation}")
     vectorstore = Chroma(
         persist_directory=f"./chroma_db/",
         embedding=OllamaEmbeddings(model="llama3.2"),
     )
     return vectorstore.as_retriever()


 class ChatConsumerAgain(AsyncWebsocketConsumer):
     async def connect(self):
         await self.accept()
@@ -881,7 +911,10 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
         print(f'received: "{message}" for conversation {conversation_id}')

         # check the moderation here
-        if await moderation_classifier.classify_async(message) == ModerationLabel.NSFW:
+        if (
+            await moderation_classifier.classify_async(message)
+            == ModerationLabel.NSFW
+        ):
             response = "Prompt has been marked as NSFW. If this is in error, submit a feedback with the prompt text."
             print("this prompt has been marked as NSFW")
             await self.send("CONVERSATION_ID")
@@ -902,9 +935,8 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):

         prompt_type = await prompt_classifier.classify_async(message)
         print(f"prompt_type: {prompt_type} for {message}")
+        if file:
+            prompt_type = PromptType.GENERAL_CHAT

         prompt_metric = await create_prompt_metric(
             prompt.id,
@@ -928,22 +960,25 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
         await self.send("START_OF_THE_STREAM_ENDER_GAME_42")
         if prompt_type == PromptType.RAG:
             service = AsyncRAGService()
-            #await service.ingest_documents()
+            # await service.ingest_documents()
             workspace = await get_workspace(conversation_id)
-            print('Time to get the rag response')
+            print("Time to get the rag response")

-            async for chunk in service.generate_response(messages, prompt.message, workspace):
-                print(f"chunk: {chunk}")
+            async for chunk in service.generate_response(
+                messages, prompt.message, workspace
+            ):
                 response += chunk
                 await self.send(chunk)
         elif prompt_type == PromptType.IMAGE_GENERATION:
             response = "Image Generation is not supported at this time, but it will be soon."
             await self.send(response)

         else:
+            print(f"using the AsyncLLMService\n\n{messages}\n{prompt.message}")
             service = AsyncLLMService()
-            async for chunk in service.generate_response(messages, prompt.message):
-                print(f"chunk: {chunk}")
+            async for chunk in service.generate_response(
+                messages, prompt.message
+            ):
                 response += chunk
                 await self.send(chunk)
         await self.send("END_OF_THE_STREAM_ENDER_GAME_42")
@@ -954,9 +989,10 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
         if bytes_data:
             print("we have byte data")


 # Document Views
 class DocumentWorkspaceView(APIView):
-    #permission_classes = [permissions.IsAuthenticated]
+    # permission_classes = [permissions.IsAuthenticated]

     def get(self, request):
         workspaces = DocumentWorkspace.objects.filter(company=request.user.company)
@@ -970,70 +1006,78 @@ class DocumentWorkspaceView(APIView):
             return Response(serializer.data, status=status.HTTP_201_CREATED)
         return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)


 class DocumentUploadView(APIView):
-    #permission_classes = [permissions.IsAuthenticated]Z
+    # permission_classes = [permissions.IsAuthenticated]Z

     def get(self, request):
-        print(f'request_3: {request}')
+        print(f"request_3: {request}")
         try:
             workspace = DocumentWorkspace.objects.get(company=request.user.company)
-            serializer = DocumentSerializer(Document.objects.filter(workspace=workspace), many=True)
+            serializer = DocumentSerializer(
+                Document.objects.filter(workspace=workspace), many=True
+            )
             return Response(serializer.data, status=status.HTTP_200_OK)

         except:
-            return Response({'error': "Workspace not found"}, status=status.HTTP_404_NOT_FOUND)
+            return Response(
+                {"error": "Workspace not found"}, status=status.HTTP_404_NOT_FOUND
+            )

     def post(self, request):
-        print(f'request: {request}')
+        print(f"request: {request}")

         try:
             workspace = DocumentWorkspace.objects.get(company=request.user.company)

         except:
-            return Response({'error': "Workspace not found"}, status=status.HTTP_404_NOT_FOUND)
+            return Response(
+                {"error": "Workspace not found"}, status=status.HTTP_404_NOT_FOUND
+            )

         print(request.FILES)
-        file = request.FILES.get('file')
+        file = request.FILES.get("file")
         if not file:
-            return Response({"error":"No file provided"}, status=status.HTTP_400_BAD_REQUEST)
+            return Response(
+                {"error": "No file provided"}, status=status.HTTP_400_BAD_REQUEST
+            )

         print("have the workspace and the file")

-        document = Document.objects.create(
-            workspace=workspace,
-            file=file
-        )
+        document = Document.objects.create(workspace=workspace, file=file)

         # process the document inthe background
         self.process_document(document)

         serializer = DocumentSerializer(document)
         return Response(serializer.data, status=status.HTTP_201_CREATED)


     def process_document(self, document):
         file_path = os.path.join(settings.MEDIA_ROOT, document.file.name)

         document.processed = True
         document.active = True
         document.save()
         service = AsyncRAGService()
-        service.add_files_to_store([(file_path, document.file.name, document.workspace_id)], workspace_id=document.workspace_id)
+        service.add_files_to_store(
+            [(file_path, document.file.name, document.workspace_id)],
+            workspace_id=document.workspace_id,
+        )


 class DocumentDetailView(APIView):
-    #permission_classes = [permissions.IsAuthenticated]
+    # permission_classes = [permissions.IsAuthenticated]

     def get(self, request, document_id):
-        print(f'request: {request}')
+        print(f"request: {request}")
         try:
             workspace = DocumentWorkspace.objects.get(company=request.user.company)

-            document = Document.objects.get(
-                workspace=workspace,
-                id=document_id
-            )
+            document = Document.objects.get(workspace=workspace, id=document_id)
         except:
-            return Response({'error': "Document not found"}, status=status.HTTP_404_NOT_FOUND)
+            return Response(
+                {"error": "Document not found"}, status=status.HTTP_404_NOT_FOUND
+            )

         serializer = DocumentWorkspaceSerializer(workspaces, many=True)
         return Response(serializer.data)
@@ -9,13 +9,18 @@ https://docs.djangoproject.com/en/3.2/howto/deployment/asgi/

 import os

+import django
 from django.core.asgi import get_asgi_application

+os.environ.setdefault("DJANGO_SETTINGS_MODULE", "llm_be.settings")
+
+django.setup()
+
+
 from channels.routing import ProtocolTypeRouter, URLRouter
 from channels.auth import AuthMiddlewareStack
 import chat_backend.routing

-os.environ.setdefault("DJANGO_SETTINGS_MODULE", "llm_be.settings")
-
 application = ProtocolTypeRouter(
     {
         "http": get_asgi_application(),
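The reordering in this hunk is load-bearing: chat_backend.routing imports consumers that touch Django models, so the settings module must be set and django.setup() run before that import executes, or Django raises AppRegistryNotReady. A minimal sketch of the required ordering (the websocket_urlpatterns attribute follows the usual Channels convention and is an assumption here, since the routing module itself is not shown):

    import os

    import django

    os.environ.setdefault("DJANGO_SETTINGS_MODULE", "llm_be.settings")
    django.setup()  # populate the app registry BEFORE model-touching imports

    from django.core.asgi import get_asgi_application  # noqa: E402
    from channels.routing import ProtocolTypeRouter, URLRouter  # noqa: E402
    from channels.auth import AuthMiddlewareStack  # noqa: E402
    import chat_backend.routing  # noqa: E402  (safe now that apps are loaded)

    application = ProtocolTypeRouter(
        {
            "http": get_asgi_application(),
            "websocket": AuthMiddlewareStack(
                URLRouter(chat_backend.routing.websocket_urlpatterns)  # assumed attribute
            ),
        }
    )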
@@ -161,7 +161,7 @@ REST_FRAMEWORK = {
 }

 SIMPLE_JWT = {
-    "ACCESS_TOKEN_LIFETIME": timedelta(hours=24),
+    "ACCESS_TOKEN_LIFETIME": timedelta(hours=5),
     "REFRESH_TOKEN_LIFETIME": timedelta(days=14),
     "ROTATE_REFRESH_TOKENS": True,
     "BLACKLIST_AFTER_ROTATION": True,
@@ -207,4 +207,4 @@ EMAIL_PORT = 2525
 EMAIL_USE_TLS = True

 # Captcha
 CAPTCHA_SECRET_KEY = "6LfENu4qAAAAABdrj6JTviq-LfdPP5imhE-Os7h9"
@@ -44,7 +44,7 @@ durationpy==0.9
 effdet==0.4.1
 emoji==2.14.1
 eval_type_backport==0.2.2
-Faker==37.0.0
+Faker
 fastapi==0.115.9
 filelock==3.17.0
 filetype==1.2.0
@@ -52,13 +52,13 @@ flatbuffers==25.2.10
 fonttools==4.56.0
 frozenlist==1.6.0
 fsspec==2025.2.0
-google-api-core==2.24.2
-google-auth==2.39.0
-google-cloud-vision==3.10.1
-googleapis-common-protos==1.70.0
+google-api-core
+google-auth
+google-cloud-vision
+googleapis-common-protos
 greenlet==3.1.1
-grpcio==1.72.0rc1
-grpcio-status==1.72.0rc1
+grpcio
+grpcio-status
 h11==0.14.0
 html5lib==1.1
 httpcore==1.0.7
@@ -121,13 +121,13 @@ oauthlib==3.2.2
 olefile==0.47
 ollama==0.4.7
 omegaconf==2.3.0
-onnx==1.18.0
-onnxruntime==1.21.1
+onnx
+onnxruntime
 openai==1.65.4
 opencv-python==4.11.0.86
-opentelemetry-api==1.32.1
-opentelemetry-exporter-otlp-proto-common==1.32.1
-opentelemetry-exporter-otlp-proto-grpc==1.32.1
+opentelemetry-api
+opentelemetry-exporter-otlp-proto-common
+opentelemetry-exporter-otlp-proto-grpc
 opentelemetry-instrumentation==0.53b1
 opentelemetry-instrumentation-asgi==0.53b1
 opentelemetry-instrumentation-fastapi==0.53b1
@@ -138,8 +138,8 @@ opentelemetry-util-http==0.53b1
 orjson==3.10.15
 overrides==7.7.0
 packaging==24.2
-pandas==2.2.3
-pandasai==2.4.2
+pandas
+pandasai
 pathspec==0.12.1
 pdf2image==1.17.0
 pdfminer.six==20250506
@@ -150,7 +150,7 @@ platformdirs==4.3.6
 posthog==4.0.1
 propcache==0.3.1
 proto-plus==1.26.1
-protobuf==6.31.0rc2
+protobuf
 psutil==7.0.0
 pyasn1==0.6.1
 pyasn1_modules==0.4.1
@@ -160,7 +160,7 @@ pydantic==2.11.4
 pydantic-settings==2.9.1
 pydantic_core==2.33.2
 Pygments==2.19.1
-PyJWT==2.10.1
+PyJWT
 pyOpenSSL==25.0.0
 pyparsing==3.2.1
 pypdf==5.4.0