fixed chat service
This commit is contained in:
17
llm_be/chat_backend/services/base_service.py
Normal file
17
llm_be/chat_backend/services/base_service.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from langchain_ollama import OllamaLLM
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
|
||||
class BaseService(ABC):
|
||||
"""Abstract base class for LLM conversation services."""
|
||||
|
||||
def __init__(self, temperature=0.7):
|
||||
self.llm = OllamaLLM(
|
||||
model="llama3.2",
|
||||
temperature=0.7,
|
||||
top_k=50,
|
||||
top_p=0.9,
|
||||
repeat_penalty=1.1,
|
||||
num_ctx=4096,
|
||||
)
|
||||
self.output_parser = StrOutputParser()
|
||||
@@ -72,8 +72,6 @@ class SyncLLMService(LLMService):
|
||||
"""Generate response with streaming support."""
|
||||
chain_input = {"query": query, "conversation": conversation}
|
||||
|
||||
print(f"chain_input:\n{chain_input}")
|
||||
|
||||
for chunk in self.conversation_chain.stream(chain_input):
|
||||
yield chunk
|
||||
|
||||
@@ -106,10 +104,8 @@ class AsyncLLMService(LLMService):
|
||||
|
||||
self.conversation_chain = (
|
||||
{
|
||||
"context": lambda x: self._format_history(x["conversation"]),
|
||||
"recent_history": lambda x: self._get_recent_messages(
|
||||
x["conversation"]
|
||||
),
|
||||
"context":lambda x: x["conversation"],
|
||||
"recent_history":lambda x: x['recent_conversation'],
|
||||
"query": lambda x: x["query"],
|
||||
}
|
||||
| self.prompt
|
||||
@@ -117,34 +113,39 @@ class AsyncLLMService(LLMService):
|
||||
| self.output_parser
|
||||
)
|
||||
|
||||
async def _format_history(self, conversation: Conversation) -> str:
|
||||
async def _format_history(self, conversation: list) -> str:
|
||||
"""Async version of format conversation history."""
|
||||
prompts = (
|
||||
await Prompt.objects.filter(conversation=conversation)
|
||||
.order_by("created_at")
|
||||
.alist()
|
||||
)
|
||||
return "\n".join(
|
||||
f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
|
||||
)
|
||||
# prompts = list(
|
||||
# await Prompt.objects.filter(conversation_id=conversation_id)
|
||||
# .order_by("created")
|
||||
|
||||
# )
|
||||
# return "\n".join(
|
||||
# f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
|
||||
# )
|
||||
return "\n".join([f"{"User" if prompt.type=="human" else "AI"}: {prompt.text()}" for prompt in conversation])
|
||||
|
||||
async def _get_recent_messages(self, conversation: Conversation) -> str:
|
||||
async def _get_recent_messages(self, conversation: list) -> str:
|
||||
"""Async version of format conversation history."""
|
||||
prompts = (
|
||||
await Prompt.objects.filter(conversation=conversation)
|
||||
.order_by("created_at")
|
||||
.alist()[-6:]
|
||||
)
|
||||
return "\n".join(
|
||||
f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
|
||||
)
|
||||
|
||||
# prompts = list(
|
||||
# await Prompt.objects.filter(conversation_id=conversation_id)
|
||||
# .order_by("created")
|
||||
# [-6:]
|
||||
# )
|
||||
# return "\n".join(
|
||||
# f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
|
||||
# )
|
||||
return "\n".join([f"{"User" if prompt.type=="human" else "AI"}: {prompt.text()}" for prompt in conversation])
|
||||
|
||||
async def generate_response(
|
||||
self, conversation: Conversation, query: str, **kwargs
|
||||
self, conversation: Conversation, query: str, conversation_id: int, **kwargs
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate response with async streaming support."""
|
||||
chain_input = {"query": query, "conversation": conversation}
|
||||
print(f"LLM Chain:\n{chain_input}")
|
||||
chain_input = {
|
||||
"query": query,
|
||||
"conversation": await self._format_history(conversation),
|
||||
"recent_conversation": await self._get_recent_messages(conversation[-6:])}
|
||||
|
||||
async for chunk in self.conversation_chain.astream(chain_input):
|
||||
yield chunk
|
||||
|
||||
@@ -2,8 +2,7 @@ from enum import Enum, auto
|
||||
from typing import Dict, Any
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
|
||||
# from langchain_community.llms import Ollama
|
||||
from langchain_ollama import OllamaLLM
|
||||
from chat_backend.services.base_service import BaseService
|
||||
|
||||
|
||||
class ModerationLabel(Enum):
|
||||
@@ -11,18 +10,19 @@ class ModerationLabel(Enum):
|
||||
FINE = auto()
|
||||
|
||||
|
||||
class ModerationClassifier:
|
||||
class ModerationClassifier(BaseService):
|
||||
"""
|
||||
Classifies prompts as NSFW or FINE (safe) content.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.llm = OllamaLLM(
|
||||
model="llama3.2",
|
||||
temperature=0.1, # Very low for strict moderation
|
||||
top_k=10,
|
||||
num_ctx=2048,
|
||||
)
|
||||
super().__init__(temperature=0.1)
|
||||
# self.llm = OllamaLLM(
|
||||
# model="llama3.2",
|
||||
# temperature=0.1, # Very low for strict moderation
|
||||
# top_k=10,
|
||||
# num_ctx=2048,
|
||||
# )
|
||||
|
||||
self.moderation_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
|
||||
@@ -2,8 +2,7 @@ from enum import Enum, auto
|
||||
from typing import Dict, Any
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
|
||||
# from langchain_community.llms import Ollama
|
||||
from langchain_ollama import OllamaLLM
|
||||
from chat_backend.services.base_service import BaseService
|
||||
|
||||
|
||||
class PromptType(Enum):
|
||||
@@ -13,19 +12,20 @@ class PromptType(Enum):
|
||||
UNKNOWN = auto()
|
||||
|
||||
|
||||
class PromptClassifier:
|
||||
class PromptClassifier(BaseService):
|
||||
"""
|
||||
Classifies user prompts to determine which service should handle them.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.llm = OllamaLLM(
|
||||
model="llama3.2",
|
||||
temperature=0.3, # Lower temp for more deterministic classification
|
||||
top_k=20,
|
||||
top_p=0.9,
|
||||
num_ctx=4096,
|
||||
)
|
||||
super().__init__(temperature=0.1)
|
||||
# self.llm = OllamaLLM(
|
||||
# model="llama3.2",
|
||||
# temperature=0.3, # Lower temp for more deterministic classification
|
||||
# top_k=20,
|
||||
# top_p=0.9,
|
||||
# num_ctx=4096,
|
||||
# )
|
||||
|
||||
self.classification_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
@@ -69,6 +69,10 @@ Examples:
|
||||
- "Explain quantum computing" (General knowledge)
|
||||
- "Summarize the meeting" (No doc reference)
|
||||
|
||||
[Definitely NOT IMAGE_GENERATION]
|
||||
- "Great, can you make it about a duck now"
|
||||
- "highlight the features of the backyard playset if they were to choose us and make the language more long form"
|
||||
|
||||
Return ONLY the label, no explanations.""",
|
||||
),
|
||||
("human", "{prompt}"),
|
||||
31
llm_be/chat_backend/services/prompt_classifier/tests.py
Normal file
31
llm_be/chat_backend/services/prompt_classifier/tests.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import os
|
||||
from unittest import TestCase, mock
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from django.test import TestCase as DjangoTestCase
|
||||
|
||||
from chat_backend.services.rag_services import (
|
||||
RAGService,
|
||||
SyncRAGService,
|
||||
AsyncRAGService,
|
||||
)
|
||||
from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
|
||||
from chat_backend.services.prompt_classifier.prompt_classifier import PromptClassifier, PromptType
|
||||
from parameterized import parameterized
|
||||
|
||||
|
||||
|
||||
class PromptClassifierTestCase(TestCase):
|
||||
def setUp(self):
|
||||
self.service = PromptClassifier()
|
||||
|
||||
@parameterized.expand([
|
||||
["Tell me a joke",PromptType.GENERAL_CHAT],
|
||||
["Create an image of a dog for me",PromptType.IMAGE_GENERATION],
|
||||
["highlight the features of the backyard playset if they were to choose us and make the language more long form",PromptType.GENERAL_CHAT],
|
||||
["Great, can you make it about a duck now", PromptType.IMAGE_GENERATION],
|
||||
])
|
||||
def test_prompt_classification(self, prompt, expected_output):
|
||||
result = self.service.classify(prompt)
|
||||
self.assertEqual(result, expected_output)
|
||||
@@ -21,6 +21,7 @@ from langchain_community.document_loaders import (
|
||||
from django.core.files.uploadedfile import UploadedFile
|
||||
from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
|
||||
from pathlib import Path
|
||||
from chat_backend.services.base_service import BaseService
|
||||
|
||||
|
||||
@database_sync_to_async
|
||||
@@ -31,7 +32,7 @@ def get_documents(workspace: DocumentWorkspace | None = None):
|
||||
return [doc for doc in Document.objects.all()]
|
||||
|
||||
|
||||
class RAGService(ABC):
|
||||
class RAGService(BaseService):
|
||||
"""Abstract base class for RAG services."""
|
||||
|
||||
_instance = None
|
||||
@@ -44,14 +45,7 @@ class RAGService(ABC):
|
||||
|
||||
def __init__(self):
|
||||
self.embedding_model = OllamaEmbeddings(model="llama3.2")
|
||||
self.llm = OllamaLLM(
|
||||
model="llama3.2",
|
||||
temperature=0.7,
|
||||
top_k=50,
|
||||
top_p=0.9,
|
||||
repeat_penalty=1.1,
|
||||
num_ctx=4096,
|
||||
)
|
||||
super().__init__()
|
||||
self.text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=1000, chunk_overlap=200
|
||||
)
|
||||
@@ -104,17 +98,17 @@ class RAGService(ABC):
|
||||
print(f"Processing the documents : {documents}")
|
||||
self._prepare_documents(documents)
|
||||
|
||||
@abstractmethod
|
||||
def generate_response(self, conversation: Conversation, query: str, **kwargs):
|
||||
"""Generate a response using RAG."""
|
||||
pass
|
||||
# @abstractmethod
|
||||
# def generate_response(self, conversation: Conversation, query: str, **kwargs):
|
||||
# """Generate a response using RAG."""
|
||||
# pass
|
||||
|
||||
@abstractmethod
|
||||
def search_documents(
|
||||
self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4
|
||||
) -> List[Document]:
|
||||
"""Search relevant documents from the vector store."""
|
||||
pass
|
||||
# @abstractmethod
|
||||
# def search_documents(
|
||||
# self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4
|
||||
# ) -> List[Document]:
|
||||
# """Search relevant documents from the vector store."""
|
||||
# pass
|
||||
|
||||
def _get_file_loader(self, file_path: str):
|
||||
"""Get appropriate loader for file type"""
|
||||
@@ -304,7 +298,7 @@ class AsyncRAGService(RAGService):
|
||||
self.rag_chain = (
|
||||
{
|
||||
"context": self._retriever_with_history,
|
||||
"history": lambda x: self._format_history(x["conversation"]),
|
||||
"history": lambda x: x['recent_conversation'], #self._format_history(x["conversation"]),
|
||||
"question": lambda x: x["query"],
|
||||
}
|
||||
| self.prompt
|
||||
@@ -314,15 +308,16 @@ class AsyncRAGService(RAGService):
|
||||
|
||||
async def _format_history(self, conversation: Conversation) -> str:
|
||||
"""Format conversation history for the prompt."""
|
||||
prompts = (
|
||||
await Prompt.objects.filter(conversation=conversation)
|
||||
.order_by("created_at")
|
||||
.alist()
|
||||
)
|
||||
print(f"prompts that we are seeding with are: {prompts}")
|
||||
return "\n".join(
|
||||
f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
|
||||
)
|
||||
# prompts = (
|
||||
# await Prompt.objects.filter(conversation=conversation)
|
||||
# .order_by("created_at")
|
||||
# .alist()
|
||||
# )
|
||||
# print(f"prompts that we are seeding with are: {prompts}")
|
||||
# return "\n".join(
|
||||
# f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
|
||||
# )
|
||||
return "\n".join([f"{"User" if prompt.type=="human" else "AI"}: {prompt.text()}" for prompt in conversation])
|
||||
|
||||
async def _retriever_with_history(self, input_dict: Dict[str, Any]) -> str:
|
||||
"""Retrieve documents considering conversation history."""
|
||||
@@ -369,6 +364,7 @@ class AsyncRAGService(RAGService):
|
||||
"query": query,
|
||||
"conversation": conversation,
|
||||
"workspace": workspace,
|
||||
"recent_conversation": await self._format_history(conversation),
|
||||
}
|
||||
|
||||
async for chunk in self.rag_chain.astream(chain_input):
|
||||
|
||||
@@ -11,226 +11,241 @@ from chat_backend.services.rag_services import (
|
||||
AsyncRAGService,
|
||||
)
|
||||
from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
|
||||
from chat_backend.services.prompt_classifier import PromptClassifier, PromptType
|
||||
from parameterized import parameterized
|
||||
|
||||
|
||||
class TestRAGService(TestCase):
|
||||
# class TestRAGService(TestCase):
|
||||
# def setUp(self):
|
||||
# self.rag_service = RAGService()
|
||||
# self.rag_service.vector_store = MagicMock()
|
||||
# self.rag_service.embedding_model = MagicMock()
|
||||
# self.rag_service.text_splitter = MagicMock()
|
||||
|
||||
# def test_initialize_vector_store(self):
|
||||
# with patch("os.path.exists", return_value=False), patch(
|
||||
# "os.makedirs"
|
||||
# ) as mock_makedirs, patch(
|
||||
# "langchain_community.vectorstores.Chroma"
|
||||
# ) as mock_chroma:
|
||||
|
||||
# # Reset the vector store to test initialization
|
||||
# self.rag_service.vector_store = None
|
||||
# result = self.rag_service._initialize_vector_store()
|
||||
|
||||
# mock_makedirs.assert_called_once_with("chroma_db")
|
||||
# mock_chroma.assert_called_once_with(
|
||||
# embedding_function=self.rag_service.embedding_model,
|
||||
# persist_directory="chroma_db",
|
||||
# )
|
||||
# self.assertIsNotNone(result)
|
||||
|
||||
# def test_prepare_documents(self):
|
||||
# mock_doc1 = MagicMock(spec=Document)
|
||||
# mock_doc1.content = "Test content"
|
||||
# mock_doc1.source = "test_source"
|
||||
# mock_doc1.workspace = MagicMock()
|
||||
# mock_doc1.workspace.id = 1
|
||||
# mock_doc1.id = 1
|
||||
|
||||
# self.rag_service.text_splitter.split_text.return_value = ["chunk1", "chunk2"]
|
||||
|
||||
# result = self.rag_service._prepare_documents([mock_doc1])
|
||||
|
||||
# self.assertEqual(len(result), 2)
|
||||
# self.rag_service.text_splitter.split_text.assert_called_once_with(
|
||||
# "Test content"
|
||||
# )
|
||||
# self.assertEqual(result[0].page_content, "chunk1")
|
||||
# self.assertEqual(result[0].metadata["source"], "test_source")
|
||||
|
||||
# def test_ingest_documents(self):
|
||||
# mock_workspace = MagicMock()
|
||||
# mock_document = MagicMock()
|
||||
# mock_documents = [mock_document]
|
||||
|
||||
# with patch(
|
||||
# "services.rag_services.Document.objects.filter", return_value=mock_documents
|
||||
# ):
|
||||
# self.rag_service._prepare_documents = MagicMock(
|
||||
# return_value=["processed_doc"]
|
||||
# )
|
||||
|
||||
# self.rag_service.ingest_documents(mock_workspace)
|
||||
|
||||
# self.rag_service.vector_store.add_documents.assert_called_once_with(
|
||||
# ["processed_doc"]
|
||||
# )
|
||||
# self.rag_service.vector_store.persist.assert_called_once()
|
||||
|
||||
|
||||
# class TestSyncRAGService(DjangoTestCase):
|
||||
# def setUp(self):
|
||||
# self.sync_service = SyncRAGService()
|
||||
# self.sync_service.vector_store = MagicMock()
|
||||
# self.sync_service.llm = MagicMock()
|
||||
# self.sync_service.rag_chain = MagicMock()
|
||||
|
||||
# self.mock_conversation = MagicMock(spec=Conversation)
|
||||
# self.mock_conversation.workspace = MagicMock()
|
||||
|
||||
# self.mock_prompt1 = MagicMock(spec=Prompt)
|
||||
# self.mock_prompt1.is_user = True
|
||||
# self.mock_prompt1.text = "User question"
|
||||
# self.mock_prompt1.created_at = "2023-01-01"
|
||||
|
||||
# self.mock_prompt2 = MagicMock(spec=Prompt)
|
||||
# self.mock_prompt2.is_user = False
|
||||
# self.mock_prompt2.text = "AI response"
|
||||
# self.mock_prompt2.created_at = "2023-01-02"
|
||||
|
||||
# def test_format_history(self):
|
||||
# with patch("services.rag_services.Prompt.objects.filter") as mock_filter:
|
||||
# mock_filter.return_value.order_by.return_value = [
|
||||
# self.mock_prompt1,
|
||||
# self.mock_prompt2,
|
||||
# ]
|
||||
|
||||
# result = self.sync_service._format_history(self.mock_conversation)
|
||||
|
||||
# expected = "User: User question\nAI: AI response"
|
||||
# self.assertEqual(result, expected)
|
||||
# mock_filter.assert_called_once_with(conversation=self.mock_conversation)
|
||||
|
||||
# def test_retriever_with_history(self):
|
||||
# input_dict = {"query": "test query", "conversation": self.mock_conversation}
|
||||
|
||||
# self.sync_service.search_documents = MagicMock(return_value=["doc1", "doc2"])
|
||||
|
||||
# result = self.sync_service._retriever_with_history(input_dict)
|
||||
|
||||
# self.sync_service.search_documents.assert_called_once_with(
|
||||
# "test query", self.mock_conversation.workspace
|
||||
# )
|
||||
# self.assertEqual(result, ["doc1", "doc2"])
|
||||
|
||||
# def test_search_documents(self):
|
||||
# mock_retriever = MagicMock()
|
||||
# mock_retriever.get_relevant_documents.return_value = ["doc1", "doc2"]
|
||||
# self.sync_service.vector_store.as_retriever.return_value = mock_retriever
|
||||
|
||||
# result = self.sync_service.search_documents(
|
||||
# "test query", self.mock_conversation.workspace
|
||||
# )
|
||||
|
||||
# self.sync_service.vector_store.as_retriever.assert_called_once_with(
|
||||
# search_type="similarity",
|
||||
# search_kwargs={
|
||||
# "k": 4,
|
||||
# "filter": {"workspace_id": self.mock_conversation.workspace.id},
|
||||
# },
|
||||
# )
|
||||
# self.assertEqual(result, ["doc1", "doc2"])
|
||||
|
||||
# def test_generate_response(self):
|
||||
# chain_input = {"query": "test query", "conversation": self.mock_conversation}
|
||||
|
||||
# mock_stream = ["chunk1", "chunk2", "chunk3"]
|
||||
# self.sync_service.rag_chain.stream.return_value = mock_stream
|
||||
|
||||
# result = list(
|
||||
# self.sync_service.generate_response(self.mock_conversation, "test query")
|
||||
# )
|
||||
|
||||
# self.sync_service.rag_chain.stream.assert_called_once_with(chain_input)
|
||||
# self.assertEqual(result, mock_stream)
|
||||
|
||||
|
||||
# class TestAsyncRAGService(DjangoTestCase):
|
||||
# def setUp(self):
|
||||
# self.async_service = AsyncRAGService()
|
||||
# self.async_service.vector_store = MagicMock()
|
||||
# self.async_service.llm = MagicMock()
|
||||
# self.async_service.rag_chain = AsyncMock()
|
||||
|
||||
# self.mock_conversation = MagicMock(spec=Conversation)
|
||||
# self.mock_conversation.workspace = MagicMock()
|
||||
|
||||
# self.mock_prompt1 = MagicMock(spec=Prompt)
|
||||
# self.mock_prompt1.is_user = True
|
||||
# self.mock_prompt1.text = "User question"
|
||||
# self.mock_prompt1.created_at = "2023-01-01"
|
||||
|
||||
# self.mock_prompt2 = MagicMock(spec=Prompt)
|
||||
# self.mock_prompt2.is_user = False
|
||||
# self.mock_prompt2.text = "AI response"
|
||||
# self.mock_prompt2.created_at = "2023-01-02"
|
||||
|
||||
# async def test_format_history(self):
|
||||
# mock_manager = AsyncMock()
|
||||
# mock_manager.order_by.return_value.alist.return_value = [
|
||||
# self.mock_prompt1,
|
||||
# self.mock_prompt2,
|
||||
# ]
|
||||
|
||||
# with patch(
|
||||
# "services.rag_services.Prompt.objects.filter", return_value=mock_manager
|
||||
# ):
|
||||
# result = await self.async_service._format_history(self.mock_conversation)
|
||||
|
||||
# expected = "User: User question\nAI: AI response"
|
||||
# self.assertEqual(result, expected)
|
||||
# mock_manager.order_by.assert_called_once_with("created_at")
|
||||
|
||||
# async def test_retriever_with_history(self):
|
||||
# input_dict = {"query": "test query", "conversation": self.mock_conversation}
|
||||
|
||||
# self.async_service.search_documents = AsyncMock(return_value=["doc1", "doc2"])
|
||||
|
||||
# result = await self.async_service._retriever_with_history(input_dict)
|
||||
|
||||
# self.async_service.search_documents.assert_awaited_once_with(
|
||||
# "test query", self.mock_conversation.workspace
|
||||
# )
|
||||
# self.assertEqual(result, ["doc1", "doc2"])
|
||||
|
||||
# async def test_search_documents(self):
|
||||
# mock_retriever = AsyncMock()
|
||||
# mock_retriever.aget_relevant_documents.return_value = ["doc1", "doc2"]
|
||||
# self.async_service.vector_store.as_retriever.return_value = mock_retriever
|
||||
|
||||
# result = await self.async_service.search_documents(
|
||||
# "test query", self.mock_conversation.workspace
|
||||
# )
|
||||
|
||||
# self.async_service.vector_store.as_retriever.assert_called_once_with(
|
||||
# search_type="similarity",
|
||||
# search_kwargs={
|
||||
# "k": 4,
|
||||
# "filter": {"workspace_id": self.mock_conversation.workspace.id},
|
||||
# },
|
||||
# )
|
||||
# self.assertEqual(result, ["doc1", "doc2"])
|
||||
|
||||
# async def test_generate_response(self):
|
||||
# chain_input = {"query": "test query", "conversation": self.mock_conversation}
|
||||
|
||||
# mock_stream = ["chunk1", "chunk2", "chunk3"]
|
||||
# self.async_service.rag_chain.astream.return_value = mock_stream
|
||||
|
||||
# chunks = []
|
||||
# async for chunk in self.async_service.generate_response(
|
||||
# self.mock_conversation, "test query"
|
||||
# ):
|
||||
# chunks.append(chunk)
|
||||
|
||||
# self.async_service.rag_chain.astream.assert_awaited_once_with(chain_input)
|
||||
# self.assertEqual(chunks, mock_stream)
|
||||
|
||||
class PromptClassifierTestCase(TestCase):
|
||||
def setUp(self):
|
||||
self.rag_service = RAGService()
|
||||
self.rag_service.vector_store = MagicMock()
|
||||
self.rag_service.embedding_model = MagicMock()
|
||||
self.rag_service.text_splitter = MagicMock()
|
||||
self.service = PromptClassifier()
|
||||
|
||||
def test_initialize_vector_store(self):
|
||||
with patch("os.path.exists", return_value=False), patch(
|
||||
"os.makedirs"
|
||||
) as mock_makedirs, patch(
|
||||
"langchain_community.vectorstores.Chroma"
|
||||
) as mock_chroma:
|
||||
|
||||
# Reset the vector store to test initialization
|
||||
self.rag_service.vector_store = None
|
||||
result = self.rag_service._initialize_vector_store()
|
||||
|
||||
mock_makedirs.assert_called_once_with("chroma_db")
|
||||
mock_chroma.assert_called_once_with(
|
||||
embedding_function=self.rag_service.embedding_model,
|
||||
persist_directory="chroma_db",
|
||||
)
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
def test_prepare_documents(self):
|
||||
mock_doc1 = MagicMock(spec=Document)
|
||||
mock_doc1.content = "Test content"
|
||||
mock_doc1.source = "test_source"
|
||||
mock_doc1.workspace = MagicMock()
|
||||
mock_doc1.workspace.id = 1
|
||||
mock_doc1.id = 1
|
||||
|
||||
self.rag_service.text_splitter.split_text.return_value = ["chunk1", "chunk2"]
|
||||
|
||||
result = self.rag_service._prepare_documents([mock_doc1])
|
||||
|
||||
self.assertEqual(len(result), 2)
|
||||
self.rag_service.text_splitter.split_text.assert_called_once_with(
|
||||
"Test content"
|
||||
)
|
||||
self.assertEqual(result[0].page_content, "chunk1")
|
||||
self.assertEqual(result[0].metadata["source"], "test_source")
|
||||
|
||||
def test_ingest_documents(self):
|
||||
mock_workspace = MagicMock()
|
||||
mock_document = MagicMock()
|
||||
mock_documents = [mock_document]
|
||||
|
||||
with patch(
|
||||
"services.rag_services.Document.objects.filter", return_value=mock_documents
|
||||
):
|
||||
self.rag_service._prepare_documents = MagicMock(
|
||||
return_value=["processed_doc"]
|
||||
)
|
||||
|
||||
self.rag_service.ingest_documents(mock_workspace)
|
||||
|
||||
self.rag_service.vector_store.add_documents.assert_called_once_with(
|
||||
["processed_doc"]
|
||||
)
|
||||
self.rag_service.vector_store.persist.assert_called_once()
|
||||
|
||||
|
||||
class TestSyncRAGService(DjangoTestCase):
|
||||
def setUp(self):
|
||||
self.sync_service = SyncRAGService()
|
||||
self.sync_service.vector_store = MagicMock()
|
||||
self.sync_service.llm = MagicMock()
|
||||
self.sync_service.rag_chain = MagicMock()
|
||||
|
||||
self.mock_conversation = MagicMock(spec=Conversation)
|
||||
self.mock_conversation.workspace = MagicMock()
|
||||
|
||||
self.mock_prompt1 = MagicMock(spec=Prompt)
|
||||
self.mock_prompt1.is_user = True
|
||||
self.mock_prompt1.text = "User question"
|
||||
self.mock_prompt1.created_at = "2023-01-01"
|
||||
|
||||
self.mock_prompt2 = MagicMock(spec=Prompt)
|
||||
self.mock_prompt2.is_user = False
|
||||
self.mock_prompt2.text = "AI response"
|
||||
self.mock_prompt2.created_at = "2023-01-02"
|
||||
|
||||
def test_format_history(self):
|
||||
with patch("services.rag_services.Prompt.objects.filter") as mock_filter:
|
||||
mock_filter.return_value.order_by.return_value = [
|
||||
self.mock_prompt1,
|
||||
self.mock_prompt2,
|
||||
]
|
||||
|
||||
result = self.sync_service._format_history(self.mock_conversation)
|
||||
|
||||
expected = "User: User question\nAI: AI response"
|
||||
self.assertEqual(result, expected)
|
||||
mock_filter.assert_called_once_with(conversation=self.mock_conversation)
|
||||
|
||||
def test_retriever_with_history(self):
|
||||
input_dict = {"query": "test query", "conversation": self.mock_conversation}
|
||||
|
||||
self.sync_service.search_documents = MagicMock(return_value=["doc1", "doc2"])
|
||||
|
||||
result = self.sync_service._retriever_with_history(input_dict)
|
||||
|
||||
self.sync_service.search_documents.assert_called_once_with(
|
||||
"test query", self.mock_conversation.workspace
|
||||
)
|
||||
self.assertEqual(result, ["doc1", "doc2"])
|
||||
|
||||
def test_search_documents(self):
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_relevant_documents.return_value = ["doc1", "doc2"]
|
||||
self.sync_service.vector_store.as_retriever.return_value = mock_retriever
|
||||
|
||||
result = self.sync_service.search_documents(
|
||||
"test query", self.mock_conversation.workspace
|
||||
)
|
||||
|
||||
self.sync_service.vector_store.as_retriever.assert_called_once_with(
|
||||
search_type="similarity",
|
||||
search_kwargs={
|
||||
"k": 4,
|
||||
"filter": {"workspace_id": self.mock_conversation.workspace.id},
|
||||
},
|
||||
)
|
||||
self.assertEqual(result, ["doc1", "doc2"])
|
||||
|
||||
def test_generate_response(self):
|
||||
chain_input = {"query": "test query", "conversation": self.mock_conversation}
|
||||
|
||||
mock_stream = ["chunk1", "chunk2", "chunk3"]
|
||||
self.sync_service.rag_chain.stream.return_value = mock_stream
|
||||
|
||||
result = list(
|
||||
self.sync_service.generate_response(self.mock_conversation, "test query")
|
||||
)
|
||||
|
||||
self.sync_service.rag_chain.stream.assert_called_once_with(chain_input)
|
||||
self.assertEqual(result, mock_stream)
|
||||
|
||||
|
||||
class TestAsyncRAGService(DjangoTestCase):
|
||||
def setUp(self):
|
||||
self.async_service = AsyncRAGService()
|
||||
self.async_service.vector_store = MagicMock()
|
||||
self.async_service.llm = MagicMock()
|
||||
self.async_service.rag_chain = AsyncMock()
|
||||
|
||||
self.mock_conversation = MagicMock(spec=Conversation)
|
||||
self.mock_conversation.workspace = MagicMock()
|
||||
|
||||
self.mock_prompt1 = MagicMock(spec=Prompt)
|
||||
self.mock_prompt1.is_user = True
|
||||
self.mock_prompt1.text = "User question"
|
||||
self.mock_prompt1.created_at = "2023-01-01"
|
||||
|
||||
self.mock_prompt2 = MagicMock(spec=Prompt)
|
||||
self.mock_prompt2.is_user = False
|
||||
self.mock_prompt2.text = "AI response"
|
||||
self.mock_prompt2.created_at = "2023-01-02"
|
||||
|
||||
async def test_format_history(self):
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.order_by.return_value.alist.return_value = [
|
||||
self.mock_prompt1,
|
||||
self.mock_prompt2,
|
||||
]
|
||||
|
||||
with patch(
|
||||
"services.rag_services.Prompt.objects.filter", return_value=mock_manager
|
||||
):
|
||||
result = await self.async_service._format_history(self.mock_conversation)
|
||||
|
||||
expected = "User: User question\nAI: AI response"
|
||||
self.assertEqual(result, expected)
|
||||
mock_manager.order_by.assert_called_once_with("created_at")
|
||||
|
||||
async def test_retriever_with_history(self):
|
||||
input_dict = {"query": "test query", "conversation": self.mock_conversation}
|
||||
|
||||
self.async_service.search_documents = AsyncMock(return_value=["doc1", "doc2"])
|
||||
|
||||
result = await self.async_service._retriever_with_history(input_dict)
|
||||
|
||||
self.async_service.search_documents.assert_awaited_once_with(
|
||||
"test query", self.mock_conversation.workspace
|
||||
)
|
||||
self.assertEqual(result, ["doc1", "doc2"])
|
||||
|
||||
async def test_search_documents(self):
|
||||
mock_retriever = AsyncMock()
|
||||
mock_retriever.aget_relevant_documents.return_value = ["doc1", "doc2"]
|
||||
self.async_service.vector_store.as_retriever.return_value = mock_retriever
|
||||
|
||||
result = await self.async_service.search_documents(
|
||||
"test query", self.mock_conversation.workspace
|
||||
)
|
||||
|
||||
self.async_service.vector_store.as_retriever.assert_called_once_with(
|
||||
search_type="similarity",
|
||||
search_kwargs={
|
||||
"k": 4,
|
||||
"filter": {"workspace_id": self.mock_conversation.workspace.id},
|
||||
},
|
||||
)
|
||||
self.assertEqual(result, ["doc1", "doc2"])
|
||||
|
||||
async def test_generate_response(self):
|
||||
chain_input = {"query": "test query", "conversation": self.mock_conversation}
|
||||
|
||||
mock_stream = ["chunk1", "chunk2", "chunk3"]
|
||||
self.async_service.rag_chain.astream.return_value = mock_stream
|
||||
|
||||
chunks = []
|
||||
async for chunk in self.async_service.generate_response(
|
||||
self.mock_conversation, "test query"
|
||||
):
|
||||
chunks.append(chunk)
|
||||
|
||||
self.async_service.rag_chain.astream.assert_awaited_once_with(chain_input)
|
||||
self.assertEqual(chunks, mock_stream)
|
||||
@parameterized.expand([
|
||||
["Tell me a joke",PromptType.GENERAL_CHAT],
|
||||
["Create an image of a dog for me",PromptType.IMAGE_GENERATION],
|
||||
["highlight the features of the backyard playset if they were to choose us and make the language more long form",PromptType.GENERAL_CHAT],
|
||||
])
|
||||
def test_prompt_classification(self, prompt, expected_output):
|
||||
result = self.service.classify(prompt)
|
||||
self.assertEqual(result, expected_output)
|
||||
@@ -10,6 +10,8 @@ from .models import DocumentWorkspace, Document, Company
|
||||
from django.contrib.auth import get_user_model
|
||||
import tempfile
|
||||
from django.core.files.uploadedfile import SimpleUploadedFile
|
||||
from parameterized import parameterized
|
||||
|
||||
|
||||
# Minimal valid PDF bytes
|
||||
VALID_PDF_BYTES = (
|
||||
@@ -185,3 +187,8 @@ class DocumentViewsTestCase(APITestCase):
|
||||
url = reverse("documents_details", kwargs={"document_id": other_document.id})
|
||||
response = self.client.get(url)
|
||||
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from .serializers import (
|
||||
DocumentWorkspaceSerializer,
|
||||
DocumentSerializer,
|
||||
)
|
||||
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.response import Response
|
||||
from .models import (
|
||||
@@ -35,7 +36,7 @@ from asgiref.sync import sync_to_async, async_to_sync
|
||||
from channels.generic.websocket import AsyncWebsocketConsumer
|
||||
from langchain_ollama.llms import OllamaLLM
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from langchain.chains import RetrievalQA
|
||||
import re
|
||||
import os
|
||||
@@ -66,7 +67,7 @@ from .services.llm_service import AsyncLLMService
|
||||
from .services.rag_services import AsyncRAGService
|
||||
from .services.title_generator import title_generator
|
||||
from .services.moderation_classifier import moderation_classifier, ModerationLabel
|
||||
from .services.prompt_classifier import prompt_classifier, PromptType
|
||||
from .services.prompt_classifier.prompt_classifier import PromptClassifier, PromptType
|
||||
|
||||
|
||||
from langchain.chains import create_retrieval_chain
|
||||
@@ -74,7 +75,7 @@ from langchain.chains.combine_documents import create_stuff_documents_chain
|
||||
from langchain_ollama import ChatOllama
|
||||
|
||||
CHANNEL_NAME: str = "llm_messages"
|
||||
MODEL_NAME: str = "llama3"
|
||||
MODEL_NAME: str = "llama3.2"
|
||||
|
||||
|
||||
# Create your views here.
|
||||
@@ -790,7 +791,7 @@ def get_messages(conversation_id, prompt, file_string: str = None, file_type: st
|
||||
altered_message = message["content"]
|
||||
|
||||
transformed_message = (
|
||||
SystemMessage(content=altered_message)
|
||||
AIMessage(content=altered_message)
|
||||
if message["role"] == "assistant"
|
||||
else HumanMessage(content=altered_message)
|
||||
)
|
||||
@@ -863,6 +864,7 @@ def get_retriever(conversation_id):
|
||||
)
|
||||
return vectorstore.as_retriever()
|
||||
|
||||
PROMPT_CLASSIFIER = PromptClassifier()
|
||||
|
||||
class ChatConsumerAgain(AsyncWebsocketConsumer):
|
||||
async def connect(self):
|
||||
@@ -932,8 +934,9 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
|
||||
messages, prompt = await get_messages(
|
||||
conversation_id, message, decoded_file, file_type
|
||||
)
|
||||
|
||||
|
||||
prompt_type = await prompt_classifier.classify_async(message)
|
||||
prompt_type = await PROMPT_CLASSIFIER.classify_async(message)
|
||||
print(f"prompt_type: {prompt_type} for {message}")
|
||||
if file:
|
||||
prompt_type = PromptType.GENERAL_CHAT
|
||||
@@ -977,7 +980,7 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
|
||||
print(f"using the AsyncLLMService\n\n{messages}\n{prompt.message}")
|
||||
service = AsyncLLMService()
|
||||
async for chunk in service.generate_response(
|
||||
messages, prompt.message
|
||||
messages, prompt.message, conversation_id
|
||||
):
|
||||
response += chunk
|
||||
await self.send(chunk)
|
||||
|
||||
Reference in New Issue
Block a user