fixed chat service

This commit is contained in:
2025-05-28 03:25:14 -05:00
parent a85f1222eb
commit 951a58f2fa
10 changed files with 374 additions and 300 deletions

View File

@@ -0,0 +1,17 @@
from abc import ABC, abstractmethod
from langchain_ollama import OllamaLLM
from langchain_core.output_parsers import StrOutputParser
class BaseService(ABC):
"""Abstract base class for LLM conversation services."""
def __init__(self, temperature=0.7):
self.llm = OllamaLLM(
model="llama3.2",
temperature=0.7,
top_k=50,
top_p=0.9,
repeat_penalty=1.1,
num_ctx=4096,
)
self.output_parser = StrOutputParser()

View File

@@ -72,8 +72,6 @@ class SyncLLMService(LLMService):
"""Generate response with streaming support."""
chain_input = {"query": query, "conversation": conversation}
print(f"chain_input:\n{chain_input}")
for chunk in self.conversation_chain.stream(chain_input):
yield chunk
@@ -106,10 +104,8 @@ class AsyncLLMService(LLMService):
self.conversation_chain = (
{
"context": lambda x: self._format_history(x["conversation"]),
"recent_history": lambda x: self._get_recent_messages(
x["conversation"]
),
"context":lambda x: x["conversation"],
"recent_history":lambda x: x['recent_conversation'],
"query": lambda x: x["query"],
}
| self.prompt
@@ -117,34 +113,39 @@ class AsyncLLMService(LLMService):
| self.output_parser
)
async def _format_history(self, conversation: Conversation) -> str:
async def _format_history(self, conversation: list) -> str:
"""Async version of format conversation history."""
prompts = (
await Prompt.objects.filter(conversation=conversation)
.order_by("created_at")
.alist()
)
return "\n".join(
f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
)
# prompts = list(
# await Prompt.objects.filter(conversation_id=conversation_id)
# .order_by("created")
# )
# return "\n".join(
# f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
# )
return "\n".join([f"{"User" if prompt.type=="human" else "AI"}: {prompt.text()}" for prompt in conversation])
async def _get_recent_messages(self, conversation: Conversation) -> str:
async def _get_recent_messages(self, conversation: list) -> str:
"""Async version of format conversation history."""
prompts = (
await Prompt.objects.filter(conversation=conversation)
.order_by("created_at")
.alist()[-6:]
)
return "\n".join(
f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
)
# prompts = list(
# await Prompt.objects.filter(conversation_id=conversation_id)
# .order_by("created")
# [-6:]
# )
# return "\n".join(
# f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
# )
return "\n".join([f"{"User" if prompt.type=="human" else "AI"}: {prompt.text()}" for prompt in conversation])
async def generate_response(
self, conversation: Conversation, query: str, **kwargs
self, conversation: Conversation, query: str, conversation_id: int, **kwargs
) -> AsyncGenerator[str, None]:
"""Generate response with async streaming support."""
chain_input = {"query": query, "conversation": conversation}
print(f"LLM Chain:\n{chain_input}")
chain_input = {
"query": query,
"conversation": await self._format_history(conversation),
"recent_conversation": await self._get_recent_messages(conversation[-6:])}
async for chunk in self.conversation_chain.astream(chain_input):
yield chunk

View File

