fixed chat service
This commit is contained in:
17
llm_be/chat_backend/services/base_service.py
Normal file
17
llm_be/chat_backend/services/base_service.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from langchain_ollama import OllamaLLM
|
||||||
|
from langchain_core.output_parsers import StrOutputParser
|
||||||
|
|
||||||
|
class BaseService(ABC):
|
||||||
|
"""Abstract base class for LLM conversation services."""
|
||||||
|
|
||||||
|
def __init__(self, temperature=0.7):
|
||||||
|
self.llm = OllamaLLM(
|
||||||
|
model="llama3.2",
|
||||||
|
temperature=0.7,
|
||||||
|
top_k=50,
|
||||||
|
top_p=0.9,
|
||||||
|
repeat_penalty=1.1,
|
||||||
|
num_ctx=4096,
|
||||||
|
)
|
||||||
|
self.output_parser = StrOutputParser()
|
||||||
@@ -72,8 +72,6 @@ class SyncLLMService(LLMService):
|
|||||||
"""Generate response with streaming support."""
|
"""Generate response with streaming support."""
|
||||||
chain_input = {"query": query, "conversation": conversation}
|
chain_input = {"query": query, "conversation": conversation}
|
||||||
|
|
||||||
print(f"chain_input:\n{chain_input}")
|
|
||||||
|
|
||||||
for chunk in self.conversation_chain.stream(chain_input):
|
for chunk in self.conversation_chain.stream(chain_input):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
@@ -106,10 +104,8 @@ class AsyncLLMService(LLMService):
|
|||||||
|
|
||||||
self.conversation_chain = (
|
self.conversation_chain = (
|
||||||
{
|
{
|
||||||
"context": lambda x: self._format_history(x["conversation"]),
|
"context":lambda x: x["conversation"],
|
||||||
"recent_history": lambda x: self._get_recent_messages(
|
"recent_history":lambda x: x['recent_conversation'],
|
||||||
x["conversation"]
|
|
||||||
),
|
|
||||||
"query": lambda x: x["query"],
|
"query": lambda x: x["query"],
|
||||||
}
|
}
|
||||||
| self.prompt
|
| self.prompt
|
||||||
@@ -117,34 +113,39 @@ class AsyncLLMService(LLMService):
|
|||||||
| self.output_parser
|
| self.output_parser
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _format_history(self, conversation: Conversation) -> str:
|
async def _format_history(self, conversation: list) -> str:
|
||||||
"""Async version of format conversation history."""
|
"""Async version of format conversation history."""
|
||||||
prompts = (
|
# prompts = list(
|
||||||
await Prompt.objects.filter(conversation=conversation)
|
# await Prompt.objects.filter(conversation_id=conversation_id)
|
||||||
.order_by("created_at")
|
# .order_by("created")
|
||||||
.alist()
|
|
||||||
)
|
# )
|
||||||
return "\n".join(
|
# return "\n".join(
|
||||||
f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
|
# f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
|
||||||
)
|
# )
|
||||||
|
return "\n".join([f"{"User" if prompt.type=="human" else "AI"}: {prompt.text()}" for prompt in conversation])
|
||||||
|
|
||||||
async def _get_recent_messages(self, conversation: Conversation) -> str:
|
async def _get_recent_messages(self, conversation: list) -> str:
|
||||||
"""Async version of format conversation history."""
|
"""Async version of format conversation history."""
|
||||||
prompts = (
|
|
||||||
await Prompt.objects.filter(conversation=conversation)
|
# prompts = list(
|
||||||
.order_by("created_at")
|
# await Prompt.objects.filter(conversation_id=conversation_id)
|
||||||
.alist()[-6:]
|
# .order_by("created")
|
||||||
)
|
# [-6:]
|
||||||
return "\n".join(
|
# )
|
||||||
f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
|
# return "\n".join(
|
||||||
)
|
# f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
|
||||||
|
# )
|
||||||
|
return "\n".join([f"{"User" if prompt.type=="human" else "AI"}: {prompt.text()}" for prompt in conversation])
|
||||||
|
|
||||||
async def generate_response(
|
async def generate_response(
|
||||||
self, conversation: Conversation, query: str, **kwargs
|
self, conversation: Conversation, query: str, conversation_id: int, **kwargs
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""Generate response with async streaming support."""
|
"""Generate response with async streaming support."""
|
||||||
chain_input = {"query": query, "conversation": conversation}
|
chain_input = {
|
||||||
print(f"LLM Chain:\n{chain_input}")
|
"query": query,
|
||||||
|
"conversation": await self._format_history(conversation),
|
||||||
|
"recent_conversation": await self._get_recent_messages(conversation[-6:])}
|
||||||
|
|
||||||
async for chunk in self.conversation_chain.astream(chain_input):
|
async for chunk in self.conversation_chain.astream(chain_input):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|||||||
@@ -2,8 +2,7 @@ from enum import Enum, auto
|
|||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
from langchain_core.prompts import ChatPromptTemplate
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
|
|
||||||
# from langchain_community.llms import Ollama
|
from chat_backend.services.base_service import BaseService
|
||||||
from langchain_ollama import OllamaLLM
|
|
||||||
|
|
||||||
|
|
||||||
class ModerationLabel(Enum):
|
class ModerationLabel(Enum):
|
||||||
@@ -11,18 +10,19 @@ class ModerationLabel(Enum):
|
|||||||
FINE = auto()
|
FINE = auto()
|
||||||
|
|
||||||
|
|
||||||
class ModerationClassifier:
|
class ModerationClassifier(BaseService):
|
||||||
"""
|
"""
|
||||||
Classifies prompts as NSFW or FINE (safe) content.
|
Classifies prompts as NSFW or FINE (safe) content.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.llm = OllamaLLM(
|
super().__init__(temperature=0.1)
|
||||||
model="llama3.2",
|
# self.llm = OllamaLLM(
|
||||||
temperature=0.1, # Very low for strict moderation
|
# model="llama3.2",
|
||||||
top_k=10,
|
# temperature=0.1, # Very low for strict moderation
|
||||||
num_ctx=2048,
|
# top_k=10,
|
||||||
)
|
# num_ctx=2048,
|
||||||
|
# )
|
||||||
|
|
||||||
self.moderation_prompt = ChatPromptTemplate.from_messages(
|
self.moderation_prompt = ChatPromptTemplate.from_messages(
|
||||||
[
|
[
|
||||||
|
|||||||
@@ -2,8 +2,7 @@ from enum import Enum, auto
|
|||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
from langchain_core.prompts import ChatPromptTemplate
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
|
|
||||||
# from langchain_community.llms import Ollama
|
from chat_backend.services.base_service import BaseService
|
||||||
from langchain_ollama import OllamaLLM
|
|
||||||
|
|
||||||
|
|
||||||
class PromptType(Enum):
|
class PromptType(Enum):
|
||||||
@@ -13,19 +12,20 @@ class PromptType(Enum):
|
|||||||
UNKNOWN = auto()
|
UNKNOWN = auto()
|
||||||
|
|
||||||
|
|
||||||
class PromptClassifier:
|
class PromptClassifier(BaseService):
|
||||||
"""
|
"""
|
||||||
Classifies user prompts to determine which service should handle them.
|
Classifies user prompts to determine which service should handle them.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.llm = OllamaLLM(
|
super().__init__(temperature=0.1)
|
||||||
model="llama3.2",
|
# self.llm = OllamaLLM(
|
||||||
temperature=0.3, # Lower temp for more deterministic classification
|
# model="llama3.2",
|
||||||
top_k=20,
|
# temperature=0.3, # Lower temp for more deterministic classification
|
||||||
top_p=0.9,
|
# top_k=20,
|
||||||
num_ctx=4096,
|
# top_p=0.9,
|
||||||
)
|
# num_ctx=4096,
|
||||||
|
# )
|
||||||
|
|
||||||
self.classification_prompt = ChatPromptTemplate.from_messages(
|
self.classification_prompt = ChatPromptTemplate.from_messages(
|
||||||
[
|
[
|
||||||
@@ -69,6 +69,10 @@ Examples:
|
|||||||
- "Explain quantum computing" (General knowledge)
|
- "Explain quantum computing" (General knowledge)
|
||||||
- "Summarize the meeting" (No doc reference)
|
- "Summarize the meeting" (No doc reference)
|
||||||
|
|
||||||
|
[Definitely NOT IMAGE_GENERATION]
|
||||||
|
- "Great, can you make it about a duck now"
|
||||||
|
- "highlight the features of the backyard playset if they were to choose us and make the language more long form"
|
||||||
|
|
||||||
Return ONLY the label, no explanations.""",
|
Return ONLY the label, no explanations.""",
|
||||||
),
|
),
|
||||||
("human", "{prompt}"),
|
("human", "{prompt}"),
|
||||||
31
llm_be/chat_backend/services/prompt_classifier/tests.py
Normal file
31
llm_be/chat_backend/services/prompt_classifier/tests.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import os
|
||||||
|
from unittest import TestCase, mock
|
||||||
|
from unittest.mock import MagicMock, patch, AsyncMock
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
from django.test import TestCase as DjangoTestCase
|
||||||
|
|
||||||
|
from chat_backend.services.rag_services import (
|
||||||
|
RAGService,
|
||||||
|
SyncRAGService,
|
||||||
|
AsyncRAGService,
|
||||||
|
)
|
||||||
|
from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
|
||||||
|
from chat_backend.services.prompt_classifier.prompt_classifier import PromptClassifier, PromptType
|
||||||
|
from parameterized import parameterized
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class PromptClassifierTestCase(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.service = PromptClassifier()
|
||||||
|
|
||||||
|
@parameterized.expand([
|
||||||
|
["Tell me a joke",PromptType.GENERAL_CHAT],
|
||||||
|
["Create an image of a dog for me",PromptType.IMAGE_GENERATION],
|
||||||
|
["highlight the features of the backyard playset if they were to choose us and make the language more long form",PromptType.GENERAL_CHAT],
|
||||||
|
["Great, can you make it about a duck now", PromptType.IMAGE_GENERATION],
|
||||||
|
])
|
||||||
|
def test_prompt_classification(self, prompt, expected_output):
|
||||||
|
result = self.service.classify(prompt)
|
||||||
|
self.assertEqual(result, expected_output)
|
||||||
@@ -21,6 +21,7 @@ from langchain_community.document_loaders import (
|
|||||||
from django.core.files.uploadedfile import UploadedFile
|
from django.core.files.uploadedfile import UploadedFile
|
||||||
from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
|
from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from chat_backend.services.base_service import BaseService
|
||||||
|
|
||||||
|
|
||||||
@database_sync_to_async
|
@database_sync_to_async
|
||||||
@@ -31,7 +32,7 @@ def get_documents(workspace: DocumentWorkspace | None = None):
|
|||||||
return [doc for doc in Document.objects.all()]
|
return [doc for doc in Document.objects.all()]
|
||||||
|
|
||||||
|
|
||||||
class RAGService(ABC):
|
class RAGService(BaseService):
|
||||||
"""Abstract base class for RAG services."""
|
"""Abstract base class for RAG services."""
|
||||||
|
|
||||||
_instance = None
|
_instance = None
|
||||||
@@ -44,14 +45,7 @@ class RAGService(ABC):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.embedding_model = OllamaEmbeddings(model="llama3.2")
|
self.embedding_model = OllamaEmbeddings(model="llama3.2")
|
||||||
self.llm = OllamaLLM(
|
super().__init__()
|
||||||
model="llama3.2",
|
|
||||||
temperature=0.7,
|
|
||||||
top_k=50,
|
|
||||||
top_p=0.9,
|
|
||||||
repeat_penalty=1.1,
|
|
||||||
num_ctx=4096,
|
|
||||||
)
|
|
||||||
self.text_splitter = RecursiveCharacterTextSplitter(
|
self.text_splitter = RecursiveCharacterTextSplitter(
|
||||||
chunk_size=1000, chunk_overlap=200
|
chunk_size=1000, chunk_overlap=200
|
||||||
)
|
)
|
||||||
@@ -104,17 +98,17 @@ class RAGService(ABC):
|
|||||||
print(f"Processing the documents : {documents}")
|
print(f"Processing the documents : {documents}")
|
||||||
self._prepare_documents(documents)
|
self._prepare_documents(documents)
|
||||||
|
|
||||||
@abstractmethod
|
# @abstractmethod
|
||||||
def generate_response(self, conversation: Conversation, query: str, **kwargs):
|
# def generate_response(self, conversation: Conversation, query: str, **kwargs):
|
||||||
"""Generate a response using RAG."""
|
# """Generate a response using RAG."""
|
||||||
pass
|
# pass
|
||||||
|
|
||||||
@abstractmethod
|
# @abstractmethod
|
||||||
def search_documents(
|
# def search_documents(
|
||||||
self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4
|
# self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4
|
||||||
) -> List[Document]:
|
# ) -> List[Document]:
|
||||||
"""Search relevant documents from the vector store."""
|
# """Search relevant documents from the vector store."""
|
||||||
pass
|
# pass
|
||||||
|
|
||||||
def _get_file_loader(self, file_path: str):
|
def _get_file_loader(self, file_path: str):
|
||||||
"""Get appropriate loader for file type"""
|
"""Get appropriate loader for file type"""
|
||||||
@@ -304,7 +298,7 @@ class AsyncRAGService(RAGService):
|
|||||||
self.rag_chain = (
|
self.rag_chain = (
|
||||||
{
|
{
|
||||||
"context": self._retriever_with_history,
|
"context": self._retriever_with_history,
|
||||||
"history": lambda x: self._format_history(x["conversation"]),
|
"history": lambda x: x['recent_conversation'], #self._format_history(x["conversation"]),
|
||||||
"question": lambda x: x["query"],
|
"question": lambda x: x["query"],
|
||||||
}
|
}
|
||||||
| self.prompt
|
| self.prompt
|
||||||
@@ -314,15 +308,16 @@ class AsyncRAGService(RAGService):
|
|||||||
|
|
||||||
async def _format_history(self, conversation: Conversation) -> str:
|
async def _format_history(self, conversation: Conversation) -> str:
|
||||||
"""Format conversation history for the prompt."""
|
"""Format conversation history for the prompt."""
|
||||||
prompts = (
|
# prompts = (
|
||||||
await Prompt.objects.filter(conversation=conversation)
|
# await Prompt.objects.filter(conversation=conversation)
|
||||||
.order_by("created_at")
|
# .order_by("created_at")
|
||||||
.alist()
|
# .alist()
|
||||||
)
|
# )
|
||||||
print(f"prompts that we are seeding with are: {prompts}")
|
# print(f"prompts that we are seeding with are: {prompts}")
|
||||||
return "\n".join(
|
# return "\n".join(
|
||||||
f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
|
# f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
|
||||||
)
|
# )
|
||||||
|
return "\n".join([f"{"User" if prompt.type=="human" else "AI"}: {prompt.text()}" for prompt in conversation])
|
||||||
|
|
||||||
async def _retriever_with_history(self, input_dict: Dict[str, Any]) -> str:
|
async def _retriever_with_history(self, input_dict: Dict[str, Any]) -> str:
|
||||||
"""Retrieve documents considering conversation history."""
|
"""Retrieve documents considering conversation history."""
|
||||||
@@ -369,6 +364,7 @@ class AsyncRAGService(RAGService):
|
|||||||
"query": query,
|
"query": query,
|
||||||
"conversation": conversation,
|
"conversation": conversation,
|
||||||
"workspace": workspace,
|
"workspace": workspace,
|
||||||
|
"recent_conversation": await self._format_history(conversation),
|
||||||
}
|
}
|
||||||
|
|
||||||
async for chunk in self.rag_chain.astream(chain_input):
|
async for chunk in self.rag_chain.astream(chain_input):
|
||||||
|
|||||||
@@ -11,226 +11,241 @@ from chat_backend.services.rag_services import (
|
|||||||
AsyncRAGService,
|
AsyncRAGService,
|
||||||
)
|
)
|
||||||
from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
|
from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
|
||||||
|
from chat_backend.services.prompt_classifier import PromptClassifier, PromptType
|
||||||
|
from parameterized import parameterized
|
||||||
|
|
||||||
|
|
||||||
class TestRAGService(TestCase):
|
# class TestRAGService(TestCase):
|
||||||
|
# def setUp(self):
|
||||||
|
# self.rag_service = RAGService()
|
||||||
|
# self.rag_service.vector_store = MagicMock()
|
||||||
|
# self.rag_service.embedding_model = MagicMock()
|
||||||
|
# self.rag_service.text_splitter = MagicMock()
|
||||||
|
|
||||||
|
# def test_initialize_vector_store(self):
|
||||||
|
# with patch("os.path.exists", return_value=False), patch(
|
||||||
|
# "os.makedirs"
|
||||||
|
# ) as mock_makedirs, patch(
|
||||||
|
# "langchain_community.vectorstores.Chroma"
|
||||||
|
# ) as mock_chroma:
|
||||||
|
|
||||||
|
# # Reset the vector store to test initialization
|
||||||
|
# self.rag_service.vector_store = None
|
||||||
|
# result = self.rag_service._initialize_vector_store()
|
||||||
|
|
||||||
|
# mock_makedirs.assert_called_once_with("chroma_db")
|
||||||
|
# mock_chroma.assert_called_once_with(
|
||||||
|
# embedding_function=self.rag_service.embedding_model,
|
||||||
|
# persist_directory="chroma_db",
|
||||||
|
# )
|
||||||
|
# self.assertIsNotNone(result)
|
||||||
|
|
||||||
|
# def test_prepare_documents(self):
|
||||||
|
# mock_doc1 = MagicMock(spec=Document)
|
||||||
|
# mock_doc1.content = "Test content"
|
||||||
|
# mock_doc1.source = "test_source"
|
||||||
|
# mock_doc1.workspace = MagicMock()
|
||||||
|
# mock_doc1.workspace.id = 1
|
||||||
|
# mock_doc1.id = 1
|
||||||
|
|
||||||
|
# self.rag_service.text_splitter.split_text.return_value = ["chunk1", "chunk2"]
|
||||||
|
|
||||||
|
# result = self.rag_service._prepare_documents([mock_doc1])
|
||||||
|
|
||||||
|
# self.assertEqual(len(result), 2)
|
||||||
|
# self.rag_service.text_splitter.split_text.assert_called_once_with(
|
||||||
|
# "Test content"
|
||||||
|
# )
|
||||||
|
# self.assertEqual(result[0].page_content, "chunk1")
|
||||||
|
# self.assertEqual(result[0].metadata["source"], "test_source")
|
||||||
|
|
||||||
|
# def test_ingest_documents(self):
|
||||||
|
# mock_workspace = MagicMock()
|
||||||
|
# mock_document = MagicMock()
|
||||||
|
# mock_documents = [mock_document]
|
||||||
|
|
||||||
|
# with patch(
|
||||||
|
# "services.rag_services.Document.objects.filter", return_value=mock_documents
|
||||||
|
# ):
|
||||||
|
# self.rag_service._prepare_documents = MagicMock(
|
||||||
|
# return_value=["processed_doc"]
|
||||||
|
# )
|
||||||
|
|
||||||
|
# self.rag_service.ingest_documents(mock_workspace)
|
||||||
|
|
||||||
|
# self.rag_service.vector_store.add_documents.assert_called_once_with(
|
||||||
|
# ["processed_doc"]
|
||||||
|
# )
|
||||||
|
# self.rag_service.vector_store.persist.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
# class TestSyncRAGService(DjangoTestCase):
|
||||||
|
# def setUp(self):
|
||||||
|
# self.sync_service = SyncRAGService()
|
||||||
|
# self.sync_service.vector_store = MagicMock()
|
||||||
|
# self.sync_service.llm = MagicMock()
|
||||||
|
# self.sync_service.rag_chain = MagicMock()
|
||||||
|
|
||||||
|
# self.mock_conversation = MagicMock(spec=Conversation)
|
||||||
|
# self.mock_conversation.workspace = MagicMock()
|
||||||
|
|
||||||
|
# self.mock_prompt1 = MagicMock(spec=Prompt)
|
||||||
|
# self.mock_prompt1.is_user = True
|
||||||
|
# self.mock_prompt1.text = "User question"
|
||||||
|
# self.mock_prompt1.created_at = "2023-01-01"
|
||||||
|
|
||||||
|
# self.mock_prompt2 = MagicMock(spec=Prompt)
|
||||||
|
# self.mock_prompt2.is_user = False
|
||||||
|
# self.mock_prompt2.text = "AI response"
|
||||||
|
# self.mock_prompt2.created_at = "2023-01-02"
|
||||||
|
|
||||||
|
# def test_format_history(self):
|
||||||
|
# with patch("services.rag_services.Prompt.objects.filter") as mock_filter:
|
||||||
|
# mock_filter.return_value.order_by.return_value = [
|
||||||
|
# self.mock_prompt1,
|
||||||
|
# self.mock_prompt2,
|
||||||
|
# ]
|
||||||
|
|
||||||
|
# result = self.sync_service._format_history(self.mock_conversation)
|
||||||
|
|
||||||
|
# expected = "User: User question\nAI: AI response"
|
||||||
|
# self.assertEqual(result, expected)
|
||||||
|
# mock_filter.assert_called_once_with(conversation=self.mock_conversation)
|
||||||
|
|
||||||
|
# def test_retriever_with_history(self):
|
||||||
|
# input_dict = {"query": "test query", "conversation": self.mock_conversation}
|
||||||
|
|
||||||
|
# self.sync_service.search_documents = MagicMock(return_value=["doc1", "doc2"])
|
||||||
|
|
||||||
|
# result = self.sync_service._retriever_with_history(input_dict)
|
||||||
|
|
||||||
|
# self.sync_service.search_documents.assert_called_once_with(
|
||||||
|
# "test query", self.mock_conversation.workspace
|
||||||
|
# )
|
||||||
|
# self.assertEqual(result, ["doc1", "doc2"])
|
||||||
|
|
||||||
|
# def test_search_documents(self):
|
||||||
|
# mock_retriever = MagicMock()
|
||||||
|
# mock_retriever.get_relevant_documents.return_value = ["doc1", "doc2"]
|
||||||
|
# self.sync_service.vector_store.as_retriever.return_value = mock_retriever
|
||||||
|
|
||||||
|
# result = self.sync_service.search_documents(
|
||||||
|
# "test query", self.mock_conversation.workspace
|
||||||
|
# )
|
||||||
|
|
||||||
|
# self.sync_service.vector_store.as_retriever.assert_called_once_with(
|
||||||
|
# search_type="similarity",
|
||||||
|
# search_kwargs={
|
||||||
|
# "k": 4,
|
||||||
|
# "filter": {"workspace_id": self.mock_conversation.workspace.id},
|
||||||
|
# },
|
||||||
|
# )
|
||||||
|
# self.assertEqual(result, ["doc1", "doc2"])
|
||||||
|
|
||||||
|
# def test_generate_response(self):
|
||||||
|
# chain_input = {"query": "test query", "conversation": self.mock_conversation}
|
||||||
|
|
||||||
|
# mock_stream = ["chunk1", "chunk2", "chunk3"]
|
||||||
|
# self.sync_service.rag_chain.stream.return_value = mock_stream
|
||||||
|
|
||||||
|
# result = list(
|
||||||
|
# self.sync_service.generate_response(self.mock_conversation, "test query")
|
||||||
|
# )
|
||||||
|
|
||||||
|
# self.sync_service.rag_chain.stream.assert_called_once_with(chain_input)
|
||||||
|
# self.assertEqual(result, mock_stream)
|
||||||
|
|
||||||
|
|
||||||
|
# class TestAsyncRAGService(DjangoTestCase):
|
||||||
|
# def setUp(self):
|
||||||
|
# self.async_service = AsyncRAGService()
|
||||||
|
# self.async_service.vector_store = MagicMock()
|
||||||
|
# self.async_service.llm = MagicMock()
|
||||||
|
# self.async_service.rag_chain = AsyncMock()
|
||||||
|
|
||||||
|
# self.mock_conversation = MagicMock(spec=Conversation)
|
||||||
|
# self.mock_conversation.workspace = MagicMock()
|
||||||
|
|
||||||
|
# self.mock_prompt1 = MagicMock(spec=Prompt)
|
||||||
|
# self.mock_prompt1.is_user = True
|
||||||
|
# self.mock_prompt1.text = "User question"
|
||||||
|
# self.mock_prompt1.created_at = "2023-01-01"
|
||||||
|
|
||||||
|
# self.mock_prompt2 = MagicMock(spec=Prompt)
|
||||||
|
# self.mock_prompt2.is_user = False
|
||||||
|
# self.mock_prompt2.text = "AI response"
|
||||||
|
# self.mock_prompt2.created_at = "2023-01-02"
|
||||||
|
|
||||||
|
# async def test_format_history(self):
|
||||||
|
# mock_manager = AsyncMock()
|
||||||
|
# mock_manager.order_by.return_value.alist.return_value = [
|
||||||
|
# self.mock_prompt1,
|
||||||
|
# self.mock_prompt2,
|
||||||
|
# ]
|
||||||
|
|
||||||
|
# with patch(
|
||||||
|
# "services.rag_services.Prompt.objects.filter", return_value=mock_manager
|
||||||
|
# ):
|
||||||
|
# result = await self.async_service._format_history(self.mock_conversation)
|
||||||
|
|
||||||
|
# expected = "User: User question\nAI: AI response"
|
||||||
|
# self.assertEqual(result, expected)
|
||||||
|
# mock_manager.order_by.assert_called_once_with("created_at")
|
||||||
|
|
||||||
|
# async def test_retriever_with_history(self):
|
||||||
|
# input_dict = {"query": "test query", "conversation": self.mock_conversation}
|
||||||
|
|
||||||
|
# self.async_service.search_documents = AsyncMock(return_value=["doc1", "doc2"])
|
||||||
|
|
||||||
|
# result = await self.async_service._retriever_with_history(input_dict)
|
||||||
|
|
||||||
|
# self.async_service.search_documents.assert_awaited_once_with(
|
||||||
|
# "test query", self.mock_conversation.workspace
|
||||||
|
# )
|
||||||
|
# self.assertEqual(result, ["doc1", "doc2"])
|
||||||
|
|
||||||
|
# async def test_search_documents(self):
|
||||||
|
# mock_retriever = AsyncMock()
|
||||||
|
# mock_retriever.aget_relevant_documents.return_value = ["doc1", "doc2"]
|
||||||
|
# self.async_service.vector_store.as_retriever.return_value = mock_retriever
|
||||||
|
|
||||||
|
# result = await self.async_service.search_documents(
|
||||||
|
# "test query", self.mock_conversation.workspace
|
||||||
|
# )
|
||||||
|
|
||||||
|
# self.async_service.vector_store.as_retriever.assert_called_once_with(
|
||||||
|
# search_type="similarity",
|
||||||
|
# search_kwargs={
|
||||||
|
# "k": 4,
|
||||||
|
# "filter": {"workspace_id": self.mock_conversation.workspace.id},
|
||||||
|
# },
|
||||||
|
# )
|
||||||
|
# self.assertEqual(result, ["doc1", "doc2"])
|
||||||
|
|
||||||
|
# async def test_generate_response(self):
|
||||||
|
# chain_input = {"query": "test query", "conversation": self.mock_conversation}
|
||||||
|
|
||||||
|
# mock_stream = ["chunk1", "chunk2", "chunk3"]
|
||||||
|
# self.async_service.rag_chain.astream.return_value = mock_stream
|
||||||
|
|
||||||
|
# chunks = []
|
||||||
|
# async for chunk in self.async_service.generate_response(
|
||||||
|
# self.mock_conversation, "test query"
|
||||||
|
# ):
|
||||||
|
# chunks.append(chunk)
|
||||||
|
|
||||||
|
# self.async_service.rag_chain.astream.assert_awaited_once_with(chain_input)
|
||||||
|
# self.assertEqual(chunks, mock_stream)
|
||||||
|
|
||||||
|
class PromptClassifierTestCase(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.rag_service = RAGService()
|
self.service = PromptClassifier()
|
||||||
self.rag_service.vector_store = MagicMock()
|
|
||||||
self.rag_service.embedding_model = MagicMock()
|
|
||||||
self.rag_service.text_splitter = MagicMock()
|
|
||||||
|
|
||||||
def test_initialize_vector_store(self):
|
@parameterized.expand([
|
||||||
with patch("os.path.exists", return_value=False), patch(
|
["Tell me a joke",PromptType.GENERAL_CHAT],
|
||||||
"os.makedirs"
|
["Create an image of a dog for me",PromptType.IMAGE_GENERATION],
|
||||||
) as mock_makedirs, patch(
|
["highlight the features of the backyard playset if they were to choose us and make the language more long form",PromptType.GENERAL_CHAT],
|
||||||
"langchain_community.vectorstores.Chroma"
|
])
|
||||||
) as mock_chroma:
|
def test_prompt_classification(self, prompt, expected_output):
|
||||||
|
result = self.service.classify(prompt)
|
||||||
# Reset the vector store to test initialization
|
self.assertEqual(result, expected_output)
|
||||||
self.rag_service.vector_store = None
|
|
||||||
result = self.rag_service._initialize_vector_store()
|
|
||||||
|
|
||||||
mock_makedirs.assert_called_once_with("chroma_db")
|
|
||||||
mock_chroma.assert_called_once_with(
|
|
||||||
embedding_function=self.rag_service.embedding_model,
|
|
||||||
persist_directory="chroma_db",
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(result)
|
|
||||||
|
|
||||||
def test_prepare_documents(self):
|
|
||||||
mock_doc1 = MagicMock(spec=Document)
|
|
||||||
mock_doc1.content = "Test content"
|
|
||||||
mock_doc1.source = "test_source"
|
|
||||||
mock_doc1.workspace = MagicMock()
|
|
||||||
mock_doc1.workspace.id = 1
|
|
||||||
mock_doc1.id = 1
|
|
||||||
|
|
||||||
self.rag_service.text_splitter.split_text.return_value = ["chunk1", "chunk2"]
|
|
||||||
|
|
||||||
result = self.rag_service._prepare_documents([mock_doc1])
|
|
||||||
|
|
||||||
self.assertEqual(len(result), 2)
|
|
||||||
self.rag_service.text_splitter.split_text.assert_called_once_with(
|
|
||||||
"Test content"
|
|
||||||
)
|
|
||||||
self.assertEqual(result[0].page_content, "chunk1")
|
|
||||||
self.assertEqual(result[0].metadata["source"], "test_source")
|
|
||||||
|
|
||||||
def test_ingest_documents(self):
|
|
||||||
mock_workspace = MagicMock()
|
|
||||||
mock_document = MagicMock()
|
|
||||||
mock_documents = [mock_document]
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"services.rag_services.Document.objects.filter", return_value=mock_documents
|
|
||||||
):
|
|
||||||
self.rag_service._prepare_documents = MagicMock(
|
|
||||||
return_value=["processed_doc"]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.rag_service.ingest_documents(mock_workspace)
|
|
||||||
|
|
||||||
self.rag_service.vector_store.add_documents.assert_called_once_with(
|
|
||||||
["processed_doc"]
|
|
||||||
)
|
|
||||||
self.rag_service.vector_store.persist.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
class TestSyncRAGService(DjangoTestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.sync_service = SyncRAGService()
|
|
||||||
self.sync_service.vector_store = MagicMock()
|
|
||||||
self.sync_service.llm = MagicMock()
|
|
||||||
self.sync_service.rag_chain = MagicMock()
|
|
||||||
|
|
||||||
self.mock_conversation = MagicMock(spec=Conversation)
|
|
||||||
self.mock_conversation.workspace = MagicMock()
|
|
||||||
|
|
||||||
self.mock_prompt1 = MagicMock(spec=Prompt)
|
|
||||||
self.mock_prompt1.is_user = True
|
|
||||||
self.mock_prompt1.text = "User question"
|
|
||||||
self.mock_prompt1.created_at = "2023-01-01"
|
|
||||||
|
|
||||||
self.mock_prompt2 = MagicMock(spec=Prompt)
|
|
||||||
self.mock_prompt2.is_user = False
|
|
||||||
self.mock_prompt2.text = "AI response"
|
|
||||||
self.mock_prompt2.created_at = "2023-01-02"
|
|
||||||
|
|
||||||
def test_format_history(self):
|
|
||||||
with patch("services.rag_services.Prompt.objects.filter") as mock_filter:
|
|
||||||
mock_filter.return_value.order_by.return_value = [
|
|
||||||
self.mock_prompt1,
|
|
||||||
self.mock_prompt2,
|
|
||||||
]
|
|
||||||
|
|
||||||
result = self.sync_service._format_history(self.mock_conversation)
|
|
||||||
|
|
||||||
expected = "User: User question\nAI: AI response"
|
|
||||||
self.assertEqual(result, expected)
|
|
||||||
mock_filter.assert_called_once_with(conversation=self.mock_conversation)
|
|
||||||
|
|
||||||
def test_retriever_with_history(self):
|
|
||||||
input_dict = {"query": "test query", "conversation": self.mock_conversation}
|
|
||||||
|
|
||||||
self.sync_service.search_documents = MagicMock(return_value=["doc1", "doc2"])
|
|
||||||
|
|
||||||
result = self.sync_service._retriever_with_history(input_dict)
|
|
||||||
|
|
||||||
self.sync_service.search_documents.assert_called_once_with(
|
|
||||||
"test query", self.mock_conversation.workspace
|
|
||||||
)
|
|
||||||
self.assertEqual(result, ["doc1", "doc2"])
|
|
||||||
|
|
||||||
def test_search_documents(self):
|
|
||||||
mock_retriever = MagicMock()
|
|
||||||
mock_retriever.get_relevant_documents.return_value = ["doc1", "doc2"]
|
|
||||||
self.sync_service.vector_store.as_retriever.return_value = mock_retriever
|
|
||||||
|
|
||||||
result = self.sync_service.search_documents(
|
|
||||||
"test query", self.mock_conversation.workspace
|
|
||||||
)
|
|
||||||
|
|
||||||
self.sync_service.vector_store.as_retriever.assert_called_once_with(
|
|
||||||
search_type="similarity",
|
|
||||||
search_kwargs={
|
|
||||||
"k": 4,
|
|
||||||
"filter": {"workspace_id": self.mock_conversation.workspace.id},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self.assertEqual(result, ["doc1", "doc2"])
|
|
||||||
|
|
||||||
def test_generate_response(self):
|
|
||||||
chain_input = {"query": "test query", "conversation": self.mock_conversation}
|
|
||||||
|
|
||||||
mock_stream = ["chunk1", "chunk2", "chunk3"]
|
|
||||||
self.sync_service.rag_chain.stream.return_value = mock_stream
|
|
||||||
|
|
||||||
result = list(
|
|
||||||
self.sync_service.generate_response(self.mock_conversation, "test query")
|
|
||||||
)
|
|
||||||
|
|
||||||
self.sync_service.rag_chain.stream.assert_called_once_with(chain_input)
|
|
||||||
self.assertEqual(result, mock_stream)
|
|
||||||
|
|
||||||
|
|
||||||
class TestAsyncRAGService(DjangoTestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.async_service = AsyncRAGService()
|
|
||||||
self.async_service.vector_store = MagicMock()
|
|
||||||
self.async_service.llm = MagicMock()
|
|
||||||
self.async_service.rag_chain = AsyncMock()
|
|
||||||
|
|
||||||
self.mock_conversation = MagicMock(spec=Conversation)
|
|
||||||
self.mock_conversation.workspace = MagicMock()
|
|
||||||
|
|
||||||
self.mock_prompt1 = MagicMock(spec=Prompt)
|
|
||||||
self.mock_prompt1.is_user = True
|
|
||||||
self.mock_prompt1.text = "User question"
|
|
||||||
self.mock_prompt1.created_at = "2023-01-01"
|
|
||||||
|
|
||||||
self.mock_prompt2 = MagicMock(spec=Prompt)
|
|
||||||
self.mock_prompt2.is_user = False
|
|
||||||
self.mock_prompt2.text = "AI response"
|
|
||||||
self.mock_prompt2.created_at = "2023-01-02"
|
|
||||||
|
|
||||||
async def test_format_history(self):
|
|
||||||
mock_manager = AsyncMock()
|
|
||||||
mock_manager.order_by.return_value.alist.return_value = [
|
|
||||||
self.mock_prompt1,
|
|
||||||
self.mock_prompt2,
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"services.rag_services.Prompt.objects.filter", return_value=mock_manager
|
|
||||||
):
|
|
||||||
result = await self.async_service._format_history(self.mock_conversation)
|
|
||||||
|
|
||||||
expected = "User: User question\nAI: AI response"
|
|
||||||
self.assertEqual(result, expected)
|
|
||||||
mock_manager.order_by.assert_called_once_with("created_at")
|
|
||||||
|
|
||||||
async def test_retriever_with_history(self):
|
|
||||||
input_dict = {"query": "test query", "conversation": self.mock_conversation}
|
|
||||||
|
|
||||||
self.async_service.search_documents = AsyncMock(return_value=["doc1", "doc2"])
|
|
||||||
|
|
||||||
result = await self.async_service._retriever_with_history(input_dict)
|
|
||||||
|
|
||||||
self.async_service.search_documents.assert_awaited_once_with(
|
|
||||||
"test query", self.mock_conversation.workspace
|
|
||||||
)
|
|
||||||
self.assertEqual(result, ["doc1", "doc2"])
|
|
||||||
|
|
||||||
async def test_search_documents(self):
|
|
||||||
mock_retriever = AsyncMock()
|
|
||||||
mock_retriever.aget_relevant_documents.return_value = ["doc1", "doc2"]
|
|
||||||
self.async_service.vector_store.as_retriever.return_value = mock_retriever
|
|
||||||
|
|
||||||
result = await self.async_service.search_documents(
|
|
||||||
"test query", self.mock_conversation.workspace
|
|
||||||
)
|
|
||||||
|
|
||||||
self.async_service.vector_store.as_retriever.assert_called_once_with(
|
|
||||||
search_type="similarity",
|
|
||||||
search_kwargs={
|
|
||||||
"k": 4,
|
|
||||||
"filter": {"workspace_id": self.mock_conversation.workspace.id},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self.assertEqual(result, ["doc1", "doc2"])
|
|
||||||
|
|
||||||
async def test_generate_response(self):
|
|
||||||
chain_input = {"query": "test query", "conversation": self.mock_conversation}
|
|
||||||
|
|
||||||
mock_stream = ["chunk1", "chunk2", "chunk3"]
|
|
||||||
self.async_service.rag_chain.astream.return_value = mock_stream
|
|
||||||
|
|
||||||
chunks = []
|
|
||||||
async for chunk in self.async_service.generate_response(
|
|
||||||
self.mock_conversation, "test query"
|
|
||||||
):
|
|
||||||
chunks.append(chunk)
|
|
||||||
|
|
||||||
self.async_service.rag_chain.astream.assert_awaited_once_with(chain_input)
|
|
||||||
self.assertEqual(chunks, mock_stream)
|
|
||||||
@@ -10,6 +10,8 @@ from .models import DocumentWorkspace, Document, Company
|
|||||||
from django.contrib.auth import get_user_model
|
from django.contrib.auth import get_user_model
|
||||||
import tempfile
|
import tempfile
|
||||||
from django.core.files.uploadedfile import SimpleUploadedFile
|
from django.core.files.uploadedfile import SimpleUploadedFile
|
||||||
|
from parameterized import parameterized
|
||||||
|
|
||||||
|
|
||||||
# Minimal valid PDF bytes
|
# Minimal valid PDF bytes
|
||||||
VALID_PDF_BYTES = (
|
VALID_PDF_BYTES = (
|
||||||
@@ -185,3 +187,8 @@ class DocumentViewsTestCase(APITestCase):
|
|||||||
url = reverse("documents_details", kwargs={"document_id": other_document.id})
|
url = reverse("documents_details", kwargs={"document_id": other_document.id})
|
||||||
response = self.client.get(url)
|
response = self.client.get(url)
|
||||||
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from .serializers import (
|
|||||||
DocumentWorkspaceSerializer,
|
DocumentWorkspaceSerializer,
|
||||||
DocumentSerializer,
|
DocumentSerializer,
|
||||||
)
|
)
|
||||||
|
|
||||||
from rest_framework.views import APIView
|
from rest_framework.views import APIView
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from .models import (
|
from .models import (
|
||||||
@@ -35,7 +36,7 @@ from asgiref.sync import sync_to_async, async_to_sync
|
|||||||
from channels.generic.websocket import AsyncWebsocketConsumer
|
from channels.generic.websocket import AsyncWebsocketConsumer
|
||||||
from langchain_ollama.llms import OllamaLLM
|
from langchain_ollama.llms import OllamaLLM
|
||||||
from langchain_core.prompts import ChatPromptTemplate
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, AIMessage
|
||||||
from langchain.chains import RetrievalQA
|
from langchain.chains import RetrievalQA
|
||||||
import re
|
import re
|
||||||
import os
|
import os
|
||||||
@@ -66,7 +67,7 @@ from .services.llm_service import AsyncLLMService
|
|||||||
from .services.rag_services import AsyncRAGService
|
from .services.rag_services import AsyncRAGService
|
||||||
from .services.title_generator import title_generator
|
from .services.title_generator import title_generator
|
||||||
from .services.moderation_classifier import moderation_classifier, ModerationLabel
|
from .services.moderation_classifier import moderation_classifier, ModerationLabel
|
||||||
from .services.prompt_classifier import prompt_classifier, PromptType
|
from .services.prompt_classifier.prompt_classifier import PromptClassifier, PromptType
|
||||||
|
|
||||||
|
|
||||||
from langchain.chains import create_retrieval_chain
|
from langchain.chains import create_retrieval_chain
|
||||||
@@ -74,7 +75,7 @@ from langchain.chains.combine_documents import create_stuff_documents_chain
|
|||||||
from langchain_ollama import ChatOllama
|
from langchain_ollama import ChatOllama
|
||||||
|
|
||||||
CHANNEL_NAME: str = "llm_messages"
|
CHANNEL_NAME: str = "llm_messages"
|
||||||
MODEL_NAME: str = "llama3"
|
MODEL_NAME: str = "llama3.2"
|
||||||
|
|
||||||
|
|
||||||
# Create your views here.
|
# Create your views here.
|
||||||
@@ -790,7 +791,7 @@ def get_messages(conversation_id, prompt, file_string: str = None, file_type: st
|
|||||||
altered_message = message["content"]
|
altered_message = message["content"]
|
||||||
|
|
||||||
transformed_message = (
|
transformed_message = (
|
||||||
SystemMessage(content=altered_message)
|
AIMessage(content=altered_message)
|
||||||
if message["role"] == "assistant"
|
if message["role"] == "assistant"
|
||||||
else HumanMessage(content=altered_message)
|
else HumanMessage(content=altered_message)
|
||||||
)
|
)
|
||||||
@@ -863,6 +864,7 @@ def get_retriever(conversation_id):
|
|||||||
)
|
)
|
||||||
return vectorstore.as_retriever()
|
return vectorstore.as_retriever()
|
||||||
|
|
||||||
|
PROMPT_CLASSIFIER = PromptClassifier()
|
||||||
|
|
||||||
class ChatConsumerAgain(AsyncWebsocketConsumer):
|
class ChatConsumerAgain(AsyncWebsocketConsumer):
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
@@ -932,8 +934,9 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
|
|||||||
messages, prompt = await get_messages(
|
messages, prompt = await get_messages(
|
||||||
conversation_id, message, decoded_file, file_type
|
conversation_id, message, decoded_file, file_type
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
prompt_type = await prompt_classifier.classify_async(message)
|
prompt_type = await PROMPT_CLASSIFIER.classify_async(message)
|
||||||
print(f"prompt_type: {prompt_type} for {message}")
|
print(f"prompt_type: {prompt_type} for {message}")
|
||||||
if file:
|
if file:
|
||||||
prompt_type = PromptType.GENERAL_CHAT
|
prompt_type = PromptType.GENERAL_CHAT
|
||||||
@@ -977,7 +980,7 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
|
|||||||
print(f"using the AsyncLLMService\n\n{messages}\n{prompt.message}")
|
print(f"using the AsyncLLMService\n\n{messages}\n{prompt.message}")
|
||||||
service = AsyncLLMService()
|
service = AsyncLLMService()
|
||||||
async for chunk in service.generate_response(
|
async for chunk in service.generate_response(
|
||||||
messages, prompt.message
|
messages, prompt.message, conversation_id
|
||||||
):
|
):
|
||||||
response += chunk
|
response += chunk
|
||||||
await self.send(chunk)
|
await self.send(chunk)
|
||||||
|
|||||||
Reference in New Issue
Block a user