Offline Chat
- {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat and not current_model_state.conversation_gpt4all %}
+ {% if current_model_state.enable_offline_model and not current_model_state.conversation_gpt4all %}
{% endif %}
@@ -266,12 +271,12 @@
Setup offline chat
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index d041fd76..b7ba66b6 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -8,21 +8,20 @@ from typing import List, Optional, Union, Any
import asyncio
# External Packages
-from fastapi import APIRouter, HTTPException, Header, Request, Depends
+from fastapi import APIRouter, HTTPException, Header, Request
from starlette.authentication import requires
from asgiref.sync import sync_to_async
# Internal Packages
-from khoj.configure import configure_processor, configure_server
+from khoj.configure import configure_server
from khoj.search_type import image_search, text_search
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
-from khoj.utils.config import TextSearchModel
+from khoj.utils.config import TextSearchModel, GPT4AllProcessorModel
from khoj.utils.helpers import ConversationCommand, is_none_or_empty, timer, command_descriptions
from khoj.utils.rawconfig import (
FullConfig,
- ProcessorConfig,
SearchConfig,
SearchResponse,
TextContentConfig,
@@ -32,16 +31,16 @@ from khoj.utils.rawconfig import (
ConversationProcessorConfig,
OfflineChatProcessorConfig,
)
-from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.state import SearchType
from khoj.utils import state, constants
-from khoj.utils.yaml import save_config_to_file_updated_state
+from khoj.utils.helpers import AsyncIteratorWrapper
from fastapi.responses import StreamingResponse, Response
from khoj.routers.helpers import (
get_conversation_command,
perform_chat_checks,
- generate_chat_response,
+ agenerate_chat_response,
update_telemetry_state,
+ is_ready_to_chat,
)
from khoj.processor.conversation.prompts import help_message
from khoj.processor.conversation.openai.gpt import extract_questions
@@ -49,7 +48,7 @@ from khoj.processor.conversation.gpt4all.chat_model import extract_questions_off
from fastapi.requests import Request
from database import adapters
-from database.adapters import EmbeddingsAdapters
+from database.adapters import EmbeddingsAdapters, ConversationAdapters
from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig, KhojUser
@@ -114,6 +113,8 @@ async def map_config_to_db(config: FullConfig, user: KhojUser):
user=user,
token=config.content_type.notion.token,
)
+ if config.processor and config.processor.conversation:
+ ConversationAdapters.set_conversation_processor_config(user, config.processor.conversation)
# If it's a demo instance, prevent updating any of the configuration.
@@ -123,8 +124,6 @@ if not state.demo:
if state.config is None:
state.config = FullConfig()
state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
- if state.processor_config is None:
- state.processor_config = configure_processor(state.config.processor)
@api.get("/config/data", response_model=FullConfig)
@requires(["authenticated"], redirect="login_page")
@@ -238,28 +237,24 @@ if not state.demo:
)
content_object = map_config_to_object(content_type)
+ if content_object is None:
+ raise ValueError(f"Invalid content type: {content_type}")
+
await content_object.objects.filter(user=user).adelete()
await sync_to_async(EmbeddingsAdapters.delete_all_embeddings)(user, content_type)
enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user)
-
return {"status": "ok"}
@api.post("/delete/config/data/processor/conversation/openai", status_code=200)
+ @requires(["authenticated"], redirect="login_page")
async def remove_processor_conversation_config_data(
request: Request,
client: Optional[str] = None,
):
- if (
- not state.config
- or not state.config.processor
- or not state.config.processor.conversation
- or not state.config.processor.conversation.openai
- ):
- return {"status": "ok"}
+ user = request.user.object
- state.config.processor.conversation.openai = None
- state.processor_config = configure_processor(state.config.processor, state.processor_config)
+ await sync_to_async(ConversationAdapters.clear_openai_conversation_config)(user)
update_telemetry_state(
request=request,
@@ -269,11 +264,7 @@ if not state.demo:
metadata={"processor_conversation_type": "openai"},
)
- try:
- save_config_to_file_updated_state()
- return {"status": "ok"}
- except Exception as e:
- return {"status": "error", "message": str(e)}
+ return {"status": "ok"}
@api.post("/config/data/content_type/{content_type}", status_code=200)
@requires(["authenticated"], redirect="login_page")
@@ -301,24 +292,17 @@ if not state.demo:
return {"status": "ok"}
@api.post("/config/data/processor/conversation/openai", status_code=200)
+ @requires(["authenticated"], redirect="login_page")
async def set_processor_openai_config_data(
request: Request,
updated_config: Union[OpenAIProcessorConfig, None],
client: Optional[str] = None,
):
- _initialize_config()
+ user = request.user.object
- if not state.config.processor or not state.config.processor.conversation:
- default_config = constants.default_config
- default_conversation_logfile = resolve_absolute_path(
- default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore
- )
- conversation_logfile = resolve_absolute_path(default_conversation_logfile)
- state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore
+ conversation_config = ConversationProcessorConfig(openai=updated_config)
- assert state.config.processor.conversation is not None
- state.config.processor.conversation.openai = updated_config
- state.processor_config = configure_processor(state.config.processor, state.processor_config)
+ await sync_to_async(ConversationAdapters.set_conversation_processor_config)(user, conversation_config)
update_telemetry_state(
request=request,
@@ -328,11 +312,7 @@ if not state.demo:
metadata={"processor_conversation_type": "conversation"},
)
- try:
- save_config_to_file_updated_state()
- return {"status": "ok"}
- except Exception as e:
- return {"status": "error", "message": str(e)}
+ return {"status": "ok"}
@api.post("/config/data/processor/conversation/offline_chat", status_code=200)
async def set_processor_enable_offline_chat_config_data(
@@ -341,24 +321,26 @@ if not state.demo:
offline_chat_model: Optional[str] = None,
client: Optional[str] = None,
):
- _initialize_config()
+ user = request.user.object
- if not state.config.processor or not state.config.processor.conversation:
- default_config = constants.default_config
- default_conversation_logfile = resolve_absolute_path(
- default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore
+ if enable_offline_chat:
+ conversation_config = ConversationProcessorConfig(
+ offline_chat=OfflineChatProcessorConfig(
+ enable_offline_chat=enable_offline_chat,
+ chat_model=offline_chat_model,
+ )
)
- conversation_logfile = resolve_absolute_path(default_conversation_logfile)
- state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore
- assert state.config.processor.conversation is not None
- if state.config.processor.conversation.offline_chat is None:
- state.config.processor.conversation.offline_chat = OfflineChatProcessorConfig()
+ await sync_to_async(ConversationAdapters.set_conversation_processor_config)(user, conversation_config)
- state.config.processor.conversation.offline_chat.enable_offline_chat = enable_offline_chat
- if offline_chat_model is not None:
- state.config.processor.conversation.offline_chat.chat_model = offline_chat_model
- state.processor_config = configure_processor(state.config.processor, state.processor_config)
+ offline_chat = await ConversationAdapters.get_offline_chat(user)
+ chat_model = offline_chat.chat_model
+ if state.gpt4all_processor_config is None:
+ state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model)
+
+ else:
+ await sync_to_async(ConversationAdapters.clear_offline_chat_conversation_config)(user)
+ state.gpt4all_processor_config = None
update_telemetry_state(
request=request,
@@ -368,11 +350,7 @@ if not state.demo:
metadata={"processor_conversation_type": f"{'enable' if enable_offline_chat else 'disable'}_local_llm"},
)
- try:
- save_config_to_file_updated_state()
- return {"status": "ok"}
- except Exception as e:
- return {"status": "error", "message": str(e)}
+ return {"status": "ok"}
# Create Routes
@@ -426,9 +404,6 @@ async def search(
if q is None or q == "":
logger.warning(f"No query param (q) passed in API call to initiate search")
return results
- if not state.search_models or not any(state.search_models.__dict__.values()):
- logger.warning(f"No search models loaded. Configure a search model before initiating search")
- return results
# initialize variables
user_query = q.strip()
@@ -565,8 +540,6 @@ def update(
components.append("Search models")
if state.content_index:
components.append("Content index")
- if state.processor_config:
- components.append("Conversation processor")
components_msg = ", ".join(components)
logger.info(f"📪 {components_msg} updated via API")
@@ -592,12 +565,11 @@ def chat_history(
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
- perform_chat_checks()
+ user = request.user.object
+ perform_chat_checks(user)
# Load Conversation History
- meta_log = {}
- if state.processor_config.conversation:
- meta_log = state.processor_config.conversation.meta_log
+ meta_log = ConversationAdapters.get_conversation_by_user(user=user).conversation_log
update_telemetry_state(
request=request,
@@ -649,30 +621,35 @@ async def chat(
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
) -> Response:
- perform_chat_checks()
+ user = request.user.object
+
+ await is_ready_to_chat(user)
conversation_command = get_conversation_command(query=q, any_references=True)
q = q.replace(f"/{conversation_command.value}", "").strip()
+ meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
+
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
- request, q, (n or 5), conversation_command
+ request, meta_log, q, (n or 5), conversation_command
)
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
conversation_command = ConversationCommand.General
if conversation_command == ConversationCommand.Help:
- model_type = "offline" if state.processor_config.conversation.offline_chat.enable_offline_chat else "openai"
+ model_type = "offline" if await ConversationAdapters.has_offline_chat(user) else "openai"
formatted_help = help_message.format(model=model_type, version=state.khoj_version)
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
# Get the (streamed) chat response from the LLM of choice.
- llm_response = generate_chat_response(
+ llm_response = await agenerate_chat_response(
defiltered_query,
- meta_log=state.processor_config.conversation.meta_log,
- compiled_references=compiled_references,
- inferred_queries=inferred_queries,
- conversation_command=conversation_command,
+ meta_log,
+ compiled_references,
+ inferred_queries,
+ conversation_command,
+ user,
)
if llm_response is None:
@@ -681,13 +658,14 @@ async def chat(
if stream:
return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
+ iterator = AsyncIteratorWrapper(llm_response)
+
# Get the full response from the generator if the stream is not requested.
aggregated_gpt_response = ""
- while True:
- try:
- aggregated_gpt_response += next(llm_response)
- except StopIteration:
+ async for item in iterator:
+ if item is None:
break
+ aggregated_gpt_response += item
actual_response = aggregated_gpt_response.split("### compiled references:")[0]
@@ -708,44 +686,53 @@ async def chat(
async def extract_references_and_questions(
request: Request,
+ meta_log: dict,
q: str,
n: int,
conversation_type: ConversationCommand = ConversationCommand.Default,
):
user = request.user.object if request.user.is_authenticated else None
- # Load Conversation History
- meta_log = state.processor_config.conversation.meta_log
# Initialize Variables
compiled_references: List[Any] = []
inferred_queries: List[str] = []
- if not EmbeddingsAdapters.user_has_embeddings(user=user):
+ if conversation_type == ConversationCommand.General:
+ return compiled_references, inferred_queries, q
+
+ if not await EmbeddingsAdapters.user_has_embeddings(user=user):
logger.warning(
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
)
return compiled_references, inferred_queries, q
- if conversation_type == ConversationCommand.General:
- return compiled_references, inferred_queries, q
-
# Extract filter terms from user message
defiltered_query = q
for filter in [DateFilter(), WordFilter(), FileFilter()]:
defiltered_query = filter.defilter(defiltered_query)
filters_in_query = q.replace(defiltered_query, "").strip()
+ using_offline_chat = False
+
# Infer search queries from user message
with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
- if state.processor_config.conversation.offline_chat.enable_offline_chat:
- loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
+ if await ConversationAdapters.has_offline_chat(user):
+ using_offline_chat = True
+ offline_chat = await ConversationAdapters.get_offline_chat(user)
+ chat_model = offline_chat.chat_model
+ if state.gpt4all_processor_config is None:
+ state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model)
+
+ loaded_model = state.gpt4all_processor_config.loaded_model
+
inferred_queries = extract_questions_offline(
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
)
- elif state.processor_config.conversation.openai_model:
- api_key = state.processor_config.conversation.openai_model.api_key
- chat_model = state.processor_config.conversation.openai_model.chat_model
+ elif await ConversationAdapters.has_openai_chat(user):
+ openai_chat = await ConversationAdapters.get_openai_chat(user)
+ api_key = openai_chat.api_key
+ chat_model = openai_chat.chat_model
inferred_queries = extract_questions(
defiltered_query, model=chat_model, api_key=api_key, conversation_log=meta_log
)
@@ -754,7 +741,7 @@ async def extract_references_and_questions(
with timer("Searching knowledge base took", logger):
result_list = []
for query in inferred_queries:
- n_items = min(n, 3) if state.processor_config.conversation.offline_chat.enable_offline_chat else n
+ n_items = min(n, 3) if using_offline_chat else n
result_list.extend(
await search(
f"{query} {filters_in_query}",
@@ -765,6 +752,8 @@ async def extract_references_and_questions(
dedupe=False,
)
)
+ # Dedupe the results again, as duplicates may be returned across queries.
+ result_list = text_search.deduplicated_search_responses(result_list)
compiled_references = [item.additional["compiled"] for item in result_list]
return compiled_references, inferred_queries, defiltered_query
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index be9e8700..8a9e53a7 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -1,34 +1,50 @@
import logging
+import asyncio
from datetime import datetime
from functools import partial
from typing import Iterator, List, Optional, Union
+from concurrent.futures import ThreadPoolExecutor
from fastapi import HTTPException, Request
from khoj.utils import state
-from khoj.utils.helpers import ConversationCommand, timer, log_telemetry
+from khoj.utils.config import GPT4AllProcessorModel
+from khoj.utils.helpers import ConversationCommand, log_telemetry
from khoj.processor.conversation.openai.gpt import converse
from khoj.processor.conversation.gpt4all.chat_model import converse_offline
-from khoj.processor.conversation.utils import reciprocal_conversation_to_chatml, message_to_log, ThreadedGenerator
+from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator
from database.models import KhojUser
+from database.adapters import ConversationAdapters
logger = logging.getLogger(__name__)
+executor = ThreadPoolExecutor(max_workers=1)
-def perform_chat_checks():
- if (
- state.processor_config
- and state.processor_config.conversation
- and (
- state.processor_config.conversation.openai_model
- or state.processor_config.conversation.gpt4all_model.loaded_model
- )
- ):
+
+def perform_chat_checks(user: KhojUser):
+ if ConversationAdapters.has_valid_offline_conversation_config(
+ user
+ ) or ConversationAdapters.has_valid_openai_conversation_config(user):
return
- raise HTTPException(
- status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings and restart it."
- )
+ raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
+
+
+async def is_ready_to_chat(user: KhojUser):
+ has_offline_config = await ConversationAdapters.has_offline_chat(user=user)
+ has_openai_config = await ConversationAdapters.has_openai_chat(user=user)
+
+ if has_offline_config:
+ offline_chat = await ConversationAdapters.get_offline_chat(user)
+ chat_model = offline_chat.chat_model
+ if state.gpt4all_processor_config is None:
+ state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model)
+ return True
+
+ ready = has_openai_config or has_offline_config
+
+ if not ready:
+ raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
def update_telemetry_state(
@@ -74,12 +90,22 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
return ConversationCommand.Default
+async def construct_conversation_logs(user: KhojUser):
+ return (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
+
+
+async def agenerate_chat_response(*args):
+ loop = asyncio.get_event_loop()
+ return await loop.run_in_executor(executor, generate_chat_response, *args)
+
+
def generate_chat_response(
q: str,
meta_log: dict,
compiled_references: List[str] = [],
inferred_queries: List[str] = [],
conversation_command: ConversationCommand = ConversationCommand.Default,
+ user: KhojUser = None,
) -> Union[ThreadedGenerator, Iterator[str]]:
def _save_to_conversation_log(
q: str,
@@ -89,17 +115,14 @@ def generate_chat_response(
inferred_queries: List[str],
meta_log,
):
- state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, chat_response])
- state.processor_config.conversation.meta_log["chat"] = message_to_log(
+ updated_conversation = message_to_log(
user_message=q,
chat_response=chat_response,
user_message_metadata={"created": user_message_time},
khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
conversation_log=meta_log.get("chat", []),
)
-
- # Load Conversation History
- meta_log = state.processor_config.conversation.meta_log
+ ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
# Initialize Variables
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -116,8 +139,14 @@ def generate_chat_response(
meta_log=meta_log,
)
- if state.processor_config.conversation.offline_chat.enable_offline_chat:
- loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
+ offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config(user=user)
+ conversation_config = ConversationAdapters.get_conversation_config(user)
+ openai_chat_config = ConversationAdapters.get_openai_conversation_config(user)
+ if offline_chat_config:
+ if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
+ state.gpt4all_processor_config = GPT4AllProcessorModel(offline_chat_config.chat_model)
+
+ loaded_model = state.gpt4all_processor_config.loaded_model
chat_response = converse_offline(
references=compiled_references,
user_query=q,
@@ -125,14 +154,14 @@ def generate_chat_response(
conversation_log=meta_log,
completion_func=partial_completion,
conversation_command=conversation_command,
- model=state.processor_config.conversation.offline_chat.chat_model,
- max_prompt_size=state.processor_config.conversation.max_prompt_size,
- tokenizer_name=state.processor_config.conversation.tokenizer,
+ model=offline_chat_config.chat_model,
+ max_prompt_size=conversation_config.max_prompt_size,
+ tokenizer_name=conversation_config.tokenizer,
)
- elif state.processor_config.conversation.openai_model:
- api_key = state.processor_config.conversation.openai_model.api_key
- chat_model = state.processor_config.conversation.openai_model.chat_model
+ elif openai_chat_config:
+ api_key = openai_chat_config.api_key
+ chat_model = openai_chat_config.chat_model
chat_response = converse(
compiled_references,
q,
@@ -141,8 +170,8 @@ def generate_chat_response(
api_key=api_key,
completion_func=partial_completion,
conversation_command=conversation_command,
- max_prompt_size=state.processor_config.conversation.max_prompt_size,
- tokenizer_name=state.processor_config.conversation.tokenizer,
+ max_prompt_size=conversation_config.max_prompt_size if conversation_config else None,
+ tokenizer_name=conversation_config.tokenizer if conversation_config else None,
)
except Exception as e:
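agenerate_chat_response wraps the still-synchronous generate_chat_response in a single-worker thread pool so the async chat endpoint does not block the event loop while a response is generated. A self-contained sketch of the same pattern, with blocking_generate standing in for the real function:

import asyncio
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)  # a single worker also serializes concurrent chat requests

def blocking_generate(prompt: str) -> str:
    # Stand-in for generate_chat_response: synchronous and potentially slow
    return prompt.upper()

async def agenerate(*args):
    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(executor, blocking_generate, *args)

print(asyncio.run(agenerate("hello")))  # prints HELLO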
diff --git a/src/khoj/routers/indexer.py b/src/khoj/routers/indexer.py
index 1e73c439..1125e653 100644
--- a/src/khoj/routers/indexer.py
+++ b/src/khoj/routers/indexer.py
@@ -92,7 +92,7 @@ async def update(
if dict_to_update is not None:
dict_to_update[file.filename] = (
- file.file.read().decode("utf-8") if encoding == "utf-8" else file.file.read()
+ file.file.read().decode("utf-8") if encoding == "utf-8" else file.file.read() # type: ignore
)
else:
logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file.filename}")
@@ -181,24 +181,25 @@ def configure_content(
files: Optional[dict[str, dict[str, str]]],
search_models: SearchModels,
regenerate: bool = False,
- t: Optional[Union[state.SearchType, str]] = None,
+ t: Optional[state.SearchType] = None,
full_corpus: bool = True,
user: KhojUser = None,
) -> Optional[ContentIndex]:
content_index = ContentIndex()
- if t in [type.value for type in state.SearchType]:
- t = state.SearchType(t).value
+ if t is not None and t.value not in [type.value for type in state.SearchType]:
+ logger.warning(f"🚨 Invalid search type: {t}")
+ return None
- assert type(t) == str or t == None, f"Invalid search type: {t}"
+ search_type = t.value if t else None
if files is None:
- logger.warning(f"🚨 No files to process for {t} search.")
+ logger.warning(f"🚨 No files to process for {search_type} search.")
return None
try:
# Initialize Org Notes Search
- if (t == None or t == state.SearchType.Org.value) and files["org"]:
+ if (search_type == None or search_type == state.SearchType.Org.value) and files["org"]:
logger.info("🦄 Setting up search for orgmode notes")
# Extract Entries, Generate Notes Embeddings
text_search.setup(
@@ -213,7 +214,7 @@ def configure_content(
try:
# Initialize Markdown Search
- if (t == None or t == state.SearchType.Markdown.value) and files["markdown"]:
+ if (search_type == None or search_type == state.SearchType.Markdown.value) and files["markdown"]:
logger.info("💎 Setting up search for markdown notes")
# Extract Entries, Generate Markdown Embeddings
text_search.setup(
@@ -229,7 +230,7 @@ def configure_content(
try:
# Initialize PDF Search
- if (t == None or t == state.SearchType.Pdf.value) and files["pdf"]:
+ if (search_type == None or search_type == state.SearchType.Pdf.value) and files["pdf"]:
logger.info("🖨️ Setting up search for pdf")
# Extract Entries, Generate PDF Embeddings
text_search.setup(
@@ -245,7 +246,7 @@ def configure_content(
try:
# Initialize Plaintext Search
- if (t == None or t == state.SearchType.Plaintext.value) and files["plaintext"]:
+ if (search_type == None or search_type == state.SearchType.Plaintext.value) and files["plaintext"]:
logger.info("📄 Setting up search for plaintext")
# Extract Entries, Generate Plaintext Embeddings
text_search.setup(
@@ -262,7 +263,7 @@ def configure_content(
try:
# Initialize Image Search
if (
- (t == None or t == state.SearchType.Image.value)
+ (search_type == None or search_type == state.SearchType.Image.value)
and content_config
and content_config.image
and search_models.image_search
@@ -278,7 +279,7 @@ def configure_content(
try:
github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
- if (t == None or t == state.SearchType.Github.value) and github_config is not None:
+ if (search_type == None or search_type == state.SearchType.Github.value) and github_config is not None:
logger.info("🐙 Setting up search for github")
# Extract Entries, Generate Github Embeddings
text_search.setup(
@@ -296,7 +297,7 @@ def configure_content(
try:
# Initialize Notion Search
notion_config = NotionConfig.objects.filter(user=user).first()
- if (t == None or t in state.SearchType.Notion.value) and notion_config:
+ if (search_type == None or search_type == state.SearchType.Notion.value) and notion_config:
logger.info("🔌 Setting up search for notion")
text_search.setup(
NotionToJsonl,
diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py
index 4122c6d0..333d89fa 100644
--- a/src/khoj/routers/web_client.py
+++ b/src/khoj/routers/web_client.py
@@ -19,7 +19,7 @@ from khoj.utils.rawconfig import (
# Internal Packages
from khoj.utils import constants, state
-from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config
+from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config, ConversationAdapters
from database.models import LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig
@@ -83,7 +83,7 @@ if not state.demo:
@web_client.get("/config", response_class=HTMLResponse)
@requires(["authenticated"], redirect="login_page")
def config_page(request: Request):
- user = request.user.object if request.user.is_authenticated else None
+ user = request.user.object
enabled_content = set(EmbeddingsAdapters.get_unique_file_types(user).all())
default_full_config = FullConfig(
content_type=None,
@@ -100,9 +100,6 @@ if not state.demo:
"github": ("github" in enabled_content),
"notion": ("notion" in enabled_content),
"plaintext": ("plaintext" in enabled_content),
- "enable_offline_model": False,
- "conversation_openai": False,
- "conversation_gpt4all": False,
}
if state.content_index:
@@ -112,13 +109,17 @@ if not state.demo:
}
)
- if state.processor_config and state.processor_config.conversation:
- successfully_configured.update(
- {
- "conversation_openai": state.processor_config.conversation.openai_model is not None,
- "conversation_gpt4all": state.processor_config.conversation.gpt4all_model.loaded_model is not None,
- }
- )
+ enabled_chat_config = ConversationAdapters.get_enabled_conversation_settings(user)
+
+ successfully_configured.update(
+ {
+ "conversation_openai": enabled_chat_config["openai"],
+ "enable_offline_model": enabled_chat_config["offline_chat"],
+ "conversation_gpt4all": state.gpt4all_processor_config.loaded_model is not None
+ if state.gpt4all_processor_config
+ else False,
+ }
+ )
return templates.TemplateResponse(
"config.html",
@@ -127,6 +128,7 @@ if not state.demo:
"current_config": current_config,
"current_model_state": successfully_configured,
"anonymous_mode": state.anonymous_mode,
+ "username": user.username if user else None,
},
)
@@ -204,22 +206,22 @@ if not state.demo:
)
@web_client.get("/config/processor/conversation/openai", response_class=HTMLResponse)
+ @requires(["authenticated"], redirect="login_page")
def conversation_processor_config_page(request: Request):
- default_copy = constants.default_config.copy()
- default_processor_config = default_copy["processor"]["conversation"]["openai"] # type: ignore
- default_openai_config = OpenAIProcessorConfig(
- api_key="",
- chat_model=default_processor_config["chat-model"],
- )
+ user = request.user.object
+ openai_config = ConversationAdapters.get_openai_conversation_config(user)
+
+ if openai_config:
+ current_processor_openai_config = OpenAIProcessorConfig(
+ api_key=openai_config.api_key,
+ chat_model=openai_config.chat_model,
+ )
+ else:
+ current_processor_openai_config = OpenAIProcessorConfig(
+ api_key="",
+ chat_model="gpt-3.5-turbo",
+ )
- current_processor_openai_config = (
- state.config.processor.conversation.openai
- if state.config
- and state.config.processor
- and state.config.processor.conversation
- and state.config.processor.conversation.openai
- else default_openai_config
- )
current_processor_openai_config = json.loads(current_processor_openai_config.json())
return templates.TemplateResponse(
diff --git a/src/khoj/search_type/image_search.py b/src/khoj/search_type/image_search.py
index 8b92d9db..d7f486af 100644
--- a/src/khoj/search_type/image_search.py
+++ b/src/khoj/search_type/image_search.py
@@ -236,6 +236,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
"image_score": f"{hit['image_score']:.9f}",
"metadata_score": f"{hit['metadata_score']:.9f}",
},
+ "corpus_id": hit["corpus_id"],
}
)
]
diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py
index 36d6a791..dc6593f5 100644
--- a/src/khoj/search_type/text_search.py
+++ b/src/khoj/search_type/text_search.py
@@ -14,10 +14,9 @@ from asgiref.sync import sync_to_async
# Internal Packages
from khoj.utils import state
from khoj.utils.helpers import get_absolute_path, resolve_absolute_path, load_model, timer
-from khoj.utils.config import TextSearchModel
from khoj.utils.models import BaseEncoder
from khoj.utils.state import SearchType
-from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, Entry
+from khoj.utils.rawconfig import SearchResponse, Entry
from khoj.utils.jsonl import load_jsonl
from khoj.processor.text_to_jsonl import TextEmbeddings
from database.adapters import EmbeddingsAdapters
@@ -36,36 +35,6 @@ search_type_to_embeddings_type = {
}
-def initialize_model(search_config: TextSearchConfig):
- "Initialize model for semantic search on text"
- torch.set_num_threads(4)
-
- # If model directory is configured
- if search_config.model_directory:
- # Convert model directory to absolute path
- search_config.model_directory = resolve_absolute_path(search_config.model_directory)
- # Create model directory if it doesn't exist
- search_config.model_directory.parent.mkdir(parents=True, exist_ok=True)
-
- # The bi-encoder encodes all entries to use for semantic search
- bi_encoder = load_model(
- model_dir=search_config.model_directory,
- model_name=search_config.encoder,
- model_type=search_config.encoder_type or SentenceTransformer,
- device=f"{state.device}",
- )
-
- # The cross-encoder re-ranks the results to improve quality
- cross_encoder = load_model(
- model_dir=search_config.model_directory,
- model_name=search_config.cross_encoder,
- model_type=CrossEncoder,
- device=f"{state.device}",
- )
-
- return TextSearchModel(bi_encoder, cross_encoder)
-
-
def extract_entries(jsonl_file) -> List[Entry]:
"Load entries from compressed jsonl"
return list(map(Entry.from_dict, load_jsonl(jsonl_file)))
@@ -176,6 +145,7 @@ def collate_results(hits, dedupe=True):
{
"entry": hit.raw,
"score": hit.distance,
+ "corpus_id": str(hit.corpus_id),
"additional": {
"file": hit.file_path,
"compiled": hit.compiled,
@@ -185,6 +155,28 @@ def collate_results(hits, dedupe=True):
)
+def deduplicated_search_responses(hits: List[SearchResponse]):
+ hit_ids = set()
+ for hit in hits:
+ if hit.corpus_id in hit_ids:
+ continue
+
+ else:
+ hit_ids.add(hit.corpus_id)
+ yield SearchResponse.parse_obj(
+ {
+ "entry": hit.entry,
+ "score": hit.score,
+ "corpus_id": hit.corpus_id,
+ "additional": {
+ "file": hit.additional["file"],
+ "compiled": hit.additional["compiled"],
+ "heading": hit.additional["heading"],
+ },
+ }
+ )
+
+
def rerank_and_sort_results(hits, query):
# Score all retrieved entries using the cross-encoder
hits = cross_encoder_score(query, hits)
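deduplicated_search_responses is a generator that drops hits whose corpus_id has already been seen; the chat flow uses it to dedupe results returned across multiple inferred queries. A minimal usage sketch with made-up field values:

from khoj.utils.rawconfig import SearchResponse
from khoj.search_type.text_search import deduplicated_search_responses

def make_hit(corpus_id: str) -> SearchResponse:
    # The "file", "compiled" and "heading" keys are read back by the dedupe generator
    return SearchResponse.parse_obj(
        {
            "entry": f"entry {corpus_id}",
            "score": 0.5,
            "corpus_id": corpus_id,
            "additional": {"file": "notes.org", "compiled": f"entry {corpus_id}", "heading": "h1"},
        }
    )

hits = [make_hit("1"), make_hit("2"), make_hit("1")]  # duplicate corpus_id "1"
assert [hit.corpus_id for hit in deduplicated_search_responses(hits)] == ["1", "2"]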
diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py
index ee5b4f9f..3c084c4f 100644
--- a/src/khoj/utils/config.py
+++ b/src/khoj/utils/config.py
@@ -5,8 +5,7 @@ from enum import Enum
import logging
from dataclasses import dataclass
-from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Optional, Union, Any
+from typing import TYPE_CHECKING, List, Optional, Union, Any
from khoj.processor.conversation.gpt4all.utils import download_model
# External Packages
@@ -19,9 +18,7 @@ logger = logging.getLogger(__name__)
# Internal Packages
if TYPE_CHECKING:
from sentence_transformers import CrossEncoder
- from khoj.search_filter.base_filter import BaseFilter
from khoj.utils.models import BaseEncoder
- from khoj.utils.rawconfig import ConversationProcessorConfig, Entry, OpenAIProcessorConfig
class SearchType(str, Enum):
@@ -79,31 +76,15 @@ class GPT4AllProcessorConfig:
loaded_model: Union[Any, None] = None
-class ConversationProcessorConfigModel:
+class GPT4AllProcessorModel:
def __init__(
self,
- conversation_config: ConversationProcessorConfig,
+ chat_model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin",
):
- self.openai_model = conversation_config.openai
- self.gpt4all_model = GPT4AllProcessorConfig()
- self.offline_chat = conversation_config.offline_chat or OfflineChatProcessorConfig()
- self.max_prompt_size = conversation_config.max_prompt_size
- self.tokenizer = conversation_config.tokenizer
- self.conversation_logfile = Path(conversation_config.conversation_logfile)
- self.chat_session: List[str] = []
- self.meta_log: dict = {}
-
- if self.offline_chat.enable_offline_chat:
- try:
- self.gpt4all_model.loaded_model = download_model(self.offline_chat.chat_model)
- except Exception as e:
- self.offline_chat.enable_offline_chat = False
- self.gpt4all_model.loaded_model = None
- logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
- else:
- self.gpt4all_model.loaded_model = None
-
-
-@dataclass
-class ProcessorConfigModel:
- conversation: Union[ConversationProcessorConfigModel, None] = None
+ self.chat_model = chat_model
+ self.loaded_model = None
+ try:
+ self.loaded_model = download_model(self.chat_model)
+ except ValueError as e:
+ self.loaded_model = None
+ logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py
index 181dee04..e9d431c6 100644
--- a/src/khoj/utils/constants.py
+++ b/src/khoj/utils/constants.py
@@ -8,136 +8,14 @@ telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
content_directory = "~/.khoj/content/"
empty_config = {
- "content-type": {
- "org": {
- "input-files": None,
- "input-filter": None,
- "compressed-jsonl": "~/.khoj/content/org/org.jsonl.gz",
- "embeddings-file": "~/.khoj/content/org/org_embeddings.pt",
- "index-heading-entries": False,
- },
- "markdown": {
- "input-files": None,
- "input-filter": None,
- "compressed-jsonl": "~/.khoj/content/markdown/markdown.jsonl.gz",
- "embeddings-file": "~/.khoj/content/markdown/markdown_embeddings.pt",
- },
- "pdf": {
- "input-files": None,
- "input-filter": None,
- "compressed-jsonl": "~/.khoj/content/pdf/pdf.jsonl.gz",
- "embeddings-file": "~/.khoj/content/pdf/pdf_embeddings.pt",
- },
- "plaintext": {
- "input-files": None,
- "input-filter": None,
- "compressed-jsonl": "~/.khoj/content/plaintext/plaintext.jsonl.gz",
- "embeddings-file": "~/.khoj/content/plaintext/plaintext_embeddings.pt",
- },
- },
"search-type": {
- "symmetric": {
- "encoder": "sentence-transformers/all-MiniLM-L6-v2",
- "cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
- "model_directory": "~/.khoj/search/symmetric/",
- },
- "asymmetric": {
- "encoder": "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
- "cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
- "model_directory": "~/.khoj/search/asymmetric/",
- },
"image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"},
},
- "processor": {
- "conversation": {
- "openai": {
- "api-key": None,
- "chat-model": "gpt-3.5-turbo",
- },
- "offline-chat": {
- "enable-offline-chat": False,
- "chat-model": "llama-2-7b-chat.ggmlv3.q4_0.bin",
- },
- "tokenizer": None,
- "max-prompt-size": None,
- "conversation-logfile": "~/.khoj/processor/conversation/conversation_logs.json",
- }
- },
}
# default app config to use
default_config = {
- "content-type": {
- "org": {
- "input-files": None,
- "input-filter": None,
- "compressed-jsonl": "~/.khoj/content/org/org.jsonl.gz",
- "embeddings-file": "~/.khoj/content/org/org_embeddings.pt",
- "index-heading-entries": False,
- },
- "markdown": {
- "input-files": None,
- "input-filter": None,
- "compressed-jsonl": "~/.khoj/content/markdown/markdown.jsonl.gz",
- "embeddings-file": "~/.khoj/content/markdown/markdown_embeddings.pt",
- },
- "pdf": {
- "input-files": None,
- "input-filter": None,
- "compressed-jsonl": "~/.khoj/content/pdf/pdf.jsonl.gz",
- "embeddings-file": "~/.khoj/content/pdf/pdf_embeddings.pt",
- },
- "image": {
- "input-directories": None,
- "input-filter": None,
- "embeddings-file": "~/.khoj/content/image/image_embeddings.pt",
- "batch-size": 50,
- "use-xmp-metadata": False,
- },
- "github": {
- "pat-token": None,
- "repos": [],
- "compressed-jsonl": "~/.khoj/content/github/github.jsonl.gz",
- "embeddings-file": "~/.khoj/content/github/github_embeddings.pt",
- },
- "notion": {
- "token": None,
- "compressed-jsonl": "~/.khoj/content/notion/notion.jsonl.gz",
- "embeddings-file": "~/.khoj/content/notion/notion_embeddings.pt",
- },
- "plaintext": {
- "input-files": None,
- "input-filter": None,
- "compressed-jsonl": "~/.khoj/content/plaintext/plaintext.jsonl.gz",
- "embeddings-file": "~/.khoj/content/plaintext/plaintext_embeddings.pt",
- },
- },
"search-type": {
- "symmetric": {
- "encoder": "sentence-transformers/all-MiniLM-L6-v2",
- "cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
- "model_directory": "~/.khoj/search/symmetric/",
- },
- "asymmetric": {
- "encoder": "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
- "cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
- "model_directory": "~/.khoj/search/asymmetric/",
- },
"image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"},
},
- "processor": {
- "conversation": {
- "openai": {
- "api-key": None,
- "chat-model": "gpt-3.5-turbo",
- },
- "offline-chat": {
- "enable-offline-chat": False,
- "chat-model": "llama-2-7b-chat.ggmlv3.q4_0.bin",
- },
- "tokenizer": None,
- "max-prompt-size": None,
- "conversation-logfile": "~/.khoj/processor/conversation/conversation_logs.json",
- }
- },
}
diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py
index e41791f9..0269a9e9 100644
--- a/src/khoj/utils/helpers.py
+++ b/src/khoj/utils/helpers.py
@@ -15,6 +15,7 @@ from time import perf_counter
import torch
from typing import Optional, Union, TYPE_CHECKING
import uuid
+from asgiref.sync import sync_to_async
# Internal Packages
from khoj.utils import constants
@@ -29,6 +30,28 @@ if TYPE_CHECKING:
from khoj.utils.rawconfig import AppConfig
+class AsyncIteratorWrapper:
+ "Wrap a synchronous iterator so async code can consume it without blocking the event loop."
+
+ def __init__(self, obj):
+ self._it = iter(obj)
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ try:
+ value = await self.next_async()
+ except StopAsyncIteration:
+ # Signal exhaustion by returning None; consumers break on None instead of StopAsyncIteration
+ return
+ return value
+
+ @sync_to_async
+ def next_async(self):
+ try:
+ return next(self._it)
+ except StopIteration:
+ raise StopAsyncIteration
+
+
def is_none_or_empty(item):
return item == None or (hasattr(item, "__iter__") and len(item) == 0) or item == ""
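Note that AsyncIteratorWrapper never raises StopAsyncIteration to its consumer: once the wrapped iterator is exhausted, __anext__ returns None, which is why the non-streaming branch of /api/chat breaks on None. A runnable sketch:

import asyncio
from khoj.utils.helpers import AsyncIteratorWrapper

def sync_chunks():
    yield from ["Hello", ", ", "world"]

async def collect() -> str:
    aggregated = ""
    async for chunk in AsyncIteratorWrapper(sync_chunks()):
        if chunk is None:  # exhaustion is signalled with None, not StopAsyncIteration
            break
        aggregated += chunk
    return aggregated

print(asyncio.run(collect()))  # prints Hello, world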
diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py
index 5d2b3ce4..a469951f 100644
--- a/src/khoj/utils/rawconfig.py
+++ b/src/khoj/utils/rawconfig.py
@@ -67,13 +67,6 @@ class ContentConfig(ConfigBase):
notion: Optional[NotionContentConfig]
-class TextSearchConfig(ConfigBase):
- encoder: str
- cross_encoder: str
- encoder_type: Optional[str]
- model_directory: Optional[Path]
-
-
class ImageSearchConfig(ConfigBase):
encoder: str
encoder_type: Optional[str]
@@ -81,8 +74,6 @@ class ImageSearchConfig(ConfigBase):
class SearchConfig(ConfigBase):
- asymmetric: Optional[TextSearchConfig]
- symmetric: Optional[TextSearchConfig]
image: Optional[ImageSearchConfig]
@@ -97,11 +88,10 @@ class OfflineChatProcessorConfig(ConfigBase):
class ConversationProcessorConfig(ConfigBase):
- conversation_logfile: Path
- openai: Optional[OpenAIProcessorConfig]
- offline_chat: Optional[OfflineChatProcessorConfig]
- max_prompt_size: Optional[int]
- tokenizer: Optional[str]
+ openai: Optional[OpenAIProcessorConfig] = None
+ offline_chat: Optional[OfflineChatProcessorConfig] = None
+ max_prompt_size: Optional[int] = None
+ tokenizer: Optional[str] = None
class ProcessorConfig(ConfigBase):
@@ -125,6 +115,7 @@ class SearchResponse(ConfigBase):
score: float
cross_score: Optional[float]
additional: Optional[dict]
+ corpus_id: str
class Entry:
diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py
index d6169d2a..40806c51 100644
--- a/src/khoj/utils/state.py
+++ b/src/khoj/utils/state.py
@@ -10,7 +10,7 @@ from pathlib import Path
# Internal Packages
from khoj.utils import config as utils_config
-from khoj.utils.config import ContentIndex, SearchModels, ProcessorConfigModel
+from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
from khoj.utils.helpers import LRU
from khoj.utils.rawconfig import FullConfig
from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel
@@ -21,7 +21,7 @@ search_models = SearchModels()
embeddings_model = EmbeddingsModel()
cross_encoder_model = CrossEncoderModel()
content_index = ContentIndex()
-processor_config = ProcessorConfigModel()
+gpt4all_processor_config: GPT4AllProcessorModel = None
config_file: Path = None
verbose: int = 0
host: str = None
diff --git a/tests/conftest.py b/tests/conftest.py
index ee4b9e57..12ac4f7b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,7 +5,6 @@ from pathlib import Path
import pytest
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI
-import factory
import os
from fastapi import FastAPI
@@ -13,7 +12,7 @@ app = FastAPI()
# Internal Packages
-from khoj.configure import configure_processor, configure_routes, configure_search_types, configure_middleware
+from khoj.configure import configure_routes, configure_search_types, configure_middleware
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
from khoj.search_type import image_search, text_search
from khoj.utils.config import SearchModels
@@ -21,13 +20,8 @@ from khoj.utils.constants import web_directory
from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.rawconfig import (
ContentConfig,
- ConversationProcessorConfig,
- OfflineChatProcessorConfig,
- OpenAIProcessorConfig,
- ProcessorConfig,
ImageContentConfig,
SearchConfig,
- TextSearchConfig,
ImageSearchConfig,
)
from khoj.utils import state, fs_syncer
@@ -42,42 +36,25 @@ from database.models import (
GithubRepoConfig,
)
+from tests.helpers import (
+ UserFactory,
+ ConversationProcessorConfigFactory,
+ OpenAIProcessorConversationConfigFactory,
+ OfflineChatProcessorConversationConfigFactory,
+)
+
@pytest.fixture(autouse=True)
def enable_db_access_for_all_tests(db):
pass
-class UserFactory(factory.django.DjangoModelFactory):
- class Meta:
- model = KhojUser
-
- username = factory.Faker("name")
- email = factory.Faker("email")
- password = factory.Faker("password")
- uuid = factory.Faker("uuid4")
-
-
@pytest.fixture(scope="session")
def search_config() -> SearchConfig:
model_dir = resolve_absolute_path("~/.khoj/search")
model_dir.mkdir(parents=True, exist_ok=True)
search_config = SearchConfig()
- search_config.symmetric = TextSearchConfig(
- encoder="sentence-transformers/all-MiniLM-L6-v2",
- cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
- model_directory=model_dir / "symmetric/",
- encoder_type=None,
- )
-
- search_config.asymmetric = TextSearchConfig(
- encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
- cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
- model_directory=model_dir / "asymmetric/",
- encoder_type=None,
- )
-
search_config.image = ImageSearchConfig(
encoder="sentence-transformers/clip-ViT-B-32",
model_directory=model_dir / "image/",
@@ -177,55 +154,48 @@ def md_content_config():
return markdown_config
-@pytest.fixture(scope="session")
-def processor_config(tmp_path_factory):
- openai_api_key = os.getenv("OPENAI_API_KEY")
- processor_dir = tmp_path_factory.mktemp("processor")
-
- # The conversation processor is the only configured processor
- # It needs an OpenAI API key to work.
- if not openai_api_key:
- return
-
- # Setup conversation processor, if OpenAI API key is set
- processor_config = ProcessorConfig()
- processor_config.conversation = ConversationProcessorConfig(
- openai=OpenAIProcessorConfig(api_key=openai_api_key),
- conversation_logfile=processor_dir.joinpath("conversation_logs.json"),
- )
-
- return processor_config
-
-
-@pytest.fixture(scope="session")
-def processor_config_offline_chat(tmp_path_factory):
- processor_dir = tmp_path_factory.mktemp("processor")
-
- # Setup conversation processor
- processor_config = ProcessorConfig()
- offline_chat = OfflineChatProcessorConfig(enable_offline_chat=True)
- processor_config.conversation = ConversationProcessorConfig(
- offline_chat=offline_chat,
- conversation_logfile=processor_dir.joinpath("conversation_logs.json"),
- )
-
- return processor_config
-
-
-@pytest.fixture(scope="session")
-def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
+@pytest.fixture(scope="function")
+def chat_client(search_config: SearchConfig, default_user2: KhojUser):
# Initialize app state
state.config.search_type = search_config
state.SearchType = configure_search_types(state.config)
+ LocalMarkdownConfig.objects.create(
+ input_files=None,
+ input_filter=["tests/data/markdown/*.markdown"],
+ user=default_user2,
+ )
+
# Index Markdown Content for Search
- all_files = fs_syncer.collect_files()
+ all_files = fs_syncer.collect_files(user=default_user2)
state.content_index = configure_content(
- state.content_index, state.config.content_type, all_files, state.search_models
+ state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2
)
# Initialize Processor from Config
- state.processor_config = configure_processor(processor_config)
+ if os.getenv("OPENAI_API_KEY"):
+ OpenAIProcessorConversationConfigFactory(user=default_user2)
+
+ state.anonymous_mode = True
+
+ app = FastAPI()
+
+ configure_routes(app)
+ configure_middleware(app)
+ app.mount("/static", StaticFiles(directory=web_directory), name="static")
+ return TestClient(app)
+
+
+@pytest.fixture(scope="function")
+def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUser):
+ # Initialize app state
+ state.config.search_type = search_config
+ state.SearchType = configure_search_types(state.config)
+
+ # Initialize Processor from Config
+ if os.getenv("OPENAI_API_KEY"):
+ OpenAIProcessorConversationConfigFactory(user=default_user2)
+
state.anonymous_mode = True
app = FastAPI()
@@ -249,7 +219,6 @@ def fastapi_app():
def client(
content_config: ContentConfig,
search_config: SearchConfig,
- processor_config: ProcessorConfig,
default_user: KhojUser,
):
state.config.content_type = content_config
@@ -274,7 +243,7 @@ def client(
user=default_user,
)
- state.processor_config = configure_processor(processor_config)
+ ConversationProcessorConfigFactory(user=default_user)
state.anonymous_mode = True
configure_routes(app)
@@ -286,25 +255,32 @@ def client(
@pytest.fixture(scope="function")
def client_offline_chat(
search_config: SearchConfig,
- processor_config_offline_chat: ProcessorConfig,
content_config: ContentConfig,
- md_content_config,
+ default_user2: KhojUser,
):
# Initialize app state
- state.config.content_type = md_content_config
+ state.config.content_type = content_config
state.config.search_type = search_config
state.SearchType = configure_search_types(state.config)
+ LocalMarkdownConfig.objects.create(
+ input_files=None,
+ input_filter=["tests/data/markdown/*.markdown"],
+ user=default_user2,
+ )
+
# Index Markdown Content for Search
state.search_models.image_search = image_search.initialize_model(search_config.image)
- all_files = fs_syncer.collect_files(state.config.content_type)
- state.content_index = configure_content(
- state.content_index, state.config.content_type, all_files, state.search_models
+ all_files = fs_syncer.collect_files(user=default_user2)
+ configure_content(
+ state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2
)
# Initialize Processor from Config
- state.processor_config = configure_processor(processor_config_offline_chat)
+ ConversationProcessorConfigFactory(user=default_user2)
+ OfflineChatProcessorConversationConfigFactory(user=default_user2)
+
state.anonymous_mode = True
configure_routes(app)
diff --git a/tests/helpers.py b/tests/helpers.py
new file mode 100644
index 00000000..655c4435
--- /dev/null
+++ b/tests/helpers.py
@@ -0,0 +1,51 @@
+import factory
+import os
+
+from database.models import (
+ KhojUser,
+ ConversationProcessorConfig,
+ OfflineChatProcessorConversationConfig,
+ OpenAIProcessorConversationConfig,
+ Conversation,
+)
+
+
+class UserFactory(factory.django.DjangoModelFactory):
+ class Meta:
+ model = KhojUser
+
+ username = factory.Faker("name")
+ email = factory.Faker("email")
+ password = factory.Faker("password")
+ uuid = factory.Faker("uuid4")
+
+
+class ConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
+ class Meta:
+ model = ConversationProcessorConfig
+
+ max_prompt_size = 2000
+ tokenizer = None
+
+
+class OfflineChatProcessorConversationConfigFactory(factory.django.DjangoModelFactory):
+ class Meta:
+ model = OfflineChatProcessorConversationConfig
+
+ enable_offline_chat = True
+ chat_model = "llama-2-7b-chat.ggmlv3.q4_0.bin"
+
+
+class OpenAIProcessorConversationConfigFactory(factory.django.DjangoModelFactory):
+ class Meta:
+ model = OpenAIProcessorConversationConfig
+
+ api_key = os.getenv("OPENAI_API_KEY")
+ chat_model = "gpt-3.5-turbo"
+
+
+class ConversationFactory(factory.django.DjangoModelFactory):
+ class Meta:
+ model = Conversation
+
+ user = factory.SubFactory(UserFactory)
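These factories replace the old session-scoped processor config fixtures in conftest.py. A minimal sketch of using them in a test, assuming the config models accept a user foreign key as the conftest changes above do:

import pytest
from tests.helpers import UserFactory, OfflineChatProcessorConversationConfigFactory

@pytest.mark.django_db(transaction=True)
def test_offline_chat_config_defaults():
    user = UserFactory()
    config = OfflineChatProcessorConversationConfigFactory(user=user)
    assert config.enable_offline_chat
    assert config.chat_model == "llama-2-7b-chat.ggmlv3.q4_0.bin"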
diff --git a/tests/test_client.py b/tests/test_client.py
index b77ba07d..1a6b1346 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -119,7 +119,12 @@ def test_get_configured_types_via_api(client, sample_org_data):
def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data, default_user2: KhojUser):
# Arrange
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2)
+
+ # Act
response = client.get(f"/api/config/types")
+
+ # Assert
+ assert response.status_code == 200
assert response.json() == ["all", "org", "image"]
diff --git a/tests/test_gpt4all_chat_director.py b/tests/test_gpt4all_chat_director.py
index 3e72a7e2..d978fc99 100644
--- a/tests/test_gpt4all_chat_director.py
+++ b/tests/test_gpt4all_chat_director.py
@@ -9,8 +9,7 @@ from faker import Faker
# Internal Packages
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import message_to_log
-from khoj.utils import state
-
+from tests.helpers import ConversationFactory
SKIP_TESTS = True
pytestmark = pytest.mark.skipif(
@@ -23,7 +22,7 @@ fake = Faker()
# Helpers
# ----------------------------------------------------------------------------------------------------
-def populate_chat_history(message_list):
+def populate_chat_history(message_list, user):
# Generate conversation logs
conversation_log = {"chat": []}
for user_message, llm_message, context in message_list:
@@ -33,14 +32,15 @@ def populate_chat_history(message_list):
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
)
- # Update Conversation Metadata Logs in Application State
- state.processor_config.conversation.meta_log = conversation_log
+ # Update Conversation Metadata Logs in Database
+ ConversationFactory(user=user, conversation_log=conversation_log)
# Tests
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
+@pytest.mark.django_db(transaction=True)
def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_chat):
# Act
response = client_offline_chat.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
@@ -56,13 +56,14 @@ def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_c
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
-def test_answer_from_chat_history(client_offline_chat):
+@pytest.mark.django_db(transaction=True)
+def test_answer_from_chat_history(client_offline_chat, default_user2):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
- populate_chat_history(message_list)
+ populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
@@ -78,7 +79,8 @@ def test_answer_from_chat_history(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
-def test_answer_from_currently_retrieved_content(client_offline_chat):
+@pytest.mark.django_db(transaction=True)
+def test_answer_from_currently_retrieved_content(client_offline_chat, default_user2):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
@@ -88,7 +90,7 @@ def test_answer_from_currently_retrieved_content(client_offline_chat):
["Testatron was born on 1st April 1984 in Testville."],
),
]
- populate_chat_history(message_list)
+ populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="Where was Xi Li born?"')
@@ -101,7 +103,8 @@ def test_answer_from_currently_retrieved_content(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
-def test_answer_from_chat_history_and_previously_retrieved_content(client_offline_chat):
+@pytest.mark.django_db(transaction=True)
+def test_answer_from_chat_history_and_previously_retrieved_content(client_offline_chat, default_user2):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
@@ -111,7 +114,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(client_offlin
["Testatron was born on 1st April 1984 in Testville."],
),
]
- populate_chat_history(message_list)
+ populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"')
@@ -130,13 +133,14 @@ def test_answer_from_chat_history_and_previously_retrieved_content(client_offlin
reason="Chat director not capable of answering this question yet because it requires extract_questions",
)
@pytest.mark.chatquality
-def test_answer_from_chat_history_and_currently_retrieved_content(client_offline_chat):
+@pytest.mark.django_db(transaction=True)
+def test_answer_from_chat_history_and_currently_retrieved_content(client_offline_chat, default_user2):
# Arrange
message_list = [
("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
- populate_chat_history(message_list)
+ populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"')
@@ -154,14 +158,15 @@ def test_answer_from_chat_history_and_currently_retrieved_content(client_offline
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
-def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat):
+@pytest.mark.django_db(transaction=True)
+def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat, default_user2):
"Chat director should say don't know as not enough contexts in chat history or retrieved to answer question"
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
- populate_chat_history(message_list)
+ populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"&stream=true')
@@ -177,11 +182,12 @@ def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
-def test_answer_using_general_command(client_offline_chat):
+@pytest.mark.django_db(transaction=True)
+def test_answer_using_general_command(client_offline_chat, default_user2):
# Arrange
query = urllib.parse.quote("/general Where was Xi Li born?")
message_list = []
- populate_chat_history(message_list)
+ populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
@@ -194,11 +200,12 @@ def test_answer_using_general_command(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
-def test_answer_from_retrieved_content_using_notes_command(client_offline_chat):
+@pytest.mark.django_db(transaction=True)
+def test_answer_from_retrieved_content_using_notes_command(client_offline_chat, default_user2):
# Arrange
query = urllib.parse.quote("/notes Where was Xi Li born?")
message_list = []
- populate_chat_history(message_list)
+ populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
@@ -211,12 +218,13 @@ def test_answer_from_retrieved_content_using_notes_command(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
-def test_answer_using_file_filter(client_offline_chat):
+@pytest.mark.django_db(transaction=True)
+def test_answer_using_file_filter(client_offline_chat, default_user2):
# Arrange
no_answer_query = urllib.parse.quote('Where was Xi Li born? file:"Namita.markdown"')
answer_query = urllib.parse.quote('Where was Xi Li born? file:"Xi Li.markdown"')
message_list = []
- populate_chat_history(message_list)
+ populate_chat_history(message_list, default_user2)
# Act
no_answer_response = client_offline_chat.get(f"/api/chat?q={no_answer_query}&stream=true").content.decode("utf-8")
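The file filter test above is the one offline-chat test that scopes retrieval with a search filter embedded in the chat query. For orientation, a minimal sketch of how such a request is assembled, using only the query shapes already visible in the Arrange and Act lines above:

    import urllib.parse

    # Quote the query so the embedded double quotes in the file filter stay URL-safe.
    answer_query = urllib.parse.quote('Where was Xi Li born? file:"Xi Li.markdown"')
    request_url = f"/api/chat?q={answer_query}&stream=true"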
@@ -229,11 +237,12 @@ def test_answer_using_file_filter(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
-def test_answer_not_known_using_notes_command(client_offline_chat):
+@pytest.mark.django_db(transaction=True)
+def test_answer_not_known_using_notes_command(client_offline_chat, default_user2):
# Arrange
query = urllib.parse.quote("/notes Where was Testatron born?")
message_list = []
- populate_chat_history(message_list)
+ populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
@@ -247,6 +256,7 @@ def test_answer_not_known_using_notes_command(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering time aware questions yet")
@pytest.mark.chatquality
+@pytest.mark.django_db(transaction=True)
@freeze_time("2023-04-01")
def test_answer_requires_current_date_awareness(client_offline_chat):
"Chat actor should be able to answer questions relative to current date using provided notes"
@@ -265,6 +275,7 @@ def test_answer_requires_current_date_awareness(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
+@pytest.mark.django_db(transaction=True)
@freeze_time("2023-04-01")
def test_answer_requires_date_aware_aggregation_across_provided_notes(client_offline_chat):
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
@@ -280,14 +291,15 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(client_off
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
-def test_answer_general_question_not_in_chat_history_or_retrieved_content(client_offline_chat):
+@pytest.mark.django_db(transaction=True)
+def test_answer_general_question_not_in_chat_history_or_retrieved_content(client_offline_chat, default_user2):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("Where was I born?", "You were born Testville.", []),
]
- populate_chat_history(message_list)
+ populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(
@@ -307,7 +319,8 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(client
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat director not consistently capable of asking for clarification yet.")
@pytest.mark.chatquality
-def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_chat):
+@pytest.mark.django_db(transaction=True)
+def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_chat, default_user2):
# Act
response = client_offline_chat.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
response_message = response.content.decode("utf-8")
@@ -328,14 +341,15 @@ def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
-def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat):
+@pytest.mark.django_db(transaction=True)
+def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat, default_user2):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("Where was I born?", "You were born Testville.", []),
]
- populate_chat_history(message_list)
+ populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
@@ -350,11 +364,12 @@ def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat):
@pytest.mark.chatquality
-def test_answer_chat_history_very_long(client_offline_chat):
+@pytest.mark.django_db(transaction=True)
+def test_answer_chat_history_very_long(client_offline_chat, default_user2):
# Arrange
message_list = [(" ".join([fake.paragraph() for _ in range(50)]), fake.sentence(), []) for _ in range(10)]
- populate_chat_history(message_list)
+ populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
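The very-long-history test above generates its fixture with Faker instead of hand-written turns: ten exchanges of roughly fifty paragraphs each are meant to overflow the chat model's context window, so the test exercises history truncation rather than recall. Reassembled from the Arrange line above (the module-scope `fake = Faker()` initialization is an assumption, as it sits outside the hunks shown):

    from faker import Faker

    fake = Faker()
    # Ten turns; each user message is ~50 paragraphs and each reply one
    # sentence, which comfortably exceeds the offline chat model's context window.
    message_list = [(" ".join([fake.paragraph() for _ in range(50)]), fake.sentence(), []) for _ in range(10)]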
@@ -368,6 +383,7 @@ def test_answer_chat_history_very_long(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
+@pytest.mark.django_db(transaction=True)
def test_answer_requires_multiple_independent_searches(client_offline_chat):
"Chat director should be able to answer by doing multiple independent searches for required information"
# Act
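Taken together, the hunks above apply one mechanical migration to every gpt4all chat director test: each test gains `@pytest.mark.django_db(transaction=True)` so it may touch the database, and any test that seeds chat history accepts a `default_user2` fixture and forwards it to `populate_chat_history`, because conversation logs now live in per-user database rows rather than in application state. A minimal sketch of the migrated test shape, assembled from the fragments above and relying on the module's own `populate_chat_history` helper (the assertions are assumptions inferred from the test name; `default_user2` is assumed to be a `KhojUser` fixture from the test conftest):

    import pytest

    @pytest.mark.chatquality
    @pytest.mark.django_db(transaction=True)  # tests now read and write per-user database rows
    def test_answer_from_chat_history(client_offline_chat, default_user2):
        # Arrange: persist a conversation log for the test user
        message_list = [
            ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
            ("When was I born?", "You were born on 1st April 1984.", []),
        ]
        populate_chat_history(message_list, default_user2)

        # Act: ask a question answerable only from the seeded chat history
        response = client_offline_chat.get('/api/chat?q="What is my name?"&stream=true')
        response_message = response.content.decode("utf-8")

        # Assert (assumed: the real test checks the seeded name surfaces)
        assert response.status_code == 200
        assert "Testatron" in response_message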
diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py
index abbd1831..14a73f15 100644
--- a/tests/test_openai_chat_director.py
+++ b/tests/test_openai_chat_director.py
@@ -9,8 +9,8 @@ from khoj.processor.conversation import prompts
# Internal Packages
from khoj.processor.conversation.utils import message_to_log
-from khoj.utils import state
-
+from tests.helpers import ConversationFactory
+from database.models import KhojUser
# Initialize variables for tests
api_key = os.getenv("OPENAI_API_KEY")
@@ -23,7 +23,7 @@ if api_key is None:
# Helpers
# ----------------------------------------------------------------------------------------------------
-def populate_chat_history(message_list):
+def populate_chat_history(message_list, user=None):
# Generate conversation logs
conversation_log = {"chat": []}
for user_message, gpt_message, context in message_list:
@@ -33,13 +33,14 @@ def populate_chat_history(message_list):
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
)
- # Update Conversation Metadata Logs in Application State
- state.processor_config.conversation.meta_log = conversation_log
+ # Update Conversation Metadata Logs in Database
+ ConversationFactory(user=user, conversation_log=conversation_log)
# Tests
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
+@pytest.mark.django_db(transaction=True)
def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
# Act
response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
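For reference, the fully assembled helper after this hunk reads roughly as below. The `message_to_log` call spans context lines elided from the hunk, so its exact argument layout is reconstructed from the visible fragment and should be treated as an approximation:

    def populate_chat_history(message_list, user=None):
        # Generate conversation logs from (user_message, gpt_message, context) triples
        conversation_log = {"chat": []}
        for user_message, gpt_message, context in message_list:
            conversation_log["chat"] += message_to_log(
                user_message,
                gpt_message,
                {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
            )
        # Update Conversation Metadata Logs in Database (previously: mutate
        # state.processor_config.conversation.meta_log in application state)
        ConversationFactory(user=user, conversation_log=conversation_log)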
@@ -54,14 +55,15 @@ def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
-def test_answer_from_chat_history(chat_client):
+def test_answer_from_chat_history(chat_client, default_user2: KhojUser):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
- populate_chat_history(message_list)
+ populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
@@ -76,8 +78,9 @@ def test_answer_from_chat_history(chat_client):
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
-def test_answer_from_currently_retrieved_content(chat_client):
+def test_answer_from_currently_retrieved_content(chat_client, default_user2: KhojUser):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
@@ -87,7 +90,7 @@ def test_answer_from_currently_retrieved_content(chat_client):
["Testatron was born on 1st April 1984 in Testville."],
),
]
- populate_chat_history(message_list)
+ populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"')
@@ -99,8 +102,9 @@ def test_answer_from_currently_retrieved_content(chat_client):
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
-def test_answer_from_chat_history_and_previously_retrieved_content(chat_client):
+def test_answer_from_chat_history_and_previously_retrieved_content(chat_client_no_background, default_user2: KhojUser):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
@@ -110,10 +114,10 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client):
["Testatron was born on 1st April 1984 in Testville."],
),
]
- populate_chat_history(message_list)
+ populate_chat_history(message_list, default_user2)
# Act
- response = chat_client.get(f'/api/chat?q="Where was I born?"')
+ response = chat_client_no_background.get(f'/api/chat?q="Where was I born?"')
response_message = response.content.decode("utf-8")
# Assert
@@ -125,14 +129,15 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
+@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
-def test_answer_from_chat_history_and_currently_retrieved_content(chat_client):
+def test_answer_from_chat_history_and_currently_retrieved_content(chat_client, default_user2: KhojUser):
# Arrange
message_list = [
("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
- populate_chat_history(message_list)
+ populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="Where was I born?"')
@@ -148,15 +153,16 @@ def test_answer_from_chat_history_and_currently_retrieved_content(chat_client):
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
-def test_no_answer_in_chat_history_or_retrieved_content(chat_client):
+def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_user2: KhojUser):
"Chat director should say don't know as not enough contexts in chat history or retrieved to answer question"
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
- populate_chat_history(message_list)
+ populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="Where was I born?"&stream=true')
@@ -171,12 +177,13 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client):
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
-def test_answer_using_general_command(chat_client):
+def test_answer_using_general_command(chat_client, default_user2: KhojUser):
# Arrange
query = urllib.parse.quote("/general Where was Xi Li born?")
message_list = []
- populate_chat_history(message_list)
+ populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f"/api/chat?q={query}&stream=true")
@@ -188,12 +195,13 @@ def test_answer_using_general_command(chat_client):
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
-def test_answer_from_retrieved_content_using_notes_command(chat_client):
+def test_answer_from_retrieved_content_using_notes_command(chat_client, default_user2: KhojUser):
# Arrange
query = urllib.parse.quote("/notes Where was Xi Li born?")
message_list = []
- populate_chat_history(message_list)
+ populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f"/api/chat?q={query}&stream=true")
@@ -205,15 +213,16 @@ def test_answer_from_retrieved_content_using_notes_command(chat_client):
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
-def test_answer_not_known_using_notes_command(chat_client):
+def test_answer_not_known_using_notes_command(chat_client_no_background, default_user2: KhojUser):
# Arrange
query = urllib.parse.quote("/notes Where was Testatron born?")
message_list = []
- populate_chat_history(message_list)
+ populate_chat_history(message_list, default_user2)
# Act
- response = chat_client.get(f"/api/chat?q={query}&stream=true")
+ response = chat_client_no_background.get(f"/api/chat?q={query}&stream=true")
response_message = response.content.decode("utf-8")
# Assert
@@ -223,6 +232,7 @@ def test_answer_not_known_using_notes_command(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering time aware questions yet")
+@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
@freeze_time("2023-04-01")
def test_answer_requires_current_date_awareness(chat_client):
@@ -240,11 +250,13 @@ def test_answer_requires_current_date_awareness(chat_client):
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
@freeze_time("2023-04-01")
def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_client):
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
# Act
+
response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"&stream=true')
response_message = response.content.decode("utf-8")
@@ -254,15 +266,16 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_clien
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
-def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_client):
+def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_client, default_user2: KhojUser):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("Where was I born?", "You were born Testville.", []),
]
- populate_chat_history(message_list)
+ populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(
@@ -280,10 +293,12 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
-def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):
+def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_background):
# Act
- response = chat_client.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
+
+ response = chat_client_no_background.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
@@ -301,15 +316,16 @@ def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat director not capable of answering this question yet")
+@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
-def test_answer_in_chat_history_beyond_lookback_window(chat_client):
+def test_answer_in_chat_history_beyond_lookback_window(chat_client, default_user2: KhojUser):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("Where was I born?", "You were born Testville.", []),
]
- populate_chat_history(message_list)
+ populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
@@ -324,6 +340,7 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client):
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_answer_requires_multiple_independent_searches(chat_client):
"Chat director should be able to answer by doing multiple independent searches for required information"
@@ -340,10 +357,12 @@ def test_answer_requires_multiple_independent_searches(chat_client):
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
def test_answer_using_file_filter(chat_client):
"Chat should be able to use search filters in the query"
# Act
query = urllib.parse.quote('Is Xi older than Namita? file:"Namita.markdown" file:"Xi Li.markdown"')
+
response = chat_client.get(f"/api/chat?q={query}&stream=true")
response_message = response.content.decode("utf-8")
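Two fixtures do most of the work in this migration, and neither definition appears in the patch. `chat_client_no_background` replaces `chat_client` in the tests above that must not race a background indexing thread, and `ConversationFactory` persists seeded conversation logs. A hypothetical sketch of the factory, assuming `tests/helpers.py` uses factory_boy and that `database.models` exposes a `Conversation` model keyed by user with a JSON `conversation_log` field (both assumptions; only the `ConversationFactory(user=..., conversation_log=...)` call shape is visible in the diff):

    import factory

    # Assumed model import; only KhojUser is confirmed by the diffs above.
    from database.models import Conversation, KhojUser

    class UserFactory(factory.django.DjangoModelFactory):
        class Meta:
            model = KhojUser

        username = factory.Faker("email")

    class ConversationFactory(factory.django.DjangoModelFactory):
        class Meta:
            model = Conversation

        user = factory.SubFactory(UserFactory)
        conversation_log = factory.LazyFunction(lambda: {"chat": []})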
diff --git a/tests/test_text_search.py b/tests/test_text_search.py
index af47ffe5..ec8034ef 100644
--- a/tests/test_text_search.py
+++ b/tests/test_text_search.py
@@ -13,12 +13,11 @@ from khoj.search_type import text_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.github.github_to_jsonl import GithubToJsonl
-from khoj.utils.config import SearchModels
-from khoj.utils.fs_syncer import get_org_files, collect_files
+from khoj.utils.fs_syncer import collect_files, get_org_files
from database.models import LocalOrgConfig, KhojUser, Embeddings, GithubConfig
logger = logging.getLogger(__name__)
-from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
+from khoj.utils.rawconfig import ContentConfig, SearchConfig
# Test