@@ -2,8 +2,7 @@ from enum import Enum, auto
from typing import Dict, Any
from langchain_core.prompts import ChatPromptTemplate
# from langchain_community.llms import Ollama
from langchain_ollama import OllamaLLM
from chat_backend.services.base_service import BaseService
class ModerationLabel(Enum):
@@ -11,18 +10,19 @@ class ModerationLabel(Enum):
FINE = auto()
class ModerationClassifier:
class ModerationClassifier(BaseService):
"""
Classifies prompts as NSFW or FINE (safe) content.
"""
def __init__(self):
self.llm = OllamaLLM(
model="llama3.2",
temperature=0.1, # Very low for strict moderation
top_k=10,
num_ctx=2048,
)
super().__init__(temperature=0.1)
# self.llm = OllamaLLM(
# model="llama3.2",
# temperature=0.1, # Very low for strict moderation
# top_k=10,
# num_ctx=2048,
# )
self.moderation_prompt = ChatPromptTemplate.from_messages(
[

View File

@@ -2,8 +2,7 @@ from enum import Enum, auto
from typing import Dict, Any
from langchain_core.prompts import ChatPromptTemplate
# from langchain_community.llms import Ollama
from langchain_ollama import OllamaLLM
from chat_backend.services.base_service import BaseService
class PromptType(Enum):
@@ -13,19 +12,20 @@ class PromptType(Enum):
UNKNOWN = auto()
class PromptClassifier:
class PromptClassifier(BaseService):
"""
Classifies user prompts to determine which service should handle them.
"""
def __init__(self):
self.llm = OllamaLLM(
model="llama3.2",
temperature=0.3, # Lower temp for more deterministic classification
top_k=20,
top_p=0.9,
num_ctx=4096,
)
super().__init__(temperature=0.1)
# self.llm = OllamaLLM(
# model="llama3.2",
# temperature=0.3, # Lower temp for more deterministic classification
# top_k=20,
# top_p=0.9,
# num_ctx=4096,
# )
self.classification_prompt = ChatPromptTemplate.from_messages(
[
@@ -69,6 +69,10 @@ Examples:
- "Explain quantum computing" (General knowledge)
- "Summarize the meeting" (No doc reference)
[Definitely NOT IMAGE_GENERATION]
- "Great, can you make it about a duck now"
- "highlight the features of the backyard playset if they were to choose us and make the language more long form"
Return ONLY the label, no explanations.""",
),
("human", "{prompt}"),

View File

@@ -0,0 +1,31 @@
import os
from unittest import TestCase, mock
from unittest.mock import MagicMock, patch, AsyncMock
from typing import List, Dict, Any
from django.test import TestCase as DjangoTestCase
from chat_backend.services.rag_services import (
RAGService,
SyncRAGService,
AsyncRAGService,
)
from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
from chat_backend.services.prompt_classifier.prompt_classifier import PromptClassifier, PromptType
from parameterized import parameterized
class PromptClassifierTestCase(TestCase):
def setUp(self):
self.service = PromptClassifier()
@parameterized.expand([
["Tell me a joke",PromptType.GENERAL_CHAT],
["Create an image of a dog for me",PromptType.IMAGE_GENERATION],
["highlight the features of the backyard playset if they were to choose us and make the language more long form",PromptType.GENERAL_CHAT],
["Great, can you make it about a duck now", PromptType.IMAGE_GENERATION],
])
def test_prompt_classification(self, prompt, expected_output):
result = self.service.classify(prompt)
self.assertEqual(result, expected_output)

View File

@@ -21,6 +21,7 @@ from langchain_community.document_loaders import (
from django.core.files.uploadedfile import UploadedFile
from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
from pathlib import Path
from chat_backend.services.base_service import BaseService
@database_sync_to_async
@@ -31,7 +32,7 @@ def get_documents(workspace: DocumentWorkspace | None = None):
return [doc for doc in Document.objects.all()]
class RAGService(ABC):
class RAGService(BaseService):
"""Abstract base class for RAG services."""
_instance = None
@@ -44,14 +45,7 @@ class RAGService(ABC):
def __init__(self):
self.embedding_model = OllamaEmbeddings(model="llama3.2")
self.llm = OllamaLLM(
model="llama3.2",
temperature=0.7,
top_k=50,
top_p=0.9,
repeat_penalty=1.1,
num_ctx=4096,
)
super().__init__()
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000, chunk_overlap=200
)
@@ -104,17 +98,17 @@ class RAGService(ABC):
print(f"Processing the documents : {documents}")
self._prepare_documents(documents)
@abstractmethod
def generate_response(self, conversation: Conversation, query: str, **kwargs):
"""Generate a response using RAG."""
pass
# @abstractmethod
# def generate_response(self, conversation: Conversation, query: str, **kwargs):
# """Generate a response using RAG."""
# pass
@abstractmethod
def search_documents(
self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4
) -> List[Document]:
"""Search relevant documents from the vector store."""
pass
# @abstractmethod
# def search_documents(
# self, query: str, workspace: Optional[DocumentWorkspace] = None, k: int = 4
# ) -> List[Document]:
# """Search relevant documents from the vector store."""
# pass
def _get_file_loader(self, file_path: str):
"""Get appropriate loader for file type"""
@@ -304,7 +298,7 @@ class AsyncRAGService(RAGService):
self.rag_chain = (
{
"context": self._retriever_with_history,
"history": lambda x: self._format_history(x["conversation"]),
"history": lambda x: x['recent_conversation'], #self._format_history(x["conversation"]),
"question": lambda x: x["query"],
}
| self.prompt
@@ -314,15 +308,16 @@ class AsyncRAGService(RAGService):
async def _format_history(self, conversation: Conversation) -> str:
"""Format conversation history for the prompt."""
prompts = (
await Prompt.objects.filter(conversation=conversation)
.order_by("created_at")
.alist()
)
print(f"prompts that we are seeding with are: {prompts}")
return "\n".join(
f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
)
# prompts = (
# await Prompt.objects.filter(conversation=conversation)
# .order_by("created_at")
# .alist()
# )
# print(f"prompts that we are seeding with are: {prompts}")
# return "\n".join(
# f"{'User' if prompt.is_user else 'AI'}: {prompt.text}" for prompt in prompts
# )
return "\n".join([f"{"User" if prompt.type=="human" else "AI"}: {prompt.text()}" for prompt in conversation])
async def _retriever_with_history(self, input_dict: Dict[str, Any]) -> str:
"""Retrieve documents considering conversation history."""
@@ -369,6 +364,7 @@ class AsyncRAGService(RAGService):
"query": query,
"conversation": conversation,
"workspace": workspace,
"recent_conversation": await self._format_history(conversation),
}
async for chunk in self.rag_chain.astream(chain_input):

View File

@@ -11,226 +11,241 @@ from chat_backend.services.rag_services import (
AsyncRAGService,
)
from chat_backend.models import Conversation, Prompt, DocumentWorkspace, Document
from chat_backend.services.prompt_classifier import PromptClassifier, PromptType
from parameterized import parameterized
class TestRAGService(TestCase):
# class TestRAGService(TestCase):
# def setUp(self):
# self.rag_service = RAGService()
# self.rag_service.vector_store = MagicMock()
# self.rag_service.embedding_model = MagicMock()
# self.rag_service.text_splitter = MagicMock()
# def test_initialize_vector_store(self):
# with patch("os.path.exists", return_value=False), patch(
# "os.makedirs"
# ) as mock_makedirs, patch(
# "langchain_community.vectorstores.Chroma"
# ) as mock_chroma:
# # Reset the vector store to test initialization
# self.rag_service.vector_store = None
# result = self.rag_service._initialize_vector_store()
# mock_makedirs.assert_called_once_with("chroma_db")
# mock_chroma.assert_called_once_with(
# embedding_function=self.rag_service.embedding_model,
# persist_directory="chroma_db",
# )
# self.assertIsNotNone(result)
# def test_prepare_documents(self):
# mock_doc1 = MagicMock(spec=Document)
# mock_doc1.content = "Test content"
# mock_doc1.source = "test_source"
# mock_doc1.workspace = MagicMock()
# mock_doc1.workspace.id = 1
# mock_doc1.id = 1
# self.rag_service.text_splitter.split_text.return_value = ["chunk1", "chunk2"]
# result = self.rag_service._prepare_documents([mock_doc1])
# self.assertEqual(len(result), 2)
# self.rag_service.text_splitter.split_text.assert_called_once_with(
# "Test content"
# )
# self.assertEqual(result[0].page_content, "chunk1")
# self.assertEqual(result[0].metadata["source"], "test_source")
# def test_ingest_documents(self):
# mock_workspace = MagicMock()
# mock_document = MagicMock()
# mock_documents = [mock_document]
# with patch(
# "services.rag_services.Document.objects.filter", return_value=mock_documents
# ):
# self.rag_service._prepare_documents = MagicMock(
# return_value=["processed_doc"]
# )
# self.rag_service.ingest_documents(mock_workspace)
# self.rag_service.vector_store.add_documents.assert_called_once_with(
# ["processed_doc"]
# )
# self.rag_service.vector_store.persist.assert_called_once()
# class TestSyncRAGService(DjangoTestCase):
# def setUp(self):
# self.sync_service = SyncRAGService()
# self.sync_service.vector_store = MagicMock()
# self.sync_service.llm = MagicMock()
# self.sync_service.rag_chain = MagicMock()
# self.mock_conversation = MagicMock(spec=Conversation)
# self.mock_conversation.workspace = MagicMock()
# self.mock_prompt1 = MagicMock(spec=Prompt)
# self.mock_prompt1.is_user = True
# self.mock_prompt1.text = "User question"
# self.mock_prompt1.created_at = "2023-01-01"
# self.mock_prompt2 = MagicMock(spec=Prompt)
# self.mock_prompt2.is_user = False
# self.mock_prompt2.text = "AI response"
# self.mock_prompt2.created_at = "2023-01-02"
# def test_format_history(self):
# with patch("services.rag_services.Prompt.objects.filter") as mock_filter:
# mock_filter.return_value.order_by.return_value = [
# self.mock_prompt1,
# self.mock_prompt2,
# ]
# result = self.sync_service._format_history(self.mock_conversation)
# expected = "User: User question\nAI: AI response"
# self.assertEqual(result, expected)
# mock_filter.assert_called_once_with(conversation=self.mock_conversation)
# def test_retriever_with_history(self):
# input_dict = {"query": "test query", "conversation": self.mock_conversation}
# self.sync_service.search_documents = MagicMock(return_value=["doc1", "doc2"])
# result = self.sync_service._retriever_with_history(input_dict)
# self.sync_service.search_documents.assert_called_once_with(
# "test query", self.mock_conversation.workspace
# )
# self.assertEqual(result, ["doc1", "doc2"])
# def test_search_documents(self):
# mock_retriever = MagicMock()
# mock_retriever.get_relevant_documents.return_value = ["doc1", "doc2"]
# self.sync_service.vector_store.as_retriever.return_value = mock_retriever
# result = self.sync_service.search_documents(
# "test query", self.mock_conversation.workspace
# )
# self.sync_service.vector_store.as_retriever.assert_called_once_with(
# search_type="similarity",
# search_kwargs={
# "k": 4,
# "filter": {"workspace_id": self.mock_conversation.workspace.id},
# },
# )
# self.assertEqual(result, ["doc1", "doc2"])
# def test_generate_response(self):
# chain_input = {"query": "test query", "conversation": self.mock_conversation}
# mock_stream = ["chunk1", "chunk2", "chunk3"]
# self.sync_service.rag_chain.stream.return_value = mock_stream
# result = list(
# self.sync_service.generate_response(self.mock_conversation, "test query")
# )
# self.sync_service.rag_chain.stream.assert_called_once_with(chain_input)
# self.assertEqual(result, mock_stream)
# class TestAsyncRAGService(DjangoTestCase):
# def setUp(self):
# self.async_service = AsyncRAGService()
# self.async_service.vector_store = MagicMock()
# self.async_service.llm = MagicMock()
# self.async_service.rag_chain = AsyncMock()
# self.mock_conversation = MagicMock(spec=Conversation)
# self.mock_conversation.workspace = MagicMock()
# self.mock_prompt1 = MagicMock(spec=Prompt)
# self.mock_prompt1.is_user = True
# self.mock_prompt1.text = "User question"
# self.mock_prompt1.created_at = "2023-01-01"
# self.mock_prompt2 = MagicMock(spec=Prompt)
# self.mock_prompt2.is_user = False
# self.mock_prompt2.text = "AI response"
# self.mock_prompt2.created_at = "2023-01-02"
# async def test_format_history(self):
# mock_manager = AsyncMock()
# mock_manager.order_by.return_value.alist.return_value = [
# self.mock_prompt1,
# self.mock_prompt2,
# ]
# with patch(
# "services.rag_services.Prompt.objects.filter", return_value=mock_manager
# ):
# result = await self.async_service._format_history(self.mock_conversation)
# expected = "User: User question\nAI: AI response"
# self.assertEqual(result, expected)
# mock_manager.order_by.assert_called_once_with("created_at")
# async def test_retriever_with_history(self):
# input_dict = {"query": "test query", "conversation": self.mock_conversation}
# self.async_service.search_documents = AsyncMock(return_value=["doc1", "doc2"])
# result = await self.async_service._retriever_with_history(input_dict)
# self.async_service.search_documents.assert_awaited_once_with(
# "test query", self.mock_conversation.workspace
# )
# self.assertEqual(result, ["doc1", "doc2"])
# async def test_search_documents(self):
# mock_retriever = AsyncMock()
# mock_retriever.aget_relevant_documents.return_value = ["doc1", "doc2"]
# self.async_service.vector_store.as_retriever.return_value = mock_retriever
# result = await self.async_service.search_documents(
# "test query", self.mock_conversation.workspace
# )
# self.async_service.vector_store.as_retriever.assert_called_once_with(
# search_type="similarity",
# search_kwargs={
# "k": 4,
# "filter": {"workspace_id": self.mock_conversation.workspace.id},
# },
# )
# self.assertEqual(result, ["doc1", "doc2"])
# async def test_generate_response(self):
# chain_input = {"query": "test query", "conversation": self.mock_conversation}
# mock_stream = ["chunk1", "chunk2", "chunk3"]
# self.async_service.rag_chain.astream.return_value = mock_stream
# chunks = []
# async for chunk in self.async_service.generate_response(
# self.mock_conversation, "test query"
# ):
# chunks.append(chunk)
# self.async_service.rag_chain.astream.assert_awaited_once_with(chain_input)
# self.assertEqual(chunks, mock_stream)
class PromptClassifierTestCase(TestCase):
def setUp(self):
self.rag_service = RAGService()
self.rag_service.vector_store = MagicMock()
self.rag_service.embedding_model = MagicMock()
self.rag_service.text_splitter = MagicMock()
self.service = PromptClassifier()
def test_initialize_vector_store(self):
with patch("os.path.exists", return_value=False), patch(
"os.makedirs"
) as mock_makedirs, patch(
"langchain_community.vectorstores.Chroma"
) as mock_chroma:
# Reset the vector store to test initialization
self.rag_service.vector_store = None
result = self.rag_service._initialize_vector_store()
mock_makedirs.assert_called_once_with("chroma_db")
mock_chroma.assert_called_once_with(
embedding_function=self.rag_service.embedding_model,
persist_directory="chroma_db",
)
self.assertIsNotNone(result)
def test_prepare_documents(self):
mock_doc1 = MagicMock(spec=Document)
mock_doc1.content = "Test content"
mock_doc1.source = "test_source"
mock_doc1.workspace = MagicMock()
mock_doc1.workspace.id = 1
mock_doc1.id = 1
self.rag_service.text_splitter.split_text.return_value = ["chunk1", "chunk2"]
result = self.rag_service._prepare_documents([mock_doc1])
self.assertEqual(len(result), 2)
self.rag_service.text_splitter.split_text.assert_called_once_with(
"Test content"
)
self.assertEqual(result[0].page_content, "chunk1")
self.assertEqual(result[0].metadata["source"], "test_source")
def test_ingest_documents(self):
mock_workspace = MagicMock()
mock_document = MagicMock()
mock_documents = [mock_document]
with patch(
"services.rag_services.Document.objects.filter", return_value=mock_documents
):
self.rag_service._prepare_documents = MagicMock(
return_value=["processed_doc"]
)
self.rag_service.ingest_documents(mock_workspace)
self.rag_service.vector_store.add_documents.assert_called_once_with(
["processed_doc"]
)
self.rag_service.vector_store.persist.assert_called_once()
class TestSyncRAGService(DjangoTestCase):
def setUp(self):
self.sync_service = SyncRAGService()
self.sync_service.vector_store = MagicMock()
self.sync_service.llm = MagicMock()
self.sync_service.rag_chain = MagicMock()
self.mock_conversation = MagicMock(spec=Conversation)
self.mock_conversation.workspace = MagicMock()
self.mock_prompt1 = MagicMock(spec=Prompt)
self.mock_prompt1.is_user = True
self.mock_prompt1.text = "User question"
self.mock_prompt1.created_at = "2023-01-01"
self.mock_prompt2 = MagicMock(spec=Prompt)
self.mock_prompt2.is_user = False
self.mock_prompt2.text = "AI response"
self.mock_prompt2.created_at = "2023-01-02"
def test_format_history(self):
with patch("services.rag_services.Prompt.objects.filter") as mock_filter:
mock_filter.return_value.order_by.return_value = [
self.mock_prompt1,
self.mock_prompt2,
]
result = self.sync_service._format_history(self.mock_conversation)
expected = "User: User question\nAI: AI response"
self.assertEqual(result, expected)
mock_filter.assert_called_once_with(conversation=self.mock_conversation)
def test_retriever_with_history(self):
input_dict = {"query": "test query", "conversation": self.mock_conversation}
self.sync_service.search_documents = MagicMock(return_value=["doc1", "doc2"])
result = self.sync_service._retriever_with_history(input_dict)
self.sync_service.search_documents.assert_called_once_with(
"test query", self.mock_conversation.workspace
)
self.assertEqual(result, ["doc1", "doc2"])
def test_search_documents(self):
mock_retriever = MagicMock()
mock_retriever.get_relevant_documents.return_value = ["doc1", "doc2"]
self.sync_service.vector_store.as_retriever.return_value = mock_retriever
result = self.sync_service.search_documents(
"test query", self.mock_conversation.workspace
)
self.sync_service.vector_store.as_retriever.assert_called_once_with(
search_type="similarity",
search_kwargs={
"k": 4,
"filter": {"workspace_id": self.mock_conversation.workspace.id},
},
)
self.assertEqual(result, ["doc1", "doc2"])
def test_generate_response(self):
chain_input = {"query": "test query", "conversation": self.mock_conversation}
mock_stream = ["chunk1", "chunk2", "chunk3"]
self.sync_service.rag_chain.stream.return_value = mock_stream
result = list(
self.sync_service.generate_response(self.mock_conversation, "test query")
)
self.sync_service.rag_chain.stream.assert_called_once_with(chain_input)
self.assertEqual(result, mock_stream)
class TestAsyncRAGService(DjangoTestCase):
def setUp(self):
self.async_service = AsyncRAGService()
self.async_service.vector_store = MagicMock()
self.async_service.llm = MagicMock()
self.async_service.rag_chain = AsyncMock()
self.mock_conversation = MagicMock(spec=Conversation)
self.mock_conversation.workspace = MagicMock()
self.mock_prompt1 = MagicMock(spec=Prompt)
self.mock_prompt1.is_user = True
self.mock_prompt1.text = "User question"
self.mock_prompt1.created_at = "2023-01-01"
self.mock_prompt2 = MagicMock(spec=Prompt)
self.mock_prompt2.is_user = False
self.mock_prompt2.text = "AI response"
self.mock_prompt2.created_at = "2023-01-02"
async def test_format_history(self):
mock_manager = AsyncMock()
mock_manager.order_by.return_value.alist.return_value = [
self.mock_prompt1,
self.mock_prompt2,
]
with patch(
"services.rag_services.Prompt.objects.filter", return_value=mock_manager
):
result = await self.async_service._format_history(self.mock_conversation)
expected = "User: User question\nAI: AI response"
self.assertEqual(result, expected)
mock_manager.order_by.assert_called_once_with("created_at")
async def test_retriever_with_history(self):
input_dict = {"query": "test query", "conversation": self.mock_conversation}
self.async_service.search_documents = AsyncMock(return_value=["doc1", "doc2"])
result = await self.async_service._retriever_with_history(input_dict)
self.async_service.search_documents.assert_awaited_once_with(
"test query", self.mock_conversation.workspace
)
self.assertEqual(result, ["doc1", "doc2"])
async def test_search_documents(self):
mock_retriever = AsyncMock()
mock_retriever.aget_relevant_documents.return_value = ["doc1", "doc2"]
self.async_service.vector_store.as_retriever.return_value = mock_retriever
result = await self.async_service.search_documents(
"test query", self.mock_conversation.workspace
)
self.async_service.vector_store.as_retriever.assert_called_once_with(
search_type="similarity",
search_kwargs={
"k": 4,
"filter": {"workspace_id": self.mock_conversation.workspace.id},
},
)
self.assertEqual(result, ["doc1", "doc2"])
async def test_generate_response(self):
chain_input = {"query": "test query", "conversation": self.mock_conversation}
mock_stream = ["chunk1", "chunk2", "chunk3"]
self.async_service.rag_chain.astream.return_value = mock_stream
chunks = []
async for chunk in self.async_service.generate_response(
self.mock_conversation, "test query"
):
chunks.append(chunk)
self.async_service.rag_chain.astream.assert_awaited_once_with(chain_input)
self.assertEqual(chunks, mock_stream)
@parameterized.expand([
["Tell me a joke",PromptType.GENERAL_CHAT],
["Create an image of a dog for me",PromptType.IMAGE_GENERATION],
["highlight the features of the backyard playset if they were to choose us and make the language more long form",PromptType.GENERAL_CHAT],
])
def test_prompt_classification(self, prompt, expected_output):
result = self.service.classify(prompt)
self.assertEqual(result, expected_output)

View File

@@ -10,6 +10,8 @@ from .models import DocumentWorkspace, Document, Company
from django.contrib.auth import get_user_model
import tempfile
from django.core.files.uploadedfile import SimpleUploadedFile
from parameterized import parameterized
# Minimal valid PDF bytes
VALID_PDF_BYTES = (
@@ -185,3 +187,8 @@ class DocumentViewsTestCase(APITestCase):
url = reverse("documents_details", kwargs={"document_id": other_document.id})
response = self.client.get(url)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

View File

@@ -15,6 +15,7 @@ from .serializers import (
DocumentWorkspaceSerializer,
DocumentSerializer,
)
from rest_framework.views import APIView
from rest_framework.response import Response
from .models import (
@@ -35,7 +36,7 @@ from asgiref.sync import sync_to_async, async_to_sync
from channels.generic.websocket import AsyncWebsocketConsumer
from langchain_ollama.llms import OllamaLLM
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.messages import HumanMessage, AIMessage
from langchain.chains import RetrievalQA
import re
import os
@@ -66,7 +67,7 @@ from .services.llm_service import AsyncLLMService
from .services.rag_services import AsyncRAGService
from .services.title_generator import title_generator
from .services.moderation_classifier import moderation_classifier, ModerationLabel
from .services.prompt_classifier import prompt_classifier, PromptType
from .services.prompt_classifier.prompt_classifier import PromptClassifier, PromptType
from langchain.chains import create_retrieval_chain
@@ -74,7 +75,7 @@ from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_ollama import ChatOllama
CHANNEL_NAME: str = "llm_messages"
MODEL_NAME: str = "llama3"
MODEL_NAME: str = "llama3.2"
# Create your views here.
@@ -790,7 +791,7 @@ def get_messages(conversation_id, prompt, file_string: str = None, file_type: st
altered_message = message["content"]
transformed_message = (
SystemMessage(content=altered_message)
AIMessage(content=altered_message)
if message["role"] == "assistant"
else HumanMessage(content=altered_message)
)
@@ -863,6 +864,7 @@ def get_retriever(conversation_id):
)
return vectorstore.as_retriever()
PROMPT_CLASSIFIER = PromptClassifier()
class ChatConsumerAgain(AsyncWebsocketConsumer):
async def connect(self):
@@ -932,8 +934,9 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
messages, prompt = await get_messages(
conversation_id, message, decoded_file, file_type
)
prompt_type = await prompt_classifier.classify_async(message)
prompt_type = await PROMPT_CLASSIFIER.classify_async(message)
print(f"prompt_type: {prompt_type} for {message}")
if file:
prompt_type = PromptType.GENERAL_CHAT
@@ -977,7 +980,7 @@ class ChatConsumerAgain(AsyncWebsocketConsumer):
print(f"using the AsyncLLMService\n\n{messages}\n{prompt.message}")
service = AsyncLLMService()
async for chunk in service.generate_response(
messages, prompt.message
messages, prompt.message, conversation_id
):
response += chunk
await self.send(chunk)